From b6e5972b64cc61fc967cf3d098fc629d781adce4 Mon Sep 17 00:00:00 2001
From: Jordan Gong <jordan.gong@protonmail.com>
Date: Mon, 22 Mar 2021 19:32:16 +0800
Subject: Add embedding visualization and validate on the test set

---
 utils/configuration.py | 1 +
 utils/triplet_loss.py  | 9 +++++++--
 2 files changed, 8 insertions(+), 2 deletions(-)

(limited to 'utils')

diff --git a/utils/configuration.py b/utils/configuration.py
index f6ac182..157d249 100644
--- a/utils/configuration.py
+++ b/utils/configuration.py
@@ -14,6 +14,7 @@ class DatasetConfiguration(TypedDict):
     name: str
     root_dir: str
     train_size: int
+    val_size: int
     num_sampled_frames: int
     truncate_threshold: int
     discard_threshold: int
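
For context, the new val_size field sits alongside the existing split and sampling
options of DatasetConfiguration. Below is a minimal sketch of a configuration dict
covering only the keys visible in this hunk; the values are purely illustrative,
and any keys of DatasetConfiguration outside the hunk are omitted.

# Illustrative values only; keys of DatasetConfiguration that are not
# visible in the hunk above are omitted, so this is a partial sketch.
dataset_config = {
    'name': 'CASIA-B',            # hypothetical dataset name
    'root_dir': 'data/CASIA-B',   # hypothetical data location
    'train_size': 74,
    'val_size': 12,               # newly added: size of the validation split
    'num_sampled_frames': 30,
    'truncate_threshold': 40,
    'discard_threshold': 15,
}
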
diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py
index 03fff21..5e3a97a 100644
--- a/utils/triplet_loss.py
+++ b/utils/triplet_loss.py
@@ -28,6 +28,7 @@ class BatchTripletLoss(nn.Module):
         else:  # is_all
             positive_negative_dist = self._all_distance(dist, y, p, n)
 
+        non_zero_counts = None
         if self.margin:
             losses = F.relu(self.margin + positive_negative_dist).view(p, -1)
             non_zero_counts = (losses != 0).sum(1).float()
@@ -35,14 +36,18 @@ class BatchTripletLoss(nn.Module):
                 loss_metric = self._none_zero_mean(losses, non_zero_counts)
             else:  # is_sum
                 loss_metric = losses.sum(1)
-            return loss_metric, flat_dist, non_zero_counts
         else:  # Soft margin
             losses = F.softplus(positive_negative_dist).view(p, -1)
             if self.is_mean:
                 loss_metric = losses.mean(1)
             else:  # is_sum
                 loss_metric = losses.sum(1)
-            return loss_metric, flat_dist, None
+
+        return {
+            'loss': loss_metric,
+            'dist': flat_dist,
+            'counts': non_zero_counts
+        }
 
     @staticmethod
     def _batch_distance(x):
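
Since forward now returns a dict rather than the previous tuples, call sites that
unpacked three values need a corresponding update. The sketch below shows one
possible updated caller; the function and argument names are assumptions (the
forward signature is not part of this patch), and only the keys 'loss', 'dist'
and 'counts' come from the change above.

def training_step(triplet_loss, embeddings, labels):
    # triplet_loss is assumed to be an already constructed BatchTripletLoss;
    # calling it with (embeddings, labels) is an assumption, since the forward
    # signature is not shown in this patch.
    results = triplet_loss(embeddings, labels)

    loss = results['loss'].sum()  # previously: loss_metric, flat_dist, counts = triplet_loss(...)
    dist = results['dist']        # flattened pairwise distances, e.g. for embedding visualization
    counts = results['counts']    # non-zero counts with a hard margin, None with soft margin

    return loss, dist, counts
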
-- 
cgit v1.2.3