diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-03-12 20:30:25 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-03-12 20:34:00 +0800 |
commit | 30b475c0a27e0f848743abf0f909607defc6a3ee (patch) | |
tree | aaab163d3d76a835c32ce5014ce62637550d0b0d /utils/triplet_loss.py | |
parent | 3d8fc322623ba61610fd206b9f52b406e85cae61 (diff) | |
parent | e83ae0bcb5c763636fd522c2712a3c8aef558f3c (diff) |
Merge branch 'data_parallel' into data_parallel_py3.8
# Conflicts:
# models/hpm.py
# models/model.py
# models/rgb_part_net.py
# utils/configuration.py
# utils/triplet_loss.py
Diffstat (limited to 'utils/triplet_loss.py')
-rw-r--r-- | utils/triplet_loss.py | 42 |
1 files changed, 1 insertions, 41 deletions
diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py index ae899ec..03fff21 100644 --- a/utils/triplet_loss.py +++ b/utils/triplet_loss.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple +from typing import Optional import torch import torch.nn as nn @@ -85,43 +85,3 @@ class BatchTripletLoss(nn.Module): non_zero_mean = losses.sum(1) / non_zero_counts non_zero_mean[non_zero_counts == 0] = 0 return non_zero_mean - - -class JointBatchTripletLoss(BatchTripletLoss): - def __init__( - self, - hpm_num_parts: int, - is_hard: bool = True, - is_mean: bool = True, - margins: Tuple[float, float] = (0.2, 0.2) - ): - super().__init__(is_hard, is_mean) - self.hpm_num_parts = hpm_num_parts - self.margin_hpm, self.margin_pn = margins - - def forward(self, x, y): - p, n, c = x.size() - dist = self._batch_distance(x) - flat_dist_mask = torch.tril_indices(n, n, offset=-1, device=dist.device) - flat_dist = dist[:, flat_dist_mask[0], flat_dist_mask[1]] - - if self.is_hard: - positive_negative_dist = self._hard_distance(dist, y, p, n) - else: # is_all - positive_negative_dist = self._all_distance(dist, y, p, n) - - hpm_part_loss = F.relu( - self.margin_hpm + positive_negative_dist[:self.hpm_num_parts] - ) - pn_part_loss = F.relu( - self.margin_pn + positive_negative_dist[self.hpm_num_parts:] - ) - losses = torch.cat((hpm_part_loss, pn_part_loss)).view(p, -1) - - non_zero_counts = (losses != 0).sum(1).float() - if self.is_mean: - loss_metric = self._none_zero_mean(losses, non_zero_counts) - else: # is_sum - loss_metric = losses.sum(1) - - return loss_metric, flat_dist, non_zero_counts |