-rw-r--r--  models/auto_encoder.py |  7 ++-----
-rw-r--r--  models/model.py        | 12 +++++++-----
-rw-r--r--  test/cuda.py           | 35 +++++++++++++++++++++++++++++++++++
3 files changed, 44 insertions(+), 10 deletions(-)
diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index eaac2fe..7c1f7ef 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -132,17 +132,14 @@ class AutoEncoder(nn.Module):
         # x_c1_t2 is the frame for later module
         (f_a_c1_t2, f_c_c1_t2, f_p_c1_t2) = self.encoder(x_c1_t2)
-        f_a_size, f_c_size, f_p_size = (
-            f_a_c1_t2.size(), f_c_c1_t2.size(), f_p_c1_t2.size()
-        )
 
         # Decode canonical features for HPM
         x_c_c1_t2 = self.decoder(
-            torch.zeros(f_a_size), f_c_c1_t2, torch.zeros(f_p_size),
+            torch.zeros_like(f_a_c1_t2), f_c_c1_t2, torch.zeros_like(f_p_c1_t2),
             no_trans_conv=True
         )
         # Decode pose features for Part Net
         x_p_c1_t2 = self.decoder(
-            torch.zeros(f_a_size), torch.zeros(f_c_size), f_p_c1_t2
+            torch.zeros_like(f_a_c1_t2), torch.zeros_like(f_c_c1_t2), f_p_c1_t2
         )
 
         if self.training:
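Side note on the change above (illustrative, not part of the patch): torch.zeros(size) always allocates a new CPU tensor with the default dtype, whereas torch.zeros_like(t) inherits both the device and the dtype of t, so the zeroed decoder inputs stay wherever the encoder outputs live. A minimal sketch with an arbitrary shape:

import torch

f_c = torch.rand(4, 128, device='cuda' if torch.cuda.is_available() else 'cpu')
z_plain = torch.zeros(f_c.size())   # always CPU, default dtype
z_like = torch.zeros_like(f_c)      # same device and dtype as f_c
assert z_plain.device.type == 'cpu'
assert z_like.device == f_c.device and z_like.dtype == f_c.dtype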
diff --git a/models/model.py b/models/model.py
index 3842844..5dc7d97 100644
--- a/models/model.py
+++ b/models/model.py
@@ -75,6 +75,7 @@ class Model:
         hp = self.hp.copy()
         lr, betas = hp.pop('lr', 1e-4), hp.pop('betas', (0.9, 0.999))
         self.rgb_pn = RGBPartNet(self.train_size, self.in_channels, **hp)
+        self.rgb_pn = self.rgb_pn.to(self.device)
         self.optimizer = optim.Adam(self.rgb_pn.parameters(), lr, betas)
         self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, 500, 0.9)
         self.writer = SummaryWriter(self.log_name)
@@ -95,9 +96,10 @@ class Model:
             # Zero the parameter gradients
             self.optimizer.zero_grad()
             # forward + backward + optimize
-            loss, metrics = self.rgb_pn(
-                batch_c1['clip'], batch_c2['clip'], batch_c1['label']
-            )
+            x_c1 = batch_c1['clip'].to(self.device)
+            x_c2 = batch_c2['clip'].to(self.device)
+            y = batch_c1['label'].to(self.device)
+            loss, metrics = self.rgb_pn(x_c1, x_c2, y)
             loss.backward()
             self.optimizer.step()
             # Step scheduler
@@ -144,8 +146,8 @@ class Model:
             self,
             dataset_config: DatasetConfiguration
     ) -> Union[CASIAB]:
-        self.train_size = dataset_config['train_size']
-        self.in_channels = dataset_config['num_input_channels']
+        self.train_size = dataset_config.get('train_size', 74)
+        self.in_channels = dataset_config.get('num_input_channels', 3)
         self._dataset_sig = self._make_signature(
             dataset_config,
             popped_keys=['root_dir', 'cache_on']
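The model.py changes above follow the usual two-step device-placement pattern: nn.Module.to() moves the network's parameters in place (the reassignment just keeps the reference), while Tensor.to() returns a new tensor on the target device, so each batch has to be reassigned before the forward pass. A rough self-contained sketch of that pattern; the module, shapes and dictionary keys below are illustrative, not taken from RGBPartNet:

import torch
from torch import nn, optim

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = nn.Linear(8, 2).to(device)                # parameters moved once
optimizer = optim.Adam(model.parameters(), 1e-4)

batch = {'clip': torch.rand(4, 8), 'label': torch.randint(0, 2, (4,))}
x = batch['clip'].to(device)                      # inputs moved every step
y = batch['label'].to(device)

optimizer.zero_grad()
loss = nn.functional.cross_entropy(model(x), y)
loss.backward()
optimizer.step()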
diff --git a/test/cuda.py b/test/cuda.py
new file mode 100644
index 0000000..ef0ea36
--- /dev/null
+++ b/test/cuda.py
@@ -0,0 +1,35 @@
+import torch
+
+from models import RGBPartNet
+
+P, K = 2, 4
+N, T, C, H, W = P * K, 10, 3, 64, 32
+
+
+def rand_x1_x2_y(n, t, c, h, w):
+    x1 = torch.rand(n, t, c, h, w)
+    x2 = torch.rand(n, t, c, h, w)
+    y = []
+    for p in range(P):
+        y += [p] * K
+    y = torch.as_tensor(y)
+    return x1, x2, y
+
+
+def test_default_rgb_part_net_cuda():
+    rgb_pa = RGBPartNet()
+    rgb_pa = rgb_pa.cuda()
+    x1, x2, y = rand_x1_x2_y(N, T, C, H, W)
+    x1, x2, y = x1.cuda(), x2.cuda(), y.cuda()
+
+    rgb_pa.train()
+    loss, metrics = rgb_pa(x1, x2, y)
+    # the four returned sub-metrics should all be plain Python floats
+    assert loss.device == torch.device('cuda', torch.cuda.current_device())
+    assert tuple(loss.size()) == ()
+    assert all(isinstance(m, float) for m in metrics)
+
+    rgb_pa.eval()
+    x = rgb_pa(x1, x2)
+    assert x.device == torch.device('cuda', torch.cuda.current_device())
+    assert tuple(x.size()) == (23, N, 256)
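test/cuda.py calls .cuda() unconditionally, so it can only pass on a machine with a visible CUDA device. If the suite also runs on CPU-only hosts, a skip guard along these lines could be added (hypothetical, not part of this commit):

import pytest
import torch

@pytest.mark.skipif(not torch.cuda.is_available(),
                    reason='requires a CUDA-capable device')
def test_default_rgb_part_net_cuda():
    ...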