diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-01-05 20:20:06 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-01-05 20:20:06 +0800 |
commit | ab29067d6469473481cc73fe42bcaf69d7633a83 (patch) | |
tree | b742b58603cce7dd5c150fbfaad3d195395bc6fa /test | |
parent | 7158a18a5f2789b2b2902c4b918ed80002970249 (diff) |
Implement Batch All Triplet Loss
Diffstat (limited to 'test')
-rw-r--r-- | test/rgb_part_net.py | 65 |
1 files changed, 65 insertions, 0 deletions
diff --git a/test/rgb_part_net.py b/test/rgb_part_net.py new file mode 100644 index 0000000..1d754a0 --- /dev/null +++ b/test/rgb_part_net.py @@ -0,0 +1,65 @@ +import torch + +from models import RGBPartNet + +P, K = 2, 4 +N, T, C, H, W = P * K, 10, 3, 64, 32 + + +def rand_x1_x2_y(n, t, c, h, w): + x1 = torch.rand(n, t, c, h, w) + x2 = torch.rand(n, t, c, h, w) + y = [] + for p in range(P): + y += [p] * K + y = torch.as_tensor(y) + return x1, x2, y + + +def test_default_rgb_part_net(): + rgb_pa = RGBPartNet() + x1, x2, y = rand_x1_x2_y(N, T, C, H, W) + + rgb_pa.train() + loss, metrics = rgb_pa(x1, x2, y) + _, _, _, _ = metrics + assert tuple(loss.size()) == () + assert isinstance(_, float) + + rgb_pa.eval() + x = rgb_pa(x1, x2) + assert tuple(x.size()) == (23, N, 256) + + +def test_custom_rgb_part_net(): + hpm_scales = (1, 2, 4, 8) + tfa_num_parts = 8 + embedding_dims = 1024 + rgb_pa = RGBPartNet(num_class=10, + ae_in_channels=1, + ae_feature_channels=32, + f_a_c_p_dims=(64, 64, 32), + hpm_scales=hpm_scales, + hpm_use_avg_pool=True, + hpm_use_max_pool=False, + fpfe_feature_channels=64, + fpfe_kernel_sizes=((5, 3), (3, 3), (3, 3), (3, 3)), + fpfe_paddings=((2, 1), (1, 1), (1, 1), (1, 1)), + fpfe_halving=(1, 1, 3, 3), + tfa_squeeze_ratio=8, + tfa_num_parts=tfa_num_parts, + embedding_dims=1024, + triplet_margin=0.4) + x1, x2, y = rand_x1_x2_y(N, T, 1, H, W) + + rgb_pa.train() + loss, metrics = rgb_pa(x1, x2, y) + _, _, _, _ = metrics + assert tuple(loss.size()) == () + assert isinstance(_, float) + + rgb_pa.eval() + x = rgb_pa(x1, x2) + assert tuple(x.size()) == ( + sum(hpm_scales) + tfa_num_parts, N, embedding_dims + ) |