diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-01-09 21:02:34 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-01-09 21:02:34 +0800 |
commit | 62f14a6ef0d902b9ffd4e57427a40663e2e5c2ad (patch) | |
tree | adaa0151474197f21e6f53bf9fed7a0ea149f311 /models/rgb_part_net.py | |
parent | 6f278a962d70e90ac530f5723e198c7c356e8297 (diff) |
Change auto-encoder input in evaluation
Diffstat (limited to 'models/rgb_part_net.py')
-rw-r--r-- | models/rgb_part_net.py | 13 |
1 files changed, 7 insertions, 6 deletions
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 3037da0..456695d 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -51,10 +51,12 @@ class RGBPartNet(nn.Module): def fc(self, x): return x @ self.fc_mat - def forward(self, x_c1, x_c2, y=None): + def forward(self, x_c1, x_c2=None, y=None): # Step 0: Swap batch_size and time dimensions for next step # n, t, c, h, w - x_c1, x_c2 = x_c1.transpose(0, 1), x_c2.transpose(0, 1) + x_c1 = x_c1.transpose(0, 1) + if self.training: + x_c2 = x_c2.transpose(0, 1) # Step 1: Disentanglement # t, n, c, h, w @@ -84,7 +86,7 @@ class RGBPartNet(nn.Module): else: return x - def _disentangle(self, x_c1, x_c2, y): + def _disentangle(self, x_c1, x_c2=None, y=None): num_frames = len(x_c1) # Decoded canonical features and Pose images x_c_c1, x_p_c1 = [], [] @@ -94,7 +96,7 @@ class RGBPartNet(nn.Module): xrecon_loss, cano_cons_loss = [], [] for t2 in range(num_frames): t1 = random.randrange(num_frames) - output = self.ae(x_c1[t1], x_c1[t2], x_c2[t2], y) + output = self.ae(x_c1[t2], x_c1[t1], x_c2[t2], y) (x_c1_t2, f_p_t2, losses) = output # Decoded features or image @@ -127,8 +129,7 @@ class RGBPartNet(nn.Module): else: # evaluating for t2 in range(num_frames): - t1 = random.randrange(num_frames) - x_c1_t2 = self.ae(x_c1[t1], x_c1[t2], x_c2[t2]) + x_c1_t2 = self.ae(x_c1[t2]) # Decoded features or image (x_c_c1_t2, x_p_c1_t2) = x_c1_t2 # Canonical Features for HPM |