From f1fe77c083f952e81cf80c0b44611fc6057a7882 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Wed, 6 Jan 2021 22:19:27 +0800 Subject: Add CUDA support --- test/cuda.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 test/cuda.py (limited to 'test/cuda.py') diff --git a/test/cuda.py b/test/cuda.py new file mode 100644 index 0000000..ef0ea36 --- /dev/null +++ b/test/cuda.py @@ -0,0 +1,35 @@ +import torch + +from models import RGBPartNet + +P, K = 2, 4 +N, T, C, H, W = P * K, 10, 3, 64, 32 + + +def rand_x1_x2_y(n, t, c, h, w): + x1 = torch.rand(n, t, c, h, w) + x2 = torch.rand(n, t, c, h, w) + y = [] + for p in range(P): + y += [p] * K + y = torch.as_tensor(y) + return x1, x2, y + + +def test_default_rgb_part_net_cuda(): + rgb_pa = RGBPartNet() + rgb_pa = rgb_pa.cuda() + x1, x2, y = rand_x1_x2_y(N, T, C, H, W) + x1, x2, y = x1.cuda(), x2.cuda(), y.cuda() + + rgb_pa.train() + loss, metrics = rgb_pa(x1, x2, y) + _, _, _, _ = metrics + assert loss.device == torch.device('cuda', torch.cuda.current_device()) + assert tuple(loss.size()) == () + assert isinstance(_, float) + + rgb_pa.eval() + x = rgb_pa(x1, x2) + assert x.device == torch.device('cuda', torch.cuda.current_device()) + assert tuple(x.size()) == (23, N, 256) -- cgit v1.2.3