summaryrefslogtreecommitdiff
path: root/utils/triplet_loss.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-03-22 20:50:30 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-03-22 20:50:30 +0800
commit4a4506fde7867df4d4d446c8992621fb0da03939 (patch)
treec2cedb06e538c6c66cc8fe9dd7f47e00b66cc98e /utils/triplet_loss.py
parentf9f45c8d32c6f99c314f1ca93140598d1005c8fb (diff)
parentb6e5972b64cc61fc967cf3d098fc629d781adce4 (diff)
Merge branch 'master' into python3.8
# Conflicts: # models/model.py
Diffstat (limited to 'utils/triplet_loss.py')
-rw-r--r--utils/triplet_loss.py9
1 files changed, 7 insertions, 2 deletions
diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py
index 03fff21..5e3a97a 100644
--- a/utils/triplet_loss.py
+++ b/utils/triplet_loss.py
@@ -28,6 +28,7 @@ class BatchTripletLoss(nn.Module):
else: # is_all
positive_negative_dist = self._all_distance(dist, y, p, n)
+ non_zero_counts = None
if self.margin:
losses = F.relu(self.margin + positive_negative_dist).view(p, -1)
non_zero_counts = (losses != 0).sum(1).float()
@@ -35,14 +36,18 @@ class BatchTripletLoss(nn.Module):
loss_metric = self._none_zero_mean(losses, non_zero_counts)
else: # is_sum
loss_metric = losses.sum(1)
- return loss_metric, flat_dist, non_zero_counts
else: # Soft margin
losses = F.softplus(positive_negative_dist).view(p, -1)
if self.is_mean:
loss_metric = losses.mean(1)
else: # is_sum
loss_metric = losses.sum(1)
- return loss_metric, flat_dist, None
+
+ return {
+ 'loss': loss_metric,
+ 'dist': flat_dist,
+ 'counts': non_zero_counts
+ }
@staticmethod
def _batch_distance(x):