diff options
Diffstat (limited to 'utils')
-rw-r--r-- | utils/triplet_loss.py | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py index 242be45..1d63a0e 100644 --- a/utils/triplet_loss.py +++ b/utils/triplet_loss.py @@ -18,7 +18,9 @@ class BatchAllTripletLoss(nn.Module): x1_squared_sum = x_squared_sum.unsqueeze(1) x2_squared_sum = x_squared_sum.unsqueeze(2) x1_times_x2_sum = x @ x.transpose(1, 2) - dist = torch.sqrt(x1_squared_sum - 2 * x1_times_x2_sum + x2_squared_sum) + dist = torch.sqrt( + F.relu(x1_squared_sum - 2 * x1_times_x2_sum + x2_squared_sum) + ) hard_positive_mask = y.unsqueeze(1) == y.unsqueeze(2) hard_negative_mask = y.unsqueeze(1) != y.unsqueeze(2) @@ -31,5 +33,5 @@ class BatchAllTripletLoss(nn.Module): parted_loss_mean = all_loss.sum(1) / (all_loss != 0).sum(1) parted_loss_mean[parted_loss_mean == float('Inf')] = 0 - loss = parted_loss_mean.mean() + loss = parted_loss_mean.sum() return loss |