diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2020-12-18 22:37:52 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2020-12-18 22:37:52 +0800 |
commit | a08c7f31d6a9f499be9d13a1059ffc06aeb75bbf (patch) | |
tree | 66d6b2af3c0937a294d21256e3b57d21dd8a4d0e /utils | |
parent | 1db7d5cefbd14f14c8393e862d08fa9c620f90f6 (diff) |
Implement triplet sampler
Diffstat (limited to 'utils')
-rw-r--r-- | utils/sampler.py | 39 |
1 files changed, 39 insertions, 0 deletions
# utils/sampler.py
import random
from typing import TYPE_CHECKING, Dict, Hashable, Iterator, List, Tuple

import numpy as np
from torch.utils import data

if TYPE_CHECKING:
    # CASIAB is referenced only in an annotation, so it is imported for
    # type checkers only and never at runtime.
    from utils.dataset import CASIAB


class TripletSampler(data.Sampler):
    """Endless batch sampler that draws P labels and K clips per label.

    Each item produced by iterating the sampler is a flat list of
    ``P * K`` dataset indexes: ``P`` distinct labels are drawn from
    ``data_source.metadata['labels']`` and, for each of them, ``K``
    indexes of clips whose entry in ``data_source.labels`` equals that
    label. The K indexes of one label are contiguous in the list.

    Args:
        data_source: Object exposing ``metadata['labels']`` (the labels to
            draw from) and ``labels`` (one label per clip).
        batch_size: The pair ``(P, K)``.
    """

    def __init__(
            self,
            data_source: 'CASIAB',
            batch_size: Tuple[int, int]
    ):
        super().__init__(data_source)
        # random.sample() needs a sequence; a set is rejected on
        # Python >= 3.11, so materialise the labels into a list.
        self.metadata_label = list(data_source.metadata['labels'])
        # Force an ndarray so that ``self.labels == label`` is an
        # element-wise mask; with a plain list it would be a scalar False.
        self.labels = np.asarray(data_source.labels)
        self.length = len(self.labels)
        self.indexes = np.arange(0, self.length)
        (self.P, self.K) = batch_size
        # label -> clip indexes, filled lazily by _clip_indexes().
        self._indexes_by_label: Dict[Hashable, List[int]] = {}

    def _clip_indexes(self, label: Hashable) -> List[int]:
        """Return (and cache) the indexes of all clips carrying *label*."""
        clip_indexes = self._indexes_by_label.get(label)
        if clip_indexes is None:
            # The mask is invariant across batches, so compute it once per
            # label. tolist() yields plain Python ints.
            clip_indexes = self.indexes[self.labels == label].tolist()
            self._indexes_by_label[label] = clip_indexes
        return clip_indexes

    def __iter__(self) -> Iterator[List[int]]:
        """Yield batches of ``P * K`` indexes forever.

        Raises:
            ValueError: From ``random.sample`` if ``P`` exceeds the number
                of available labels.
            IndexError: From ``random.choices`` if a drawn label has no
                clips at all.
        """
        while True:
            sampled_indexes: List[int] = []
            sampled_labels = random.sample(self.metadata_label, k=self.P)
            for label in sampled_labels:
                clip_indexes = self._clip_indexes(label)
                # Sample without replacement if there are enough clips,
                # otherwise fall back to sampling with replacement.
                if len(clip_indexes) >= self.K:
                    sampled_indexes += random.sample(clip_indexes, k=self.K)
                else:
                    sampled_indexes += random.choices(clip_indexes, k=self.K)

            yield sampled_indexes

    def __len__(self) -> int:
        """Return the number of clips in the data source.

        Iteration itself never terminates; this is the clip count, not the
        number of batches.
        """
        return self.length