diff options
Diffstat (limited to 'test/hpm.py')
-rw-r--r-- | test/hpm.py | 23 |
1 files changed, 0 insertions, 23 deletions
diff --git a/test/hpm.py b/test/hpm.py deleted file mode 100644 index 0aefbb8..0000000 --- a/test/hpm.py +++ /dev/null @@ -1,23 +0,0 @@ -import torch - -from models.hpm import HorizontalPyramidMatching - -T, N, C, H, W = 15, 4, 256, 32, 16 - - -def test_default_hpm(): - hpm = HorizontalPyramidMatching(in_channels=C) - x = torch.rand(T, N, C, H, W) - x = hpm(x) - assert tuple(x.size()) == (1 + 2 + 4, T, N, 128) - - -def test_custom_hpm(): - hpm = HorizontalPyramidMatching(in_channels=2048, - out_channels=256, - scales=(1, 2, 4, 8), - use_avg_pool=True, - use_max_pool=False) - x = torch.rand(T, N, 2048, H, W) - x = hpm(x) - assert tuple(x.size()) == (1 + 2 + 4 + 8, T, N, 256) |