diff options
-rw-r--r-- | utils/sampler.py | 22 |
1 files changed, 13 insertions, 9 deletions
diff --git a/utils/sampler.py b/utils/sampler.py index 1dd33ca..8dec846 100644 --- a/utils/sampler.py +++ b/utils/sampler.py @@ -14,23 +14,27 @@ class TripletSampler(data.Sampler): batch_size: Tuple[int, int] ): super().__init__(data_source) - self.metadata_label = data_source.metadata['labels'] + self.metadata_labels = data_source.metadata['labels'] self.labels = data_source.labels self.length = len(self.labels) self.indexes = np.arange(0, self.length) - (self.P, self.K) = batch_size + (self.pr, self.k) = batch_size def __iter__(self) -> Iterator[int]: while True: sampled_indexes = [] - sampled_labels = random.sample(self.metadata_label, k=self.P) - for label in sampled_labels: - clip_indexes = list(self.indexes[self.labels == label]) - # Sample without replacement if have enough clips - if len(clip_indexes) >= self.K: - _sampled_indexes = random.sample(clip_indexes, k=self.K) + # Sample pr subjects by sampling labels appeared in dataset + sampled_subjects = random.sample(self.metadata_labels, k=self.pr) + for label in sampled_subjects: + clips_from_subject = self.indexes[self.labels == label].tolist() + # Sample k clips from the subject without replacement if + # have enough clips, k more clips will sampled for + # disentanglement + k = self.k * 2 + if len(clips_from_subject) >= k: + _sampled_indexes = random.sample(clips_from_subject, k=k) else: - _sampled_indexes = random.choices(clip_indexes, k=self.K) + _sampled_indexes = random.choices(clips_from_subject, k=k) sampled_indexes += _sampled_indexes yield sampled_indexes |