"""Triplet (P x K) batch sampler (``utils/sampler.py``)."""
import random
from typing import TYPE_CHECKING, Dict, Iterator, List, Tuple

import numpy as np
from torch.utils import data

if TYPE_CHECKING:
    # Only needed for the type annotation on ``TripletSampler.__init__``.
    from utils.dataset import CASIAB


class TripletSampler(data.Sampler):
    """Endlessly yield batches of ``P`` labels times ``K`` clip indexes.

    Each yielded batch is a flat list of ``P * K`` dataset indexes: ``P``
    distinct labels are drawn from ``data_source.metadata['labels']`` and,
    for every drawn label, ``K`` indexes of clips carrying that label are
    drawn from ``data_source.labels``.  The ``K`` indexes belonging to one
    label are contiguous in the batch.

    NOTE(review): because whole batches are yielded, this is presumably
    meant to be passed as ``batch_sampler=`` to a ``DataLoader`` -- confirm
    against the caller.
    """

    def __init__(
            self,
            data_source: "CASIAB",
            batch_size: Tuple[int, int]
    ):
        """Store the label metadata and the ``(P, K)`` batch layout.

        Args:
            data_source: Dataset exposing ``metadata['labels']`` (the pool of
                labels to draw from) and ``labels`` (one label per clip).
            batch_size: ``(P, K)`` -- labels per batch and clips per label.
        """
        super().__init__(data_source)
        self.metadata_label = data_source.metadata['labels']
        self.labels = data_source.labels
        self.length = len(self.labels)
        self.indexes = np.arange(0, self.length)
        (self.P, self.K) = batch_size

    def __iter__(self) -> Iterator[List[int]]:
        """Yield an endless stream of ``P * K`` index batches.

        Yields:
            A list of ``P * K`` plain ``int`` dataset indexes.

        Raises:
            ValueError: From ``random.sample`` when ``P`` exceeds the number
                of available labels.
            IndexError: When a drawn label has no clip in ``self.labels``
                (and ``K`` is greater than zero).
        """
        # random.sample() needs a real sequence; sets and ndarrays are
        # rejected on Python >= 3.11.
        label_pool = list(self.metadata_label)
        # Element-wise ``==`` below only works on an ndarray, not a list.
        clip_labels = np.asarray(self.labels)
        # Per-label clip indexes are loop-invariant, so compute each once
        # instead of rescanning all labels for every batch.
        clips_by_label: Dict[object, List[int]] = {}

        while True:
            batch: List[int] = []
            for label in random.sample(label_pool, k=self.P):
                clips = clips_by_label.get(label)
                if clips is None:
                    clips = self.indexes[clip_labels == label].tolist()
                    clips_by_label[label] = clips

                if len(clips) >= self.K:
                    # Enough clips: sample without replacement.
                    picked = random.sample(clips, k=self.K)
                elif not clips:
                    # random.choices() would raise an opaque IndexError
                    # here; keep the type but name the offending label.
                    raise IndexError(
                        f'no clips found for label {label!r}; '
                        f'cannot sample K={self.K} clips'
                    )
                else:
                    # Too few clips: fall back to sampling with replacement.
                    picked = random.choices(clips, k=self.K)
                batch += picked

            yield batch

    def __len__(self) -> int:
        """Return the number of clips in the data source.

        This is not a batch count: ``__iter__`` never terminates on its own.
        """
        return self.length