| author | Jordan Gong <jordan.gong@protonmail.com> | 2021-04-04 12:45:36 +0800 |
|---|---|---|
| committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-04-04 12:45:36 +0800 |
| commit | 6a8824e4fb8bdd1f3e763b78b765830788415cfb (patch) | |
| tree | 1853fce051cb9c46c7d253d7ddcec3cde0478899 /models | |
| parent | 2f4516549c7b9ebfaf650c6a71e4da43cd3372fa (diff) | |
| parent | cb05de36f5ffd8584d78c6776dbe90e21abff25a (diff) | |
Merge branch 'disentangling_only' into disentangling_only_py3.8
# Conflicts:
# models/model.py
# utils/configuration.py
# utils/sampler.py
Diffstat (limited to 'models')
| -rw-r--r-- | models/model.py | 28 |

1 file changed, 13 insertions, 15 deletions
diff --git a/models/model.py b/models/model.py
index 667a0a7..46987ca 100644
--- a/models/model.py
+++ b/models/model.py
@@ -17,7 +17,7 @@ from utils.configuration import DataloaderConfiguration, \
     HyperparameterConfiguration, DatasetConfiguration, ModelConfiguration, \
     SystemConfiguration
 from utils.dataset import CASIAB, ClipConditions, ClipViews, ClipClasses
-from utils.sampler import TripletSampler
+from utils.sampler import DisentanglingSampler


 class Model:
@@ -56,8 +56,7 @@ class Model:
         self.is_train: bool = True
         self.in_channels: int = 3
         self.in_size: Tuple[int, int] = (64, 48)
-        self.pr: Optional[int] = None
-        self.k: Optional[int] = None
+        self.batch_size: Optional[int] = None

         self._gallery_dataset_meta: Optional[Dict[str, List]] = None
         self._probe_datasets_meta: Optional[Dict[str, Dict[str, List]]] = None
@@ -91,7 +90,7 @@ class Model:
     @property
     def _checkpoint_sig(self) -> str:
         return '_'.join((self._model_sig, self._hp_sig, self._dataset_sig,
-                         str(self.pr), str(self.k)))
+                         str(self.batch_size)))

     @property
     def _checkpoint_name(self) -> str:
@@ -100,7 +99,7 @@ class Model:
     @property
     def _log_sig(self) -> str:
         return '_'.join((self._model_name, str(self.total_iter), self._hp_sig,
-                         self._dataset_sig, str(self.pr), str(self.k)))
+                         self._dataset_sig, str(self.batch_size)))

     @property
     def _log_name(self) -> str:
@@ -441,11 +440,11 @@ class Model:
             dataloader_config: DataloaderConfiguration
     ) -> DataLoader:
         config: Dict = dataloader_config.copy()
-        (self.pr, self.k) = config.pop('batch_size', (8, 16))
+        self.batch_size = config.pop('batch_size', 16)
         if self.is_train:
-            triplet_sampler = TripletSampler(dataset, (self.pr, self.k))
+            dis_sampler = DisentanglingSampler(dataset, self.batch_size)
             return DataLoader(dataset,
-                              batch_sampler=triplet_sampler,
+                              batch_sampler=dis_sampler,
                               collate_fn=self._batch_splitter,
                               **config)
         else:  # is_test
@@ -458,15 +457,14 @@ class Model:
                    Dict[str, Union[List[str], torch.Tensor]]]:
         """
         Disentanglement need two random conditions, this function will
-        split pr * k * 2 samples to 2 dicts each containing pr * k
-        samples. labels and clip data are tensor, and others are list.
+        split batch_size * 2 samples to 2 dicts each containing
+        batch_size samples. labels and clip data are tensor, and others
+        are list.
         """
-        _batch = [[], []]
-        for i in range(0, self.pr * self.k * 2, self.k * 2):
-            _batch[0] += batch[i:i + self.k]
-            _batch[1] += batch[i + self.k:i + self.k * 2]
+        batch_0 = batch[slice(0, self.batch_size * 2, 2)]
+        batch_1 = batch[slice(1, self.batch_size * 2, 2)]

-        return default_collate(_batch[0]), default_collate(_batch[1])
+        return default_collate(batch_0), default_collate(batch_1)

     def _make_signature(self,
                         config: Dict,
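The behavioural core of this diff is `_batch_splitter`: the old version regrouped a `pr * k * 2` triplet batch in blocks of `k`, while the new version assumes the new `DisentanglingSampler` emits `batch_size * 2` samples as interleaved condition pairs and separates them with stride-2 slices before collating. Below is a minimal, self-contained sketch of that split; the dict keys and tensor shapes are illustrative assumptions, not taken from the repository.

```python
import torch
from torch.utils.data.dataloader import default_collate

batch_size = 4

# Stand-in for what the sampler + dataset would deliver per batch:
# 2 * batch_size samples, assumed to be interleaved as
# (condition_0, condition_1, condition_0, condition_1, ...).
batch = [{'label': i // 2, 'clip': torch.full((3, 64, 48), float(i))}
         for i in range(batch_size * 2)]

# Stride-2 slices undo the interleaving: even indices form the first
# condition's batch, odd indices the second's.
batch_0 = batch[0:batch_size * 2:2]
batch_1 = batch[1:batch_size * 2:2]

collated_0 = default_collate(batch_0)
collated_1 = default_collate(batch_1)
print(collated_0['clip'].shape)  # torch.Size([4, 3, 64, 48])
print(collated_0['label'])       # tensor([0, 1, 2, 3])
```

Note that this stride-2 split is only correct if the sampler really yields the two conditions in alternating order; with the old `TripletSampler` layout (blocks of `k` samples per condition) the removed block-wise loop would still be required.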