summaryrefslogtreecommitdiff
path: root/utils/triplet_loss.py
diff options
context:
space:
mode:
Diffstat (limited to 'utils/triplet_loss.py')
-rw-r--r--utils/triplet_loss.py60
1 files changed, 52 insertions, 8 deletions
diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py
index 954def2..6025bd3 100644
--- a/utils/triplet_loss.py
+++ b/utils/triplet_loss.py
@@ -1,3 +1,5 @@
+from typing import Tuple
+
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -11,6 +13,25 @@ class BatchAllTripletLoss(nn.Module):
def forward(self, x, y):
p, n, c = x.size()
+ dist = self._batch_distance(x)
+ positive_negative_dist = self._hard_distance(dist, y, p, n)
+ all_loss = F.relu(self.margin + positive_negative_dist).view(p, -1)
+ parted_loss_mean = self._none_zero_parted_mean(all_loss)
+
+ return parted_loss_mean
+
+ @staticmethod
+ def _hard_distance(dist, y, p, n):
+ hard_positive_mask = y.unsqueeze(1) == y.unsqueeze(2)
+ hard_negative_mask = y.unsqueeze(1) != y.unsqueeze(2)
+ all_hard_positive = dist[hard_positive_mask].view(p, n, -1, 1)
+ all_hard_negative = dist[hard_negative_mask].view(p, n, 1, -1)
+ positive_negative_dist = all_hard_positive - all_hard_negative
+
+ return positive_negative_dist
+
+ @staticmethod
+ def _batch_distance(x):
# Euclidean distance p x n x n
x_squared_sum = torch.sum(x ** 2, dim=2)
x1_squared_sum = x_squared_sum.unsqueeze(2)
@@ -20,17 +41,40 @@ class BatchAllTripletLoss(nn.Module):
F.relu(x1_squared_sum - 2 * x1_times_x2_sum + x2_squared_sum)
)
- hard_positive_mask = y.unsqueeze(1) == y.unsqueeze(2)
- hard_negative_mask = y.unsqueeze(1) != y.unsqueeze(2)
- all_hard_positive = dist[hard_positive_mask].view(p, n, -1, 1)
- all_hard_negative = dist[hard_negative_mask].view(p, n, 1, -1)
- positive_negative_dist = all_hard_positive - all_hard_negative
- all_loss = F.relu(self.margin + positive_negative_dist).view(p, -1)
+ return dist
+ @staticmethod
+ def _none_zero_parted_mean(all_loss):
# Non-zero parted mean
non_zero_counts = (all_loss != 0).sum(1)
parted_loss_mean = all_loss.sum(1) / non_zero_counts
parted_loss_mean[non_zero_counts == 0] = 0
- loss = parted_loss_mean.mean()
- return loss
+ return parted_loss_mean
+
+
+class JointBatchAllTripletLoss(BatchAllTripletLoss):
+ def __init__(
+ self,
+ hpm_num_parts: int,
+ margins: Tuple[float, float] = (0.2, 0.2)
+ ):
+ super().__init__()
+ self.hpm_num_parts = hpm_num_parts
+ self.margin_hpm, self.margin_pn = margins
+
+ def forward(self, x, y):
+ p, n, c = x.size()
+
+ dist = self._batch_distance(x)
+ positive_negative_dist = self._hard_distance(dist, y, p, n)
+ hpm_part_loss = F.relu(
+ self.margin_hpm + positive_negative_dist[:self.hpm_num_parts]
+ ).view(self.hpm_num_parts, -1)
+ pn_part_loss = F.relu(
+ self.margin_pn + positive_negative_dist[self.hpm_num_parts:]
+ ).view(p - self.hpm_num_parts, -1)
+ all_loss = torch.cat((hpm_part_loss, pn_part_loss)).view(p, -1)
+ parted_loss_mean = self._none_zero_parted_mean(all_loss)
+
+ return parted_loss_mean