diff options
Diffstat (limited to 'models/hpm.py')
-rw-r--r-- | models/hpm.py | 11 |
1 files changed, 9 insertions, 2 deletions
diff --git a/models/hpm.py b/models/hpm.py index 6dd58f4..4a1f1a4 100644 --- a/models/hpm.py +++ b/models/hpm.py @@ -39,6 +39,10 @@ class HorizontalPyramidMatching(nn.Module): return pyramid def forward(self, x): + # Flatten frames in all batches + n, t, c, h, w = x.size() + x = x.view(-1, c, h, w) + x = self.backbone(x) n, c, h, w = x.size() @@ -52,6 +56,9 @@ class HorizontalPyramidMatching(nn.Module): x_slice = hpp(x_slice) x_slice = x_slice.view(n, -1) feature.append(x_slice) + x = torch.cat(feature, dim=1) - feature = torch.cat(feature, dim=1) - return feature + # Unfold frames to original batch + _, d = x.size() + x = x.view(n, t, d) + return x |