diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-04-10 22:34:25 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-04-10 22:34:25 +0800 |
commit | b294b715ec0de6ba94199f3b068dc828095fd2f1 (patch) | |
tree | 6b52d1639a80c1800c1fc03dd48c824f92cb0b40 /models/rgb_part_net.py | |
parent | af7faa0f6d1eb3117359f5cf8e4d27a75f3f961c (diff) |
Calculate pose similarity loss and canonical consistency loss of each part after pooling
Diffstat (limited to 'models/rgb_part_net.py')
-rw-r--r-- | models/rgb_part_net.py | 37 |
1 files changed, 20 insertions, 17 deletions
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index b0169e3..06cbf28 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -39,26 +39,26 @@ class RGBPartNet(nn.Module): self.num_parts = self.hpm.num_parts + tfa_num_parts def forward(self, x_c1, x_c2=None): - # Step 1: Disentanglement - # n, t, c, h, w - (f_a, f_c, f_p), ae_losses = self._disentangle(x_c1, x_c2) + if self.training: + # Step 1: Disentanglement + # n, t, c, h, w + (f_a, f_c, f_p), f_loss = self._disentangle(x_c1, x_c2) - # Step 2.a: Static Gait Feature Aggregation & HPM - # n, c, h, w - f_c_mean = f_c.mean(1) - x_c = self.hpm(f_c_mean) - # p, n, d + # Step 2.a: Static Gait Feature Aggregation & HPM + # n, t, c, h, w + x_c, f_c_loss = self.hpm(f_c, *f_loss[1]) + # p, n, d / p, n, t, c - # Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation) - # n, t, c, h, w - x_p = self.pn(f_p) - # p, n, d + # Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation) + # n, t, c, h, w + x_p, f_p_loss = self.pn(f_p, f_loss[2]) + # p, n, d / p, n, t, c - if self.training: i_a, i_c, i_p = None, None, None if self.image_log_on: with torch.no_grad(): f_a_mean = f_a.mean(1) + f_c_mean = f_c.mean(1) i_a = self.ae.decoder( f_a_mean, torch.zeros_like(f_c_mean), @@ -77,15 +77,18 @@ class RGBPartNet(nn.Module): device=f_c.device), f_p.view(-1, *f_p_size[2:]) ).view(x_c1.size()) - return x_c, x_p, ae_losses, (i_a, i_c, i_p) - else: + return (x_c, x_p), (f_loss[0], f_c_loss, f_p_loss), (i_a, i_c, i_p) + else: # Evaluating + f_c, f_p = self._disentangle(x_c1, x_c2) + x_c = self.hpm(f_c) + x_p = self.pn(f_p) return x_c, x_p def _disentangle(self, x_c1_t2, x_c2_t2=None): if self.training: x_c1_t1 = x_c1_t2[:, torch.randperm(x_c1_t2.size(1)), :, :, :] - features, losses = self.ae(x_c1_t2, x_c1_t1, x_c2_t2) - return features, losses + features, f_loss = self.ae(x_c1_t2, x_c1_t1, x_c2_t2) + return features, f_loss else: # evaluating features = self.ae(x_c1_t2) return features, None |