From b294b715ec0de6ba94199f3b068dc828095fd2f1 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Sat, 10 Apr 2021 22:34:25 +0800 Subject: Calculate pose similarity loss and canonical consistency loss of each part after pooling --- models/rgb_part_net.py | 37 ++++++++++++++++++++----------------- 1 file changed, 20 insertions(+), 17 deletions(-) (limited to 'models/rgb_part_net.py') diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index b0169e3..06cbf28 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -39,26 +39,26 @@ class RGBPartNet(nn.Module): self.num_parts = self.hpm.num_parts + tfa_num_parts def forward(self, x_c1, x_c2=None): - # Step 1: Disentanglement - # n, t, c, h, w - (f_a, f_c, f_p), ae_losses = self._disentangle(x_c1, x_c2) + if self.training: + # Step 1: Disentanglement + # n, t, c, h, w + (f_a, f_c, f_p), f_loss = self._disentangle(x_c1, x_c2) - # Step 2.a: Static Gait Feature Aggregation & HPM - # n, c, h, w - f_c_mean = f_c.mean(1) - x_c = self.hpm(f_c_mean) - # p, n, d + # Step 2.a: Static Gait Feature Aggregation & HPM + # n, t, c, h, w + x_c, f_c_loss = self.hpm(f_c, *f_loss[1]) + # p, n, d / p, n, t, c - # Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation) - # n, t, c, h, w - x_p = self.pn(f_p) - # p, n, d + # Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation) + # n, t, c, h, w + x_p, f_p_loss = self.pn(f_p, f_loss[2]) + # p, n, d / p, n, t, c - if self.training: i_a, i_c, i_p = None, None, None if self.image_log_on: with torch.no_grad(): f_a_mean = f_a.mean(1) + f_c_mean = f_c.mean(1) i_a = self.ae.decoder( f_a_mean, torch.zeros_like(f_c_mean), @@ -77,15 +77,18 @@ class RGBPartNet(nn.Module): device=f_c.device), f_p.view(-1, *f_p_size[2:]) ).view(x_c1.size()) - return x_c, x_p, ae_losses, (i_a, i_c, i_p) - else: + return (x_c, x_p), (f_loss[0], f_c_loss, f_p_loss), (i_a, i_c, i_p) + else: # Evaluating + f_c, f_p = self._disentangle(x_c1, x_c2) + x_c = self.hpm(f_c) + x_p = self.pn(f_p) return x_c, x_p def _disentangle(self, x_c1_t2, x_c2_t2=None): if self.training: x_c1_t1 = x_c1_t2[:, torch.randperm(x_c1_t2.size(1)), :, :, :] - features, losses = self.ae(x_c1_t2, x_c1_t1, x_c2_t2) - return features, losses + features, f_loss = self.ae(x_c1_t2, x_c1_t1, x_c2_t2) + return features, f_loss else: # evaluating features = self.ae(x_c1_t2) return features, None -- cgit v1.2.3