diff options
Diffstat (limited to 'models')
-rw-r--r-- | models/model.py | 8 | ||||
-rw-r--r-- | models/rgb_part_net.py | 10 |
2 files changed, 9 insertions, 9 deletions
diff --git a/models/model.py b/models/model.py index 0c3e5eb..139fd59 100644 --- a/models/model.py +++ b/models/model.py @@ -223,16 +223,16 @@ class Model: if self.image_log_on: i_a, i_c, i_p = images self.writer.add_images( + 'Appearance image', i_a, self.curr_iter + ) + self.writer.add_images( 'Canonical image', i_c, self.curr_iter ) - for (i, (o, a, p)) in enumerate(zip(x_c1, i_a, i_p)): + for i, (o, p) in enumerate(zip(x_c1, i_p)): self.writer.add_images( f'Original image/batch {i}', o, self.curr_iter ) self.writer.add_images( - f'Appearance image/batch {i}', a, self.curr_iter - ) - self.writer.add_images( f'Pose image/batch {i}', p, self.curr_iter ) time_used = datetime.now() - start_time diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 2aa680c..bf52efe 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -101,7 +101,7 @@ class RGBPartNet(nn.Module): i_a, i_c, i_p = None, None, None if self.image_log_on: - i_a = self._decode_appr_feature(f_a_, n, t, c, h, w, device) + i_a = self._decode_appr_feature(f_a_, n, t, device) # Continue decoding canonical features i_c = self.ae.decoder.trans_conv3(x_c) i_c = torch.sigmoid(self.ae.decoder.trans_conv4(i_c)) @@ -115,14 +115,14 @@ class RGBPartNet(nn.Module): x_p = self._decode_pose_feature(f_p_, n, t, c, h, w, device) return (x_c, x_p), None, None - def _decode_appr_feature(self, f_a_, n, t, c, h, w, device): + def _decode_appr_feature(self, f_a_, n, t, device): # Decode appearance features - x_a_ = self.ae.decoder( - f_a_, + f_a = f_a_.view(n, t, -1) + x_a = self.ae.decoder( + f_a.mean(1), torch.zeros((n * t, self.f_c_dim), device=device), torch.zeros((n * t, self.f_p_dim), device=device) ) - x_a = x_a_.view(n, t, c, h, w) return x_a def _decode_cano_feature(self, f_c_, n, t, device): |