summaryrefslogtreecommitdiff
path: root/models/hpm.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/hpm.py')
-rw-r--r--models/hpm.py11
1 files changed, 9 insertions, 2 deletions
diff --git a/models/hpm.py b/models/hpm.py
index 6dd58f4..4a1f1a4 100644
--- a/models/hpm.py
+++ b/models/hpm.py
@@ -39,6 +39,10 @@ class HorizontalPyramidMatching(nn.Module):
return pyramid
def forward(self, x):
+ # Flatten frames in all batches
+ n, t, c, h, w = x.size()
+ x = x.view(-1, c, h, w)
+
x = self.backbone(x)
n, c, h, w = x.size()
@@ -52,6 +56,9 @@ class HorizontalPyramidMatching(nn.Module):
x_slice = hpp(x_slice)
x_slice = x_slice.view(n, -1)
feature.append(x_slice)
+ x = torch.cat(feature, dim=1)
- feature = torch.cat(feature, dim=1)
- return feature
+ # Unfold frames to original batch
+ _, d = x.size()
+ x = x.view(n, t, d)
+ return x