From ab29067d6469473481cc73fe42bcaf69d7633a83 Mon Sep 17 00:00:00 2001
From: Jordan Gong <jordan.gong@protonmail.com>
Date: Tue, 5 Jan 2021 20:20:06 +0800
Subject: Implement Batch All Triplet Loss

---
 models/rgb_part_net.py | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

(limited to 'models')

diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 5012765..a58be39 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -5,6 +5,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from models import AutoEncoder, HorizontalPyramidMatching, PartNet
+from utils.triplet_loss import BatchAllTripletLoss
 
 
 class RGBPartNet(nn.Module):
@@ -24,7 +25,7 @@ class RGBPartNet(nn.Module):
             tfa_squeeze_ratio: int = 4,
             tfa_num_parts: int = 16,
             embedding_dims: int = 256,
-            triplet_margin: int = 0.2
+            triplet_margin: float = 0.2
     ):
         super().__init__()
         self.ae = AutoEncoder(
@@ -43,8 +44,10 @@ class RGBPartNet(nn.Module):
         empty_fc = torch.empty(total_parts, out_channels, embedding_dims)
         self.fc_mat = nn.Parameter(empty_fc)
 
+        self.ba_triplet_loss = BatchAllTripletLoss(triplet_margin)
+
     def fc(self, x):
-        return torch.matmul(x, self.fc_mat)
+        return x @ self.fc_mat
 
     def forward(self, x_c1, x_c2, y=None):
         # Step 0: Swap batch_size and time dimensions for next step
@@ -72,10 +75,10 @@ class RGBPartNet(nn.Module):
         x = self.fc(x)
 
         if self.training:
-            # TODO Implement Batch All triplet loss function
-            batch_all_triplet_loss = torch.tensor(0.)
-            loss = torch.sum(torch.stack((*losses, batch_all_triplet_loss)))
-            return loss
+            batch_all_triplet_loss = self.ba_triplet_loss(x, y)
+            losses = (*losses, batch_all_triplet_loss)
+            loss = torch.sum(torch.stack(losses))
+            return loss, (loss.item() for loss in losses)
         else:
             return x
 
-- 
cgit v1.2.3