summaryrefslogtreecommitdiff
path: root/test/hpm.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-01-03 15:08:19 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-01-03 15:08:19 +0800
commit2ac1787e4580521848460215e6b06f4bb1648f06 (patch)
tree75e4121a89b38f69c600711bac9e3734294f7d83 /test/hpm.py
parent2e6a6d5bda3ddea10afda8e07d2cfe5697a26de3 (diff)
Unit testing on auto-encoder, HPM and Part Net
Diffstat (limited to 'test/hpm.py')
-rw-r--r--test/hpm.py23
1 files changed, 23 insertions, 0 deletions
diff --git a/test/hpm.py b/test/hpm.py
new file mode 100644
index 0000000..a68337d
--- /dev/null
+++ b/test/hpm.py
@@ -0,0 +1,23 @@
+import torch
+
+from models import HorizontalPyramidMatching
+
+T, N, C, H, W = 15, 4, 256, 32, 16
+
+
+def test_default_hpm():
+ hpm = HorizontalPyramidMatching(in_channels=C)
+ x = torch.rand(T, N, C, H, W)
+ x = hpm(x)
+ assert tuple(x.size()) == (1 + 2 + 4, T, N, 128)
+
+
+def test_custom_hpm():
+ hpm = HorizontalPyramidMatching(in_channels=2048,
+ out_channels=256,
+ scales=(1, 2, 4, 8),
+ use_avg_pool=True,
+ use_max_pool=False)
+ x = torch.rand(T, N, 2048, H, W)
+ x = hpm(x)
+ assert tuple(x.size()) == (1 + 2 + 4 + 8, T, N, 256)