Diffstat (limited to 'models/model.py')
 models/model.py | 17 +++++++++++++++--
 1 file changed, 15 insertions(+), 2 deletions(-)
diff --git a/models/model.py b/models/model.py
index 5dc7d97..bf8b5fb 100644
--- a/models/model.py
+++ b/models/model.py
@@ -24,7 +24,16 @@ class Model:
model_config: ModelConfiguration,
hyperparameter_config: HyperparameterConfiguration
):
- self.device = system_config['device']
+ self.disable_acc = system_config['disable_acc']
+ if self.disable_acc:
+ self.device = torch.device('cpu')
+ else: # Enable accelerator
+ if torch.cuda.is_available():
+ self.device = torch.device('cuda')
+ else:
+ print('No accelerator available, fallback to CPU.')
+ self.device = torch.device('cpu')
+
self.save_dir = system_config['save_dir']
self.checkpoint_dir = os.path.join(self.save_dir, 'checkpoint')
self.log_dir = os.path.join(self.save_dir, 'logs')
@@ -75,11 +84,15 @@ class Model:
hp = self.hp.copy()
lr, betas = hp.pop('lr', 1e-4), hp.pop('betas', (0.9, 0.999))
self.rgb_pn = RGBPartNet(self.train_size, self.in_channels, **hp)
- self.rgb_pn = self.rgb_pn.to(self.device)
self.optimizer = optim.Adam(self.rgb_pn.parameters(), lr, betas)
self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, 500, 0.9)
self.writer = SummaryWriter(self.log_name)
+ if not self.disable_acc:
+ if torch.cuda.device_count() > 1:
+ self.rgb_pn = nn.DataParallel(self.rgb_pn)
+ self.rgb_pn = self.rgb_pn.to(self.device)
+
self.rgb_pn.train()
# Init weights at first iter
if self.curr_iter == 0:
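For reference, a minimal standalone sketch (not part of the patch) of the device-selection logic the first hunk adds: the disable_acc flag mirrors the new system_config key, and select_device is a hypothetical helper.

import torch


def select_device(disable_acc: bool) -> torch.device:
    # Forced CPU mode takes priority over any available accelerator.
    if disable_acc:
        return torch.device('cpu')
    # Otherwise prefer CUDA; warn and fall back to CPU when absent.
    if torch.cuda.is_available():
        return torch.device('cuda')
    print('No accelerator available, fallback to CPU.')
    return torch.device('cpu')


print(select_device(disable_acc=True))   # always cpu
print(select_device(disable_acc=False))  # cuda if available, else cpu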
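The second hunk also reorders the device move: the model is wrapped in nn.DataParallel first (only when acceleration is enabled and more than one GPU is visible) and moved to the device afterwards. A hedged sketch of that ordering, with nn.Linear standing in for RGBPartNet:

import torch
from torch import nn, optim

model = nn.Linear(8, 2)  # stand-in for RGBPartNet
# Module.to() moves parameters in place, so an optimizer built before
# the move keeps valid references to them, as in the patch.
optimizer = optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.999))

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda' and torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)  # replicate across all visible GPUs
model = model.to(device)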