diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-04-04 12:45:36 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-04-04 12:45:36 +0800 |
commit | 6a8824e4fb8bdd1f3e763b78b765830788415cfb (patch) | |
tree | 1853fce051cb9c46c7d253d7ddcec3cde0478899 /utils | |
parent | 2f4516549c7b9ebfaf650c6a71e4da43cd3372fa (diff) | |
parent | cb05de36f5ffd8584d78c6776dbe90e21abff25a (diff) |
Merge branch 'disentangling_only' into disentangling_only_py3.8
# Conflicts:
# models/model.py
# utils/configuration.py
# utils/sampler.py
Diffstat (limited to 'utils')
-rw-r--r-- | utils/configuration.py | 2 | ||||
-rw-r--r-- | utils/sampler.py | 20 |
2 files changed, 8 insertions, 14 deletions
```diff
diff --git a/utils/configuration.py b/utils/configuration.py
index 6f04c68..651689d 100644
--- a/utils/configuration.py
+++ b/utils/configuration.py
@@ -25,7 +25,7 @@ class DatasetConfiguration(TypedDict):
 
 
 class DataloaderConfiguration(TypedDict):
-    batch_size: Tuple[int, int]
+    batch_size: int
     num_workers: int
     pin_memory: bool
 
diff --git a/utils/sampler.py b/utils/sampler.py
index 581d7a2..b017d66 100644
--- a/utils/sampler.py
+++ b/utils/sampler.py
@@ -7,11 +7,11 @@ from torch.utils import data
 from utils.dataset import CASIAB
 
 
-class TripletSampler(data.Sampler):
+class DisentanglingSampler(data.Sampler):
     def __init__(
             self,
             data_source: Union[CASIAB],
-            batch_size: Tuple[int, int]
+            batch_size: int
     ):
         super().__init__(data_source)
         self.metadata_labels = data_source.metadata['labels']
@@ -29,13 +29,14 @@ class TripletSampler(data.Sampler):
         self.conditions = data_source.conditions
         self.length = len(self.labels)
         self.indexes = np.arange(0, self.length)
-        (self.pr, self.k) = batch_size
+        self.batch_size = batch_size
 
     def __iter__(self) -> Iterator[int]:
         while True:
             sampled_indexes = []
-            # Sample pr subjects by sampling labels appeared in dataset
-            sampled_subjects = random.sample(self.metadata_labels, k=self.pr)
+            sampled_subjects = random.sample(
+                self.metadata_labels, k=self.batch_size
+            )
             for label in sampled_subjects:
                 mask = self.labels == label
                 # Fix unbalanced datasets
@@ -54,14 +55,7 @@ class TripletSampler(data.Sampler):
                         condition_mask |= self.conditions == condition
                     mask &= condition_mask
                 clips = self.indexes[mask].tolist()
-                # Sample k clips from the subject without replacement if
-                # have enough clips, k more clips will sampled for
-                # disentanglement
-                k = self.k * 2
-                if len(clips) >= k:
-                    _sampled_indexes = random.sample(clips, k=k)
-                else:
-                    _sampled_indexes = random.choices(clips, k=k)
+                _sampled_indexes = random.sample(clips, k=2)
                 sampled_indexes += _sampled_indexes
 
             yield sampled_indexes
```