summaryrefslogtreecommitdiff
path: root/models/rgb_part_net.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/rgb_part_net.py')
-rw-r--r--models/rgb_part_net.py26
1 files changed, 10 insertions, 16 deletions
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index e707c26..2cc0958 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -86,15 +86,15 @@ class RGBPartNet(nn.Module):
return x.unsqueeze(1).view(-1)
def _disentangle(self, x_c1, x_c2=None, y=None):
- num_frames = len(x_c1)
- # Decoded canonical features and Pose images
- x_c_c1, x_p_c1 = [], []
+ t, n, c, h, w = x_c1.size()
if self.training:
+ # Decoded canonical features and Pose images
+ x_c_c1, x_p_c1 = [], []
# Features required to calculate losses
f_p_c1, f_p_c2 = [], []
xrecon_loss, cano_cons_loss = [], []
- for t2 in range(num_frames):
- t1 = random.randrange(num_frames)
+ for t2 in range(t):
+ t1 = random.randrange(t)
output = self.ae(x_c1[t2], x_c1[t1], x_c2[t2], y)
(x_c1_t2, f_p_t2, losses) = output
@@ -127,17 +127,11 @@ class RGBPartNet(nn.Module):
(xrecon_loss, pose_sim_loss, cano_cons_loss))
else: # evaluating
- for t2 in range(num_frames):
- x_c1_t2 = self.ae(x_c1[t2])
- # Decoded features or image
- (x_c_c1_t2, x_p_c1_t2) = x_c1_t2
- # Canonical Features for HPM
- x_c_c1.append(x_c_c1_t2)
- # Pose image for Part Net
- x_p_c1.append(x_p_c1_t2)
-
- x_c_c1 = torch.stack(x_c_c1)
- x_p_c1 = torch.stack(x_p_c1)
+ x_c1 = x_c1.view(-1, c, h, w)
+ x_c_c1, x_p_c1 = self.ae(x_c1)
+ _, c_c, h_c, w_c = x_c_c1.size()
+ x_c_c1 = x_c_c1.view(t, n, c_c, h_c, w_c)
+ x_p_c1 = x_p_c1.view(t, n, c, h, w)
return (x_c_c1, x_p_c1), None