From cb05de36f5ffd8584d78c6776dbe90e21abff25a Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Sun, 4 Apr 2021 12:44:40 +0800 Subject: Remove triplet sampler --- models/model.py | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) (limited to 'models/model.py') diff --git a/models/model.py b/models/model.py index 25c8a4f..0829f33 100644 --- a/models/model.py +++ b/models/model.py @@ -17,7 +17,7 @@ from utils.configuration import DataloaderConfiguration, \ HyperparameterConfiguration, DatasetConfiguration, ModelConfiguration, \ SystemConfiguration from utils.dataset import CASIAB, ClipConditions, ClipViews, ClipClasses -from utils.sampler import TripletSampler +from utils.sampler import DisentanglingSampler class Model: @@ -56,8 +56,7 @@ class Model: self.is_train: bool = True self.in_channels: int = 3 self.in_size: tuple[int, int] = (64, 48) - self.pr: Optional[int] = None - self.k: Optional[int] = None + self.batch_size: Optional[int] = None self._gallery_dataset_meta: Optional[dict[str, list]] = None self._probe_datasets_meta: Optional[dict[str, dict[str, list]]] = None @@ -91,7 +90,7 @@ class Model: @property def _checkpoint_sig(self) -> str: return '_'.join((self._model_sig, self._hp_sig, self._dataset_sig, - str(self.pr), str(self.k))) + str(self.batch_size))) @property def _checkpoint_name(self) -> str: @@ -100,7 +99,7 @@ class Model: @property def _log_sig(self) -> str: return '_'.join((self._model_name, str(self.total_iter), self._hp_sig, - self._dataset_sig, str(self.pr), str(self.k))) + self._dataset_sig, str(self.batch_size))) @property def _log_name(self) -> str: @@ -441,11 +440,11 @@ class Model: dataloader_config: DataloaderConfiguration ) -> DataLoader: config: dict = dataloader_config.copy() - (self.pr, self.k) = config.pop('batch_size', (8, 16)) + self.batch_size = config.pop('batch_size', 16) if self.is_train: - triplet_sampler = TripletSampler(dataset, (self.pr, self.k)) + dis_sampler = DisentanglingSampler(dataset, self.batch_size) return DataLoader(dataset, - batch_sampler=triplet_sampler, + batch_sampler=dis_sampler, collate_fn=self._batch_splitter, **config) else: # is_test @@ -458,15 +457,14 @@ class Model: dict[str, Union[list[str], torch.Tensor]]]: """ Disentanglement need two random conditions, this function will - split pr * k * 2 samples to 2 dicts each containing pr * k - samples. labels and clip data are tensor, and others are list. + split batch_size * 2 samples to 2 dicts each containing + batch_size samples. labels and clip data are tensor, and others + are list. """ - _batch = [[], []] - for i in range(0, self.pr * self.k * 2, self.k * 2): - _batch[0] += batch[i:i + self.k] - _batch[1] += batch[i + self.k:i + self.k * 2] + batch_0 = batch[slice(0, self.batch_size * 2, 2)] + batch_1 = batch[slice(1, self.batch_size * 2, 2)] - return default_collate(_batch[0]), default_collate(_batch[1]) + return default_collate(batch_0), default_collate(batch_1) def _make_signature(self, config: dict, -- cgit v1.2.3