diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2020-12-31 13:39:57 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2020-12-31 13:39:57 +0800 |
commit | 57275a210b93f9ffd30a53e22c3c28f49f228d14 (patch) | |
tree | 21518b80edd2ce2bf1e4036cd13838481c27f835 /models | |
parent | 0d307facb16236174c302dc7714f7f20f434a4a6 (diff) |
Make HPM capable of processing frames in all batches
Diffstat (limited to 'models')
-rw-r--r-- | models/hpm.py | 11 |
1 files changed, 9 insertions, 2 deletions
diff --git a/models/hpm.py b/models/hpm.py index 6dd58f4..4a1f1a4 100644 --- a/models/hpm.py +++ b/models/hpm.py @@ -39,6 +39,10 @@ class HorizontalPyramidMatching(nn.Module): return pyramid def forward(self, x): + # Flatten frames in all batches + n, t, c, h, w = x.size() + x = x.view(-1, c, h, w) + x = self.backbone(x) n, c, h, w = x.size() @@ -52,6 +56,9 @@ class HorizontalPyramidMatching(nn.Module): x_slice = hpp(x_slice) x_slice = x_slice.view(n, -1) feature.append(x_slice) + x = torch.cat(feature, dim=1) - feature = torch.cat(feature, dim=1) - return feature + # Unfold frames to original batch + _, d = x.size() + x = x.view(n, t, d) + return x |