1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
|
import torch
from models.hpm import HorizontalPyramidMatching
# Input tensor dimensions shared by the tests below; inputs are built as
# torch.rand(T, N, C, H, W).
T = 15   # presumably number of frames per sequence — TODO confirm
N = 4    # presumably batch size — TODO confirm
C = 256  # input channel count (passed as in_channels in test_default_hpm)
H = 32   # presumably feature-map height — TODO confirm
W = 16   # presumably feature-map width — TODO confirm
def test_default_hpm():
    """Check the output shape of HorizontalPyramidMatching with default settings.

    Only ``in_channels`` is passed, so every other option takes the module's
    defaults. The expected shape implies those defaults produce
    1 + 2 + 4 = 7 pooled strips of 128 features each (presumably
    ``scales=(1, 2, 4)`` and ``out_channels=128`` — confirm against
    models/hpm.py).
    """
    hpm = HorizontalPyramidMatching(in_channels=C)
    x = torch.rand(T, N, C, H, W)
    out = hpm(x)
    assert tuple(out.size()) == (1 + 2 + 4, T, N, 128)
def test_custom_hpm():
    """Check the output shape of HorizontalPyramidMatching with explicit settings.

    The module is configured with four scales, average pooling only, and a
    custom channel mapping. The output's leading dimension is expected to
    equal ``sum(scales)`` (1 + 2 + 4 + 8 = 15) and its last dimension to
    equal ``out_channels``.
    """
    in_channels, out_channels = 2048, 256
    scales = (1, 2, 4, 8)
    hpm = HorizontalPyramidMatching(in_channels=in_channels,
                                    out_channels=out_channels,
                                    scales=scales,
                                    use_avg_pool=True,
                                    use_max_pool=False)
    # The input channel count must match the in_channels the module was built with.
    x = torch.rand(T, N, in_channels, H, W)
    out = hpm(x)
    # Derive the expectation from the configuration instead of restating it.
    assert tuple(out.size()) == (sum(scales), T, N, out_channels)
|