summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-01-03 15:08:19 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-01-03 15:08:19 +0800
commit2ac1787e4580521848460215e6b06f4bb1648f06 (patch)
tree75e4121a89b38f69c600711bac9e3734294f7d83
parent2e6a6d5bda3ddea10afda8e07d2cfe5697a26de3 (diff)
Unit testing on auto-encoder, HPM and Part Net
-rw-r--r--test/auto_encoder.py99
-rw-r--r--test/hpm.py23
-rw-r--r--test/part_net.py71
3 files changed, 193 insertions, 0 deletions
diff --git a/test/auto_encoder.py b/test/auto_encoder.py
new file mode 100644
index 0000000..5cefb8e
--- /dev/null
+++ b/test/auto_encoder.py
@@ -0,0 +1,99 @@
+import torch
+
+from models.auto_encoder import Encoder, Decoder, AutoEncoder
+
+N, C, H, W = 128, 3, 64, 32
+
+
+def test_default_encoder():
+ encoder = Encoder()
+ x = torch.rand(N, C, H, W)
+ f_a, f_c, f_p = encoder(x)
+
+ assert tuple(f_a.size()) == (N, 128)
+ assert tuple(f_c.size()) == (N, 128)
+ assert tuple(f_p.size()) == (N, 64)
+
+
+def test_custom_encoder():
+ output_dims = (64, 64, 32)
+ encoder = Encoder(in_channels=1,
+ feature_channels=32,
+ output_dims=output_dims)
+ x = torch.rand(N, 1, H, W)
+ f_a, f_c, f_p = encoder(x)
+
+ assert tuple(f_a.size()) == (N, output_dims[0])
+ assert tuple(f_c.size()) == (N, output_dims[1])
+ assert tuple(f_p.size()) == (N, output_dims[2])
+
+
+def test_default_decoder():
+ decoder = Decoder()
+ f_a, f_c, f_p = torch.rand(N, 128), torch.rand(N, 128), torch.rand(N, 64)
+
+ x_trans_conv = decoder(f_a, f_c, f_p)
+ assert tuple(x_trans_conv.size()) == (N, C, H, W)
+ x_no_trans_conv = decoder(f_a, f_c, f_p, no_trans_conv=True)
+ assert tuple(x_no_trans_conv.size()) == (N, 64 * 8, 4, 2)
+
+
+def test_custom_decoder():
+ embedding_dims = (64, 64, 32)
+ feature_channels = 32
+ decoder = Decoder(input_dims=embedding_dims,
+ feature_channels=feature_channels,
+ out_channels=1)
+ f_a, f_c, f_p = (torch.rand(N, embedding_dims[0]),
+ torch.rand(N, embedding_dims[1]),
+ torch.rand(N, embedding_dims[2]))
+
+ x_trans_conv = decoder(f_a, f_c, f_p)
+ assert tuple(x_trans_conv.size()) == (N, 1, H, W)
+ x_no_trans_conv = decoder(f_a, f_c, f_p, no_trans_conv=True)
+ assert tuple(x_no_trans_conv.size()) == (N, feature_channels * 8, 4, 2)
+
+
+def test_default_auto_encoder():
+ ae = AutoEncoder()
+ x = torch.rand(N, C, H, W)
+ y = torch.randint(74, (N,))
+
+ ae.train()
+ ((x_c, x_p), (f_p_c1, f_p_c2), (xrecon, cano)) = ae(x, x, x, y)
+ assert tuple(x_c.size()) == (N, 64 * 8, 4, 2)
+ assert tuple(x_p.size()) == (N, C, H, W)
+ assert tuple(f_p_c1.size()) == tuple(f_p_c2.size()) == (N, 64)
+ assert tuple(xrecon.size()) == tuple(cano.size()) == ()
+
+ ae.eval()
+ (x_c, x_p) = ae(x, x, x)
+ assert tuple(x_c.size()) == (N, 64 * 8, 4, 2)
+ assert tuple(x_p.size()) == (N, C, H, W)
+
+
+def test_custom_auto_encoder():
+ num_class = 10
+ channels = 1
+ embedding_dims = (64, 64, 32)
+ feature_channels = 32
+ ae = AutoEncoder(num_class=num_class,
+ channels=channels,
+ feature_channels=feature_channels,
+ embedding_dims=embedding_dims)
+ x = torch.rand(N, 1, H, W)
+ y = torch.randint(num_class, (N,))
+
+ ae.train()
+ ((x_c, x_p), (f_p_c1, f_p_c2), (xrecon, cano)) = ae(x, x, x, y)
+ assert tuple(x_c.size()) == (N, feature_channels * 8, 4, 2)
+ assert tuple(x_p.size()) == (N, 1, H, W)
+ assert tuple(f_p_c1.size()) \
+ == tuple(f_p_c2.size()) \
+ == (N, embedding_dims[2])
+ assert tuple(xrecon.size()) == tuple(cano.size()) == ()
+
+ ae.eval()
+ (x_c, x_p) = ae(x, x, x)
+ assert tuple(x_c.size()) == (N, feature_channels * 8, 4, 2)
+ assert tuple(x_p.size()) == (N, 1, H, W)
diff --git a/test/hpm.py b/test/hpm.py
new file mode 100644
index 0000000..a68337d
--- /dev/null
+++ b/test/hpm.py
@@ -0,0 +1,23 @@
+import torch
+
+from models import HorizontalPyramidMatching
+
+T, N, C, H, W = 15, 4, 256, 32, 16
+
+
+def test_default_hpm():
+ hpm = HorizontalPyramidMatching(in_channels=C)
+ x = torch.rand(T, N, C, H, W)
+ x = hpm(x)
+ assert tuple(x.size()) == (1 + 2 + 4, T, N, 128)
+
+
+def test_custom_hpm():
+ hpm = HorizontalPyramidMatching(in_channels=2048,
+ out_channels=256,
+ scales=(1, 2, 4, 8),
+ use_avg_pool=True,
+ use_max_pool=False)
+ x = torch.rand(T, N, 2048, H, W)
+ x = hpm(x)
+ assert tuple(x.size()) == (1 + 2 + 4 + 8, T, N, 256)
diff --git a/test/part_net.py b/test/part_net.py
new file mode 100644
index 0000000..25e92ae
--- /dev/null
+++ b/test/part_net.py
@@ -0,0 +1,71 @@
+import torch
+
+from models.part_net import FrameLevelPartFeatureExtractor, \
+ TemporalFeatureAggregator, PartNet
+
+T, N, C, H, W = 15, 4, 3, 64, 32
+
+
+def test_default_fpfe():
+ fpfe = FrameLevelPartFeatureExtractor()
+ x = torch.rand(T, N, C, H, W)
+ x = fpfe(x)
+
+ assert tuple(x.size()) == (T * N, 32 * 4, 16, 8)
+
+
+def test_custom_fpfe():
+ feature_channels = 64
+ fpfe = FrameLevelPartFeatureExtractor(
+ in_channels=1,
+ feature_channels=feature_channels,
+ kernel_sizes=((5, 3), (3, 3), (3, 3), (3, 3)),
+ paddings=((2, 1), (1, 1), (1, 1), (1, 1)),
+ halving=(1, 1, 3, 3)
+ )
+ x = torch.rand(T, N, 1, H, W)
+ x = fpfe(x)
+
+ assert tuple(x.size()) == (T * N, feature_channels * 8, 8, 4)
+
+
+def test_default_tfa():
+ in_channels = 32 * 4
+ tfa = TemporalFeatureAggregator(in_channels)
+ x = torch.rand(16, T, N, in_channels)
+ x = tfa(x)
+
+ assert tuple(x.size()) == (16, N, in_channels)
+
+
+def test_custom_tfa():
+ in_channels = 64 * 8
+ num_part = 8
+ tfa = TemporalFeatureAggregator(in_channels=in_channels,
+ squeeze_ratio=8, num_part=num_part)
+ x = torch.rand(num_part, T, N, in_channels)
+ x = tfa(x)
+
+ assert tuple(x.size()) == (num_part, N, in_channels)
+
+
+def test_default_part_net():
+ pa = PartNet()
+ x = torch.rand(T, N, C, H, W)
+ x = pa(x)
+
+ assert tuple(x.size()) == (16, N, 32 * 4)
+
+
+def test_custom_part_net():
+ feature_channels = 64
+ pa = PartNet(in_channels=1, feature_channels=feature_channels,
+ kernel_sizes=((5, 3), (3, 3), (3, 3), (3, 3)),
+ paddings=((2, 1), (1, 1), (1, 1), (1, 1)),
+ halving=(1, 1, 3, 3),
+ squeeze_ratio=8,
+ num_part=8)
+ x = torch.rand(T, N, 1, H, W)
+ x = pa(x)
+
+ assert tuple(x.size()) == (8, N, pa.tfa_in_channels)