From 9f3d4cc14ad36e515b56e86fb8e26f519bde831e Mon Sep 17 00:00:00 2001
From: Jordan Gong <jordan.gong@protonmail.com>
Date: Tue, 9 Feb 2021 16:51:08 +0800
Subject: Some optimizations

1. Scheduler will decay the learning rate of auto-encoder only
2. Write learning rate history to tensorboard
3. Reduce image log frequency
---
 models/model.py | 76 +++++++++++++++++++++++++++++++++------------------------
 1 file changed, 44 insertions(+), 32 deletions(-)

diff --git a/models/model.py b/models/model.py
index 0418070..ee07615 100644
--- a/models/model.py
+++ b/models/model.py
@@ -153,8 +153,18 @@ class Model:
         self.rgb_pn = self.rgb_pn.to(self.device)
         self.optimizer = optim.Adam([
             {'params': self.rgb_pn.ae.parameters(), **ae_optim_hp},
+            {'params': self.rgb_pn.pn.parameters(), **pn_optim_hp},
+            {'params': self.rgb_pn.hpm.parameters(), **hpm_optim_hp},
+            {'params': self.rgb_pn.fc_mat, **fc_optim_hp}
         ], **optim_hp)
-        self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, **sched_hp)
+        sched_gamma = sched_hp.get('gamma', 0.9)
+        sched_step_size = sched_hp.get('step_size', 500)
+        self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=[
+            lambda epoch: sched_gamma ** (epoch // sched_step_size),
+            lambda epoch: 0 if epoch < start_iter else 1,
+            lambda epoch: 0 if epoch < start_iter else 1,
+            lambda epoch: 0 if epoch < start_iter else 1,
+        ])
         self.writer = SummaryWriter(self._log_name)
 
         self.rgb_pn.train()
@@ -172,18 +182,8 @@ class Model:
         start_time = datetime.now()
         running_loss = torch.zeros(5, device=self.device)
         print(f"{'Iter':^5} {'Loss':^6} {'Xrecon':^8} {'PoseSim':^8}",
-              f"{'CanoCons':^8} {'BATripH':^8} {'BATripP':^8} LR(s)")
+              f"{'CanoCons':^8} {'BATripH':^8} {'BATripP':^8} {'LRs':^19}")
         for (batch_c1, batch_c2) in dataloader:
-            if self.curr_iter == start_iter:
-                self.optimizer.add_param_group(
-                    {'params': self.rgb_pn.pn.parameters(), **pn_optim_hp}
-                )
-                self.optimizer.add_param_group(
-                    {'params': self.rgb_pn.hpm.parameters(), **hpm_optim_hp}
-                )
-                self.optimizer.add_param_group(
-                    {'params': self.rgb_pn.fc_mat, **fc_optim_hp}
-                )
             self.curr_iter += 1
             # Zero the parameter gradients
             self.optimizer.zero_grad()
@@ -205,31 +205,39 @@ class Model:
                 'Canonical consistency loss', 'Batch All triplet loss (HPM)',
                 'Batch All triplet loss (PartNet)'
             ], losses)), self.curr_iter)
-            if self.image_log_on:
-                (appearance_image, canonical_image, pose_image) = images
-                self.writer.add_images(
-                    'Canonical image', canonical_image, self.curr_iter
-                )
-                for i in range(self.pr * self.k):
-                    self.writer.add_images(
-                        f'Original image/batch {i}', x_c1[i], self.curr_iter
-                    )
-                    self.writer.add_images(
-                        f'Appearance image/batch {i}',
-                        appearance_image[:, i, :, :, :],
-                        self.curr_iter
-                    )
-                    self.writer.add_images(
-                        f'Pose image/batch {i}',
-                        pose_image[:, i, :, :, :],
-                        self.curr_iter
-                    )
 
             if self.curr_iter % 100 == 0:
                 lrs = self.scheduler.get_last_lr()
+                # Write learning rates
+                self.writer.add_scalar(
+                    'Learning rate/Auto-encoder', lrs[0], self.curr_iter
+                )
+                self.writer.add_scalar(
+                    'Learning rate/Others', lrs[1], self.curr_iter
+                )
+                # Write disentangled images
+                if self.image_log_on:
+                    (appearance_image, canonical_image, pose_image) = images
+                    self.writer.add_images(
+                        'Canonical image', canonical_image, self.curr_iter
+                    )
+                    for i in range(self.pr * self.k):
+                        self.writer.add_images(
+                            f'Original image/batch {i}', x_c1[i], self.curr_iter
+                        )
+                        self.writer.add_images(
+                            f'Appearance image/batch {i}',
+                            appearance_image[:, i, :, :, :],
+                            self.curr_iter
+                        )
+                        self.writer.add_images(
+                            f'Pose image/batch {i}',
+                            pose_image[:, i, :, :, :],
+                            self.curr_iter
+                        )
                 print(f'{self.curr_iter:5d} {running_loss.sum() / 100:6.3f}',
                       '{:f} {:f} {:f} {:f} {:f}'.format(*running_loss / 100),
-                      ' '.join(('{:.3e}'.format(lr) for lr in lrs)))
+                      '{:.3e} {:.3e}'.format(lrs[0], lrs[1]))
                 running_loss.zero_()
 
             # Step scheduler
@@ -244,6 +252,10 @@ class Model:
                 }, self._checkpoint_name)
                 print(datetime.now() - start_time, 'used')
                 start_time = datetime.now()
+                print(
+                    f"{'Iter':^5} {'Loss':^6} {'Xrecon':^8} {'PoseSim':^8}",
+                    f"{'CanoCons':^8} {'BATripH':^8} {'BATripP':^8} {'LRs':^19}"
+                )
 
             if self.curr_iter == self.total_iter:
                 self.writer.close()
-- 
cgit v1.2.3


From 916cf90d04e57fee23092c966740fbe94fd92cff Mon Sep 17 00:00:00 2001
From: Jordan Gong <jordan.gong@protonmail.com>
Date: Tue, 9 Feb 2021 21:21:57 +0800
Subject: Improve performance when disentangling

This is a HUGE performance optimization, up to 2x faster than before. Mainly because of the replacement of randomized for-loop with randomized tensor.
---
 models/auto_encoder.py |  43 +++++++++----
 models/model.py        |  37 +++++------
 models/part_net.py     |  17 +++--
 models/rgb_part_net.py | 168 ++++++++++++++++---------------------------------
 4 files changed, 108 insertions(+), 157 deletions(-)

diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index f04ffdb..a9312dd 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -121,23 +121,42 @@ class AutoEncoder(nn.Module):
         self.decoder = Decoder(embedding_dims, feature_channels, channels)
 
     def forward(self, x_c1_t2, x_c1_t1=None, x_c2_t2=None):
+        n, t, c, h, w = x_c1_t2.size()
         # x_c1_t2 is the frame for later module
-        (f_a_c1_t2, f_c_c1_t2, f_p_c1_t2) = self.encoder(x_c1_t2)
+        x_c1_t2_ = x_c1_t2.view(n * t, c, h, w)
+        (f_a_c1_t2_, f_c_c1_t2_, f_p_c1_t2_) = self.encoder(x_c1_t2_)
 
         if self.training:
             # t1 is random time step, c2 is another condition
-            (f_a_c1_t1, f_c_c1_t1, _) = self.encoder(x_c1_t1)
-            (_, f_c_c2_t2, f_p_c2_t2) = self.encoder(x_c2_t2)
-
-            x_c1_t2_ = self.decoder(f_a_c1_t1, f_c_c1_t1, f_p_c1_t2)
-            xrecon_loss_t2 = F.mse_loss(x_c1_t2, x_c1_t2_)
-            cano_cons_loss_t2 = (F.mse_loss(f_c_c1_t1, f_c_c1_t2)
-                                 + F.mse_loss(f_c_c1_t2, f_c_c2_t2))
+            x_c1_t1 = x_c1_t1.view(n * t, c, h, w)
+            (f_a_c1_t1_, f_c_c1_t1_, _) = self.encoder(x_c1_t1)
+            x_c2_t2 = x_c2_t2.view(n * t, c, h, w)
+            (_, f_c_c2_t2_, f_p_c2_t2_) = self.encoder(x_c2_t2)
+
+            x_c1_t2_pred_ = self.decoder(f_a_c1_t1_, f_c_c1_t1_, f_p_c1_t2_)
+            x_c1_t2_pred = x_c1_t2_pred_.view(n, t, c, h, w)
+
+            xrecon_loss = torch.stack([
+                F.mse_loss(x_c1_t2[:, i, :, :, :], x_c1_t2_pred[:, i, :, :, :])
+                for i in range(t)
+            ]).sum()
+
+            f_c_c1_t1 = f_c_c1_t1_.view(n, t, -1)
+            f_c_c1_t2 = f_c_c1_t2_.view(n, t, -1)
+            f_c_c2_t2 = f_c_c2_t2_.view(n, t, -1)
+            cano_cons_loss = torch.stack([
+                F.mse_loss(f_c_c1_t1[:, i, :], f_c_c1_t2[:, i, :])
+                + F.mse_loss(f_c_c1_t2[:, i, :], f_c_c2_t2[:, i, :])
+                for i in range(t)
+            ]).mean()
+
+            f_p_c1_t2 = f_p_c1_t2_.view(n, t, -1)
+            f_p_c2_t2 = f_p_c2_t2_.view(n, t, -1)
+            pose_sim_loss = F.mse_loss(f_p_c1_t2.mean(1), f_p_c2_t2.mean(1))
 
             return (
-                (f_a_c1_t2, f_c_c1_t2, f_p_c1_t2),
-                (f_p_c1_t2, f_p_c2_t2),
-                (xrecon_loss_t2, cano_cons_loss_t2)
+                (f_a_c1_t2_, f_c_c1_t2_, f_p_c1_t2_),
+                (xrecon_loss, cano_cons_loss, pose_sim_loss * 10)
             )
         else:  # evaluating
-            return f_c_c1_t2, f_p_c1_t2
+            return f_c_c1_t2_, f_p_c1_t2_
diff --git a/models/model.py b/models/model.py
index ee07615..70c43b3 100644
--- a/models/model.py
+++ b/models/model.py
@@ -181,8 +181,9 @@ class Model:
         # Training start
         start_time = datetime.now()
         running_loss = torch.zeros(5, device=self.device)
-        print(f"{'Iter':^5} {'Loss':^6} {'Xrecon':^8} {'PoseSim':^8}",
-              f"{'CanoCons':^8} {'BATripH':^8} {'BATripP':^8} {'LRs':^19}")
+        print(f"{'Time':^8} {'Iter':^5} {'Loss':^6}",
+              f"{'Xrecon':^8} {'CanoCons':^8} {'PoseSim':^8}",
+              f"{'BATripH':^8} {'BATripP':^8} {'LRs':^19}")
         for (batch_c1, batch_c2) in dataloader:
             self.curr_iter += 1
             # Zero the parameter gradients
@@ -201,8 +202,8 @@ class Model:
             # Write losses to TensorBoard
             self.writer.add_scalar('Loss/all', loss, self.curr_iter)
             self.writer.add_scalars('Loss/details', dict(zip([
-                'Cross reconstruction loss', 'Pose similarity loss',
-                'Canonical consistency loss', 'Batch All triplet loss (HPM)',
+                'Cross reconstruction loss', 'Canonical consistency loss',
+                'Pose similarity loss', 'Batch All triplet loss (HPM)',
                 'Batch All triplet loss (PartNet)'
             ], losses)), self.curr_iter)
 
@@ -217,25 +218,25 @@ class Model:
                 )
                 # Write disentangled images
                 if self.image_log_on:
-                    (appearance_image, canonical_image, pose_image) = images
+                    i_a, i_c, i_p = images
                     self.writer.add_images(
-                        'Canonical image', canonical_image, self.curr_iter
+                        'Canonical image', i_c, self.curr_iter
                     )
-                    for i in range(self.pr * self.k):
+                    for (i, (o, a, p)) in enumerate(zip(x_c1, i_a, i_p)):
                         self.writer.add_images(
-                            f'Original image/batch {i}', x_c1[i], self.curr_iter
+                            f'Original image/batch {i}', o, self.curr_iter
                         )
                         self.writer.add_images(
-                            f'Appearance image/batch {i}',
-                            appearance_image[:, i, :, :, :],
-                            self.curr_iter
+                            f'Appearance image/batch {i}', a, self.curr_iter
                         )
                         self.writer.add_images(
-                            f'Pose image/batch {i}',
-                            pose_image[:, i, :, :, :],
-                            self.curr_iter
+                            f'Pose image/batch {i}', p, self.curr_iter
                         )
-                print(f'{self.curr_iter:5d} {running_loss.sum() / 100:6.3f}',
+                time_used = datetime.now() - start_time
+                remaining_minute, second = divmod(time_used.seconds, 60)
+                hour, minute = divmod(remaining_minute, 60)
+                print(f'{hour:02}:{minute:02}:{second:02}',
+                      f'{self.curr_iter:5d} {running_loss.sum() / 100:6.3f}',
                       '{:f} {:f} {:f} {:f} {:f}'.format(*running_loss / 100),
                       '{:.3e} {:.3e}'.format(lrs[0], lrs[1]))
                 running_loss.zero_()
@@ -250,12 +251,6 @@ class Model:
                     'optim_state_dict': self.optimizer.state_dict(),
                     'loss': loss,
                 }, self._checkpoint_name)
-                print(datetime.now() - start_time, 'used')
-                start_time = datetime.now()
-                print(
-                    f"{'Iter':^5} {'Loss':^6} {'Xrecon':^8} {'PoseSim':^8}",
-                    f"{'CanoCons':^8} {'BATripH':^8} {'BATripP':^8} {'LRs':^19}"
-                )
 
             if self.curr_iter == self.total_iter:
                 self.writer.close()
diff --git a/models/part_net.py b/models/part_net.py
index ac7c434..62a2bac 100644
--- a/models/part_net.py
+++ b/models/part_net.py
@@ -30,8 +30,8 @@ class FrameLevelPartFeatureExtractor(nn.Module):
 
     def forward(self, x):
         # Flatten frames in all batches
-        t, n, c, h, w = x.size()
-        x = x.view(-1, c, h, w)
+        n, t, c, h, w = x.size()
+        x = x.view(n * t, c, h, w)
 
         for fconv_block in self.fconv_blocks:
             x = fconv_block(x)
@@ -75,8 +75,8 @@ class TemporalFeatureAggregator(nn.Module):
                               for _ in range(self.num_part)])
 
     def forward(self, x):
-        # p, t, n, c
-        x = x.permute(0, 2, 3, 1).contiguous()
+        # p, n, t, c
+        x = x.transpose(2, 3)
         p, n, c, t = x.size()
         feature = x.split(1, dim=0)
         feature = [f.squeeze(0) for f in feature]
@@ -134,19 +134,18 @@ class PartNet(nn.Module):
         self.max_pool = nn.AdaptiveMaxPool2d(1)
 
     def forward(self, x):
-        t, n, _, _, _ = x.size()
-        # t, n, c, h, w
+        n, t, _, _, _ = x.size()
         x = self.fpfe(x)
-        # t_n, c, h, w
+        # n * t x c x h x w
 
         # Horizontal Pooling
         _, c, h, w = x.size()
         split_size = h // self.num_part
         x = x.split(split_size, dim=2)
         x = [self.avg_pool(x_) + self.max_pool(x_) for x_ in x]
-        x = [x_.view(t, n, c) for x_ in x]
+        x = [x_.view(n, t, c) for x_ in x]
         x = torch.stack(x)
 
-        # p, t, n, c
+        # p, n, t, c
         x = self.tfa(x)
         return x
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 0e7d8b3..8ebcfd3 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -1,8 +1,5 @@
-import random
-
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 from models.auto_encoder import AutoEncoder
 from models.hpm import HorizontalPyramidMatching
@@ -59,24 +56,18 @@ class RGBPartNet(nn.Module):
         return x @ self.fc_mat
 
     def forward(self, x_c1, x_c2=None, y=None):
-        # Step 0: Swap batch_size and time dimensions for next step
-        # n, t, c, h, w
-        x_c1 = x_c1.transpose(0, 1)
-        if self.training:
-            x_c2 = x_c2.transpose(0, 1)
-
         # Step 1: Disentanglement
-        # t, n, c, h, w
-        ((x_c_c1, x_p_c1), images, losses) = self._disentangle(x_c1, x_c2)
+        # n, t, c, h, w
+        ((x_c, x_p), losses, images) = self._disentangle(x_c1, x_c2)
 
         # Step 2.a: Static Gait Feature Aggregation & HPM
         # n, c, h, w
-        x_c = self.hpm(x_c_c1)
+        x_c = self.hpm(x_c)
         # p, n, c
 
         # Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation)
-        # t, n, c, h, w
-        x_p = self.pn(x_p_c1)
+        # n, t, c, h, w
+        x_p = self.pn(x_p)
         # p, n, c
 
         # Step 3: Cat feature map together and fc
@@ -91,113 +82,60 @@ class RGBPartNet(nn.Module):
         else:
             return x.unsqueeze(1).view(-1)
 
-    def _disentangle(self, x_c1, x_c2=None):
-        t, n, c, h, w = x_c1.size()
-        device = x_c1.device
+    def _disentangle(self, x_c1_t2, x_c2_t2=None):
+        n, t, c, h, w = x_c1_t2.size()
+        device = x_c1_t2.device
+        x_c1_t1 = x_c1_t2[:, torch.randperm(t), :, :, :]
         if self.training:
-            # Encoded appearance, canonical and pose features
-            f_a_c1, f_c_c1, f_p_c1 = [], [], []
-            # Features required to calculate losses
-            f_p_c2 = []
-            xrecon_loss, cano_cons_loss = [], []
-            for t2 in range(t):
-                t1 = random.randrange(t)
-                output = self.ae(x_c1[t2], x_c1[t1], x_c2[t2])
-                (f_c1_t2, f_p_t2, losses) = output
-
-                (f_a_c1_t2, f_c_c1_t2, f_p_c1_t2) = f_c1_t2
-                if self.image_log_on:
-                    f_a_c1.append(f_a_c1_t2)
-                # Save canonical features and pose features
-                f_c_c1.append(f_c_c1_t2)
-                f_p_c1.append(f_p_c1_t2)
-
-                # Losses per time step
-                # Used in pose similarity loss
-                (_, f_p_c2_t2) = f_p_t2
-                f_p_c2.append(f_p_c2_t2)
-
-                # Cross reconstruction loss and canonical loss
-                (xrecon_loss_t2, cano_cons_loss_t2) = losses
-                xrecon_loss.append(xrecon_loss_t2)
-                cano_cons_loss.append(cano_cons_loss_t2)
-            if self.image_log_on:
-                f_a_c1 = torch.stack(f_a_c1)
-            f_c_c1_mean = torch.stack(f_c_c1).mean(0)
-            f_p_c1 = torch.stack(f_p_c1)
-            f_p_c2 = torch.stack(f_p_c2)
-
+            ((f_a_, f_c_, f_p_), losses) = self.ae(x_c1_t2, x_c1_t1, x_c2_t2)
             # Decode features
-            appearance_image, canonical_image, pose_image = None, None, None
             with torch.no_grad():
-                # Decode average canonical features to higher dimension
-                x_c_c1 = self.ae.decoder(
-                    torch.zeros((n, self.f_a_dim), device=device),
-                    f_c_c1_mean,
-                    torch.zeros((n, self.f_p_dim), device=device),
-                    cano_only=True
-                )
-                # Decode pose features to images
-                f_p_c1_ = f_p_c1.view(t * n, -1)
-                x_p_c1_ = self.ae.decoder(
-                    torch.zeros((t * n, self.f_a_dim), device=device),
-                    torch.zeros((t * n, self.f_c_dim), device=device),
-                    f_p_c1_
-                )
-                x_p_c1 = x_p_c1_.view(t, n, c, h, w)
+                x_c = self._decode_cano_feature(f_c_, n, t, device)
+                x_p = self._decode_pose_feature(f_p_, n, t, c, h, w, device)
 
+                i_a, i_c, i_p = None, None, None
                 if self.image_log_on:
-                    # Decode appearance features
-                    f_a_c1_ = f_a_c1.view(t * n, -1)
-                    appearance_image_ = self.ae.decoder(
-                        f_a_c1_,
-                        torch.zeros((t * n, self.f_c_dim), device=device),
-                        torch.zeros((t * n, self.f_p_dim), device=device)
-                    )
-                    appearance_image = appearance_image_.view(t, n, c, h, w)
+                    i_a = self._decode_appr_feature(f_a_, n, t, c, h, w, device)
                     # Continue decoding canonical features
-                    canonical_image = self.ae.decoder.trans_conv3(x_c_c1)
-                    canonical_image = torch.sigmoid(
-                        self.ae.decoder.trans_conv4(canonical_image)
-                    )
-                    pose_image = x_p_c1
-
-            # Losses
-            xrecon_loss = torch.sum(torch.stack(xrecon_loss))
-            pose_sim_loss = self._pose_sim_loss(f_p_c1, f_p_c2) * 10
-            cano_cons_loss = torch.mean(torch.stack(cano_cons_loss))
+                    i_c = self.ae.decoder.trans_conv3(x_c)
+                    i_c = torch.sigmoid(self.ae.decoder.trans_conv4(i_c))
+                    i_p = x_p
 
-            return ((x_c_c1, x_p_c1),
-                    (appearance_image, canonical_image, pose_image),
-                    (xrecon_loss, pose_sim_loss, cano_cons_loss))
+            return (x_c, x_p), losses, (i_a, i_c, i_p)
 
         else:  # evaluating
-            x_c1_ = x_c1.view(t * n, c, h, w)
-            (f_c_c1_, f_p_c1_) = self.ae(x_c1_)
-
-            # Canonical features
-            f_c_c1 = f_c_c1_.view(t, n, -1)
-            f_c_c1_mean = f_c_c1.mean(0)
-            x_c_c1 = self.ae.decoder(
-                torch.zeros((n, self.f_a_dim)),
-                f_c_c1_mean,
-                torch.zeros((n, self.f_p_dim)),
-                cano_only=True
-            )
-
-            # Pose features
-            x_p_c1_ = self.ae.decoder(
-                torch.zeros((t * n, self.f_a_dim)),
-                torch.zeros((t * n, self.f_c_dim)),
-                f_p_c1_
-            )
-            x_p_c1 = x_p_c1_.view(t, n, c, h, w)
-
-            return (x_c_c1, x_p_c1), None, None
-
-    @staticmethod
-    def _pose_sim_loss(f_p_c1: torch.Tensor,
-                       f_p_c2: torch.Tensor) -> torch.Tensor:
-        f_p_c1_mean = f_p_c1.mean(dim=0)
-        f_p_c2_mean = f_p_c2.mean(dim=0)
-        return F.mse_loss(f_p_c1_mean, f_p_c2_mean)
+            f_c_, f_p_ = self.ae(x_c1_t2)
+            x_c = self._decode_cano_feature(f_c_, n, t, device)
+            x_p = self._decode_pose_feature(f_p_, n, t, c, h, w, device)
+            return (x_c, x_p), None, None
+
+    def _decode_appr_feature(self, f_a_, n, t, c, h, w, device):
+        # Decode appearance features
+        x_a_ = self.ae.decoder(
+            f_a_,
+            torch.zeros((n * t, self.f_c_dim), device=device),
+            torch.zeros((n * t, self.f_p_dim), device=device)
+        )
+        x_a = x_a_.view(n, t, c, h, w)
+        return x_a
+
+    def _decode_cano_feature(self, f_c_, n, t, device):
+        # Decode average canonical features to higher dimension
+        f_c = f_c_.view(n, t, -1)
+        x_c = self.ae.decoder(
+            torch.zeros((n, self.f_a_dim), device=device),
+            f_c.mean(1),
+            torch.zeros((n, self.f_p_dim), device=device),
+            cano_only=True
+        )
+        return x_c
+
+    def _decode_pose_feature(self, f_p_, n, t, c, h, w, device):
+        # Decode pose features to images
+        x_p_ = self.ae.decoder(
+            torch.zeros((n * t, self.f_a_dim), device=device),
+            torch.zeros((n * t, self.f_c_dim), device=device),
+            f_p_
+        )
+        x_p = x_p_.view(n, t, c, h, w)
+        return x_p
-- 
cgit v1.2.3