1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
|
import torch
from models.part_net import FrameLevelPartFeatureExtractor, \
TemporalFeatureAggregator, PartNet
# Shared dimensions for the random input tensors built by the tests below;
# image-shaped inputs are created as torch.rand(T, N, C, H, W).
# NOTE(review): presumably T = frames per sequence, N = batch size,
# C = image channels, H / W = frame height / width -- confirm against
# models.part_net.
T = 15
N = 4
C = 3
H = 64
W = 32
def test_default_fpfe():
    """Check the output shape of a default-constructed frame-level extractor.

    A random (T, N, C, H, W) tensor is fed through the module and the result
    is expected to have size (T * N, 32 * 4, 16, 8).
    """
    extractor = FrameLevelPartFeatureExtractor()
    frames = torch.rand(T, N, C, H, W)
    features = extractor(frames)

    expected_shape = (T * N, 32 * 4, 16, 8)
    assert tuple(features.size()) == expected_shape
def test_custom_fpfe():
    """Check the output shape of an extractor built with explicit settings.

    The module is configured for single-channel input with 64 feature
    channels; a random (T, N, 1, H, W) tensor is expected to come out with
    size (T * N, feature_channels * 8, 8, 4).
    """
    feature_channels = 64
    extractor_kwargs = {
        "in_channels": 1,
        "feature_channels": feature_channels,
        "kernel_sizes": ((5, 3), (3, 3), (3, 3), (3, 3)),
        "paddings": ((2, 1), (1, 1), (1, 1), (1, 1)),
        "halving": (1, 1, 3, 3),
    }
    extractor = FrameLevelPartFeatureExtractor(**extractor_kwargs)

    frames = torch.rand(T, N, 1, H, W)
    features = extractor(frames)

    expected_shape = (T * N, feature_channels * 8, 8, 4)
    assert tuple(features.size()) == expected_shape
def test_default_tfa():
    """Check the output shape of an aggregator built with only in_channels.

    A random (16, T, N, in_channels) tensor is expected to be reduced to
    (16, N, in_channels), i.e. the T axis is no longer present.
    NOTE(review): the leading 16 presumably matches the module's default
    number of parts -- confirm against models.part_net.
    """
    in_channels = 32 * 4
    aggregator = TemporalFeatureAggregator(in_channels)

    part_features = torch.rand(16, T, N, in_channels)
    aggregated = aggregator(part_features)

    expected_shape = (16, N, in_channels)
    assert tuple(aggregated.size()) == expected_shape
def test_custom_tfa():
    """Check the output shape of an aggregator built with explicit settings.

    With squeeze_ratio=8 and num_part=8, a random
    (num_part, T, N, in_channels) tensor is expected to be reduced to
    (num_part, N, in_channels).
    """
    in_channels = 64 * 8
    num_part = 8
    aggregator = TemporalFeatureAggregator(
        in_channels=in_channels,
        squeeze_ratio=8,
        num_part=num_part,
    )

    part_features = torch.rand(num_part, T, N, in_channels)
    aggregated = aggregator(part_features)

    expected_shape = (num_part, N, in_channels)
    assert tuple(aggregated.size()) == expected_shape
def test_default_part_net():
    """Check the output shape of a default-constructed PartNet.

    A random (T, N, C, H, W) tensor is fed through the network and the
    result is expected to have size (16, N, 32 * 4).
    """
    net = PartNet()
    frames = torch.rand(T, N, C, H, W)
    output = net(frames)

    expected_shape = (16, N, 32 * 4)
    assert tuple(output.size()) == expected_shape
def test_custom_part_net():
    """Check the output shape of a PartNet built with explicit settings.

    The network is configured for single-channel input with 8 parts; a random
    (T, N, 1, H, W) tensor is expected to come out with size
    (8, N, net.tfa_in_channels), where tfa_in_channels is read from the
    constructed network.
    """
    feature_channels = 64
    net_kwargs = {
        "in_channels": 1,
        "feature_channels": feature_channels,
        "kernel_sizes": ((5, 3), (3, 3), (3, 3), (3, 3)),
        "paddings": ((2, 1), (1, 1), (1, 1), (1, 1)),
        "halving": (1, 1, 3, 3),
        "squeeze_ratio": 8,
        "num_part": 8,
    }
    net = PartNet(**net_kwargs)

    frames = torch.rand(T, N, 1, H, W)
    output = net(frames)

    expected_shape = (8, N, net.tfa_in_channels)
    assert tuple(output.size()) == expected_shape
|