summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--models/model.py31
1 files changed, 13 insertions, 18 deletions
diff --git a/models/model.py b/models/model.py
index ca1497f..e9714b8 100644
--- a/models/model.py
+++ b/models/model.py
@@ -1,5 +1,6 @@
from typing import Union, Optional
+import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
@@ -45,23 +46,20 @@ class Model:
def _batch_splitter(
self,
- batch: list[dict[str, Union[str, torch.Tensor]]]
- ) -> list[tuple[dict[str, list[Union[str, torch.Tensor]]],
- dict[str, list[Union[str, torch.Tensor]]]]]:
+ batch: list[dict[str, Union[np.int64, str, torch.Tensor]]]
+ ) -> tuple[dict[str, Union[list[str], torch.Tensor]],
+ dict[str, Union[list[str], torch.Tensor]]]:
"""
- Disentanglement cannot be processed on different subjects at the
- same time, we need to load `pr` subjects one by one. The batch
- splitter will return a pr-length list of tuples (with 2 dicts
- containing k-length lists of labels, conditions, view and
- k-length tensor of clip data, representing condition 1 and
- condition 2 respectively).
+ Disentanglement need two random conditions, this function will
+ split pr * k * 2 samples to 2 dicts each containing pr * k
+ samples. labels and clip data are tensor, and others are list.
"""
- _batch = []
+ _batch = [[], []]
for i in range(0, self.pr * self.k * 2, self.k * 2):
- _batch.append((default_collate(batch[i:i + self.k]),
- default_collate(batch[i + self.k:i + self.k * 2])))
+ _batch[0] += batch[i:i + self.k]
+ _batch[1] += batch[i + self.k:i + self.k * 2]
- return _batch
+ return default_collate(_batch[0]), default_collate(_batch[1])
def fit(
self,
@@ -71,12 +69,9 @@ class Model:
self.is_train = True
dataset = self._parse_dataset_config(dataset_config)
dataloader = self._parse_dataloader_config(dataset, dataloader_config)
- for iter_i, samples_batched in enumerate(dataloader):
- for sub_i, (subject_c1, subject_c2) in enumerate(samples_batched):
- pass
+ for iter_i, (samples_c1, samples_c2) in enumerate(dataloader):
+ pass
- if sub_i == 0:
- break
if iter_i == 0:
break