diff options
Diffstat (limited to 'models/rgb_part_net.py')
-rw-r--r-- | models/rgb_part_net.py | 40 |
1 files changed, 27 insertions, 13 deletions
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 377c108..0ff8251 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -13,7 +13,7 @@ class RGBPartNet(nn.Module): ae_in_channels: int = 3, ae_feature_channels: int = 64, f_a_c_p_dims: tuple[int, int, int] = (128, 128, 64), - hpm_scales: tuple[int, ...] = (1, 2, 4, 8), + hpm_scales: tuple[int, ...] = (1, 2, 4), hpm_use_avg_pool: bool = True, hpm_use_max_pool: bool = True, fpfe_feature_channels: int = 32, @@ -32,7 +32,7 @@ class RGBPartNet(nn.Module): fpfe_paddings, fpfe_halving, tfa_squeeze_ratio, tfa_num_part ) self.hpm = HorizontalPyramidMatching( - ae_in_channels, self.pn.tfa_in_channels, hpm_scales, + ae_feature_channels * 8, self.pn.tfa_in_channels, hpm_scales, hpm_use_avg_pool, hpm_use_max_pool ) @@ -54,38 +54,52 @@ class RGBPartNet(nn.Module): # Step 1: Disentanglement # t, n, c, h, w num_frames = len(x_c1) - f_c_c1, f_p_c1, f_p_c2 = [], [], [] + # Decoded canonical features and Pose images + x_c_c1, x_p_c1 = [], [] + # Features required to calculate losses + f_p_c1, f_p_c2 = [], [] xrecon_loss, cano_cons_loss = torch.zeros(1), torch.zeros(1) for t2 in range(num_frames): t1 = random.randrange(num_frames) output = self.ae(x_c1[t1], x_c1[t2], x_c2[t2], y) - (feature_t2, xrecon_loss_t2, cano_cons_loss_t2) = output - (f_c_c1_t2, f_p_c1_t2, f_p_c2_t2) = feature_t2 - # Features for next step - f_c_c1.append(f_c_c1_t2) - f_p_c1.append(f_p_c1_t2) + (x_c1_t2, f_p_t2, losses) = output + + # Decoded features or image + (x_c_c1_t2, x_p_c1_t2) = x_c1_t2 + # Canonical Features for HPM + x_c_c1.append(x_c_c1_t2) + # Pose image for Part Net + x_p_c1.append(x_p_c1_t2) + # Losses per time step + # Used in pose similarity loss + (f_p_c1_t2, f_p_c2_t2) = f_p_t2 + f_p_c1.append(f_p_c1_t2) f_p_c2.append(f_p_c2_t2) + # Cross reconstruction loss and canonical loss + (xrecon_loss_t2, cano_cons_loss_t2) = losses xrecon_loss += xrecon_loss_t2 cano_cons_loss += cano_cons_loss_t2 - f_c_c1 = torch.stack(f_c_c1) - f_p_c1 = torch.stack(f_p_c1) + + x_c_c1 = torch.stack(x_c_c1) + x_p_c1 = torch.stack(x_p_c1) # Step 2.a: HPM & Static Gait Feature Aggregation # t, n, c, h, w - x_c = self.hpm(f_c_c1) + x_c = self.hpm(x_c_c1) # p, t, n, c x_c = x_c.mean(dim=1) # p, n, c # Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation) # t, n, c, h, w - x_p = self.pn(f_p_c1) + x_p = self.pn(x_p_c1) # p, n, c # Step 3: Cat feature map together and calculate losses - x = torch.cat(x_c, x_p) + x = torch.cat([x_c, x_p]) # Losses + f_p_c1 = torch.stack(f_p_c1) f_p_c2 = torch.stack(f_p_c2) pose_sim_loss = self.pose_sim_loss(f_p_c1, f_p_c2) cano_cons_loss /= num_frames |