diff options
Diffstat (limited to 'models/rgb_part_net.py')
-rw-r--r-- | models/rgb_part_net.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index bf52efe..c3954bc 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -120,8 +120,8 @@ class RGBPartNet(nn.Module): f_a = f_a_.view(n, t, -1) x_a = self.ae.decoder( f_a.mean(1), - torch.zeros((n * t, self.f_c_dim), device=device), - torch.zeros((n * t, self.f_p_dim), device=device) + torch.zeros((n, self.f_c_dim), device=device), + torch.zeros((n, self.f_p_dim), device=device) ) return x_a |