From 820d3dec284f38e6a3089dad5277bc3f6c5123bf Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Sat, 20 Feb 2021 14:19:30 +0800 Subject: Separate triplet loss from model --- utils/triplet_loss.py | 58 ++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 50 insertions(+), 8 deletions(-) (limited to 'utils') diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py index 954def2..0df2188 100644 --- a/utils/triplet_loss.py +++ b/utils/triplet_loss.py @@ -11,6 +11,25 @@ class BatchAllTripletLoss(nn.Module): def forward(self, x, y): p, n, c = x.size() + dist = self._batch_distance(x) + positive_negative_dist = self._hard_distance(dist, y, p, n) + all_loss = F.relu(self.margin + positive_negative_dist).view(p, -1) + parted_loss_mean = self._none_zero_parted_mean(all_loss) + + return parted_loss_mean + + @staticmethod + def _hard_distance(dist, y, p, n): + hard_positive_mask = y.unsqueeze(1) == y.unsqueeze(2) + hard_negative_mask = y.unsqueeze(1) != y.unsqueeze(2) + all_hard_positive = dist[hard_positive_mask].view(p, n, -1, 1) + all_hard_negative = dist[hard_negative_mask].view(p, n, 1, -1) + positive_negative_dist = all_hard_positive - all_hard_negative + + return positive_negative_dist + + @staticmethod + def _batch_distance(x): # Euclidean distance p x n x n x_squared_sum = torch.sum(x ** 2, dim=2) x1_squared_sum = x_squared_sum.unsqueeze(2) @@ -20,17 +39,40 @@ class BatchAllTripletLoss(nn.Module): F.relu(x1_squared_sum - 2 * x1_times_x2_sum + x2_squared_sum) ) - hard_positive_mask = y.unsqueeze(1) == y.unsqueeze(2) - hard_negative_mask = y.unsqueeze(1) != y.unsqueeze(2) - all_hard_positive = dist[hard_positive_mask].view(p, n, -1, 1) - all_hard_negative = dist[hard_negative_mask].view(p, n, 1, -1) - positive_negative_dist = all_hard_positive - all_hard_negative - all_loss = F.relu(self.margin + positive_negative_dist).view(p, -1) + return dist + @staticmethod + def _none_zero_parted_mean(all_loss): # Non-zero parted mean non_zero_counts = (all_loss != 0).sum(1) parted_loss_mean = all_loss.sum(1) / non_zero_counts parted_loss_mean[non_zero_counts == 0] = 0 - loss = parted_loss_mean.mean() - return loss + return parted_loss_mean + + +class JointBatchAllTripletLoss(BatchAllTripletLoss): + def __init__( + self, + hpm_num_parts: int, + margins: tuple[float, float] = (0.2, 0.2) + ): + super().__init__() + self.hpm_num_parts = hpm_num_parts + self.margin_hpm, self.margin_pn = margins + + def forward(self, x, y): + p, n, c = x.size() + + dist = self._batch_distance(x) + positive_negative_dist = self._hard_distance(dist, y, p, n) + hpm_part_loss = F.relu( + self.margin_hpm + positive_negative_dist[:self.hpm_num_parts] + ).view(self.hpm_num_parts, -1) + pn_part_loss = F.relu( + self.margin_pn + positive_negative_dist[self.hpm_num_parts:] + ).view(p - self.hpm_num_parts, -1) + all_loss = torch.cat((hpm_part_loss, pn_part_loss)).view(p, -1) + parted_loss_mean = self._none_zero_parted_mean(all_loss) + + return parted_loss_mean -- cgit v1.2.3 From 46391257ff50848efa1aa251ab3f15dc8b7a2d2c Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Sat, 27 Feb 2021 22:14:21 +0800 Subject: Implement Batch Hard triplet loss and soft margin --- utils/configuration.py | 1 + utils/triplet_loss.py | 84 +++++++++++++++++++++++++++++++++----------------- 2 files changed, 57 insertions(+), 28 deletions(-) (limited to 'utils') diff --git a/utils/configuration.py b/utils/configuration.py index 435d815..20aec76 100644 --- a/utils/configuration.py +++ b/utils/configuration.py @@ -43,6 +43,7 @@ class 
ModelHPConfiguration(TypedDict): tfa_squeeze_ratio: int tfa_num_parts: int embedding_dims: int + triplet_is_hard: bool triplet_margins: tuple[float, float] diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py index 0df2188..c3e5802 100644 --- a/utils/triplet_loss.py +++ b/utils/triplet_loss.py @@ -1,32 +1,36 @@ +from typing import Optional + import torch import torch.nn as nn import torch.nn.functional as F -class BatchAllTripletLoss(nn.Module): - def __init__(self, margin: float = 0.2): +class BatchTripletLoss(nn.Module): + def __init__( + self, + is_hard: bool = True, + margin: Optional[float] = 0.2, + ): super().__init__() + self.is_hard = is_hard self.margin = margin def forward(self, x, y): p, n, c = x.size() - dist = self._batch_distance(x) - positive_negative_dist = self._hard_distance(dist, y, p, n) - all_loss = F.relu(self.margin + positive_negative_dist).view(p, -1) - parted_loss_mean = self._none_zero_parted_mean(all_loss) - return parted_loss_mean + if self.is_hard: + positive_negative_dist = self._hard_distance(dist, y, p, n) + else: # is_all + positive_negative_dist = self._all_distance(dist, y, p, n) - @staticmethod - def _hard_distance(dist, y, p, n): - hard_positive_mask = y.unsqueeze(1) == y.unsqueeze(2) - hard_negative_mask = y.unsqueeze(1) != y.unsqueeze(2) - all_hard_positive = dist[hard_positive_mask].view(p, n, -1, 1) - all_hard_negative = dist[hard_negative_mask].view(p, n, 1, -1) - positive_negative_dist = all_hard_positive - all_hard_negative + if self.margin: + all_loss = F.relu(self.margin + positive_negative_dist).view(p, -1) + else: + all_loss = F.softplus(positive_negative_dist).view(p, -1) + non_zero_mean, non_zero_counts = self._none_zero_parted_mean(all_loss) - return positive_negative_dist + return non_zero_mean, dist.mean((1, 2)), non_zero_counts @staticmethod def _batch_distance(x): @@ -38,41 +42,65 @@ class BatchAllTripletLoss(nn.Module): dist = torch.sqrt( F.relu(x1_squared_sum - 2 * x1_times_x2_sum + x2_squared_sum) ) - return dist + @staticmethod + def _hard_distance(dist, y, p, n): + positive_mask = y.unsqueeze(1) == y.unsqueeze(2) + negative_mask = y.unsqueeze(1) != y.unsqueeze(2) + hard_positive = dist[positive_mask].view(p, n, -1).max(-1).values + hard_negative = dist[negative_mask].view(p, n, -1).min(-1).values + positive_negative_dist = hard_positive - hard_negative + + return positive_negative_dist + + @staticmethod + def _all_distance(dist, y, p, n): + positive_mask = y.unsqueeze(1) == y.unsqueeze(2) + negative_mask = y.unsqueeze(1) != y.unsqueeze(2) + all_positive = dist[positive_mask].view(p, n, -1, 1) + all_negative = dist[negative_mask].view(p, n, 1, -1) + positive_negative_dist = all_positive - all_negative + + return positive_negative_dist + @staticmethod def _none_zero_parted_mean(all_loss): # Non-zero parted mean - non_zero_counts = (all_loss != 0).sum(1) - parted_loss_mean = all_loss.sum(1) / non_zero_counts - parted_loss_mean[non_zero_counts == 0] = 0 + non_zero_counts = (all_loss != 0).sum(1).float() + non_zero_mean = all_loss.sum(1) / non_zero_counts + non_zero_mean[non_zero_counts == 0] = 0 - return parted_loss_mean + return non_zero_mean, non_zero_counts -class JointBatchAllTripletLoss(BatchAllTripletLoss): +class JointBatchTripletLoss(BatchTripletLoss): def __init__( self, hpm_num_parts: int, + is_hard: bool = True, margins: tuple[float, float] = (0.2, 0.2) ): - super().__init__() + super().__init__(is_hard) self.hpm_num_parts = hpm_num_parts self.margin_hpm, self.margin_pn = margins def forward(self, x, y): p, n, c = 
x.size() - dist = self._batch_distance(x) - positive_negative_dist = self._hard_distance(dist, y, p, n) + + if self.is_hard: + positive_negative_dist = self._hard_distance(dist, y, p, n) + else: # is_all + positive_negative_dist = self._all_distance(dist, y, p, n) + hpm_part_loss = F.relu( self.margin_hpm + positive_negative_dist[:self.hpm_num_parts] - ).view(self.hpm_num_parts, -1) + ) pn_part_loss = F.relu( self.margin_pn + positive_negative_dist[self.hpm_num_parts:] - ).view(p - self.hpm_num_parts, -1) + ) all_loss = torch.cat((hpm_part_loss, pn_part_loss)).view(p, -1) - parted_loss_mean = self._none_zero_parted_mean(all_loss) + non_zero_mean, non_zero_counts = self._none_zero_parted_mean(all_loss) - return parted_loss_mean + return non_zero_mean, dist.mean((1, 2)), non_zero_counts -- cgit v1.2.3 From b837336695213e3e660992fcd01c5a52c654ea4f Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Sun, 28 Feb 2021 22:14:27 +0800 Subject: Log n-ile embedding distance and norm --- utils/triplet_loss.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) (limited to 'utils') diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py index c3e5802..52d676e 100644 --- a/utils/triplet_loss.py +++ b/utils/triplet_loss.py @@ -18,6 +18,8 @@ class BatchTripletLoss(nn.Module): def forward(self, x, y): p, n, c = x.size() dist = self._batch_distance(x) + flat_dist = dist.tril(-1) + flat_dist = flat_dist[flat_dist != 0].view(p, -1) if self.is_hard: positive_negative_dist = self._hard_distance(dist, y, p, n) @@ -26,11 +28,12 @@ class BatchTripletLoss(nn.Module): if self.margin: all_loss = F.relu(self.margin + positive_negative_dist).view(p, -1) - else: + loss_mean, non_zero_counts = self._none_zero_parted_mean(all_loss) + return loss_mean, flat_dist, non_zero_counts + else: # Soft margin all_loss = F.softplus(positive_negative_dist).view(p, -1) - non_zero_mean, non_zero_counts = self._none_zero_parted_mean(all_loss) - - return non_zero_mean, dist.mean((1, 2)), non_zero_counts + loss_mean = all_loss.mean(1) + return loss_mean, flat_dist, None @staticmethod def _batch_distance(x): @@ -103,4 +106,4 @@ class JointBatchTripletLoss(BatchTripletLoss): all_loss = torch.cat((hpm_part_loss, pn_part_loss)).view(p, -1) non_zero_mean, non_zero_counts = self._none_zero_parted_mean(all_loss) - return non_zero_mean, dist.mean((1, 2)), non_zero_counts + return non_zero_mean, dist, non_zero_counts -- cgit v1.2.3 From fed5e6a9b35fda8306147e9ce772dfbf3142a061 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Sun, 28 Feb 2021 23:11:05 +0800 Subject: Implement sum of loss default in [1] MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit [1]A. Hermans, L. Beyer, and B. Leibe, “In defense of the triplet loss for person re-identification,” arXiv preprint arXiv:1703.07737, 2017. 
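For reference, a self-contained sketch of the two reductions this patch
toggles between (a paraphrase, not the patched code itself; `losses` stands
in for the per-part hinge matrix of shape (p, num_triplets)):

    import torch

    p, num_triplets = 16, 448
    losses = torch.relu(torch.randn(p, num_triplets))  # stand-in hinge values

    non_zero_counts = (losses != 0).sum(1).float()
    mean_metric = losses.sum(1) / non_zero_counts      # non-zero parted mean
    mean_metric[non_zero_counts == 0] = 0              # guard all-zero parts
    sum_metric = losses.sum(1)                         # plain sum, default in [1]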
--- utils/configuration.py | 1 + utils/triplet_loss.py | 43 ++++++++++++++++++++++++++++--------------- 2 files changed, 29 insertions(+), 15 deletions(-) (limited to 'utils') diff --git a/utils/configuration.py b/utils/configuration.py index 20aec76..31eb243 100644 --- a/utils/configuration.py +++ b/utils/configuration.py @@ -44,6 +44,7 @@ class ModelHPConfiguration(TypedDict): tfa_num_parts: int embedding_dims: int triplet_is_hard: bool + triplet_is_mean: bool triplet_margins: tuple[float, float] diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py index 52d676e..db0cf0f 100644 --- a/utils/triplet_loss.py +++ b/utils/triplet_loss.py @@ -9,10 +9,12 @@ class BatchTripletLoss(nn.Module): def __init__( self, is_hard: bool = True, + is_mean: bool = True, margin: Optional[float] = 0.2, ): super().__init__() self.is_hard = is_hard + self.is_mean = is_mean self.margin = margin def forward(self, x, y): @@ -27,13 +29,20 @@ class BatchTripletLoss(nn.Module): positive_negative_dist = self._all_distance(dist, y, p, n) if self.margin: - all_loss = F.relu(self.margin + positive_negative_dist).view(p, -1) - loss_mean, non_zero_counts = self._none_zero_parted_mean(all_loss) - return loss_mean, flat_dist, non_zero_counts + losses = F.relu(self.margin + positive_negative_dist).view(p, -1) + non_zero_counts = (losses != 0).sum(1).float() + if self.is_mean: + loss_metric = self._none_zero_mean(losses, non_zero_counts) + else: # is_sum + loss_metric = losses.sum(1) + return loss_metric, flat_dist, non_zero_counts else: # Soft margin - all_loss = F.softplus(positive_negative_dist).view(p, -1) - loss_mean = all_loss.mean(1) - return loss_mean, flat_dist, None + losses = F.softplus(positive_negative_dist).view(p, -1) + if self.is_mean: + loss_metric = losses.mean(1) + else: # is_sum + loss_metric = losses.sum(1) + return loss_metric, flat_dist, None @staticmethod def _batch_distance(x): @@ -68,13 +77,11 @@ class BatchTripletLoss(nn.Module): return positive_negative_dist @staticmethod - def _none_zero_parted_mean(all_loss): + def _none_zero_mean(losses, non_zero_counts): # Non-zero parted mean - non_zero_counts = (all_loss != 0).sum(1).float() - non_zero_mean = all_loss.sum(1) / non_zero_counts + non_zero_mean = losses.sum(1) / non_zero_counts non_zero_mean[non_zero_counts == 0] = 0 - - return non_zero_mean, non_zero_counts + return non_zero_mean class JointBatchTripletLoss(BatchTripletLoss): @@ -82,9 +89,10 @@ class JointBatchTripletLoss(BatchTripletLoss): self, hpm_num_parts: int, is_hard: bool = True, + is_mean: bool = True, margins: tuple[float, float] = (0.2, 0.2) ): - super().__init__(is_hard) + super().__init__(is_hard, is_mean) self.hpm_num_parts = hpm_num_parts self.margin_hpm, self.margin_pn = margins @@ -103,7 +111,12 @@ class JointBatchTripletLoss(BatchTripletLoss): pn_part_loss = F.relu( self.margin_pn + positive_negative_dist[self.hpm_num_parts:] ) - all_loss = torch.cat((hpm_part_loss, pn_part_loss)).view(p, -1) - non_zero_mean, non_zero_counts = self._none_zero_parted_mean(all_loss) + losses = torch.cat((hpm_part_loss, pn_part_loss)).view(p, -1) + + non_zero_counts = (losses != 0).sum(1).float() + if self.is_mean: + loss_metric = self._none_zero_mean(losses, non_zero_counts) + else: # is_sum + loss_metric = losses.sum(1) - return non_zero_mean, dist, non_zero_counts + return loss_metric, dist, non_zero_counts -- cgit v1.2.3 From 4bdc37bbd86a83647bbbda7bd1367c08e6c6f6d4 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Mon, 1 Mar 2021 11:20:34 +0800 Subject: Remove identical sample in Batch 
All case --- utils/triplet_loss.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'utils') diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py index db0cf0f..6822cf6 100644 --- a/utils/triplet_loss.py +++ b/utils/triplet_loss.py @@ -68,7 +68,10 @@ class BatchTripletLoss(nn.Module): @staticmethod def _all_distance(dist, y, p, n): - positive_mask = y.unsqueeze(1) == y.unsqueeze(2) + # Unmask identical samples + positive_mask = torch.eye( + n, dtype=torch.bool, device=y.device + ) ^ (y.unsqueeze(1) == y.unsqueeze(2)) negative_mask = y.unsqueeze(1) != y.unsqueeze(2) all_positive = dist[positive_mask].view(p, n, -1, 1) all_negative = dist[negative_mask].view(p, n, 1, -1) -- cgit v1.2.3 From 7318a09451852e3f7d5f68180964f03bd0b0f616 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Mon, 1 Mar 2021 14:04:02 +0800 Subject: Change flat distance calculation method --- utils/triplet_loss.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) (limited to 'utils') diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py index 6822cf6..e05b69d 100644 --- a/utils/triplet_loss.py +++ b/utils/triplet_loss.py @@ -20,8 +20,8 @@ class BatchTripletLoss(nn.Module): def forward(self, x, y): p, n, c = x.size() dist = self._batch_distance(x) - flat_dist = dist.tril(-1) - flat_dist = flat_dist[flat_dist != 0].view(p, -1) + flat_dist_mask = torch.tril_indices(n, n, offset=-1, device=dist.device) + flat_dist = dist[:, flat_dist_mask[0], flat_dist_mask[1]] if self.is_hard: positive_negative_dist = self._hard_distance(dist, y, p, n) @@ -102,6 +102,8 @@ class JointBatchTripletLoss(BatchTripletLoss): def forward(self, x, y): p, n, c = x.size() dist = self._batch_distance(x) + flat_dist_mask = torch.tril_indices(n, n, offset=-1, device=dist.device) + flat_dist = dist[:, flat_dist_mask[0], flat_dist_mask[1]] if self.is_hard: positive_negative_dist = self._hard_distance(dist, y, p, n) @@ -122,4 +124,4 @@ class JointBatchTripletLoss(BatchTripletLoss): else: # is_sum loss_metric = losses.sum(1) - return loss_metric, dist, non_zero_counts + return loss_metric, flat_dist, non_zero_counts -- cgit v1.2.3 From 6002b2d2017912f90e8917e6e8b71b78ce58e7c2 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Mon, 1 Mar 2021 18:20:38 +0800 Subject: New scheduler and new config --- utils/configuration.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) (limited to 'utils') diff --git a/utils/configuration.py b/utils/configuration.py index 31eb243..0f8d9ff 100644 --- a/utils/configuration.py +++ b/utils/configuration.py @@ -57,7 +57,6 @@ class SubOptimizerHPConfiguration(TypedDict): class OptimizerHPConfiguration(TypedDict): - start_iter: int lr: int betas: tuple[float, float] eps: float @@ -70,8 +69,8 @@ class OptimizerHPConfiguration(TypedDict): class SchedulerHPConfiguration(TypedDict): - step_size: int - gamma: float + start_step: int + final_gamma: float class HyperparameterConfiguration(TypedDict): -- cgit v1.2.3 From 1b8d1614168ce6590c5e029c7f1007ac9b17048c Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Wed, 10 Mar 2021 14:10:24 +0800 Subject: Bug fixes 1. Resolve reference problems when parsing dataset selectors 2. 
Transform gallery using different models --- utils/dataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'utils') diff --git a/utils/dataset.py b/utils/dataset.py index c487988..387c211 100644 --- a/utils/dataset.py +++ b/utils/dataset.py @@ -111,9 +111,9 @@ class CASIAB(data.Dataset): # in Bag #2 condition from 90 degree angle classes, conditions, views = [], [], [] if selector: - selected_classes = selector.pop('classes', None) - selected_conditions = selector.pop('conditions', None) - selected_views = selector.pop('views', None) + selected_classes = selector.get('classes', None) + selected_conditions = selector.get('conditions', None) + selected_views = selector.get('views', None) class_regex = r'\d{3}' condition_regex = r'(nm|bg|cl)-0[0-6]' -- cgit v1.2.3 From c74df416b00f837ba051f3947be92f76e7afbd88 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Fri, 12 Mar 2021 13:56:17 +0800 Subject: Code refactoring 1. Separate FCs and triplet losses for HPM and PartNet 2. Remove FC-equivalent 1x1 conv layers in HPM 3. Support adjustable learning rate schedulers --- utils/configuration.py | 20 +++++++++++--------- utils/triplet_loss.py | 40 ---------------------------------------- 2 files changed, 11 insertions(+), 49 deletions(-) (limited to 'utils') diff --git a/utils/configuration.py b/utils/configuration.py index 0f8d9ff..f6ac182 100644 --- a/utils/configuration.py +++ b/utils/configuration.py @@ -33,16 +33,11 @@ class ModelHPConfiguration(TypedDict): ae_feature_channels: int f_a_c_p_dims: tuple[int, int, int] hpm_scales: tuple[int, ...] - hpm_use_1x1conv: bool hpm_use_avg_pool: bool hpm_use_max_pool: bool - fpfe_feature_channels: int - fpfe_kernel_sizes: tuple[tuple, ...] - fpfe_paddings: tuple[tuple, ...] - fpfe_halving: tuple[int, ...] 
- tfa_squeeze_ratio: int tfa_num_parts: int - embedding_dims: int + tfa_squeeze_ratio: int + embedding_dims: tuple[int] triplet_is_hard: bool triplet_is_mean: bool triplet_margins: tuple[float, float] @@ -63,14 +58,21 @@ class OptimizerHPConfiguration(TypedDict): weight_decay: float amsgrad: bool auto_encoder: SubOptimizerHPConfiguration - part_net: SubOptimizerHPConfiguration hpm: SubOptimizerHPConfiguration - fc: SubOptimizerHPConfiguration + part_net: SubOptimizerHPConfiguration + + +class SubSchedulerHPConfiguration(TypedDict): + start_step: int + final_gamma: float class SchedulerHPConfiguration(TypedDict): start_step: int final_gamma: float + auto_encoder: SubSchedulerHPConfiguration + hpm: SubSchedulerHPConfiguration + part_net: SubSchedulerHPConfiguration class HyperparameterConfiguration(TypedDict): diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py index e05b69d..03fff21 100644 --- a/utils/triplet_loss.py +++ b/utils/triplet_loss.py @@ -85,43 +85,3 @@ class BatchTripletLoss(nn.Module): non_zero_mean = losses.sum(1) / non_zero_counts non_zero_mean[non_zero_counts == 0] = 0 return non_zero_mean - - -class JointBatchTripletLoss(BatchTripletLoss): - def __init__( - self, - hpm_num_parts: int, - is_hard: bool = True, - is_mean: bool = True, - margins: tuple[float, float] = (0.2, 0.2) - ): - super().__init__(is_hard, is_mean) - self.hpm_num_parts = hpm_num_parts - self.margin_hpm, self.margin_pn = margins - - def forward(self, x, y): - p, n, c = x.size() - dist = self._batch_distance(x) - flat_dist_mask = torch.tril_indices(n, n, offset=-1, device=dist.device) - flat_dist = dist[:, flat_dist_mask[0], flat_dist_mask[1]] - - if self.is_hard: - positive_negative_dist = self._hard_distance(dist, y, p, n) - else: # is_all - positive_negative_dist = self._all_distance(dist, y, p, n) - - hpm_part_loss = F.relu( - self.margin_hpm + positive_negative_dist[:self.hpm_num_parts] - ) - pn_part_loss = F.relu( - self.margin_pn + positive_negative_dist[self.hpm_num_parts:] - ) - losses = torch.cat((hpm_part_loss, pn_part_loss)).view(p, -1) - - non_zero_counts = (losses != 0).sum(1).float() - if self.is_mean: - loss_metric = self._none_zero_mean(losses, non_zero_counts) - else: # is_sum - loss_metric = losses.sum(1) - - return loss_metric, flat_dist, non_zero_counts -- cgit v1.2.3 From 263b8001ca1b25a43d1c87f187423054e141925d Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Sun, 14 Mar 2021 21:07:03 +0800 Subject: Fix unbalanced datasets --- utils/sampler.py | 35 +++++++++++++++++++++++++++++++---- 1 file changed, 31 insertions(+), 4 deletions(-) (limited to 'utils') diff --git a/utils/sampler.py b/utils/sampler.py index cdf1984..0c9872c 100644 --- a/utils/sampler.py +++ b/utils/sampler.py @@ -16,7 +16,18 @@ class TripletSampler(data.Sampler): ): super().__init__(data_source) self.metadata_labels = data_source.metadata['labels'] + metadata_conditions = data_source.metadata['conditions'] + self.subsets = {} + for condition in metadata_conditions: + pre, _ = condition.split('-') + if self.subsets.get(pre, None) is None: + self.subsets[pre] = [] + self.subsets[pre].append(condition) + self.num_subsets = len(self.subsets) + self.num_seq = {pre: len(seq) for (pre, seq) in self.subsets.items()} + self.min_num_seq = min(self.num_seq.values()) self.labels = data_source.labels + self.conditions = data_source.conditions self.length = len(self.labels) self.indexes = np.arange(0, self.length) (self.pr, self.k) = batch_size @@ -27,15 +38,31 @@ class TripletSampler(data.Sampler): # Sample pr subjects by 
sampling labels appeared in dataset sampled_subjects = random.sample(self.metadata_labels, k=self.pr) for label in sampled_subjects: - clips_from_subject = self.indexes[self.labels == label].tolist() + mask = self.labels == label + # Fix unbalanced datasets + if self.num_subsets > 1: + condition_mask = np.zeros(self.conditions.shape, dtype=bool) + for num, conditions_ in zip( + self.num_seq.values(), self.subsets.values() + ): + if num > self.min_num_seq: + conditions = random.sample( + conditions_, self.min_num_seq + ) + else: + conditions = conditions_ + for condition in conditions: + condition_mask |= self.conditions == condition + mask &= condition_mask + clips = self.indexes[mask].tolist() # Sample k clips from the subject without replacement if # have enough clips, k more clips will sampled for # disentanglement k = self.k * 2 - if len(clips_from_subject) >= k: - _sampled_indexes = random.sample(clips_from_subject, k=k) + if len(clips) >= k: + _sampled_indexes = random.sample(clips, k=k) else: - _sampled_indexes = random.choices(clips_from_subject, k=k) + _sampled_indexes = random.choices(clips, k=k) sampled_indexes += _sampled_indexes yield sampled_indexes -- cgit v1.2.3 From b6e5972b64cc61fc967cf3d098fc629d781adce4 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Mon, 22 Mar 2021 19:32:16 +0800 Subject: Add embedding visualization and validate on testing set --- utils/configuration.py | 1 + utils/triplet_loss.py | 9 +++++++-- 2 files changed, 8 insertions(+), 2 deletions(-) (limited to 'utils') diff --git a/utils/configuration.py b/utils/configuration.py index f6ac182..157d249 100644 --- a/utils/configuration.py +++ b/utils/configuration.py @@ -14,6 +14,7 @@ class DatasetConfiguration(TypedDict): name: str root_dir: str train_size: int + val_size: int num_sampled_frames: int truncate_threshold: int discard_threshold: int diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py index 03fff21..5e3a97a 100644 --- a/utils/triplet_loss.py +++ b/utils/triplet_loss.py @@ -28,6 +28,7 @@ class BatchTripletLoss(nn.Module): else: # is_all positive_negative_dist = self._all_distance(dist, y, p, n) + non_zero_counts = None if self.margin: losses = F.relu(self.margin + positive_negative_dist).view(p, -1) non_zero_counts = (losses != 0).sum(1).float() @@ -35,14 +36,18 @@ class BatchTripletLoss(nn.Module): loss_metric = self._none_zero_mean(losses, non_zero_counts) else: # is_sum loss_metric = losses.sum(1) - return loss_metric, flat_dist, non_zero_counts else: # Soft margin losses = F.softplus(positive_negative_dist).view(p, -1) if self.is_mean: loss_metric = losses.mean(1) else: # is_sum loss_metric = losses.sum(1) - return loss_metric, flat_dist, None + + return { + 'loss': loss_metric, + 'dist': flat_dist, + 'counts': non_zero_counts + } @staticmethod def _batch_distance(x): -- cgit v1.2.3 From 5a063855dbecb8f1a86ad25d9e61a9c8b63312b3 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Thu, 25 Mar 2021 12:23:23 +0800 Subject: Bug fixes and refactoring 1. Correct trained model signature 2. 
Move `val_size` to system config --- utils/configuration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'utils') diff --git a/utils/configuration.py b/utils/configuration.py index 157d249..5a5bc0c 100644 --- a/utils/configuration.py +++ b/utils/configuration.py @@ -8,13 +8,13 @@ class SystemConfiguration(TypedDict): CUDA_VISIBLE_DEVICES: str save_dir: str image_log_on: bool + val_size: int class DatasetConfiguration(TypedDict): name: str root_dir: str train_size: int - val_size: int num_sampled_frames: int truncate_threshold: int discard_threshold: int -- cgit v1.2.3
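Usage note: a minimal, self-contained sketch of calling BatchTripletLoss as
it stands after the last triplet-loss patch above. The module path and the
balanced 4-subjects-by-2-clips batch are illustrative assumptions; the input
shape (parts x batch x channels) and the returned dict follow the diffs.

    import torch

    from utils.triplet_loss import BatchTripletLoss

    p, n, c = 16, 8, 256                  # parts x batch x embedding channels
    x = torch.randn(p, n, c)              # part-wise embeddings
    # The masked .view calls inside the loss expect a label-balanced batch,
    # as produced by TripletSampler: here 4 subjects with 2 clips each.
    y = torch.arange(4).repeat_interleave(2).expand(p, n)

    criterion = BatchTripletLoss(is_hard=True, is_mean=True, margin=0.2)
    result = criterion(x, y)
    print(result['loss'].shape)   # (p,) one loss value per part
    print(result['dist'].shape)   # (p, n * (n - 1) // 2) lower-triangular distances
    print(result['counts'])       # non-zero triplet counts (None with soft margin)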