diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-02-28 22:14:27 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-02-28 22:32:31 +0800 |
commit | b837336695213e3e660992fcd01c5a52c654ea4f (patch) | |
tree | d1c072f9e69f3c8b6952785b9cc9f219a72d5397 /utils | |
parent | c96a6c88fa63d62ec62807abf957c9a8df307b43 (diff) |
Log n-ile embedding distance and norm
Diffstat (limited to 'utils')
-rw-r--r-- | utils/triplet_loss.py | 13 |
1 files changed, 8 insertions, 5 deletions
diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py index c3e5802..52d676e 100644 --- a/utils/triplet_loss.py +++ b/utils/triplet_loss.py @@ -18,6 +18,8 @@ class BatchTripletLoss(nn.Module): def forward(self, x, y): p, n, c = x.size() dist = self._batch_distance(x) + flat_dist = dist.tril(-1) + flat_dist = flat_dist[flat_dist != 0].view(p, -1) if self.is_hard: positive_negative_dist = self._hard_distance(dist, y, p, n) @@ -26,11 +28,12 @@ class BatchTripletLoss(nn.Module): if self.margin: all_loss = F.relu(self.margin + positive_negative_dist).view(p, -1) - else: + loss_mean, non_zero_counts = self._none_zero_parted_mean(all_loss) + return loss_mean, flat_dist, non_zero_counts + else: # Soft margin all_loss = F.softplus(positive_negative_dist).view(p, -1) - non_zero_mean, non_zero_counts = self._none_zero_parted_mean(all_loss) - - return non_zero_mean, dist.mean((1, 2)), non_zero_counts + loss_mean = all_loss.mean(1) + return loss_mean, flat_dist, None @staticmethod def _batch_distance(x): @@ -103,4 +106,4 @@ class JointBatchTripletLoss(BatchTripletLoss): all_loss = torch.cat((hpm_part_loss, pn_part_loss)).view(p, -1) non_zero_mean, non_zero_counts = self._none_zero_parted_mean(all_loss) - return non_zero_mean, dist.mean((1, 2)), non_zero_counts + return non_zero_mean, dist, non_zero_counts |