From b294b715ec0de6ba94199f3b068dc828095fd2f1 Mon Sep 17 00:00:00 2001
From: Jordan Gong <jordan.gong@protonmail.com>
Date: Sat, 10 Apr 2021 22:34:25 +0800
Subject: Calculate pose similarity loss and canonical consistency loss of each
 part after pooling

---
 models/auto_encoder.py | 14 +------
 models/hpm.py          | 20 ++++++++--
 models/model.py        | 99 ++++++++++++++++++++++++++++++--------------------
 models/part_net.py     | 20 ++++++----
 models/rgb_part_net.py | 37 ++++++++++---------
 5 files changed, 108 insertions(+), 82 deletions(-)

(limited to 'models')

diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index 96dfdb3..dc7843a 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -134,25 +134,13 @@ class AutoEncoder(nn.Module):
             x_c1_t2_pred_ = self.decoder(f_a_c1_t1_, f_c_c1_t1_, f_p_c1_t2_)
             x_c1_t2_pred = x_c1_t2_pred_.view(n, t, c, h, w)
 
-            xrecon_loss = torch.stack([
-                F.mse_loss(x_c1_t2[:, i], x_c1_t2_pred[:, i])
-                for i in range(t)
-            ]).sum()
-
             f_c_c1_t1 = f_c_c1_t1_.view(f_size[1])
             f_c_c2_t2 = f_c_c2_t2_.view(f_size[1])
-            cano_cons_loss = torch.stack([
-                F.mse_loss(f_c_c1_t1[:, i], f_c_c1_t2[:, i])
-                + F.mse_loss(f_c_c1_t2[:, i], f_c_c2_t2[:, i])
-                for i in range(t)
-            ]).mean()
-
             f_p_c2_t2 = f_p_c2_t2_.view(f_size[2])
-            pose_sim_loss = F.mse_loss(f_p_c1_t2.mean(1), f_p_c2_t2.mean(1))
 
             return (
                 (f_a_c1_t2, f_c_c1_t2, f_p_c1_t2),
-                (xrecon_loss, cano_cons_loss, pose_sim_loss * 10)
+                (x_c1_t2_pred, (f_c_c1_t1, f_c_c2_t2), f_p_c2_t2)
             )
         else:  # evaluating
             return f_a_c1_t2, f_c_c1_t2, f_p_c1_t2
diff --git a/models/hpm.py b/models/hpm.py
index 8186b20..fa0f69e 100644
--- a/models/hpm.py
+++ b/models/hpm.py
@@ -33,8 +33,9 @@ class HorizontalPyramidMatching(nn.Module):
         ])
         return pyramid
 
-    def forward(self, x):
-        n, c, h, w = x.size()
+    def _horizontal_pyramid_pool(self, x):
+        n, t, c, h, w = x.size()
+        x = x.view(n * t, c, h, w)
         feature = []
         for scale, pyramid in zip(self.scales, self.pyramids):
             h_per_hpp = h // scale
@@ -43,12 +44,23 @@ class HorizontalPyramidMatching(nn.Module):
                                         (hpp_index + 1) * h_per_hpp)
                 x_slice = x[:, :, h_filter, :]
                 x_slice = hpp(x_slice)
-                x_slice = x_slice.view(n, -1)
+                x_slice = x_slice.view(n, t, c)
                 feature.append(x_slice)
         x = torch.stack(feature)
+        return x
 
+    def forward(self, f_c1_t2, f_c1_t1=None, f_c2_t2=None):
+        # n, t, c, h, w
+        f_c1_t2_ = self._horizontal_pyramid_pool(f_c1_t2)
+        # p, n, t, c
+        x = f_c1_t2_.mean(2)
         # p, n, c
         x = x @ self.fc_mat
         # p, n, d
 
-        return x
+        if self.training:
+            f_c1_t1_ = self._horizontal_pyramid_pool(f_c1_t1)
+            f_c2_t2_ = self._horizontal_pyramid_pool(f_c2_t2)
+            return x, (f_c1_t2_, f_c1_t1_, f_c2_t2_)
+        else:
+            return x
diff --git a/models/model.py b/models/model.py
index 45067e6..6118bdf 100644
--- a/models/model.py
+++ b/models/model.py
@@ -267,11 +267,15 @@ class Model:
             # forward + backward + optimize
             x_c1 = batch_c1['clip'].to(self.device)
             x_c2 = batch_c2['clip'].to(self.device)
-            embed_c, embed_p, ae_losses, images = self.rgb_pn(x_c1, x_c2)
+            embedding, f_loss, images = self.rgb_pn(x_c1, x_c2)
+            ae_losses = self._disentangling_loss(x_c1, f_loss)
             y = batch_c1['label'].to(self.device)
-            losses, hpm_result, pn_result = self._classification_loss(
-                embed_c, embed_p, ae_losses, y
-            )
+            results = self._classification_loss(embedding, y)
+            losses = torch.stack((
+                *ae_losses,
+                results[0]['loss'].mean(),
+                results[1]['loss'].mean()
+            ))
             loss = losses.sum()
             loss.backward()
             self.optimizer.step()
@@ -282,9 +286,7 @@ class Model:
                 'Auto-encoder', 'HPM', 'PartNet'
             ), self.scheduler.get_last_lr())), self.curr_iter)
             # Other stats
-            self._write_stat(
-                'Train', embed_c, embed_p, hpm_result, pn_result, loss, losses
-            )
+            self._write_stat('Train', embedding, results, losses)
 
             # Write disentangled images
             if self.image_log_on and self.curr_iter % self.image_log_steps \
@@ -306,8 +308,8 @@ class Model:
 
             if self.curr_iter % 100 == 99:
                 # Validation
-                embed_c = self._flatten_embedding(embed_c)
-                embed_p = self._flatten_embedding(embed_p)
+                embed_c = self._flatten_embedding(embedding[0])
+                embed_p = self._flatten_embedding(embedding[1])
                 self._write_embedding('HPM Train', embed_c, x_c1, y)
                 self._write_embedding('PartNet Train', embed_p, x_c1, y)
 
@@ -316,18 +318,19 @@ class Model:
                 x_c1 = batch_c1['clip'].to(self.device)
                 x_c2 = batch_c2['clip'].to(self.device)
                 with torch.no_grad():
-                    embed_c, embed_p, ae_losses, _ = self.rgb_pn(x_c1, x_c2)
+                    embedding, f_loss, images = self.rgb_pn(x_c1, x_c2)
+                ae_losses = self._disentangling_loss(x_c1, f_loss)
                 y = batch_c1['label'].to(self.device)
-                losses, hpm_result, pn_result = self._classification_loss(
-                    embed_c, embed_p, ae_losses, y
-                )
-                loss = losses.sum()
-
-                self._write_stat(
-                    'Val', embed_c, embed_p, hpm_result, pn_result, loss, losses
-                )
-                embed_c = self._flatten_embedding(embed_c)
-                embed_p = self._flatten_embedding(embed_p)
+                results = self._classification_loss(embedding, y)
+                losses = torch.stack((
+                    *ae_losses,
+                    results[0]['loss'].mean(),
+                    results[1]['loss'].mean()
+                ))
+
+                self._write_stat('Val', embedding, results, losses)
+                embed_c = self._flatten_embedding(embedding[0])
+                embed_p = self._flatten_embedding(embedding[1])
                 self._write_embedding('HPM Val', embed_c, x_c1, y)
                 self._write_embedding('PartNet Val', embed_p, x_c1, y)
 
@@ -342,21 +345,39 @@ class Model:
 
         self.writer.close()
 
-    def _classification_loss(self, embed_c, embed_p, ae_losses, y):
+    @staticmethod
+    def _disentangling_loss(x_c1_t2, f_loss):
+        n, t, c, h, w = x_c1_t2.size()
+        x_c1_t2_pred = f_loss[0]
+        xrecon_loss = torch.stack([
+            F.mse_loss(x_c1_t2[:, i, :, :, :], x_c1_t2_pred[:, i, :, :, :])
+            for i in range(t)
+        ]).sum()
+        cano_cons_loss = torch.stack([
+            torch.stack([
+                F.mse_loss(_f_c_c1_t1[:, i, :], _f_c_c1_t2[:, i, :])
+                + F.mse_loss(_f_c_c1_t2[:, i, :], _f_c_c2_t2[:, i, :])
+                for i in range(t)
+            ]).mean()
+            for _f_c_c1_t2, _f_c_c1_t1, _f_c_c2_t2 in zip(*f_loss[1])
+        ]).sum()
+        pose_sim_loss = torch.stack([
+            F.mse_loss(_f_p_c1_t2.mean(1), _f_p_c2_t2.mean(1))
+            for _f_p_c1_t2, _f_p_c2_t2 in zip(*f_loss[2])
+        ]).sum()
+
+        return xrecon_loss, cano_cons_loss * 10, pose_sim_loss * 100
+
+    def _classification_loss(self, embedding, y):
         # Duplicate labels for each part
         y_triplet = y.repeat(self.rgb_pn.num_parts, 1)
         hpm_result = self.triplet_loss_hpm(
-            embed_c, y_triplet[:self.rgb_pn.hpm.num_parts]
+            embedding[0], y_triplet[:self.rgb_pn.hpm.num_parts]
         )
         pn_result = self.triplet_loss_pn(
-            embed_p, y_triplet[self.rgb_pn.hpm.num_parts:]
+            embedding[1], y_triplet[self.rgb_pn.hpm.num_parts:]
         )
-        losses = torch.stack((
-            *ae_losses,
-            hpm_result.pop('loss').mean(),
-            pn_result.pop('loss').mean()
-        ))
-        return losses, hpm_result, pn_result
+        return hpm_result, pn_result
 
     def _write_embedding(self, tag, embed, x, y):
         frame = x[:, 0, :, :, :].cpu()
@@ -374,11 +395,8 @@ class Model:
     def _flatten_embedding(self, embed):
         return embed.detach().transpose(0, 1).reshape(self.k * self.pr, -1)
 
-    def _write_stat(
-            self, postfix, embed_c, embed_p, hpm_result, pn_result, loss, losses
-    ):
+    def _write_stat(self, postfix, embeddings, results, losses):
         # Write losses to TensorBoard
-        self.writer.add_scalar(f'Loss/all {postfix}', loss, self.curr_iter)
         self.writer.add_scalars(f'Loss/disentanglement {postfix}', dict(zip((
             'Cross reconstruction loss', 'Canonical consistency loss',
             'Pose similarity loss'
@@ -388,30 +406,31 @@ class Model:
             'PartNet': losses[4]
         }, self.curr_iter)
         # None-zero losses in batch
-        if hpm_result['counts'] is not None and pn_result['counts'] is not None:
+        if results[0]['counts'] is not None \
+                and results[1]['counts'] is not None:
             self.writer.add_scalars(f'Loss/non-zero counts {postfix}', {
-                'HPM': hpm_result['counts'].mean(),
-                'PartNet': pn_result['counts'].mean()
+                'HPM': results[0]['counts'].mean(),
+                'PartNet': results[1]['counts'].mean()
             }, self.curr_iter)
         # Embedding distance
-        mean_hpm_dist = hpm_result['dist'].mean(0)
+        mean_hpm_dist = results[0]['dist'].mean(0)
         self._add_ranked_scalars(
             f'Embedding/HPM distance {postfix}', mean_hpm_dist,
             self.num_pos_pairs, self.num_pairs, self.curr_iter
         )
-        mean_pn_dist = pn_result['dist'].mean(0)
+        mean_pn_dist = results[1]['dist'].mean(0)
         self._add_ranked_scalars(
             f'Embedding/ParNet distance {postfix}', mean_pn_dist,
             self.num_pos_pairs, self.num_pairs, self.curr_iter
         )
         # Embedding norm
-        mean_hpm_embedding = embed_c.mean(0)
+        mean_hpm_embedding = embeddings[0].mean(0)
         mean_hpm_norm = mean_hpm_embedding.norm(dim=-1)
         self._add_ranked_scalars(
             f'Embedding/HPM norm {postfix}', mean_hpm_norm,
             self.k, self.pr * self.k, self.curr_iter
         )
-        mean_pa_embedding = embed_p.mean(0)
+        mean_pa_embedding = embeddings[1].mean(0)
         mean_pa_norm = mean_pa_embedding.norm(dim=-1)
         self._add_ranked_scalars(
             f'Embedding/PartNet norm {postfix}', mean_pa_norm,
diff --git a/models/part_net.py b/models/part_net.py
index f2236bf..65a2c14 100644
--- a/models/part_net.py
+++ b/models/part_net.py
@@ -127,23 +127,27 @@ class PartNet(nn.Module):
             torch.empty(num_parts, in_channels, embedding_dims)
         )
 
-    def forward(self, x):
+    def _horizontal_pool(self, x):
         n, t, c, h, w = x.size()
         x = x.view(n * t, c, h, w)
-        # n * t x c x h x w
-
-        # Horizontal Pooling
-        _, c, h, w = x.size()
         split_size = h // self.num_part
         x = x.split(split_size, dim=2)
         x = [self.avg_pool(x_) + self.max_pool(x_) for x_ in x]
         x = [x_.view(n, t, c) for x_ in x]
         x = torch.stack(x)
+        return x
 
+    def forward(self, f_c1_t2, f_c2_t2=None):
+        # n, t, c, h, w
+        f_c1_t2_ = self._horizontal_pool(f_c1_t2)
         # p, n, t, c
-        x = self.tfa(x)
-
+        x = self.tfa(f_c1_t2_)
         # p, n, c
         x = x @ self.fc_mat
         # p, n, d
-        return x
+
+        if self.training:
+            f_c2_t2_ = self._horizontal_pool(f_c2_t2)
+            return x, (f_c1_t2_, f_c2_t2_)
+        else:
+            return x
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index b0169e3..06cbf28 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -39,26 +39,26 @@ class RGBPartNet(nn.Module):
         self.num_parts = self.hpm.num_parts + tfa_num_parts
 
     def forward(self, x_c1, x_c2=None):
-        # Step 1: Disentanglement
-        # n, t, c, h, w
-        (f_a, f_c, f_p), ae_losses = self._disentangle(x_c1, x_c2)
+        if self.training:
+            # Step 1: Disentanglement
+            # n, t, c, h, w
+            (f_a, f_c, f_p), f_loss = self._disentangle(x_c1, x_c2)
 
-        # Step 2.a: Static Gait Feature Aggregation & HPM
-        # n, c, h, w
-        f_c_mean = f_c.mean(1)
-        x_c = self.hpm(f_c_mean)
-        # p, n, d
+            # Step 2.a: Static Gait Feature Aggregation & HPM
+            # n, t, c, h, w
+            x_c, f_c_loss = self.hpm(f_c, *f_loss[1])
+            # p, n, d / p, n, t, c
 
-        # Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation)
-        # n, t, c, h, w
-        x_p = self.pn(f_p)
-        # p, n, d
+            # Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation)
+            # n, t, c, h, w
+            x_p, f_p_loss = self.pn(f_p, f_loss[2])
+            # p, n, d / p, n, t, c
 
-        if self.training:
             i_a, i_c, i_p = None, None, None
             if self.image_log_on:
                 with torch.no_grad():
                     f_a_mean = f_a.mean(1)
+                    f_c_mean = f_c.mean(1)
                     i_a = self.ae.decoder(
                         f_a_mean,
                         torch.zeros_like(f_c_mean),
@@ -77,15 +77,18 @@ class RGBPartNet(nn.Module):
                                     device=f_c.device),
                         f_p.view(-1, *f_p_size[2:])
                     ).view(x_c1.size())
-            return x_c, x_p, ae_losses, (i_a, i_c, i_p)
-        else:
+            return (x_c, x_p), (f_loss[0], f_c_loss, f_p_loss), (i_a, i_c, i_p)
+        else:  # Evaluating: _disentangle returns ((f_a, f_c, f_p), None)
+            (_, f_c, f_p), _ = self._disentangle(x_c1, x_c2)
+            x_c = self.hpm(f_c)
+            x_p = self.pn(f_p)
             return x_c, x_p
 
     def _disentangle(self, x_c1_t2, x_c2_t2=None):
         if self.training:
             x_c1_t1 = x_c1_t2[:, torch.randperm(x_c1_t2.size(1)), :, :, :]
-            features, losses = self.ae(x_c1_t2, x_c1_t1, x_c2_t2)
-            return features, losses
+            features, f_loss = self.ae(x_c1_t2, x_c1_t1, x_c2_t2)
+            return features, f_loss
         else:  # evaluating
             features = self.ae(x_c1_t2)
             return features, None
-- 
cgit v1.2.3