diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-03-01 14:04:33 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-03-01 14:04:33 +0800 |
commit | c3e354de4a48e483866b15858c3f3b4b7e53f660 (patch) | |
tree | bd1cb5696391033f91bc3c869d1edf4afd8a4928 /utils | |
parent | 2f3a7fbef70efd2cf91b7d77b3b71ffb4de907e2 (diff) | |
parent | 7318a09451852e3f7d5f68180964f03bd0b0f616 (diff) |
Merge branch 'master' into python3.8
Diffstat (limited to 'utils')
-rw-r--r-- | utils/triplet_loss.py | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py index 60faa0c..ae899ec 100644 --- a/utils/triplet_loss.py +++ b/utils/triplet_loss.py @@ -20,8 +20,8 @@ class BatchTripletLoss(nn.Module): def forward(self, x, y): p, n, c = x.size() dist = self._batch_distance(x) - flat_dist = dist.tril(-1) - flat_dist = flat_dist[flat_dist != 0].view(p, -1) + flat_dist_mask = torch.tril_indices(n, n, offset=-1, device=dist.device) + flat_dist = dist[:, flat_dist_mask[0], flat_dist_mask[1]] if self.is_hard: positive_negative_dist = self._hard_distance(dist, y, p, n) @@ -102,6 +102,8 @@ class JointBatchTripletLoss(BatchTripletLoss): def forward(self, x, y): p, n, c = x.size() dist = self._batch_distance(x) + flat_dist_mask = torch.tril_indices(n, n, offset=-1, device=dist.device) + flat_dist = dist[:, flat_dist_mask[0], flat_dist_mask[1]] if self.is_hard: positive_negative_dist = self._hard_distance(dist, y, p, n) @@ -122,4 +124,4 @@ class JointBatchTripletLoss(BatchTripletLoss): else: # is_sum loss_metric = losses.sum(1) - return loss_metric, dist, non_zero_counts + return loss_metric, flat_dist, non_zero_counts |