From 820d3dec284f38e6a3089dad5277bc3f6c5123bf Mon Sep 17 00:00:00 2001
From: Jordan Gong <jordan.gong@protonmail.com>
Date: Sat, 20 Feb 2021 14:19:30 +0800
Subject: Separate triplet loss from model

---
 models/auto_encoder.py |  2 +-
 models/model.py        | 21 ++++++++++++++++++---
 models/rgb_part_net.py | 20 +++-----------------
 3 files changed, 22 insertions(+), 21 deletions(-)

(limited to 'models')

diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index 2d715db..1ef7494 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -171,7 +171,7 @@ class AutoEncoder(nn.Module):
 
             return (
                 (f_a_c1_t2_, f_c_c1_t2_, f_p_c1_t2_),
-                (xrecon_loss, cano_cons_loss, pose_sim_loss * 10)
+                torch.stack((xrecon_loss, cano_cons_loss, pose_sim_loss * 10))
             )
         else:  # evaluating
             return f_c_c1_t2_, f_p_c1_t2_
diff --git a/models/model.py b/models/model.py
index 82d6461..5899fc0 100644
--- a/models/model.py
+++ b/models/model.py
@@ -18,6 +18,7 @@ from utils.configuration import DataloaderConfiguration, \
     SystemConfiguration
 from utils.dataset import CASIAB, ClipConditions, ClipViews, ClipClasses
 from utils.sampler import TripletSampler
+from utils.triplet_loss import JointBatchAllTripletLoss
 
 
 class Model:
@@ -67,6 +68,7 @@ class Model:
         self._dataset_sig: str = 'undefined'
 
         self.rgb_pn: Optional[RGBPartNet] = None
+        self.ba_triplet_loss: Optional[JointBatchAllTripletLoss] = None
         self.optimizer: Optional[optim.Adam] = None
         self.scheduler: Optional[optim.lr_scheduler.StepLR] = None
         self.writer: Optional[SummaryWriter] = None
@@ -140,7 +142,8 @@ class Model:
         dataset = self._parse_dataset_config(dataset_config)
         dataloader = self._parse_dataloader_config(dataset, dataloader_config)
         # Prepare for model, optimizer and scheduler
-        model_hp = self.hp.get('model', {})
+        model_hp: dict = self.hp.get('model', {}).copy()
+        triplet_margins = model_hp.pop('triplet_margins', (0.2, 0.2))
         optim_hp: dict = self.hp.get('optimizer', {}).copy()
         start_iter = optim_hp.pop('start_iter', 0)
         ae_optim_hp = optim_hp.pop('auto_encoder', {})
@@ -150,8 +153,12 @@ class Model:
         sched_hp = self.hp.get('scheduler', {})
         self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp,
                                  image_log_on=self.image_log_on)
+        self.ba_triplet_loss = JointBatchAllTripletLoss(
+            self.rgb_pn.hpm_num_parts, triplet_margins
+        )
         # Try to accelerate computation using CUDA or others
         self.rgb_pn = self.rgb_pn.to(self.device)
+        self.ba_triplet_loss = self.ba_triplet_loss.to(self.device)
         self.optimizer = optim.Adam([
             {'params': self.rgb_pn.ae.parameters(), **ae_optim_hp},
             {'params': self.rgb_pn.pn.parameters(), **pn_optim_hp},
@@ -193,10 +200,18 @@ class Model:
             # forward + backward + optimize
             x_c1 = batch_c1['clip'].to(self.device)
             x_c2 = batch_c2['clip'].to(self.device)
+            feature, ae_losses, images = self.rgb_pn(x_c1, x_c2)
             y = batch_c1['label'].to(self.device)
             # Duplicate labels for each part
-            y = y.unsqueeze(1).repeat(1, self.rgb_pn.num_total_parts)
-            losses, images = self.rgb_pn(x_c1, x_c2, y)
+            y = y.repeat(self.rgb_pn.num_total_parts, 1)
+            triplet_loss = self.ba_triplet_loss(feature, y)
+            losses = torch.cat((
+                ae_losses,
+                torch.stack((
+                    triplet_loss[:self.rgb_pn.hpm_num_parts].mean(),
+                    triplet_loss[self.rgb_pn.hpm_num_parts:].mean()
+                ))
+            ))
             loss = losses.sum()
             loss.backward()
             self.optimizer.step()
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 67acac3..408bca0 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -4,7 +4,6 @@ import torch.nn as nn
 from models.auto_encoder import AutoEncoder
 from models.hpm import HorizontalPyramidMatching
 from models.part_net import PartNet
-from utils.triplet_loss import BatchAllTripletLoss
 
 
 class RGBPartNet(nn.Module):
@@ -25,7 +24,6 @@ class RGBPartNet(nn.Module):
             tfa_squeeze_ratio: int = 4,
             tfa_num_parts: int = 16,
             embedding_dims: int = 256,
-            triplet_margins: tuple[float, float] = (0.2, 0.2),
             image_log_on: bool = False
     ):
         super().__init__()
@@ -50,17 +48,13 @@ class RGBPartNet(nn.Module):
                                out_channels, embedding_dims)
         self.fc_mat = nn.Parameter(empty_fc)
 
-        (hpm_margin, pn_margin) = triplet_margins
-        self.hpm_ba_trip = BatchAllTripletLoss(hpm_margin)
-        self.pn_ba_trip = BatchAllTripletLoss(pn_margin)
-
     def fc(self, x):
         return x @ self.fc_mat
 
-    def forward(self, x_c1, x_c2=None, y=None):
+    def forward(self, x_c1, x_c2=None):
         # Step 1: Disentanglement
         # n, t, c, h, w
-        ((x_c, x_p), losses, images) = self._disentangle(x_c1, x_c2)
+        ((x_c, x_p), ae_losses, images) = self._disentangle(x_c1, x_c2)
 
         # Step 2.a: Static Gait Feature Aggregation & HPM
         # n, c, h, w
@@ -77,15 +71,7 @@ class RGBPartNet(nn.Module):
         x = self.fc(x)
 
         if self.training:
-            y = y.T
-            hpm_ba_trip = self.hpm_ba_trip(
-                x[:self.hpm_num_parts], y[:self.hpm_num_parts]
-            )
-            pn_ba_trip = self.pn_ba_trip(
-                x[self.hpm_num_parts:], y[self.hpm_num_parts:]
-            )
-            losses = torch.stack((*losses, hpm_ba_trip, pn_ba_trip))
-            return losses, images
+            return x, ae_losses, images
         else:
             return x.unsqueeze(1).view(-1)
 
-- 
cgit v1.2.3