From 62f14a6ef0d902b9ffd4e57427a40663e2e5c2ad Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Sat, 9 Jan 2021 21:02:34 +0800 Subject: Change auto-encoder input in evaluation --- models/rgb_part_net.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) (limited to 'models/rgb_part_net.py') diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 3037da0..456695d 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -51,10 +51,12 @@ class RGBPartNet(nn.Module): def fc(self, x): return x @ self.fc_mat - def forward(self, x_c1, x_c2, y=None): + def forward(self, x_c1, x_c2=None, y=None): # Step 0: Swap batch_size and time dimensions for next step # n, t, c, h, w - x_c1, x_c2 = x_c1.transpose(0, 1), x_c2.transpose(0, 1) + x_c1 = x_c1.transpose(0, 1) + if self.training: + x_c2 = x_c2.transpose(0, 1) # Step 1: Disentanglement # t, n, c, h, w @@ -84,7 +86,7 @@ class RGBPartNet(nn.Module): else: return x - def _disentangle(self, x_c1, x_c2, y): + def _disentangle(self, x_c1, x_c2=None, y=None): num_frames = len(x_c1) # Decoded canonical features and Pose images x_c_c1, x_p_c1 = [], [] @@ -94,7 +96,7 @@ class RGBPartNet(nn.Module): xrecon_loss, cano_cons_loss = [], [] for t2 in range(num_frames): t1 = random.randrange(num_frames) - output = self.ae(x_c1[t1], x_c1[t2], x_c2[t2], y) + output = self.ae(x_c1[t2], x_c1[t1], x_c2[t2], y) (x_c1_t2, f_p_t2, losses) = output # Decoded features or image @@ -127,8 +129,7 @@ class RGBPartNet(nn.Module): else: # evaluating for t2 in range(num_frames): - t1 = random.randrange(num_frames) - x_c1_t2 = self.ae(x_c1[t1], x_c1[t2], x_c2[t2]) + x_c1_t2 = self.ae(x_c1[t2]) # Decoded features or image (x_c_c1_t2, x_p_c1_t2) = x_c1_t2 # Canonical Features for HPM -- cgit v1.2.3 From de911a563fc503114559d7e0e7f710db090cec0d Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Sat, 9 Jan 2021 21:54:10 +0800 Subject: Add prototype predict function --- models/rgb_part_net.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'models/rgb_part_net.py') diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 456695d..f39b40b 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -84,7 +84,7 @@ class RGBPartNet(nn.Module): loss = torch.sum(torch.stack(losses)) return loss, [loss.item() for loss in losses] else: - return x + return x.unsqueeze(1).view(-1) def _disentangle(self, x_c1, x_c2=None, y=None): num_frames = len(x_c1) -- cgit v1.2.3