diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-01-03 15:08:19 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-01-03 15:08:19 +0800 |
commit | 2ac1787e4580521848460215e6b06f4bb1648f06 (patch) | |
tree | 75e4121a89b38f69c600711bac9e3734294f7d83 /test/hpm.py | |
parent | 2e6a6d5bda3ddea10afda8e07d2cfe5697a26de3 (diff) |
Unit testing on auto-encoder, HPM and Part Net
Diffstat (limited to 'test/hpm.py')
-rw-r--r-- | test/hpm.py | 23 |
1 files changed, 23 insertions, 0 deletions
diff --git a/test/hpm.py b/test/hpm.py new file mode 100644 index 0000000..a68337d --- /dev/null +++ b/test/hpm.py @@ -0,0 +1,23 @@ +import torch + +from models import HorizontalPyramidMatching + +T, N, C, H, W = 15, 4, 256, 32, 16 + + +def test_default_hpm(): + hpm = HorizontalPyramidMatching(in_channels=C) + x = torch.rand(T, N, C, H, W) + x = hpm(x) + assert tuple(x.size()) == (1 + 2 + 4, T, N, 128) + + +def test_custom_hpm(): + hpm = HorizontalPyramidMatching(in_channels=2048, + out_channels=256, + scales=(1, 2, 4, 8), + use_avg_pool=True, + use_max_pool=False) + x = torch.rand(T, N, 2048, H, W) + x = hpm(x) + assert tuple(x.size()) == (1 + 2 + 4 + 8, T, N, 256) |