summaryrefslogtreecommitdiff
path: root/models/model.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/model.py')
-rw-r--r--models/model.py12
1 files changed, 7 insertions, 5 deletions
diff --git a/models/model.py b/models/model.py
index 3842844..5dc7d97 100644
--- a/models/model.py
+++ b/models/model.py
@@ -75,6 +75,7 @@ class Model:
hp = self.hp.copy()
lr, betas = hp.pop('lr', 1e-4), hp.pop('betas', (0.9, 0.999))
self.rgb_pn = RGBPartNet(self.train_size, self.in_channels, **hp)
+ self.rgb_pn = self.rgb_pn.to(self.device)
self.optimizer = optim.Adam(self.rgb_pn.parameters(), lr, betas)
self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, 500, 0.9)
self.writer = SummaryWriter(self.log_name)
@@ -95,9 +96,10 @@ class Model:
# Zero the parameter gradients
self.optimizer.zero_grad()
# forward + backward + optimize
- loss, metrics = self.rgb_pn(
- batch_c1['clip'], batch_c2['clip'], batch_c1['label']
- )
+ x_c1 = batch_c1['clip'].to(self.device)
+ x_c2 = batch_c2['clip'].to(self.device)
+ y = batch_c1['label'].to(self.device)
+ loss, metrics = self.rgb_pn(x_c1, x_c2, y)
loss.backward()
self.optimizer.step()
# Step scheduler
@@ -144,8 +146,8 @@ class Model:
self,
dataset_config: DatasetConfiguration
) -> Union[CASIAB]:
- self.train_size = dataset_config['train_size']
- self.in_channels = dataset_config['num_input_channels']
+ self.train_size = dataset_config.get('train_size', 74)
+ self.in_channels = dataset_config.get('num_input_channels', 3)
self._dataset_sig = self._make_signature(
dataset_config,
popped_keys=['root_dir', 'cache_on']