From 02aaefaba26b6842d2feb403edfd71aaa75904da Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Sat, 2 Jan 2021 19:10:08 +0800 Subject: Correct feature dims after disentanglement and HPM backbone removal 1. Features used in HPM is decoded canonical embedding without transpose convolution 2. Decode pose embedding to image for Part Net 3. Backbone seems to be redundant, we can use feature map given by auto-decoder --- models/rgb_part_net.py | 40 +++++++++++++++++++++++++++------------- 1 file changed, 27 insertions(+), 13 deletions(-) (limited to 'models/rgb_part_net.py') diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 377c108..0ff8251 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -13,7 +13,7 @@ class RGBPartNet(nn.Module): ae_in_channels: int = 3, ae_feature_channels: int = 64, f_a_c_p_dims: tuple[int, int, int] = (128, 128, 64), - hpm_scales: tuple[int, ...] = (1, 2, 4, 8), + hpm_scales: tuple[int, ...] = (1, 2, 4), hpm_use_avg_pool: bool = True, hpm_use_max_pool: bool = True, fpfe_feature_channels: int = 32, @@ -32,7 +32,7 @@ class RGBPartNet(nn.Module): fpfe_paddings, fpfe_halving, tfa_squeeze_ratio, tfa_num_part ) self.hpm = HorizontalPyramidMatching( - ae_in_channels, self.pn.tfa_in_channels, hpm_scales, + ae_feature_channels * 8, self.pn.tfa_in_channels, hpm_scales, hpm_use_avg_pool, hpm_use_max_pool ) @@ -54,38 +54,52 @@ class RGBPartNet(nn.Module): # Step 1: Disentanglement # t, n, c, h, w num_frames = len(x_c1) - f_c_c1, f_p_c1, f_p_c2 = [], [], [] + # Decoded canonical features and Pose images + x_c_c1, x_p_c1 = [], [] + # Features required to calculate losses + f_p_c1, f_p_c2 = [], [] xrecon_loss, cano_cons_loss = torch.zeros(1), torch.zeros(1) for t2 in range(num_frames): t1 = random.randrange(num_frames) output = self.ae(x_c1[t1], x_c1[t2], x_c2[t2], y) - (feature_t2, xrecon_loss_t2, cano_cons_loss_t2) = output - (f_c_c1_t2, f_p_c1_t2, f_p_c2_t2) = feature_t2 - # Features for next step - f_c_c1.append(f_c_c1_t2) - f_p_c1.append(f_p_c1_t2) + (x_c1_t2, f_p_t2, losses) = output + + # Decoded features or image + (x_c_c1_t2, x_p_c1_t2) = x_c1_t2 + # Canonical Features for HPM + x_c_c1.append(x_c_c1_t2) + # Pose image for Part Net + x_p_c1.append(x_p_c1_t2) + # Losses per time step + # Used in pose similarity loss + (f_p_c1_t2, f_p_c2_t2) = f_p_t2 + f_p_c1.append(f_p_c1_t2) f_p_c2.append(f_p_c2_t2) + # Cross reconstruction loss and canonical loss + (xrecon_loss_t2, cano_cons_loss_t2) = losses xrecon_loss += xrecon_loss_t2 cano_cons_loss += cano_cons_loss_t2 - f_c_c1 = torch.stack(f_c_c1) - f_p_c1 = torch.stack(f_p_c1) + + x_c_c1 = torch.stack(x_c_c1) + x_p_c1 = torch.stack(x_p_c1) # Step 2.a: HPM & Static Gait Feature Aggregation # t, n, c, h, w - x_c = self.hpm(f_c_c1) + x_c = self.hpm(x_c_c1) # p, t, n, c x_c = x_c.mean(dim=1) # p, n, c # Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation) # t, n, c, h, w - x_p = self.pn(f_p_c1) + x_p = self.pn(x_p_c1) # p, n, c # Step 3: Cat feature map together and calculate losses - x = torch.cat(x_c, x_p) + x = torch.cat([x_c, x_p]) # Losses + f_p_c1 = torch.stack(f_p_c1) f_p_c2 = torch.stack(f_p_c2) pose_sim_loss = self.pose_sim_loss(f_p_c1, f_p_c2) cano_cons_loss /= num_frames -- cgit v1.2.3