diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2020-12-26 22:14:58 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2020-12-26 22:14:58 +0800 |
commit | 31b20fbff0786c998c54b8585de759d02f41eda7 (patch) | |
tree | 1838e752dffa1b132d6120252709208d2711fa55 /models | |
parent | 6a767f503f24017bc4cd391e080414d8730d081b (diff) |
Implement batch splitter to split sampled data
Disentanglement cannot be processed on different subjects at the same time, we need to load `pr` subjects one by one. The batch splitter will return a pr-length list of tuples (with 2 dicts containing k-length lists of labels, conditions, view and k-length tensor of clip data, representing condition 1 and condition 2 respectively).
Diffstat (limited to 'models')
-rw-r--r-- | models/model.py | 32 |
1 files changed, 32 insertions, 0 deletions
diff --git a/models/model.py b/models/model.py new file mode 100644 index 0000000..369d6c2 --- /dev/null +++ b/models/model.py @@ -0,0 +1,32 @@ +from typing import List, Dict, Union, Tuple + +import torch +from torch.utils.data.dataloader import default_collate + + +class Model: + def __init__( + self, + batch_size: Tuple[int, int] + ): + (self.pr, self.k) = batch_size + + def _batch_splitter( + self, + batch: List[Dict[str, Union[str, torch.Tensor]]] + ) -> List[Tuple[Dict[str, List[Union[str, torch.Tensor]]], + Dict[str, List[Union[str, torch.Tensor]]]]]: + """ + Disentanglement cannot be processed on different subjects at the + same time, we need to load `pr` subjects one by one. The batch + splitter will return a pr-length list of tuples (with 2 dicts + containing k-length lists of labels, conditions, view and + k-length tensor of clip data, representing condition 1 and + condition 2 respectively). + """ + _batch = [] + for i in range(0, self.pr * self.k * 2, self.k * 2): + _batch.append((default_collate(batch[i:i + self.k]), + default_collate(batch[i + self.k:i + self.k * 2]))) + + return _batch |