diff options
Diffstat (limited to 'models/rgb_part_net.py')
-rw-r--r-- | models/rgb_part_net.py | 2 |
1 files changed, 1 insertions, 1 deletions
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 456695d..f39b40b 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -84,7 +84,7 @@ class RGBPartNet(nn.Module): loss = torch.sum(torch.stack(losses)) return loss, [loss.item() for loss in losses] else: - return x + return x.unsqueeze(1).view(-1) def _disentangle(self, x_c1, x_c2=None, y=None): num_frames = len(x_c1) |