diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-03-01 11:26:23 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-03-01 11:26:23 +0800 |
commit | 2f3a7fbef70efd2cf91b7d77b3b71ffb4de907e2 (patch) | |
tree | b7099202831f014ec27e2d840dc2c659d19df347 /utils/triplet_loss.py | |
parent | e04f54d0bfc8fc711e53561065d772dae1926b64 (diff) | |
parent | db0564967d8cfc03b2d3fe4f7d10eff0867e1771 (diff) |
Merge branch 'master' into python3.8
Diffstat (limited to 'utils/triplet_loss.py')
-rw-r--r-- | utils/triplet_loss.py | 5 |
1 files changed, 4 insertions, 1 deletions
diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py index 77c7234..60faa0c 100644 --- a/utils/triplet_loss.py +++ b/utils/triplet_loss.py @@ -68,7 +68,10 @@ class BatchTripletLoss(nn.Module): @staticmethod def _all_distance(dist, y, p, n): - positive_mask = y.unsqueeze(1) == y.unsqueeze(2) + # Unmask identical samples + positive_mask = torch.eye( + n, dtype=torch.bool, device=y.device + ) ^ (y.unsqueeze(1) == y.unsqueeze(2)) negative_mask = y.unsqueeze(1) != y.unsqueeze(2) all_positive = dist[positive_mask].view(p, n, -1, 1) all_negative = dist[negative_mask].view(p, n, 1, -1) |