From fed5e6a9b35fda8306147e9ce772dfbf3142a061 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Sun, 28 Feb 2021 23:11:05 +0800 Subject: Implement sum of loss default in [1] MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit [1]A. Hermans, L. Beyer, and B. Leibe, “In defense of the triplet loss for person re-identification,” arXiv preprint arXiv:1703.07737, 2017. --- utils/triplet_loss.py | 43 ++++++++++++++++++++++++++++--------------- 1 file changed, 28 insertions(+), 15 deletions(-) (limited to 'utils/triplet_loss.py') diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py index 52d676e..db0cf0f 100644 --- a/utils/triplet_loss.py +++ b/utils/triplet_loss.py @@ -9,10 +9,12 @@ class BatchTripletLoss(nn.Module): def __init__( self, is_hard: bool = True, + is_mean: bool = True, margin: Optional[float] = 0.2, ): super().__init__() self.is_hard = is_hard + self.is_mean = is_mean self.margin = margin def forward(self, x, y): @@ -27,13 +29,20 @@ class BatchTripletLoss(nn.Module): positive_negative_dist = self._all_distance(dist, y, p, n) if self.margin: - all_loss = F.relu(self.margin + positive_negative_dist).view(p, -1) - loss_mean, non_zero_counts = self._none_zero_parted_mean(all_loss) - return loss_mean, flat_dist, non_zero_counts + losses = F.relu(self.margin + positive_negative_dist).view(p, -1) + non_zero_counts = (losses != 0).sum(1).float() + if self.is_mean: + loss_metric = self._none_zero_mean(losses, non_zero_counts) + else: # is_sum + loss_metric = losses.sum(1) + return loss_metric, flat_dist, non_zero_counts else: # Soft margin - all_loss = F.softplus(positive_negative_dist).view(p, -1) - loss_mean = all_loss.mean(1) - return loss_mean, flat_dist, None + losses = F.softplus(positive_negative_dist).view(p, -1) + if self.is_mean: + loss_metric = losses.mean(1) + else: # is_sum + loss_metric = losses.sum(1) + return loss_metric, flat_dist, None @staticmethod def _batch_distance(x): @@ -68,13 +77,11 @@ class BatchTripletLoss(nn.Module): return positive_negative_dist @staticmethod - def _none_zero_parted_mean(all_loss): + def _none_zero_mean(losses, non_zero_counts): # Non-zero parted mean - non_zero_counts = (all_loss != 0).sum(1).float() - non_zero_mean = all_loss.sum(1) / non_zero_counts + non_zero_mean = losses.sum(1) / non_zero_counts non_zero_mean[non_zero_counts == 0] = 0 - - return non_zero_mean, non_zero_counts + return non_zero_mean class JointBatchTripletLoss(BatchTripletLoss): @@ -82,9 +89,10 @@ class JointBatchTripletLoss(BatchTripletLoss): self, hpm_num_parts: int, is_hard: bool = True, + is_mean: bool = True, margins: tuple[float, float] = (0.2, 0.2) ): - super().__init__(is_hard) + super().__init__(is_hard, is_mean) self.hpm_num_parts = hpm_num_parts self.margin_hpm, self.margin_pn = margins @@ -103,7 +111,12 @@ class JointBatchTripletLoss(BatchTripletLoss): pn_part_loss = F.relu( self.margin_pn + positive_negative_dist[self.hpm_num_parts:] ) - all_loss = torch.cat((hpm_part_loss, pn_part_loss)).view(p, -1) - non_zero_mean, non_zero_counts = self._none_zero_parted_mean(all_loss) + losses = torch.cat((hpm_part_loss, pn_part_loss)).view(p, -1) + + non_zero_counts = (losses != 0).sum(1).float() + if self.is_mean: + loss_metric = self._none_zero_mean(losses, non_zero_counts) + else: # is_sum + loss_metric = losses.sum(1) - return non_zero_mean, dist, non_zero_counts + return loss_metric, dist, non_zero_counts -- cgit v1.2.3