summaryrefslogtreecommitdiff
path: root/test/part_net.py
blob: 25e92ae365058e0a97197e7839b3f570227b2763 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import torch

from models.part_net import FrameLevelPartFeatureExtractor, \
    TemporalFeatureAggregator, PartNet

# Dimensions of the random dummy inputs shared by the tests in this module.
# T is the axis the TemporalFeatureAggregator collapses: its input is
# (parts, T, N, channels) and its output is (parts, N, channels).
# N is presumably the batch size, and (C, H, W) presumably per-frame
# channels/height/width. NOTE(review): confirm against the model docs.
T = 15
N = 4
C = 3
H = 64
W = 32


def test_default_fpfe():
    """A default FPFE merges the T and N axes and emits 128 x 16 x 8 maps.

    Input is a (T, N, C, H, W) tensor with H=64, W=32. The expected output
    has T * N leading entries, 32 * 4 channels and a 16 x 8 spatial grid.
    """
    extractor = FrameLevelPartFeatureExtractor()
    frames = torch.rand(T, N, C, H, W)
    features = extractor(frames)

    expected_shape = (T * N, 32 * 4, 16, 8)
    assert tuple(features.size()) == expected_shape


def test_custom_fpfe():
    """A single-channel FPFE with custom kernels and halving yields the expected shape.

    With ``feature_channels=64`` the output is expected to carry ``64 * 8``
    channels on an 8 x 4 spatial grid, with T * N leading entries.
    """
    channels = 64
    extractor = FrameLevelPartFeatureExtractor(
        in_channels=1,
        feature_channels=channels,
        kernel_sizes=((5, 3), (3, 3), (3, 3), (3, 3)),
        paddings=((2, 1), (1, 1), (1, 1), (1, 1)),
        halving=(1, 1, 3, 3),
    )
    features = extractor(torch.rand(T, N, 1, H, W))

    expected_shape = (T * N, channels * 8, 8, 4)
    assert tuple(features.size()) == expected_shape


def test_default_tfa():
    """A default TFA reduces a (16, T, N, channels) input to (16, N, channels).

    The leading 16 presumably matches the aggregator's default part count
    (NOTE(review): confirm against TemporalFeatureAggregator's defaults).
    """
    channels = 32 * 4
    aggregator = TemporalFeatureAggregator(channels)
    part_features = torch.rand(16, T, N, channels)
    aggregated = aggregator(part_features)

    assert tuple(aggregated.size()) == (16, N, channels)


def test_custom_tfa():
    """A TFA configured with num_part=8 and squeeze_ratio=8 drops only the T axis."""
    channels = 64 * 8
    parts = 8
    aggregator = TemporalFeatureAggregator(
        in_channels=channels,
        squeeze_ratio=8,
        num_part=parts,
    )
    aggregated = aggregator(torch.rand(parts, T, N, channels))

    expected_shape = (parts, N, channels)
    assert tuple(aggregated.size()) == expected_shape


def test_default_part_net():
    """A default PartNet maps a (T, N, C, H, W) input to a (16, N, 128) output."""
    model = PartNet()
    inputs = torch.rand(T, N, C, H, W)
    outputs = model(inputs)

    expected_shape = (16, N, 32 * 4)
    assert tuple(outputs.size()) == expected_shape


def test_custom_part_net():
    """Check the output shape of a custom single-channel PartNet.

    The expected channel count is computed independently from
    ``feature_channels``: the FPFE configured this way emits
    ``feature_channels * 8`` channels, which the TFA consumes. Reading the
    expectation back from the model (``pa.tfa_in_channels``) would make the
    shape assertion pass even if that attribute were computed wrongly.
    """
    feature_channels = 64
    num_part = 8
    pa = PartNet(in_channels=1, feature_channels=feature_channels,
                 kernel_sizes=((5, 3), (3, 3), (3, 3), (3, 3)),
                 paddings=((2, 1), (1, 1), (1, 1), (1, 1)),
                 halving=(1, 1, 3, 3),
                 squeeze_ratio=8,
                 num_part=num_part)
    x = torch.rand(T, N, 1, H, W)
    x = pa(x)

    expected_channels = feature_channels * 8
    # The model's own bookkeeping must agree with the independent expectation.
    assert pa.tfa_in_channels == expected_channels
    assert tuple(x.size()) == (num_part, N, expected_channels)