diff options
-rw-r--r-- | config.py | 8 | ||||
-rw-r--r-- | models/model.py | 17 | ||||
-rw-r--r-- | utils/configuration.py | 4 |
3 files changed, 19 insertions, 10 deletions
@@ -1,12 +1,10 @@ -import torch - from utils.configuration import Configuration config: Configuration = { 'system': { - # Device(s) used in training and testing (CPU or CUDA) - 'device': torch.device('cuda'), - # GPU(s) used in training or testing, if CUDA enabled + # Disable accelerator + 'disable_acc': False, + # GPU(s) used in training or testing if available 'CUDA_VISIBLE_DEVICES': '0', # Directory used in training or testing for temporary storage 'save_dir': 'runs', diff --git a/models/model.py b/models/model.py index 5dc7d97..bf8b5fb 100644 --- a/models/model.py +++ b/models/model.py @@ -24,7 +24,16 @@ class Model: model_config: ModelConfiguration, hyperparameter_config: HyperparameterConfiguration ): - self.device = system_config['device'] + self.disable_acc = system_config['disable_acc'] + if self.disable_acc: + self.device = torch.device('cpu') + else: # Enable accelerator + if torch.cuda.is_available(): + self.device = torch.device('cuda') + else: + print('No accelerator available, fallback to CPU.') + self.device = torch.device('cpu') + self.save_dir = system_config['save_dir'] self.checkpoint_dir = os.path.join(self.save_dir, 'checkpoint') self.log_dir = os.path.join(self.save_dir, 'logs') @@ -75,11 +84,15 @@ class Model: hp = self.hp.copy() lr, betas = hp.pop('lr', 1e-4), hp.pop('betas', (0.9, 0.999)) self.rgb_pn = RGBPartNet(self.train_size, self.in_channels, **hp) - self.rgb_pn = self.rgb_pn.to(self.device) self.optimizer = optim.Adam(self.rgb_pn.parameters(), lr, betas) self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, 500, 0.9) self.writer = SummaryWriter(self.log_name) + if not self.disable_acc: + if torch.cuda.device_count() > 1: + self.rgb_pn = nn.DataParallel(self.rgb_pn) + self.rgb_pn = self.rgb_pn.to(self.device) + self.rgb_pn.train() # Init weights at first iter if self.curr_iter == 0: diff --git a/utils/configuration.py b/utils/configuration.py index 3e98343..f3ae0b3 100644 --- a/utils/configuration.py +++ b/utils/configuration.py @@ -1,12 +1,10 @@ from typing import TypedDict, Optional, Union -import torch - from utils.dataset import ClipClasses, ClipConditions, ClipViews class SystemConfiguration(TypedDict): - device: torch.device + disable_acc: bool CUDA_VISIBLE_DEVICES: str save_dir: str |