From 57275a210b93f9ffd30a53e22c3c28f49f228d14 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Thu, 31 Dec 2020 13:39:57 +0800 Subject: Make HPM capable of processing frames in all batches --- models/hpm.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) (limited to 'models/hpm.py') diff --git a/models/hpm.py b/models/hpm.py index 6dd58f4..4a1f1a4 100644 --- a/models/hpm.py +++ b/models/hpm.py @@ -39,6 +39,10 @@ class HorizontalPyramidMatching(nn.Module): return pyramid def forward(self, x): + # Flatten frames in all batches + n, t, c, h, w = x.size() + x = x.view(-1, c, h, w) + x = self.backbone(x) n, c, h, w = x.size() @@ -52,6 +56,9 @@ class HorizontalPyramidMatching(nn.Module): x_slice = hpp(x_slice) x_slice = x_slice.view(n, -1) feature.append(x_slice) + x = torch.cat(feature, dim=1) - feature = torch.cat(feature, dim=1) - return feature + # Unfold frames to original batch + _, d = x.size() + x = x.view(n, t, d) + return x -- cgit v1.2.3