diff options
Diffstat (limited to 'utils')
-rw-r--r-- | utils/configuration.py | 2 | ||||
-rw-r--r-- | utils/sampler.py | 20 |
2 files changed, 8 insertions, 14 deletions
diff --git a/utils/configuration.py b/utils/configuration.py index fff3876..1ace241 100644 --- a/utils/configuration.py +++ b/utils/configuration.py @@ -25,7 +25,7 @@ class DatasetConfiguration(TypedDict): class DataloaderConfiguration(TypedDict): - batch_size: tuple[int, int] + batch_size: int num_workers: int pin_memory: bool diff --git a/utils/sampler.py b/utils/sampler.py index 0c9872c..e609e2d 100644 --- a/utils/sampler.py +++ b/utils/sampler.py @@ -8,11 +8,11 @@ from torch.utils import data from utils.dataset import CASIAB -class TripletSampler(data.Sampler): +class DisentanglingSampler(data.Sampler): def __init__( self, data_source: Union[CASIAB], - batch_size: tuple[int, int] + batch_size: int ): super().__init__(data_source) self.metadata_labels = data_source.metadata['labels'] @@ -30,13 +30,14 @@ class TripletSampler(data.Sampler): self.conditions = data_source.conditions self.length = len(self.labels) self.indexes = np.arange(0, self.length) - (self.pr, self.k) = batch_size + self.batch_size = batch_size def __iter__(self) -> Iterator[int]: while True: sampled_indexes = [] - # Sample pr subjects by sampling labels appeared in dataset - sampled_subjects = random.sample(self.metadata_labels, k=self.pr) + sampled_subjects = random.sample( + self.metadata_labels, k=self.batch_size + ) for label in sampled_subjects: mask = self.labels == label # Fix unbalanced datasets @@ -55,14 +56,7 @@ class TripletSampler(data.Sampler): condition_mask |= self.conditions == condition mask &= condition_mask clips = self.indexes[mask].tolist() - # Sample k clips from the subject without replacement if - # have enough clips, k more clips will sampled for - # disentanglement - k = self.k * 2 - if len(clips) >= k: - _sampled_indexes = random.sample(clips, k=k) - else: - _sampled_indexes = random.choices(clips, k=k) + _sampled_indexes = random.sample(clips, k=2) sampled_indexes += _sampled_indexes yield sampled_indexes |