summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--config.py8
-rw-r--r--models/model.py28
-rw-r--r--utils/configuration.py2
-rw-r--r--utils/sampler.py20
4 files changed, 24 insertions, 34 deletions
diff --git a/config.py b/config.py
index dc4e0ba..cf76618 100644
--- a/config.py
+++ b/config.py
@@ -11,7 +11,7 @@ config: Configuration = {
# Recorde disentangled image or not
'image_log_on': True,
# The number of subjects for validating (Part of testing set)
- 'val_size': 10,
+ 'val_size': 20,
},
# Dataset settings
'dataset': {
@@ -36,10 +36,8 @@ config: Configuration = {
},
# Dataloader settings
'dataloader': {
- # Batch size (pr, k)
- # `pr` denotes number of persons
- # `k` denotes number of sequences per person
- 'batch_size': (4, 6),
+ # Batch size
+ 'batch_size': 16,
# Number of workers of Dataloader
'num_workers': 4,
# Faster data transfer from RAM to GPU if enabled
diff --git a/models/model.py b/models/model.py
index 25c8a4f..0829f33 100644
--- a/models/model.py
+++ b/models/model.py
@@ -17,7 +17,7 @@ from utils.configuration import DataloaderConfiguration, \
HyperparameterConfiguration, DatasetConfiguration, ModelConfiguration, \
SystemConfiguration
from utils.dataset import CASIAB, ClipConditions, ClipViews, ClipClasses
-from utils.sampler import TripletSampler
+from utils.sampler import DisentanglingSampler
class Model:
@@ -56,8 +56,7 @@ class Model:
self.is_train: bool = True
self.in_channels: int = 3
self.in_size: tuple[int, int] = (64, 48)
- self.pr: Optional[int] = None
- self.k: Optional[int] = None
+ self.batch_size: Optional[int] = None
self._gallery_dataset_meta: Optional[dict[str, list]] = None
self._probe_datasets_meta: Optional[dict[str, dict[str, list]]] = None
@@ -91,7 +90,7 @@ class Model:
@property
def _checkpoint_sig(self) -> str:
return '_'.join((self._model_sig, self._hp_sig, self._dataset_sig,
- str(self.pr), str(self.k)))
+ str(self.batch_size)))
@property
def _checkpoint_name(self) -> str:
@@ -100,7 +99,7 @@ class Model:
@property
def _log_sig(self) -> str:
return '_'.join((self._model_name, str(self.total_iter), self._hp_sig,
- self._dataset_sig, str(self.pr), str(self.k)))
+ self._dataset_sig, str(self.batch_size)))
@property
def _log_name(self) -> str:
@@ -441,11 +440,11 @@ class Model:
dataloader_config: DataloaderConfiguration
) -> DataLoader:
config: dict = dataloader_config.copy()
- (self.pr, self.k) = config.pop('batch_size', (8, 16))
+ self.batch_size = config.pop('batch_size', 16)
if self.is_train:
- triplet_sampler = TripletSampler(dataset, (self.pr, self.k))
+ dis_sampler = DisentanglingSampler(dataset, self.batch_size)
return DataLoader(dataset,
- batch_sampler=triplet_sampler,
+ batch_sampler=dis_sampler,
collate_fn=self._batch_splitter,
**config)
else: # is_test
@@ -458,15 +457,14 @@ class Model:
dict[str, Union[list[str], torch.Tensor]]]:
"""
Disentanglement need two random conditions, this function will
- split pr * k * 2 samples to 2 dicts each containing pr * k
- samples. labels and clip data are tensor, and others are list.
+ split batch_size * 2 samples to 2 dicts each containing
+ batch_size samples. labels and clip data are tensor, and others
+ are list.
"""
- _batch = [[], []]
- for i in range(0, self.pr * self.k * 2, self.k * 2):
- _batch[0] += batch[i:i + self.k]
- _batch[1] += batch[i + self.k:i + self.k * 2]
+ batch_0 = batch[slice(0, self.batch_size * 2, 2)]
+ batch_1 = batch[slice(1, self.batch_size * 2, 2)]
- return default_collate(_batch[0]), default_collate(_batch[1])
+ return default_collate(batch_0), default_collate(batch_1)
def _make_signature(self,
config: dict,
diff --git a/utils/configuration.py b/utils/configuration.py
index fff3876..1ace241 100644
--- a/utils/configuration.py
+++ b/utils/configuration.py
@@ -25,7 +25,7 @@ class DatasetConfiguration(TypedDict):
class DataloaderConfiguration(TypedDict):
- batch_size: tuple[int, int]
+ batch_size: int
num_workers: int
pin_memory: bool
diff --git a/utils/sampler.py b/utils/sampler.py
index 0c9872c..e609e2d 100644
--- a/utils/sampler.py
+++ b/utils/sampler.py
@@ -8,11 +8,11 @@ from torch.utils import data
from utils.dataset import CASIAB
-class TripletSampler(data.Sampler):
+class DisentanglingSampler(data.Sampler):
def __init__(
self,
data_source: Union[CASIAB],
- batch_size: tuple[int, int]
+ batch_size: int
):
super().__init__(data_source)
self.metadata_labels = data_source.metadata['labels']
@@ -30,13 +30,14 @@ class TripletSampler(data.Sampler):
self.conditions = data_source.conditions
self.length = len(self.labels)
self.indexes = np.arange(0, self.length)
- (self.pr, self.k) = batch_size
+ self.batch_size = batch_size
def __iter__(self) -> Iterator[int]:
while True:
sampled_indexes = []
- # Sample pr subjects by sampling labels appeared in dataset
- sampled_subjects = random.sample(self.metadata_labels, k=self.pr)
+ sampled_subjects = random.sample(
+ self.metadata_labels, k=self.batch_size
+ )
for label in sampled_subjects:
mask = self.labels == label
# Fix unbalanced datasets
@@ -55,14 +56,7 @@ class TripletSampler(data.Sampler):
condition_mask |= self.conditions == condition
mask &= condition_mask
clips = self.indexes[mask].tolist()
- # Sample k clips from the subject without replacement if
- # have enough clips, k more clips will sampled for
- # disentanglement
- k = self.k * 2
- if len(clips) >= k:
- _sampled_indexes = random.sample(clips, k=k)
- else:
- _sampled_indexes = random.choices(clips, k=k)
+ _sampled_indexes = random.sample(clips, k=2)
sampled_indexes += _sampled_indexes
yield sampled_indexes