diff options
Diffstat (limited to 'models')
-rw-r--r-- | models/rgb_part_net.py | 26 |
1 files changed, 10 insertions, 16 deletions
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index e707c26..2cc0958 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -86,15 +86,15 @@ class RGBPartNet(nn.Module): return x.unsqueeze(1).view(-1) def _disentangle(self, x_c1, x_c2=None, y=None): - num_frames = len(x_c1) - # Decoded canonical features and Pose images - x_c_c1, x_p_c1 = [], [] + t, n, c, h, w = x_c1.size() if self.training: + # Decoded canonical features and Pose images + x_c_c1, x_p_c1 = [], [] # Features required to calculate losses f_p_c1, f_p_c2 = [], [] xrecon_loss, cano_cons_loss = [], [] - for t2 in range(num_frames): - t1 = random.randrange(num_frames) + for t2 in range(t): + t1 = random.randrange(t) output = self.ae(x_c1[t2], x_c1[t1], x_c2[t2], y) (x_c1_t2, f_p_t2, losses) = output @@ -127,17 +127,11 @@ class RGBPartNet(nn.Module): (xrecon_loss, pose_sim_loss, cano_cons_loss)) else: # evaluating - for t2 in range(num_frames): - x_c1_t2 = self.ae(x_c1[t2]) - # Decoded features or image - (x_c_c1_t2, x_p_c1_t2) = x_c1_t2 - # Canonical Features for HPM - x_c_c1.append(x_c_c1_t2) - # Pose image for Part Net - x_p_c1.append(x_p_c1_t2) - - x_c_c1 = torch.stack(x_c_c1) - x_p_c1 = torch.stack(x_p_c1) + x_c1 = x_c1.view(-1, c, h, w) + x_c_c1, x_p_c1 = self.ae(x_c1) + _, c_c, h_c, w_c = x_c_c1.size() + x_c_c1 = x_c_c1.view(t, n, c_c, h_c, w_c) + x_p_c1 = x_p_c1.view(t, n, c, h, w) return (x_c_c1, x_p_c1), None |