From 6a767f503f24017bc4cd391e080414d8730d081b Mon Sep 17 00:00:00 2001
From: Jordan Gong <jordan.gong@protonmail.com>
Date: Sat, 26 Dec 2020 21:36:10 +0800
Subject: Sample k more clips for disentanglement

---
 utils/sampler.py | 22 +++++++++++++---------
 1 file changed, 13 insertions(+), 9 deletions(-)

(limited to 'utils')

diff --git a/utils/sampler.py b/utils/sampler.py
index 1dd33ca..8dec846 100644
--- a/utils/sampler.py
+++ b/utils/sampler.py
@@ -14,23 +14,27 @@ class TripletSampler(data.Sampler):
             batch_size: Tuple[int, int]
     ):
         super().__init__(data_source)
-        self.metadata_label = data_source.metadata['labels']
+        self.metadata_labels = data_source.metadata['labels']
         self.labels = data_source.labels
         self.length = len(self.labels)
         self.indexes = np.arange(0, self.length)
-        (self.P, self.K) = batch_size
+        (self.pr, self.k) = batch_size
 
     def __iter__(self) -> Iterator[int]:
         while True:
             sampled_indexes = []
-            sampled_labels = random.sample(self.metadata_label, k=self.P)
-            for label in sampled_labels:
-                clip_indexes = list(self.indexes[self.labels == label])
-                # Sample without replacement if have enough clips
-                if len(clip_indexes) >= self.K:
-                    _sampled_indexes = random.sample(clip_indexes, k=self.K)
+            # Sample pr subjects by sampling labels that appear in the dataset
+            sampled_subjects = random.sample(self.metadata_labels, k=self.pr)
+            for label in sampled_subjects:
+                clips_from_subject = self.indexes[self.labels == label].tolist()
+                # Sample k clips from the subject without replacement if it
+                # has enough clips; k more clips will be sampled for
+                # disentanglement
+                k = self.k * 2
+                if len(clips_from_subject) >= k:
+                    _sampled_indexes = random.sample(clips_from_subject, k=k)
                 else:
-                    _sampled_indexes = random.choices(clip_indexes, k=self.K)
+                    _sampled_indexes = random.choices(clips_from_subject, k=k)
                 sampled_indexes += _sampled_indexes
 
             yield sampled_indexes
-- 
cgit v1.2.3