diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-01-07 18:37:43 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-01-07 18:37:43 +0800 |
commit | 4a284084c253b9114fc02e1782962556ff113761 (patch) | |
tree | d6ceff8da68b224186d84772ee6153353675bcfe /utils/triplet_loss.py | |
parent | a27af5dfd58e7b48cf3bd063fa2b4b51ed1e0277 (diff) |
Add typical training script and some bug fixes
1. Resolve deprecated scheduler stepping issue
2. Make losses in the same scale(replace mean with sum in separate triplet loss, enlarge pose similarity loss 10x)
3. Add ReLU when compute distance in triplet loss
4. Remove classes except Model from `models` package init
Diffstat (limited to 'utils/triplet_loss.py')
-rw-r--r-- | utils/triplet_loss.py | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py index 242be45..1d63a0e 100644 --- a/utils/triplet_loss.py +++ b/utils/triplet_loss.py @@ -18,7 +18,9 @@ class BatchAllTripletLoss(nn.Module): x1_squared_sum = x_squared_sum.unsqueeze(1) x2_squared_sum = x_squared_sum.unsqueeze(2) x1_times_x2_sum = x @ x.transpose(1, 2) - dist = torch.sqrt(x1_squared_sum - 2 * x1_times_x2_sum + x2_squared_sum) + dist = torch.sqrt( + F.relu(x1_squared_sum - 2 * x1_times_x2_sum + x2_squared_sum) + ) hard_positive_mask = y.unsqueeze(1) == y.unsqueeze(2) hard_negative_mask = y.unsqueeze(1) != y.unsqueeze(2) @@ -31,5 +33,5 @@ class BatchAllTripletLoss(nn.Module): parted_loss_mean = all_loss.sum(1) / (all_loss != 0).sum(1) parted_loss_mean[parted_loss_mean == float('Inf')] = 0 - loss = parted_loss_mean.mean() + loss = parted_loss_mean.sum() return loss |