diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2020-12-26 21:36:10 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2020-12-26 21:36:10 +0800 |
commit | 6a767f503f24017bc4cd391e080414d8730d081b (patch) | |
tree | 8473742dfe0b6f81491d7d832a689d5b36a3fc90 /utils/sampler.py | |
parent | d5f7cdab1466566d805f9cbf81c05767880886ae (diff) |
Sample k more clips for disentanglement
Diffstat (limited to 'utils/sampler.py')
-rw-r--r-- | utils/sampler.py | 22 |
1 files changed, 13 insertions, 9 deletions
diff --git a/utils/sampler.py b/utils/sampler.py index 1dd33ca..8dec846 100644 --- a/utils/sampler.py +++ b/utils/sampler.py @@ -14,23 +14,27 @@ class TripletSampler(data.Sampler): batch_size: Tuple[int, int] ): super().__init__(data_source) - self.metadata_label = data_source.metadata['labels'] + self.metadata_labels = data_source.metadata['labels'] self.labels = data_source.labels self.length = len(self.labels) self.indexes = np.arange(0, self.length) - (self.P, self.K) = batch_size + (self.pr, self.k) = batch_size def __iter__(self) -> Iterator[int]: while True: sampled_indexes = [] - sampled_labels = random.sample(self.metadata_label, k=self.P) - for label in sampled_labels: - clip_indexes = list(self.indexes[self.labels == label]) - # Sample without replacement if have enough clips - if len(clip_indexes) >= self.K: - _sampled_indexes = random.sample(clip_indexes, k=self.K) + # Sample pr subjects by sampling labels appeared in dataset + sampled_subjects = random.sample(self.metadata_labels, k=self.pr) + for label in sampled_subjects: + clips_from_subject = self.indexes[self.labels == label].tolist() + # Sample k clips from the subject without replacement if + # have enough clips, k more clips will sampled for + # disentanglement + k = self.k * 2 + if len(clips_from_subject) >= k: + _sampled_indexes = random.sample(clips_from_subject, k=k) else: - _sampled_indexes = random.choices(clip_indexes, k=self.K) + _sampled_indexes = random.choices(clips_from_subject, k=k) sampled_indexes += _sampled_indexes yield sampled_indexes |