diff options
Diffstat (limited to 'models/rgb_part_net.py')
-rw-r--r-- | models/rgb_part_net.py | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 1cda91c..2853571 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -52,7 +52,7 @@ class RGBPartNet(nn.Module): def forward(self, x_c1, x_c2=None): # Step 1: Disentanglement # n, t, c, h, w - ((x_c, x_p), ae_losses, images) = self._disentangle(x_c1, x_c2) + ((x_c, x_p), images, f_loss) = self._disentangle(x_c1, x_c2) # Step 2.a: Static Gait Feature Aggregation & HPM # n, c, h, w @@ -69,7 +69,7 @@ class RGBPartNet(nn.Module): x = self.fc(x) if self.training: - return x.transpose(0, 1), ae_losses, images + return x.transpose(0, 1), images, f_loss else: return x.unsqueeze(1).view(-1) @@ -78,7 +78,7 @@ class RGBPartNet(nn.Module): device = x_c1_t2.device if self.training: x_c1_t1 = x_c1_t2[:, torch.randperm(t), :, :, :] - ((f_a_, f_c_, f_p_), losses) = self.ae(x_c1_t2, x_c1_t1, x_c2_t2) + (f_a_, f_c_, f_p_), f_loss = self.ae(x_c1_t2, x_c1_t1, x_c2_t2) # Decode features x_c = self._decode_cano_feature(f_c_, n, t, device) x_p_ = self._decode_pose_feature(f_p_, n, t, device) @@ -95,7 +95,7 @@ class RGBPartNet(nn.Module): i_p_ = torch.sigmoid(self.ae.decoder.trans_conv4(i_p_)) i_p = i_p_.view(n, t, c, h, w) - return (x_c, x_p), losses, (i_a, i_c, i_p) + return (x_c, x_p), (i_a, i_c, i_p), f_loss else: # evaluating f_c_, f_p_ = self.ae(x_c1_t2) |