summaryrefslogtreecommitdiff
path: root/utils
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-03-01 11:26:34 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-03-01 11:26:34 +0800
commit8afcd7659fba5f15e221eee0158237edda749317 (patch)
tree8f629415e958e299168752e1462bccb95258b0a5 /utils
parent0d2b643d7e04eba872e8b1fc9b04478a026bb3b0 (diff)
parent2f3a7fbef70efd2cf91b7d77b3b71ffb4de907e2 (diff)
Merge branch 'python3.8' into python3.7
Diffstat (limited to 'utils')
-rw-r--r--utils/triplet_loss.py5
1 files changed, 4 insertions, 1 deletions
diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py
index 77c7234..60faa0c 100644
--- a/utils/triplet_loss.py
+++ b/utils/triplet_loss.py
@@ -68,7 +68,10 @@ class BatchTripletLoss(nn.Module):
@staticmethod
def _all_distance(dist, y, p, n):
- positive_mask = y.unsqueeze(1) == y.unsqueeze(2)
+ # Unmask identical samples
+ positive_mask = torch.eye(
+ n, dtype=torch.bool, device=y.device
+ ) ^ (y.unsqueeze(1) == y.unsqueeze(2))
negative_mask = y.unsqueeze(1) != y.unsqueeze(2)
all_positive = dist[positive_mask].view(p, n, -1, 1)
all_negative = dist[negative_mask].view(p, n, 1, -1)