diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-04-10 22:38:23 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-04-10 22:38:23 +0800 |
commit | f7df51490f9c3cc932672493e6c91595686df9ce (patch) | |
tree | 04975117a0f4a6b0d559a38091aba1db51a7a048 /models/rgb_part_net.py | |
parent | 5e00cd7de1729db12329e793a4e84b6c7900a948 (diff) | |
parent | 20110729ab450c84d90965f5b8930236035f093a (diff) |
Merge branch 'python3.8' into python3.7
Diffstat (limited to 'models/rgb_part_net.py')
-rw-r--r-- | models/rgb_part_net.py | 37 |
1 file changed, 20 insertions, 17 deletions
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index ffee044..3a251da 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -41,26 +41,26 @@ class RGBPartNet(nn.Module): self.num_parts = self.hpm.num_parts + tfa_num_parts def forward(self, x_c1, x_c2=None): - # Step 1: Disentanglement - # n, t, c, h, w - (f_a, f_c, f_p), ae_losses = self._disentangle(x_c1, x_c2) + if self.training: + # Step 1: Disentanglement + # n, t, c, h, w + (f_a, f_c, f_p), f_loss = self._disentangle(x_c1, x_c2) - # Step 2.a: Static Gait Feature Aggregation & HPM - # n, c, h, w - f_c_mean = f_c.mean(1) - x_c = self.hpm(f_c_mean) - # p, n, d + # Step 2.a: Static Gait Feature Aggregation & HPM + # n, t, c, h, w + x_c, f_c_loss = self.hpm(f_c, *f_loss[1]) + # p, n, d / p, n, t, c - # Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation) - # n, t, c, h, w - x_p = self.pn(f_p) - # p, n, d + # Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation) + # n, t, c, h, w + x_p, f_p_loss = self.pn(f_p, f_loss[2]) + # p, n, d / p, n, t, c - if self.training: i_a, i_c, i_p = None, None, None if self.image_log_on: with torch.no_grad(): f_a_mean = f_a.mean(1) + f_c_mean = f_c.mean(1) i_a = self.ae.decoder( f_a_mean, torch.zeros_like(f_c_mean), @@ -79,15 +79,18 @@ class RGBPartNet(nn.Module): device=f_c.device), f_p.view(-1, *f_p_size[2:]) ).view(x_c1.size()) - return x_c, x_p, ae_losses, (i_a, i_c, i_p) - else: + return (x_c, x_p), (f_loss[0], f_c_loss, f_p_loss), (i_a, i_c, i_p) + else: # Evaluating + f_c, f_p = self._disentangle(x_c1, x_c2) + x_c = self.hpm(f_c) + x_p = self.pn(f_p) return x_c, x_p def _disentangle(self, x_c1_t2, x_c2_t2=None): if self.training: x_c1_t1 = x_c1_t2[:, torch.randperm(x_c1_t2.size(1)), :, :, :] - features, losses = self.ae(x_c1_t2, x_c1_t1, x_c2_t2) - return features, losses + features, f_loss = self.ae(x_c1_t2, x_c1_t1, x_c2_t2) + return features, f_loss else: # evaluating features = self.ae(x_c1_t2) return features, None |