diff options
Diffstat (limited to 'models')
-rw-r--r-- | models/model.py | 55 |
1 files changed, 36 insertions, 19 deletions
diff --git a/models/model.py b/models/model.py index b7bde9e..258fd71 100644 --- a/models/model.py +++ b/models/model.py @@ -24,7 +24,7 @@ class Model: model_config: Dict, hyperparameter_config: Dict ): - self.disable_acc = system_config['disable_acc'] + self.disable_acc = system_config.get('disable_acc', False) if self.disable_acc: self.device = torch.device('cpu') else: # Enable accelerator @@ -34,17 +34,21 @@ class Model: print('No accelerator available, fallback to CPU.') self.device = torch.device('cpu') - self.save_dir = system_config['save_dir'] + self.save_dir = system_config.get('save_dir', 'runs') + if not os.path.exists(self.save_dir): + os.makedirs(self.save_dir) self.checkpoint_dir = os.path.join(self.save_dir, 'checkpoint') self.log_dir = os.path.join(self.save_dir, 'logs') - for dir_ in (self.save_dir, self.log_dir, self.checkpoint_dir): + for dir_ in (self.log_dir, self.checkpoint_dir): if not os.path.exists(dir_): os.mkdir(dir_) self.meta = model_config self.hp = hyperparameter_config - self.curr_iter = self.meta['restore_iter'] - self.total_iter = self.meta['total_iter'] + self.curr_iter = self.meta.get('restore_iter', 0) + self.total_iter = self.meta.get('total_iter', 80_000) + self.curr_iters = self.meta.get('restore_iters', (0, 0, 0)) + self.total_iters = self.meta.get('total_iters', (80000, 80000, 80000)) self.is_train: bool = True self.train_size: int = 74 @@ -55,11 +59,9 @@ class Model: self._gallery_dataset_meta: Optional[Dict[str, List]] = None self._probe_datasets_meta: Optional[Dict[str, Dict[str, List]]] = None - self._model_sig: str = self._make_signature(self.meta, ['restore_iter']) + self._model_name: str = self.meta.get('name', 'RGB-GaitPart') self._hp_sig: str = self._make_signature(self.hp) self._dataset_sig: str = 'undefined' - self._log_sig: str = '_'.join((self._model_sig, self._hp_sig)) - self._log_name: str = os.path.join(self.log_dir, self._log_sig) self.rgb_pn: Optional[RGBPartNet] = None self.optimizer: Optional[optim.Adam] = None @@ -76,13 +78,28 @@ class Model: } @property - def _signature(self) -> str: - return '_'.join((self._model_sig, str(self.curr_iter), self._hp_sig, - self._dataset_sig, str(self.pr), str(self.k))) + def _model_sig(self) -> str: + return '_'.join( + (self._model_name, str(self.curr_iter), str(self.total_iter)) + ) + + @property + def _checkpoint_sig(self) -> str: + return '_'.join((self._model_sig, self._hp_sig, self._dataset_sig, + str(self.pr), str(self.k))) @property def _checkpoint_name(self) -> str: - return os.path.join(self.checkpoint_dir, self._signature) + return os.path.join(self.checkpoint_dir, self._checkpoint_sig) + + @property + def _log_sig(self) -> str: + return '_'.join((self._model_name, str(self.total_iter), self._hp_sig, + self._dataset_sig, str(self.pr), str(self.k))) + + @property + def _log_name(self) -> str: + return os.path.join(self.log_dir, self._log_sig) def fit_all( self, @@ -92,8 +109,12 @@ class Model: ], dataloader_config: Dict, ): - for (condition, selector) in dataset_selectors.items(): + for (curr_iter, total_iter, (condition, selector)) in zip( + self.curr_iters, self.total_iters, dataset_selectors.items() + ): print(f'Training model {condition} ...') + self.curr_iter = curr_iter + self.total_iter = total_iter self.fit( dict(**dataset_config, **{'selector': selector}), dataloader_config @@ -169,7 +190,6 @@ class Model: start_time = datetime.now() if self.curr_iter == self.total_iter: - self.curr_iter = 0 self.writer.close() break @@ -370,7 +390,6 @@ class Model: dataset_config, popped_keys=['root_dir', 'cache_on'] ) - self._log_name = '_'.join((self._log_name, self._dataset_sig)) config: Dict = dataset_config.copy() name = config.pop('name', 'CASIA-B') if name == 'CASIA-B': @@ -386,10 +405,8 @@ class Model: dataloader_config: Dict ) -> DataLoader: config: Dict = dataloader_config.copy() - (self.pr, self.k) = config.pop('batch_size') + (self.pr, self.k) = config.pop('batch_size', (8, 16)) if self.is_train: - self._log_name = '_'.join( - (self._log_name, str(self.pr), str(self.k))) triplet_sampler = TripletSampler(dataset, (self.pr, self.k)) return DataLoader(dataset, batch_sampler=triplet_sampler, @@ -421,7 +438,7 @@ class Model: _config = config.copy() if popped_keys: for key in popped_keys: - _config.pop(key) + _config.pop(key, None) return self._gen_sig(list(_config.values())) |