summaryrefslogtreecommitdiff
path: root/models/hpm.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2020-12-31 13:39:57 +0800
committerJordan Gong <jordan.gong@protonmail.com>2020-12-31 13:39:57 +0800
commit57275a210b93f9ffd30a53e22c3c28f49f228d14 (patch)
tree21518b80edd2ce2bf1e4036cd13838481c27f835 /models/hpm.py
parent0d307facb16236174c302dc7714f7f20f434a4a6 (diff)
Make HPM capable of processing frames in all batches
Diffstat (limited to 'models/hpm.py')
-rw-r--r--models/hpm.py11
1 files changed, 9 insertions, 2 deletions
diff --git a/models/hpm.py b/models/hpm.py
index 6dd58f4..4a1f1a4 100644
--- a/models/hpm.py
+++ b/models/hpm.py
@@ -39,6 +39,10 @@ class HorizontalPyramidMatching(nn.Module):
return pyramid
def forward(self, x):
+ # Flatten frames in all batches
+ n, t, c, h, w = x.size()
+ x = x.view(-1, c, h, w)
+
x = self.backbone(x)
n, c, h, w = x.size()
@@ -52,6 +56,9 @@ class HorizontalPyramidMatching(nn.Module):
x_slice = hpp(x_slice)
x_slice = x_slice.view(n, -1)
feature.append(x_slice)
+ x = torch.cat(feature, dim=1)
- feature = torch.cat(feature, dim=1)
- return feature
+ # Unfold frames to original batch
+ _, d = x.size()
+ x = x.view(n, t, d)
+ return x