summaryrefslogtreecommitdiff
path: root/models
diff options
context:
space:
mode:
Diffstat (limited to 'models')
-rw-r--r--models/rgb_part_net.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index bf52efe..c3954bc 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -120,8 +120,8 @@ class RGBPartNet(nn.Module):
f_a = f_a_.view(n, t, -1)
x_a = self.ae.decoder(
f_a.mean(1),
- torch.zeros((n * t, self.f_c_dim), device=device),
- torch.zeros((n * t, self.f_p_dim), device=device)
+ torch.zeros((n, self.f_c_dim), device=device),
+ torch.zeros((n, self.f_p_dim), device=device)
)
return x_a