summaryrefslogtreecommitdiff
path: root/models/auto_encoder.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-01-06 22:19:27 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-01-06 22:19:27 +0800
commitf1fe77c083f952e81cf80c0b44611fc6057a7882 (patch)
treeb36dbbdfc21a540bbbfb26b98cfdee0f3652f5c9 /models/auto_encoder.py
parent4befe59046fb3adf8ef8eb589999a74cf7136ff6 (diff)
Add CUDA support
Diffstat (limited to 'models/auto_encoder.py')
-rw-r--r--models/auto_encoder.py7
1 files changed, 2 insertions, 5 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index eaac2fe..7c1f7ef 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -132,17 +132,14 @@ class AutoEncoder(nn.Module):
# x_c1_t2 is the frame for later module
(f_a_c1_t2, f_c_c1_t2, f_p_c1_t2) = self.encoder(x_c1_t2)
- f_a_size, f_c_size, f_p_size = (
- f_a_c1_t2.size(), f_c_c1_t2.size(), f_p_c1_t2.size()
- )
# Decode canonical features for HPM
x_c_c1_t2 = self.decoder(
- torch.zeros(f_a_size), f_c_c1_t2, torch.zeros(f_p_size),
+ torch.zeros_like(f_a_c1_t2), f_c_c1_t2, torch.zeros_like(f_p_c1_t2),
no_trans_conv=True
)
# Decode pose features for Part Net
x_p_c1_t2 = self.decoder(
- torch.zeros(f_a_size), torch.zeros(f_c_size), f_p_c1_t2
+ torch.zeros_like(f_a_c1_t2), torch.zeros_like(f_c_c1_t2), f_p_c1_t2
)
if self.training: