summaryrefslogtreecommitdiff
path: root/test/hpm.py
blob: a68337d003080ad57221cf3ae8d8c949a6728377 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch

from models import HorizontalPyramidMatching

# Shape constants for the random test inputs, used as torch.rand(T, N, C, H, W).
T = 15  # presumably number of frames / sequence length — TODO confirm
N = 4  # presumably batch size — TODO confirm
C = 256  # input channel count; passed as in_channels in the default test
H = 32  # presumably feature-map height — TODO confirm
W = 16  # presumably feature-map width — TODO confirm


def test_default_hpm():
    """Check the output shape of HorizontalPyramidMatching with default settings.

    Only ``in_channels`` is supplied. The test expects ``1 + 2 + 4`` parts
    stacked along the leading output dimension, followed by the T and N input
    dimensions, and 128 features per part.
    """
    model = HorizontalPyramidMatching(in_channels=C)
    features = torch.rand(T, N, C, H, W)
    output = model(features)
    expected_shape = (1 + 2 + 4, T, N, 128)
    assert tuple(output.size()) == expected_shape


def test_custom_hpm():
    """Check the output shape of a customised HorizontalPyramidMatching.

    The model is built with 2048 input channels, 256 output channels, scales
    (1, 2, 4, 8), average pooling enabled and max pooling disabled. The test
    expects the leading output dimension to equal the sum of the scales, and
    the last dimension to equal the requested output channel count.
    """
    in_channels, out_channels = 2048, 256
    scales = (1, 2, 4, 8)
    model = HorizontalPyramidMatching(in_channels=in_channels,
                                      out_channels=out_channels,
                                      scales=scales,
                                      use_avg_pool=True,
                                      use_max_pool=False)
    output = model(torch.rand(T, N, in_channels, H, W))
    # sum(scales) == 1 + 2 + 4 + 8 == 15 parts in total.
    assert tuple(output.size()) == (sum(scales), T, N, out_channels)