summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-03-01 14:04:54 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-03-01 14:04:54 +0800
commitd8b2ac1f91c28ad2c79caf9bdcd54789f6523732 (patch)
treec979c0b38c2ae193008ddef4b5c42a045dbd68e2
parent8afcd7659fba5f15e221eee0158237edda749317 (diff)
parentc3e354de4a48e483866b15858c3f3b4b7e53f660 (diff)
Merge branch 'python3.8' into python3.7
-rw-r--r--utils/triplet_loss.py8
1 files changed, 5 insertions, 3 deletions
diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py
index 60faa0c..ae899ec 100644
--- a/utils/triplet_loss.py
+++ b/utils/triplet_loss.py
@@ -20,8 +20,8 @@ class BatchTripletLoss(nn.Module):
def forward(self, x, y):
p, n, c = x.size()
dist = self._batch_distance(x)
- flat_dist = dist.tril(-1)
- flat_dist = flat_dist[flat_dist != 0].view(p, -1)
+ flat_dist_mask = torch.tril_indices(n, n, offset=-1, device=dist.device)
+ flat_dist = dist[:, flat_dist_mask[0], flat_dist_mask[1]]
if self.is_hard:
positive_negative_dist = self._hard_distance(dist, y, p, n)
@@ -102,6 +102,8 @@ class JointBatchTripletLoss(BatchTripletLoss):
def forward(self, x, y):
p, n, c = x.size()
dist = self._batch_distance(x)
+ flat_dist_mask = torch.tril_indices(n, n, offset=-1, device=dist.device)
+ flat_dist = dist[:, flat_dist_mask[0], flat_dist_mask[1]]
if self.is_hard:
positive_negative_dist = self._hard_distance(dist, y, p, n)
@@ -122,4 +124,4 @@ class JointBatchTripletLoss(BatchTripletLoss):
else: # is_sum
loss_metric = losses.sum(1)
- return loss_metric, dist, non_zero_counts
+ return loss_metric, flat_dist, non_zero_counts