From 4bdc37bbd86a83647bbbda7bd1367c08e6c6f6d4 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Mon, 1 Mar 2021 11:20:34 +0800 Subject: Remove identical sample in Batch All case --- utils/triplet_loss.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'utils') diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py index db0cf0f..6822cf6 100644 --- a/utils/triplet_loss.py +++ b/utils/triplet_loss.py @@ -68,7 +68,10 @@ class BatchTripletLoss(nn.Module): @staticmethod def _all_distance(dist, y, p, n): - positive_mask = y.unsqueeze(1) == y.unsqueeze(2) + # Unmask identical samples + positive_mask = torch.eye( + n, dtype=torch.bool, device=y.device + ) ^ (y.unsqueeze(1) == y.unsqueeze(2)) negative_mask = y.unsqueeze(1) != y.unsqueeze(2) all_positive = dist[positive_mask].view(p, n, -1, 1) all_negative = dist[negative_mask].view(p, n, 1, -1) -- cgit v1.2.3