diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-01-03 15:08:19 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-01-03 15:08:19 +0800 |
commit | 2ac1787e4580521848460215e6b06f4bb1648f06 (patch) | |
tree | 75e4121a89b38f69c600711bac9e3734294f7d83 /test/part_net.py | |
parent | 2e6a6d5bda3ddea10afda8e07d2cfe5697a26de3 (diff) |
Unit testing on auto-encoder, HPM and Part Net
Diffstat (limited to 'test/part_net.py')
-rw-r--r-- | test/part_net.py | 71 |
1 files changed, 71 insertions, 0 deletions
diff --git a/test/part_net.py b/test/part_net.py new file mode 100644 index 0000000..25e92ae --- /dev/null +++ b/test/part_net.py @@ -0,0 +1,71 @@ +import torch + +from models.part_net import FrameLevelPartFeatureExtractor, \ + TemporalFeatureAggregator, PartNet + +T, N, C, H, W = 15, 4, 3, 64, 32 + + +def test_default_fpfe(): + fpfe = FrameLevelPartFeatureExtractor() + x = torch.rand(T, N, C, H, W) + x = fpfe(x) + + assert tuple(x.size()) == (T * N, 32 * 4, 16, 8) + + +def test_custom_fpfe(): + feature_channels = 64 + fpfe = FrameLevelPartFeatureExtractor( + in_channels=1, + feature_channels=feature_channels, + kernel_sizes=((5, 3), (3, 3), (3, 3), (3, 3)), + paddings=((2, 1), (1, 1), (1, 1), (1, 1)), + halving=(1, 1, 3, 3) + ) + x = torch.rand(T, N, 1, H, W) + x = fpfe(x) + + assert tuple(x.size()) == (T * N, feature_channels * 8, 8, 4) + + +def test_default_tfa(): + in_channels = 32 * 4 + tfa = TemporalFeatureAggregator(in_channels) + x = torch.rand(16, T, N, in_channels) + x = tfa(x) + + assert tuple(x.size()) == (16, N, in_channels) + + +def test_custom_tfa(): + in_channels = 64 * 8 + num_part = 8 + tfa = TemporalFeatureAggregator(in_channels=in_channels, + squeeze_ratio=8, num_part=num_part) + x = torch.rand(num_part, T, N, in_channels) + x = tfa(x) + + assert tuple(x.size()) == (num_part, N, in_channels) + + +def test_default_part_net(): + pa = PartNet() + x = torch.rand(T, N, C, H, W) + x = pa(x) + + assert tuple(x.size()) == (16, N, 32 * 4) + + +def test_custom_part_net(): + feature_channels = 64 + pa = PartNet(in_channels=1, feature_channels=feature_channels, + kernel_sizes=((5, 3), (3, 3), (3, 3), (3, 3)), + paddings=((2, 1), (1, 1), (1, 1), (1, 1)), + halving=(1, 1, 3, 3), + squeeze_ratio=8, + num_part=8) + x = torch.rand(T, N, 1, H, W) + x = pa(x) + + assert tuple(x.size()) == (8, N, pa.tfa_in_channels) |