diff options
Diffstat (limited to 'models/auto_encoder.py')
-rw-r--r-- | models/auto_encoder.py | 7 |
1 files changed, 2 insertions, 5 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py index eaac2fe..7c1f7ef 100644 --- a/models/auto_encoder.py +++ b/models/auto_encoder.py @@ -132,17 +132,14 @@ class AutoEncoder(nn.Module): # x_c1_t2 is the frame for later module (f_a_c1_t2, f_c_c1_t2, f_p_c1_t2) = self.encoder(x_c1_t2) - f_a_size, f_c_size, f_p_size = ( - f_a_c1_t2.size(), f_c_c1_t2.size(), f_p_c1_t2.size() - ) # Decode canonical features for HPM x_c_c1_t2 = self.decoder( - torch.zeros(f_a_size), f_c_c1_t2, torch.zeros(f_p_size), + torch.zeros_like(f_a_c1_t2), f_c_c1_t2, torch.zeros_like(f_p_c1_t2), no_trans_conv=True ) # Decode pose features for Part Net x_p_c1_t2 = self.decoder( - torch.zeros(f_a_size), torch.zeros(f_c_size), f_p_c1_t2 + torch.zeros_like(f_a_c1_t2), torch.zeros_like(f_c_c1_t2), f_p_c1_t2 ) if self.training: |