From 31b20fbff0786c998c54b8585de759d02f41eda7 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Sat, 26 Dec 2020 22:14:58 +0800 Subject: Implement batch splitter to split sampled data Disentanglement cannot be processed on different subjects at the same time, we need to load `pr` subjects one by one. The batch splitter will return a pr-length list of tuples (with 2 dicts containing k-length lists of labels, conditions, view and k-length tensor of clip data, representing condition 1 and condition 2 respectively). --- models/model.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 models/model.py (limited to 'models') diff --git a/models/model.py b/models/model.py new file mode 100644 index 0000000..369d6c2 --- /dev/null +++ b/models/model.py @@ -0,0 +1,32 @@ +from typing import List, Dict, Union, Tuple + +import torch +from torch.utils.data.dataloader import default_collate + + +class Model: + def __init__( + self, + batch_size: Tuple[int, int] + ): + (self.pr, self.k) = batch_size + + def _batch_splitter( + self, + batch: List[Dict[str, Union[str, torch.Tensor]]] + ) -> List[Tuple[Dict[str, List[Union[str, torch.Tensor]]], + Dict[str, List[Union[str, torch.Tensor]]]]]: + """ + Disentanglement cannot be processed on different subjects at the + same time, we need to load `pr` subjects one by one. The batch + splitter will return a pr-length list of tuples (with 2 dicts + containing k-length lists of labels, conditions, view and + k-length tensor of clip data, representing condition 1 and + condition 2 respectively). + """ + _batch = [] + for i in range(0, self.pr * self.k * 2, self.k * 2): + _batch.append((default_collate(batch[i:i + self.k]), + default_collate(batch[i + self.k:i + self.k * 2]))) + + return _batch -- cgit v1.2.3