diff options
-rw-r--r-- | config.py | 8 | ||||
-rw-r--r-- | models/model.py | 28 | ||||
-rw-r--r-- | utils/configuration.py | 2 | ||||
-rw-r--r-- | utils/sampler.py | 20 |
4 files changed, 24 insertions, 34 deletions
@@ -11,7 +11,7 @@ config: Configuration = { # Recorde disentangled image or not 'image_log_on': True, # The number of subjects for validating (Part of testing set) - 'val_size': 10, + 'val_size': 20, }, # Dataset settings 'dataset': { @@ -36,10 +36,8 @@ config: Configuration = { }, # Dataloader settings 'dataloader': { - # Batch size (pr, k) - # `pr` denotes number of persons - # `k` denotes number of sequences per person - 'batch_size': (4, 6), + # Batch size + 'batch_size': 16, # Number of workers of Dataloader 'num_workers': 4, # Faster data transfer from RAM to GPU if enabled diff --git a/models/model.py b/models/model.py index 667a0a7..46987ca 100644 --- a/models/model.py +++ b/models/model.py @@ -17,7 +17,7 @@ from utils.configuration import DataloaderConfiguration, \ HyperparameterConfiguration, DatasetConfiguration, ModelConfiguration, \ SystemConfiguration from utils.dataset import CASIAB, ClipConditions, ClipViews, ClipClasses -from utils.sampler import TripletSampler +from utils.sampler import DisentanglingSampler class Model: @@ -56,8 +56,7 @@ class Model: self.is_train: bool = True self.in_channels: int = 3 self.in_size: Tuple[int, int] = (64, 48) - self.pr: Optional[int] = None - self.k: Optional[int] = None + self.batch_size: Optional[int] = None self._gallery_dataset_meta: Optional[Dict[str, List]] = None self._probe_datasets_meta: Optional[Dict[str, Dict[str, List]]] = None @@ -91,7 +90,7 @@ class Model: @property def _checkpoint_sig(self) -> str: return '_'.join((self._model_sig, self._hp_sig, self._dataset_sig, - str(self.pr), str(self.k))) + str(self.batch_size))) @property def _checkpoint_name(self) -> str: @@ -100,7 +99,7 @@ class Model: @property def _log_sig(self) -> str: return '_'.join((self._model_name, str(self.total_iter), self._hp_sig, - self._dataset_sig, str(self.pr), str(self.k))) + self._dataset_sig, str(self.batch_size))) @property def _log_name(self) -> str: @@ -441,11 +440,11 @@ class Model: dataloader_config: DataloaderConfiguration ) -> DataLoader: config: Dict = dataloader_config.copy() - (self.pr, self.k) = config.pop('batch_size', (8, 16)) + self.batch_size = config.pop('batch_size', 16) if self.is_train: - triplet_sampler = TripletSampler(dataset, (self.pr, self.k)) + dis_sampler = DisentanglingSampler(dataset, self.batch_size) return DataLoader(dataset, - batch_sampler=triplet_sampler, + batch_sampler=dis_sampler, collate_fn=self._batch_splitter, **config) else: # is_test @@ -458,15 +457,14 @@ class Model: Dict[str, Union[List[str], torch.Tensor]]]: """ Disentanglement need two random conditions, this function will - split pr * k * 2 samples to 2 dicts each containing pr * k - samples. labels and clip data are tensor, and others are list. + split batch_size * 2 samples to 2 dicts each containing + batch_size samples. labels and clip data are tensor, and others + are list. """ - _batch = [[], []] - for i in range(0, self.pr * self.k * 2, self.k * 2): - _batch[0] += batch[i:i + self.k] - _batch[1] += batch[i + self.k:i + self.k * 2] + batch_0 = batch[slice(0, self.batch_size * 2, 2)] + batch_1 = batch[slice(1, self.batch_size * 2, 2)] - return default_collate(_batch[0]), default_collate(_batch[1]) + return default_collate(batch_0), default_collate(batch_1) def _make_signature(self, config: Dict, diff --git a/utils/configuration.py b/utils/configuration.py index 6f04c68..651689d 100644 --- a/utils/configuration.py +++ b/utils/configuration.py @@ -25,7 +25,7 @@ class DatasetConfiguration(TypedDict): class DataloaderConfiguration(TypedDict): - batch_size: Tuple[int, int] + batch_size: int num_workers: int pin_memory: bool diff --git a/utils/sampler.py b/utils/sampler.py index 581d7a2..b017d66 100644 --- a/utils/sampler.py +++ b/utils/sampler.py @@ -7,11 +7,11 @@ from torch.utils import data from utils.dataset import CASIAB -class TripletSampler(data.Sampler): +class DisentanglingSampler(data.Sampler): def __init__( self, data_source: Union[CASIAB], - batch_size: Tuple[int, int] + batch_size: int ): super().__init__(data_source) self.metadata_labels = data_source.metadata['labels'] @@ -29,13 +29,14 @@ class TripletSampler(data.Sampler): self.conditions = data_source.conditions self.length = len(self.labels) self.indexes = np.arange(0, self.length) - (self.pr, self.k) = batch_size + self.batch_size = batch_size def __iter__(self) -> Iterator[int]: while True: sampled_indexes = [] - # Sample pr subjects by sampling labels appeared in dataset - sampled_subjects = random.sample(self.metadata_labels, k=self.pr) + sampled_subjects = random.sample( + self.metadata_labels, k=self.batch_size + ) for label in sampled_subjects: mask = self.labels == label # Fix unbalanced datasets @@ -54,14 +55,7 @@ class TripletSampler(data.Sampler): condition_mask |= self.conditions == condition mask &= condition_mask clips = self.indexes[mask].tolist() - # Sample k clips from the subject without replacement if - # have enough clips, k more clips will sampled for - # disentanglement - k = self.k * 2 - if len(clips) >= k: - _sampled_indexes = random.sample(clips, k=k) - else: - _sampled_indexes = random.choices(clips, k=k) + _sampled_indexes = random.sample(clips, k=2) sampled_indexes += _sampled_indexes yield sampled_indexes |