summaryrefslogtreecommitdiff
path: root/utils/sampler.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2020-12-26 21:36:10 +0800
committerJordan Gong <jordan.gong@protonmail.com>2020-12-26 21:36:10 +0800
commit6a767f503f24017bc4cd391e080414d8730d081b (patch)
tree8473742dfe0b6f81491d7d832a689d5b36a3fc90 /utils/sampler.py
parentd5f7cdab1466566d805f9cbf81c05767880886ae (diff)
Sample k more clips for disentanglement
Diffstat (limited to 'utils/sampler.py')
-rw-r--r--utils/sampler.py22
1 files changed, 13 insertions, 9 deletions
diff --git a/utils/sampler.py b/utils/sampler.py
index 1dd33ca..8dec846 100644
--- a/utils/sampler.py
+++ b/utils/sampler.py
@@ -14,23 +14,27 @@ class TripletSampler(data.Sampler):
batch_size: Tuple[int, int]
):
super().__init__(data_source)
- self.metadata_label = data_source.metadata['labels']
+ self.metadata_labels = data_source.metadata['labels']
self.labels = data_source.labels
self.length = len(self.labels)
self.indexes = np.arange(0, self.length)
- (self.P, self.K) = batch_size
+ (self.pr, self.k) = batch_size
def __iter__(self) -> Iterator[int]:
while True:
sampled_indexes = []
- sampled_labels = random.sample(self.metadata_label, k=self.P)
- for label in sampled_labels:
- clip_indexes = list(self.indexes[self.labels == label])
- # Sample without replacement if have enough clips
- if len(clip_indexes) >= self.K:
- _sampled_indexes = random.sample(clip_indexes, k=self.K)
+ # Sample pr subjects by sampling labels that appear in the dataset
+ sampled_subjects = random.sample(self.metadata_labels, k=self.pr)
+ for label in sampled_subjects:
+ clips_from_subject = self.indexes[self.labels == label].tolist()
+ # Sample k clips from the subject without replacement if it
+ # has enough clips; k more clips will be sampled for
+ # disentanglement (2 * self.k in total)
+ k = self.k * 2
+ if len(clips_from_subject) >= k:
+ _sampled_indexes = random.sample(clips_from_subject, k=k)
else:
- _sampled_indexes = random.choices(clip_indexes, k=self.K)
+ _sampled_indexes = random.choices(clips_from_subject, k=k)
sampled_indexes += _sampled_indexes
yield sampled_indexes