summaryrefslogtreecommitdiff
path: root/models
diff options
context:
space:
mode:
Diffstat (limited to 'models')
-rw-r--r--models/auto_encoder.py7
-rw-r--r--models/model.py12
2 files changed, 9 insertions, 10 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index eaac2fe..7c1f7ef 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -132,17 +132,14 @@ class AutoEncoder(nn.Module):
# x_c1_t2 is the frame for later module
(f_a_c1_t2, f_c_c1_t2, f_p_c1_t2) = self.encoder(x_c1_t2)
- f_a_size, f_c_size, f_p_size = (
- f_a_c1_t2.size(), f_c_c1_t2.size(), f_p_c1_t2.size()
- )
# Decode canonical features for HPM
x_c_c1_t2 = self.decoder(
- torch.zeros(f_a_size), f_c_c1_t2, torch.zeros(f_p_size),
+ torch.zeros_like(f_a_c1_t2), f_c_c1_t2, torch.zeros_like(f_p_c1_t2),
no_trans_conv=True
)
# Decode pose features for Part Net
x_p_c1_t2 = self.decoder(
- torch.zeros(f_a_size), torch.zeros(f_c_size), f_p_c1_t2
+ torch.zeros_like(f_a_c1_t2), torch.zeros_like(f_c_c1_t2), f_p_c1_t2
)
if self.training:
diff --git a/models/model.py b/models/model.py
index 3842844..5dc7d97 100644
--- a/models/model.py
+++ b/models/model.py
@@ -75,6 +75,7 @@ class Model:
hp = self.hp.copy()
lr, betas = hp.pop('lr', 1e-4), hp.pop('betas', (0.9, 0.999))
self.rgb_pn = RGBPartNet(self.train_size, self.in_channels, **hp)
+ self.rgb_pn = self.rgb_pn.to(self.device)
self.optimizer = optim.Adam(self.rgb_pn.parameters(), lr, betas)
self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, 500, 0.9)
self.writer = SummaryWriter(self.log_name)
@@ -95,9 +96,10 @@ class Model:
# Zero the parameter gradients
self.optimizer.zero_grad()
# forward + backward + optimize
- loss, metrics = self.rgb_pn(
- batch_c1['clip'], batch_c2['clip'], batch_c1['label']
- )
+ x_c1 = batch_c1['clip'].to(self.device)
+ x_c2 = batch_c2['clip'].to(self.device)
+ y = batch_c1['label'].to(self.device)
+ loss, metrics = self.rgb_pn(x_c1, x_c2, y)
loss.backward()
self.optimizer.step()
# Step scheduler
@@ -144,8 +146,8 @@ class Model:
self,
dataset_config: DatasetConfiguration
) -> Union[CASIAB]:
- self.train_size = dataset_config['train_size']
- self.in_channels = dataset_config['num_input_channels']
+ self.train_size = dataset_config.get('train_size', 74)
+ self.in_channels = dataset_config.get('num_input_channels', 3)
self._dataset_sig = self._make_signature(
dataset_config,
popped_keys=['root_dir', 'cache_on']