diff options
| author | Jordan Gong <jordan.gong@protonmail.com> | 2021-01-03 15:08:19 +0800 | 
|---|---|---|
| committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-01-03 15:08:19 +0800 | 
| commit | 2ac1787e4580521848460215e6b06f4bb1648f06 (patch) | |
| tree | 75e4121a89b38f69c600711bac9e3734294f7d83 /test/auto_encoder.py | |
| parent | 2e6a6d5bda3ddea10afda8e07d2cfe5697a26de3 (diff) | |
Unit testing on auto-encoder, HPM and Part Net
Diffstat (limited to 'test/auto_encoder.py')
| -rw-r--r-- | test/auto_encoder.py | 99 | 
1 files changed, 99 insertions, 0 deletions
| diff --git a/test/auto_encoder.py b/test/auto_encoder.py new file mode 100644 index 0000000..5cefb8e --- /dev/null +++ b/test/auto_encoder.py @@ -0,0 +1,99 @@ +import torch + +from models.auto_encoder import Encoder, Decoder, AutoEncoder + +N, C, H, W = 128, 3, 64, 32 + + +def test_default_encoder(): +    encoder = Encoder() +    x = torch.rand(N, C, H, W) +    f_a, f_c, f_p = encoder(x) + +    assert tuple(f_a.size()) == (N, 128) +    assert tuple(f_c.size()) == (N, 128) +    assert tuple(f_p.size()) == (N, 64) + + +def test_custom_encoder(): +    output_dims = (64, 64, 32) +    encoder = Encoder(in_channels=1, +                      feature_channels=32, +                      output_dims=output_dims) +    x = torch.rand(N, 1, H, W) +    f_a, f_c, f_p = encoder(x) + +    assert tuple(f_a.size()) == (N, output_dims[0]) +    assert tuple(f_c.size()) == (N, output_dims[1]) +    assert tuple(f_p.size()) == (N, output_dims[2]) + + +def test_default_decoder(): +    decoder = Decoder() +    f_a, f_c, f_p = torch.rand(N, 128), torch.rand(N, 128), torch.rand(N, 64) + +    x_trans_conv = decoder(f_a, f_c, f_p) +    assert tuple(x_trans_conv.size()) == (N, C, H, W) +    x_no_trans_conv = decoder(f_a, f_c, f_p, no_trans_conv=True) +    assert tuple(x_no_trans_conv.size()) == (N, 64 * 8, 4, 2) + + +def test_custom_decoder(): +    embedding_dims = (64, 64, 32) +    feature_channels = 32 +    decoder = Decoder(input_dims=embedding_dims, +                      feature_channels=feature_channels, +                      out_channels=1) +    f_a, f_c, f_p = (torch.rand(N, embedding_dims[0]), +                     torch.rand(N, embedding_dims[1]), +                     torch.rand(N, embedding_dims[2])) + +    x_trans_conv = decoder(f_a, f_c, f_p) +    assert tuple(x_trans_conv.size()) == (N, 1, H, W) +    x_no_trans_conv = decoder(f_a, f_c, f_p, no_trans_conv=True) +    assert tuple(x_no_trans_conv.size()) == (N, feature_channels * 8, 4, 2) + + +def test_default_auto_encoder(): +    ae = AutoEncoder() +    x = torch.rand(N, C, H, W) +    y = torch.randint(74, (N,)) + +    ae.train() +    ((x_c, x_p), (f_p_c1, f_p_c2), (xrecon, cano)) = ae(x, x, x, y) +    assert tuple(x_c.size()) == (N, 64 * 8, 4, 2) +    assert tuple(x_p.size()) == (N, C, H, W) +    assert tuple(f_p_c1.size()) == tuple(f_p_c2.size()) == (N, 64) +    assert tuple(xrecon.size()) == tuple(cano.size()) == () + +    ae.eval() +    (x_c, x_p) = ae(x, x, x) +    assert tuple(x_c.size()) == (N, 64 * 8, 4, 2) +    assert tuple(x_p.size()) == (N, C, H, W) + + +def test_custom_auto_encoder(): +    num_class = 10 +    channels = 1 +    embedding_dims = (64, 64, 32) +    feature_channels = 32 +    ae = AutoEncoder(num_class=num_class, +                     channels=channels, +                     feature_channels=feature_channels, +                     embedding_dims=embedding_dims) +    x = torch.rand(N, 1, H, W) +    y = torch.randint(num_class, (N,)) + +    ae.train() +    ((x_c, x_p), (f_p_c1, f_p_c2), (xrecon, cano)) = ae(x, x, x, y) +    assert tuple(x_c.size()) == (N, feature_channels * 8, 4, 2) +    assert tuple(x_p.size()) == (N, 1, H, W) +    assert tuple(f_p_c1.size()) \ +           == tuple(f_p_c2.size()) \ +           == (N, embedding_dims[2]) +    assert tuple(xrecon.size()) == tuple(cano.size()) == () + +    ae.eval() +    (x_c, x_p) = ae(x, x, x) +    assert tuple(x_c.size()) == (N, feature_channels * 8, 4, 2) +    assert tuple(x_p.size()) == (N, 1, H, W) | 
