summaryrefslogtreecommitdiff
path: root/models/rgb_part_net.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-01-09 21:54:10 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-01-09 21:54:10 +0800
commitde911a563fc503114559d7e0e7f710db090cec0d (patch)
treeb0720a7bf4a5de4b2c7d529e8a45f5c50dd023fd /models/rgb_part_net.py
parent62f14a6ef0d902b9ffd4e57427a40663e2e5c2ad (diff)
Add prototype predict function
Diffstat (limited to 'models/rgb_part_net.py')
-rw-r--r--models/rgb_part_net.py2
1 files changed, 1 insertions, 1 deletions
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 456695d..f39b40b 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -84,7 +84,7 @@ class RGBPartNet(nn.Module):
loss = torch.sum(torch.stack(losses))
return loss, [loss.item() for loss in losses]
else:
- return x
+ return x.unsqueeze(1).view(-1)
def _disentangle(self, x_c1, x_c2=None, y=None):
num_frames = len(x_c1)