summaryrefslogtreecommitdiff
path: root/test/hpm.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/hpm.py')
-rw-r--r--test/hpm.py23
1 files changed, 23 insertions, 0 deletions
diff --git a/test/hpm.py b/test/hpm.py
new file mode 100644
index 0000000..a68337d
--- /dev/null
+++ b/test/hpm.py
@@ -0,0 +1,23 @@
+import torch
+
+from models import HorizontalPyramidMatching
+
+T, N, C, H, W = 15, 4, 256, 32, 16
+
+
+def test_default_hpm():
+ hpm = HorizontalPyramidMatching(in_channels=C)
+ x = torch.rand(T, N, C, H, W)
+ x = hpm(x)
+ assert tuple(x.size()) == (1 + 2 + 4, T, N, 128)
+
+
+def test_custom_hpm():
+ hpm = HorizontalPyramidMatching(in_channels=2048,
+ out_channels=256,
+ scales=(1, 2, 4, 8),
+ use_avg_pool=True,
+ use_max_pool=False)
+ x = torch.rand(T, N, 2048, H, W)
+ x = hpm(x)
+ assert tuple(x.size()) == (1 + 2 + 4 + 8, T, N, 256)