summaryrefslogtreecommitdiff
path: root/models
diff options
context:
space:
mode:
Diffstat (limited to 'models')
-rw-r--r--models/model.py28
1 files changed, 13 insertions, 15 deletions
diff --git a/models/model.py b/models/model.py
index 25c8a4f..0829f33 100644
--- a/models/model.py
+++ b/models/model.py
@@ -17,7 +17,7 @@ from utils.configuration import DataloaderConfiguration, \
HyperparameterConfiguration, DatasetConfiguration, ModelConfiguration, \
SystemConfiguration
from utils.dataset import CASIAB, ClipConditions, ClipViews, ClipClasses
-from utils.sampler import TripletSampler
+from utils.sampler import DisentanglingSampler
class Model:
@@ -56,8 +56,7 @@ class Model:
self.is_train: bool = True
self.in_channels: int = 3
self.in_size: tuple[int, int] = (64, 48)
- self.pr: Optional[int] = None
- self.k: Optional[int] = None
+ self.batch_size: Optional[int] = None
self._gallery_dataset_meta: Optional[dict[str, list]] = None
self._probe_datasets_meta: Optional[dict[str, dict[str, list]]] = None
@@ -91,7 +90,7 @@ class Model:
@property
def _checkpoint_sig(self) -> str:
return '_'.join((self._model_sig, self._hp_sig, self._dataset_sig,
- str(self.pr), str(self.k)))
+ str(self.batch_size)))
@property
def _checkpoint_name(self) -> str:
@@ -100,7 +99,7 @@ class Model:
@property
def _log_sig(self) -> str:
return '_'.join((self._model_name, str(self.total_iter), self._hp_sig,
- self._dataset_sig, str(self.pr), str(self.k)))
+ self._dataset_sig, str(self.batch_size)))
@property
def _log_name(self) -> str:
@@ -441,11 +440,11 @@ class Model:
dataloader_config: DataloaderConfiguration
) -> DataLoader:
config: dict = dataloader_config.copy()
- (self.pr, self.k) = config.pop('batch_size', (8, 16))
+ self.batch_size = config.pop('batch_size', 16)
if self.is_train:
- triplet_sampler = TripletSampler(dataset, (self.pr, self.k))
+ dis_sampler = DisentanglingSampler(dataset, self.batch_size)
return DataLoader(dataset,
- batch_sampler=triplet_sampler,
+ batch_sampler=dis_sampler,
collate_fn=self._batch_splitter,
**config)
else: # is_test
@@ -458,15 +457,14 @@ class Model:
dict[str, Union[list[str], torch.Tensor]]]:
"""
Disentanglement need two random conditions, this function will
- split pr * k * 2 samples to 2 dicts each containing pr * k
- samples. labels and clip data are tensor, and others are list.
+ split batch_size * 2 samples to 2 dicts each containing
+ batch_size samples. labels and clip data are tensor, and others
+ are list.
"""
- _batch = [[], []]
- for i in range(0, self.pr * self.k * 2, self.k * 2):
- _batch[0] += batch[i:i + self.k]
- _batch[1] += batch[i + self.k:i + self.k * 2]
+ batch_0 = batch[slice(0, self.batch_size * 2, 2)]
+ batch_1 = batch[slice(1, self.batch_size * 2, 2)]
- return default_collate(_batch[0]), default_collate(_batch[1])
+ return default_collate(batch_0), default_collate(batch_1)
def _make_signature(self,
config: dict,