diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-01-06 22:19:27 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-01-06 22:19:27 +0800 |
commit | f1fe77c083f952e81cf80c0b44611fc6057a7882 (patch) | |
tree | b36dbbdfc21a540bbbfb26b98cfdee0f3652f5c9 /models/auto_encoder.py | |
parent | 4befe59046fb3adf8ef8eb589999a74cf7136ff6 (diff) |
Add CUDA support
Diffstat (limited to 'models/auto_encoder.py')
-rw-r--r-- | models/auto_encoder.py | 7 |
1 files changed, 2 insertions, 5 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py index eaac2fe..7c1f7ef 100644 --- a/models/auto_encoder.py +++ b/models/auto_encoder.py @@ -132,17 +132,14 @@ class AutoEncoder(nn.Module): # x_c1_t2 is the frame for later module (f_a_c1_t2, f_c_c1_t2, f_p_c1_t2) = self.encoder(x_c1_t2) - f_a_size, f_c_size, f_p_size = ( - f_a_c1_t2.size(), f_c_c1_t2.size(), f_p_c1_t2.size() - ) # Decode canonical features for HPM x_c_c1_t2 = self.decoder( - torch.zeros(f_a_size), f_c_c1_t2, torch.zeros(f_p_size), + torch.zeros_like(f_a_c1_t2), f_c_c1_t2, torch.zeros_like(f_p_c1_t2), no_trans_conv=True ) # Decode pose features for Part Net x_p_c1_t2 = self.decoder( - torch.zeros(f_a_size), torch.zeros(f_c_size), f_p_c1_t2 + torch.zeros_like(f_a_c1_t2), torch.zeros_like(f_c_c1_t2), f_p_c1_t2 ) if self.training: |