diff options
-rw-r--r-- | models/model.py | 31 |
1 files changed, 13 insertions, 18 deletions
diff --git a/models/model.py b/models/model.py index ca1497f..e9714b8 100644 --- a/models/model.py +++ b/models/model.py @@ -1,5 +1,6 @@ from typing import Union, Optional +import numpy as np import torch from torch.utils.data import DataLoader from torch.utils.data.dataloader import default_collate @@ -45,23 +46,20 @@ class Model: def _batch_splitter( self, - batch: list[dict[str, Union[str, torch.Tensor]]] - ) -> list[tuple[dict[str, list[Union[str, torch.Tensor]]], - dict[str, list[Union[str, torch.Tensor]]]]]: + batch: list[dict[str, Union[np.int64, str, torch.Tensor]]] + ) -> tuple[dict[str, Union[list[str], torch.Tensor]], + dict[str, Union[list[str], torch.Tensor]]]: """ - Disentanglement cannot be processed on different subjects at the - same time, we need to load `pr` subjects one by one. The batch - splitter will return a pr-length list of tuples (with 2 dicts - containing k-length lists of labels, conditions, view and - k-length tensor of clip data, representing condition 1 and - condition 2 respectively). + Disentanglement need two random conditions, this function will + split pr * k * 2 samples to 2 dicts each containing pr * k + samples. labels and clip data are tensor, and others are list. """ - _batch = [] + _batch = [[], []] for i in range(0, self.pr * self.k * 2, self.k * 2): - _batch.append((default_collate(batch[i:i + self.k]), - default_collate(batch[i + self.k:i + self.k * 2]))) + _batch[0] += batch[i:i + self.k] + _batch[1] += batch[i + self.k:i + self.k * 2] - return _batch + return default_collate(_batch[0]), default_collate(_batch[1]) def fit( self, @@ -71,12 +69,9 @@ class Model: self.is_train = True dataset = self._parse_dataset_config(dataset_config) dataloader = self._parse_dataloader_config(dataset, dataloader_config) - for iter_i, samples_batched in enumerate(dataloader): - for sub_i, (subject_c1, subject_c2) in enumerate(samples_batched): - pass + for iter_i, (samples_c1, samples_c2) in enumerate(dataloader): + pass - if sub_i == 0: - break if iter_i == 0: break |