summaryrefslogtreecommitdiff
path: root/test/hpm.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-02-19 22:39:49 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-02-19 22:39:49 +0800
commitd12dd6b04a4e7c2b1ee43ab6f36f25d0c35ca364 (patch)
tree71b5209ce4b5cfb1d09b89fe133028bbfa481dc9 /test/hpm.py
parent4aa9044122878a8e2b887a8b170c036983431559 (diff)
New branch with auto-encoder only
Diffstat (limited to 'test/hpm.py')
-rw-r--r--test/hpm.py23
1 files changed, 0 insertions, 23 deletions
diff --git a/test/hpm.py b/test/hpm.py
deleted file mode 100644
index 0aefbb8..0000000
--- a/test/hpm.py
+++ /dev/null
@@ -1,23 +0,0 @@
-import torch
-
-from models.hpm import HorizontalPyramidMatching
-
-T, N, C, H, W = 15, 4, 256, 32, 16
-
-
-def test_default_hpm():
- hpm = HorizontalPyramidMatching(in_channels=C)
- x = torch.rand(T, N, C, H, W)
- x = hpm(x)
- assert tuple(x.size()) == (1 + 2 + 4, T, N, 128)
-
-
-def test_custom_hpm():
- hpm = HorizontalPyramidMatching(in_channels=2048,
- out_channels=256,
- scales=(1, 2, 4, 8),
- use_avg_pool=True,
- use_max_pool=False)
- x = torch.rand(T, N, 2048, H, W)
- x = hpm(x)
- assert tuple(x.size()) == (1 + 2 + 4 + 8, T, N, 256)