From f1fe77c083f952e81cf80c0b44611fc6057a7882 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Wed, 6 Jan 2021 22:19:27 +0800 Subject: Add CUDA support --- models/auto_encoder.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) (limited to 'models/auto_encoder.py') diff --git a/models/auto_encoder.py b/models/auto_encoder.py index eaac2fe..7c1f7ef 100644 --- a/models/auto_encoder.py +++ b/models/auto_encoder.py @@ -132,17 +132,14 @@ class AutoEncoder(nn.Module): # x_c1_t2 is the frame for later module (f_a_c1_t2, f_c_c1_t2, f_p_c1_t2) = self.encoder(x_c1_t2) - f_a_size, f_c_size, f_p_size = ( - f_a_c1_t2.size(), f_c_c1_t2.size(), f_p_c1_t2.size() - ) # Decode canonical features for HPM x_c_c1_t2 = self.decoder( - torch.zeros(f_a_size), f_c_c1_t2, torch.zeros(f_p_size), + torch.zeros_like(f_a_c1_t2), f_c_c1_t2, torch.zeros_like(f_p_c1_t2), no_trans_conv=True ) # Decode pose features for Part Net x_p_c1_t2 = self.decoder( - torch.zeros(f_a_size), torch.zeros(f_c_size), f_p_c1_t2 + torch.zeros_like(f_a_c1_t2), torch.zeros_like(f_c_c1_t2), f_p_c1_t2 ) if self.training: -- cgit v1.2.3