diff options
Diffstat (limited to 'test')
-rw-r--r-- | test/auto_encoder.py | 99 | ||||
-rw-r--r-- | test/hpm.py | 23 | ||||
-rw-r--r-- | test/part_net.py | 71 |
3 files changed, 193 insertions, 0 deletions
diff --git a/test/auto_encoder.py b/test/auto_encoder.py new file mode 100644 index 0000000..5cefb8e --- /dev/null +++ b/test/auto_encoder.py @@ -0,0 +1,99 @@ +import torch + +from models.auto_encoder import Encoder, Decoder, AutoEncoder + +N, C, H, W = 128, 3, 64, 32 + + +def test_default_encoder(): + encoder = Encoder() + x = torch.rand(N, C, H, W) + f_a, f_c, f_p = encoder(x) + + assert tuple(f_a.size()) == (N, 128) + assert tuple(f_c.size()) == (N, 128) + assert tuple(f_p.size()) == (N, 64) + + +def test_custom_encoder(): + output_dims = (64, 64, 32) + encoder = Encoder(in_channels=1, + feature_channels=32, + output_dims=output_dims) + x = torch.rand(N, 1, H, W) + f_a, f_c, f_p = encoder(x) + + assert tuple(f_a.size()) == (N, output_dims[0]) + assert tuple(f_c.size()) == (N, output_dims[1]) + assert tuple(f_p.size()) == (N, output_dims[2]) + + +def test_default_decoder(): + decoder = Decoder() + f_a, f_c, f_p = torch.rand(N, 128), torch.rand(N, 128), torch.rand(N, 64) + + x_trans_conv = decoder(f_a, f_c, f_p) + assert tuple(x_trans_conv.size()) == (N, C, H, W) + x_no_trans_conv = decoder(f_a, f_c, f_p, no_trans_conv=True) + assert tuple(x_no_trans_conv.size()) == (N, 64 * 8, 4, 2) + + +def test_custom_decoder(): + embedding_dims = (64, 64, 32) + feature_channels = 32 + decoder = Decoder(input_dims=embedding_dims, + feature_channels=feature_channels, + out_channels=1) + f_a, f_c, f_p = (torch.rand(N, embedding_dims[0]), + torch.rand(N, embedding_dims[1]), + torch.rand(N, embedding_dims[2])) + + x_trans_conv = decoder(f_a, f_c, f_p) + assert tuple(x_trans_conv.size()) == (N, 1, H, W) + x_no_trans_conv = decoder(f_a, f_c, f_p, no_trans_conv=True) + assert tuple(x_no_trans_conv.size()) == (N, feature_channels * 8, 4, 2) + + +def test_default_auto_encoder(): + ae = AutoEncoder() + x = torch.rand(N, C, H, W) + y = torch.randint(74, (N,)) + + ae.train() + ((x_c, x_p), (f_p_c1, f_p_c2), (xrecon, cano)) = ae(x, x, x, y) + assert tuple(x_c.size()) == (N, 64 * 8, 4, 2) + assert tuple(x_p.size()) == (N, C, H, W) + assert tuple(f_p_c1.size()) == tuple(f_p_c2.size()) == (N, 64) + assert tuple(xrecon.size()) == tuple(cano.size()) == () + + ae.eval() + (x_c, x_p) = ae(x, x, x) + assert tuple(x_c.size()) == (N, 64 * 8, 4, 2) + assert tuple(x_p.size()) == (N, C, H, W) + + +def test_custom_auto_encoder(): + num_class = 10 + channels = 1 + embedding_dims = (64, 64, 32) + feature_channels = 32 + ae = AutoEncoder(num_class=num_class, + channels=channels, + feature_channels=feature_channels, + embedding_dims=embedding_dims) + x = torch.rand(N, 1, H, W) + y = torch.randint(num_class, (N,)) + + ae.train() + ((x_c, x_p), (f_p_c1, f_p_c2), (xrecon, cano)) = ae(x, x, x, y) + assert tuple(x_c.size()) == (N, feature_channels * 8, 4, 2) + assert tuple(x_p.size()) == (N, 1, H, W) + assert tuple(f_p_c1.size()) \ + == tuple(f_p_c2.size()) \ + == (N, embedding_dims[2]) + assert tuple(xrecon.size()) == tuple(cano.size()) == () + + ae.eval() + (x_c, x_p) = ae(x, x, x) + assert tuple(x_c.size()) == (N, feature_channels * 8, 4, 2) + assert tuple(x_p.size()) == (N, 1, H, W) diff --git a/test/hpm.py b/test/hpm.py new file mode 100644 index 0000000..a68337d --- /dev/null +++ b/test/hpm.py @@ -0,0 +1,23 @@ +import torch + +from models import HorizontalPyramidMatching + +T, N, C, H, W = 15, 4, 256, 32, 16 + + +def test_default_hpm(): + hpm = HorizontalPyramidMatching(in_channels=C) + x = torch.rand(T, N, C, H, W) + x = hpm(x) + assert tuple(x.size()) == (1 + 2 + 4, T, N, 128) + + +def test_custom_hpm(): + hpm = HorizontalPyramidMatching(in_channels=2048, + out_channels=256, + scales=(1, 2, 4, 8), + use_avg_pool=True, + use_max_pool=False) + x = torch.rand(T, N, 2048, H, W) + x = hpm(x) + assert tuple(x.size()) == (1 + 2 + 4 + 8, T, N, 256) diff --git a/test/part_net.py b/test/part_net.py new file mode 100644 index 0000000..25e92ae --- /dev/null +++ b/test/part_net.py @@ -0,0 +1,71 @@ +import torch + +from models.part_net import FrameLevelPartFeatureExtractor, \ + TemporalFeatureAggregator, PartNet + +T, N, C, H, W = 15, 4, 3, 64, 32 + + +def test_default_fpfe(): + fpfe = FrameLevelPartFeatureExtractor() + x = torch.rand(T, N, C, H, W) + x = fpfe(x) + + assert tuple(x.size()) == (T * N, 32 * 4, 16, 8) + + +def test_custom_fpfe(): + feature_channels = 64 + fpfe = FrameLevelPartFeatureExtractor( + in_channels=1, + feature_channels=feature_channels, + kernel_sizes=((5, 3), (3, 3), (3, 3), (3, 3)), + paddings=((2, 1), (1, 1), (1, 1), (1, 1)), + halving=(1, 1, 3, 3) + ) + x = torch.rand(T, N, 1, H, W) + x = fpfe(x) + + assert tuple(x.size()) == (T * N, feature_channels * 8, 8, 4) + + +def test_default_tfa(): + in_channels = 32 * 4 + tfa = TemporalFeatureAggregator(in_channels) + x = torch.rand(16, T, N, in_channels) + x = tfa(x) + + assert tuple(x.size()) == (16, N, in_channels) + + +def test_custom_tfa(): + in_channels = 64 * 8 + num_part = 8 + tfa = TemporalFeatureAggregator(in_channels=in_channels, + squeeze_ratio=8, num_part=num_part) + x = torch.rand(num_part, T, N, in_channels) + x = tfa(x) + + assert tuple(x.size()) == (num_part, N, in_channels) + + +def test_default_part_net(): + pa = PartNet() + x = torch.rand(T, N, C, H, W) + x = pa(x) + + assert tuple(x.size()) == (16, N, 32 * 4) + + +def test_custom_part_net(): + feature_channels = 64 + pa = PartNet(in_channels=1, feature_channels=feature_channels, + kernel_sizes=((5, 3), (3, 3), (3, 3), (3, 3)), + paddings=((2, 1), (1, 1), (1, 1), (1, 1)), + halving=(1, 1, 3, 3), + squeeze_ratio=8, + num_part=8) + x = torch.rand(T, N, 1, H, W) + x = pa(x) + + assert tuple(x.size()) == (8, N, pa.tfa_in_channels) |