diff options
Diffstat (limited to 'test')
-rw-r--r-- | test/hpm.py | 23 | ||||
-rw-r--r-- | test/part_net.py | 71 |
2 files changed, 0 insertions, 94 deletions
diff --git a/test/hpm.py b/test/hpm.py deleted file mode 100644 index 0aefbb8..0000000 --- a/test/hpm.py +++ /dev/null @@ -1,23 +0,0 @@ -import torch - -from models.hpm import HorizontalPyramidMatching - -T, N, C, H, W = 15, 4, 256, 32, 16 - - -def test_default_hpm(): - hpm = HorizontalPyramidMatching(in_channels=C) - x = torch.rand(T, N, C, H, W) - x = hpm(x) - assert tuple(x.size()) == (1 + 2 + 4, T, N, 128) - - -def test_custom_hpm(): - hpm = HorizontalPyramidMatching(in_channels=2048, - out_channels=256, - scales=(1, 2, 4, 8), - use_avg_pool=True, - use_max_pool=False) - x = torch.rand(T, N, 2048, H, W) - x = hpm(x) - assert tuple(x.size()) == (1 + 2 + 4 + 8, T, N, 256) diff --git a/test/part_net.py b/test/part_net.py deleted file mode 100644 index fada2c4..0000000 --- a/test/part_net.py +++ /dev/null @@ -1,71 +0,0 @@ -import torch - -from models.part_net import FrameLevelPartFeatureExtractor, \ - TemporalFeatureAggregator, PartNet - -T, N, C, H, W = 15, 4, 3, 64, 32 - - -def test_default_fpfe(): - fpfe = FrameLevelPartFeatureExtractor() - x = torch.rand(T, N, C, H, W) - x = fpfe(x) - - assert tuple(x.size()) == (T * N, 32 * 4, 16, 8) - - -def test_custom_fpfe(): - feature_channels = 64 - fpfe = FrameLevelPartFeatureExtractor( - in_channels=1, - feature_channels=feature_channels, - kernel_sizes=((5, 3), (3, 3), (3, 3), (3, 3)), - paddings=((2, 1), (1, 1), (1, 1), (1, 1)), - halving=(1, 1, 3, 3) - ) - x = torch.rand(T, N, 1, H, W) - x = fpfe(x) - - assert tuple(x.size()) == (T * N, feature_channels * 8, 8, 4) - - -def test_default_tfa(): - in_channels = 32 * 4 - tfa = TemporalFeatureAggregator(in_channels) - x = torch.rand(16, T, N, in_channels) - x = tfa(x) - - assert tuple(x.size()) == (16, N, in_channels) - - -def test_custom_tfa(): - in_channels = 64 * 8 - num_part = 8 - tfa = TemporalFeatureAggregator(in_channels=in_channels, - squeeze_ratio=8, num_part=num_part) - x = torch.rand(num_part, T, N, in_channels) - x = tfa(x) - - assert tuple(x.size()) == (num_part, N, in_channels) - - -def test_default_part_net(): - pa = PartNet() - x = torch.rand(T, N, C, H, W) - x = pa(x) - - assert tuple(x.size()) == (16, N, 32 * 4) - - -def test_custom_part_net(): - feature_channels = 64 - pa = PartNet(in_channels=1, feature_channels=feature_channels, - kernel_sizes=((5, 3), (3, 3), (3, 3), (3, 3)), - paddings=((2, 1), (1, 1), (1, 1), (1, 1)), - halving=(1, 1, 3, 3), - squeeze_ratio=8, - num_parts=8) - x = torch.rand(T, N, 1, H, W) - x = pa(x) - - assert tuple(x.size()) == (8, N, pa.tfa_in_channels) |