From 6f278a962d70e90ac530f5723e198c7c356e8297 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Sat, 9 Jan 2021 20:53:29 +0800 Subject: Fix NaN when separate sum is zero --- utils/triplet_loss.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'utils') diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py index 1d63a0e..1899cc9 100644 --- a/utils/triplet_loss.py +++ b/utils/triplet_loss.py @@ -30,8 +30,9 @@ class BatchAllTripletLoss(nn.Module): all_loss = F.relu(self.margin + positive_negative_dist).view(p, -1) # Non-zero parted mean - parted_loss_mean = all_loss.sum(1) / (all_loss != 0).sum(1) - parted_loss_mean[parted_loss_mean == float('Inf')] = 0 + non_zero_counts = (all_loss != 0).sum(1) + parted_loss_mean = all_loss.sum(1) / non_zero_counts + parted_loss_mean[non_zero_counts == 0] = 0 loss = parted_loss_mean.sum() return loss -- cgit v1.2.3 From 7188d71b2b6faf3da527c8d0ade9a32ec4893dc5 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Mon, 11 Jan 2021 21:15:58 +0800 Subject: Implement evaluator --- utils/triplet_loss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'utils') diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py index 1899cc9..8c143d6 100644 --- a/utils/triplet_loss.py +++ b/utils/triplet_loss.py @@ -15,8 +15,8 @@ class BatchAllTripletLoss(nn.Module): # Euclidean distance p x n x n x_squared_sum = torch.sum(x ** 2, dim=2) - x1_squared_sum = x_squared_sum.unsqueeze(1) - x2_squared_sum = x_squared_sum.unsqueeze(2) + x1_squared_sum = x_squared_sum.unsqueeze(2) + x2_squared_sum = x_squared_sum.unsqueeze(1) x1_times_x2_sum = x @ x.transpose(1, 2) dist = torch.sqrt( F.relu(x1_squared_sum - 2 * x1_times_x2_sum + x2_squared_sum) -- cgit v1.2.3 From 72a53806746bc7ffa2f3939721e34b5cfdb7330a Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Mon, 11 Jan 2021 23:59:30 +0800 Subject: Add evaluation script, code review and fix some bugs 1. Add new `train_all` method for one shot calling 2. Print time used in 1k iterations 3. Correct label dimension in predict function 4. Transpose distance matrix for convenient indexing 5. Sort dictionary before generate signature 6. Extract visible CUDA setting function --- utils/misc.py | 10 ++++++++++ 1 file changed, 10 insertions(+) create mode 100644 utils/misc.py (limited to 'utils') diff --git a/utils/misc.py b/utils/misc.py new file mode 100644 index 0000000..b850830 --- /dev/null +++ b/utils/misc.py @@ -0,0 +1,10 @@ +import os + +from utils.configuration import SystemConfiguration + + +def set_visible_cuda(config: SystemConfiguration): + """Set environment variable CUDA device(s)""" + CUDA_VISIBLE_DEVICES = config.get('CUDA_VISIBLE_DEVICES', None) + if CUDA_VISIBLE_DEVICES: + os.environ['CUDA_VISIBLE_DEVICES'] = CUDA_VISIBLE_DEVICES -- cgit v1.2.3