summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-02-19 22:39:49 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-02-19 22:39:49 +0800
commitd12dd6b04a4e7c2b1ee43ab6f36f25d0c35ca364 (patch)
tree71b5209ce4b5cfb1d09b89fe133028bbfa481dc9 /test
parent4aa9044122878a8e2b887a8b170c036983431559 (diff)
New branch with auto-encoder only
Diffstat (limited to 'test')
-rw-r--r--test/hpm.py23
-rw-r--r--test/part_net.py71
2 files changed, 0 insertions, 94 deletions
diff --git a/test/hpm.py b/test/hpm.py
deleted file mode 100644
index 0aefbb8..0000000
--- a/test/hpm.py
+++ /dev/null
@@ -1,23 +0,0 @@
-import torch
-
-from models.hpm import HorizontalPyramidMatching
-
-T, N, C, H, W = 15, 4, 256, 32, 16
-
-
-def test_default_hpm():
- hpm = HorizontalPyramidMatching(in_channels=C)
- x = torch.rand(T, N, C, H, W)
- x = hpm(x)
- assert tuple(x.size()) == (1 + 2 + 4, T, N, 128)
-
-
-def test_custom_hpm():
- hpm = HorizontalPyramidMatching(in_channels=2048,
- out_channels=256,
- scales=(1, 2, 4, 8),
- use_avg_pool=True,
- use_max_pool=False)
- x = torch.rand(T, N, 2048, H, W)
- x = hpm(x)
- assert tuple(x.size()) == (1 + 2 + 4 + 8, T, N, 256)
diff --git a/test/part_net.py b/test/part_net.py
deleted file mode 100644
index 25e92ae..0000000
--- a/test/part_net.py
+++ /dev/null
@@ -1,71 +0,0 @@
-import torch
-
-from models.part_net import FrameLevelPartFeatureExtractor, \
- TemporalFeatureAggregator, PartNet
-
-T, N, C, H, W = 15, 4, 3, 64, 32
-
-
-def test_default_fpfe():
- fpfe = FrameLevelPartFeatureExtractor()
- x = torch.rand(T, N, C, H, W)
- x = fpfe(x)
-
- assert tuple(x.size()) == (T * N, 32 * 4, 16, 8)
-
-
-def test_custom_fpfe():
- feature_channels = 64
- fpfe = FrameLevelPartFeatureExtractor(
- in_channels=1,
- feature_channels=feature_channels,
- kernel_sizes=((5, 3), (3, 3), (3, 3), (3, 3)),
- paddings=((2, 1), (1, 1), (1, 1), (1, 1)),
- halving=(1, 1, 3, 3)
- )
- x = torch.rand(T, N, 1, H, W)
- x = fpfe(x)
-
- assert tuple(x.size()) == (T * N, feature_channels * 8, 8, 4)
-
-
-def test_default_tfa():
- in_channels = 32 * 4
- tfa = TemporalFeatureAggregator(in_channels)
- x = torch.rand(16, T, N, in_channels)
- x = tfa(x)
-
- assert tuple(x.size()) == (16, N, in_channels)
-
-
-def test_custom_tfa():
- in_channels = 64 * 8
- num_part = 8
- tfa = TemporalFeatureAggregator(in_channels=in_channels,
- squeeze_ratio=8, num_part=num_part)
- x = torch.rand(num_part, T, N, in_channels)
- x = tfa(x)
-
- assert tuple(x.size()) == (num_part, N, in_channels)
-
-
-def test_default_part_net():
- pa = PartNet()
- x = torch.rand(T, N, C, H, W)
- x = pa(x)
-
- assert tuple(x.size()) == (16, N, 32 * 4)
-
-
-def test_custom_part_net():
- feature_channels = 64
- pa = PartNet(in_channels=1, feature_channels=feature_channels,
- kernel_sizes=((5, 3), (3, 3), (3, 3), (3, 3)),
- paddings=((2, 1), (1, 1), (1, 1), (1, 1)),
- halving=(1, 1, 3, 3),
- squeeze_ratio=8,
- num_part=8)
- x = torch.rand(T, N, 1, H, W)
- x = pa(x)
-
- assert tuple(x.size()) == (8, N, pa.tfa_in_channels)