diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-01-13 10:59:59 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-01-13 10:59:59 +0800 |
commit | 8ee391b65e2b48d777a268749f54b3aa9e4b9142 (patch) | |
tree | 85bea16f0c95722ae3c3b37c5ebd0d50800432e6 /models | |
parent | abb6989683829f12bf43ff19580444bf3396ac44 (diff) |
Add multiple checkpoints for different model and set default config value
Diffstat (limited to 'models')
-rw-r--r-- | models/model.py | 53 |
1 files changed, 34 insertions, 19 deletions
diff --git a/models/model.py b/models/model.py index 3cae788..7373dbb 100644 --- a/models/model.py +++ b/models/model.py @@ -27,7 +27,7 @@ class Model: model_config: ModelConfiguration, hyperparameter_config: HyperparameterConfiguration ): - self.disable_acc = system_config['disable_acc'] + self.disable_acc = system_config.get('disable_acc', False) if self.disable_acc: self.device = torch.device('cpu') else: # Enable accelerator @@ -37,17 +37,21 @@ class Model: print('No accelerator available, fallback to CPU.') self.device = torch.device('cpu') - self.save_dir = system_config['save_dir'] + self.save_dir = system_config.get('save_dir', 'runs') + if not os.path.exists(self.save_dir): + os.makedirs(self.save_dir) self.checkpoint_dir = os.path.join(self.save_dir, 'checkpoint') self.log_dir = os.path.join(self.save_dir, 'logs') - for dir_ in (self.save_dir, self.log_dir, self.checkpoint_dir): + for dir_ in (self.log_dir, self.checkpoint_dir): if not os.path.exists(dir_): os.mkdir(dir_) self.meta = model_config self.hp = hyperparameter_config - self.curr_iter = self.meta['restore_iter'] - self.total_iter = self.meta['total_iter'] + self.curr_iter = self.meta.get('restore_iter', 0) + self.total_iter = self.meta.get('total_iter', 80_000) + self.curr_iters = self.meta.get('restore_iters', (0, 0, 0)) + self.total_iters = self.meta.get('total_iters', (80000, 80000, 80000)) self.is_train: bool = True self.train_size: int = 74 @@ -58,11 +62,9 @@ class Model: self._gallery_dataset_meta: Optional[dict[str, list]] = None self._probe_datasets_meta: Optional[dict[str, dict[str, list]]] = None - self._model_sig: str = self._make_signature(self.meta, ['restore_iter']) + self._model_name: str = self.meta.get('name', 'RGB-GaitPart') self._hp_sig: str = self._make_signature(self.hp) self._dataset_sig: str = 'undefined' - self._log_sig: str = '_'.join((self._model_sig, self._hp_sig)) - self._log_name: str = os.path.join(self.log_dir, self._log_sig) self.rgb_pn: Optional[RGBPartNet] = None self.optimizer: Optional[optim.Adam] = None @@ -79,13 +81,26 @@ class Model: } @property - def _signature(self) -> str: - return '_'.join((self._model_sig, str(self.curr_iter), self._hp_sig, - self._dataset_sig, str(self.pr), str(self.k))) + def _model_sig(self) -> str: + return '_'.join((self._model_name, self.curr_iter, self.total_iter)) + + @property + def _checkpoint_sig(self) -> str: + return '_'.join((self._model_sig, self._hp_sig, self._dataset_sig, + str(self.pr), str(self.k))) @property def _checkpoint_name(self) -> str: - return os.path.join(self.checkpoint_dir, self._signature) + return os.path.join(self.checkpoint_dir, self._checkpoint_sig) + + @property + def _log_sig(self) -> str: + return '_'.join((self._model_name, self.total_iter, self._hp_sig, + self._dataset_sig, str(self.pr), str(self.k))) + + @property + def _log_name(self) -> str: + return os.path.join(self.log_dir, self._log_sig) def fit_all( self, @@ -95,8 +110,12 @@ class Model: ], dataloader_config: DataloaderConfiguration, ): - for (condition, selector) in dataset_selectors.items(): + for (curr_iter, total_iter, (condition, selector)) in zip( + self.curr_iters, self.total_iters, dataset_selectors.items() + ): print(f'Training model {condition} ...') + self.curr_iter = curr_iter + self.total_iter = total_iter self.fit( dict(**dataset_config, **{'selector': selector}), dataloader_config @@ -172,7 +191,6 @@ class Model: start_time = datetime.now() if self.curr_iter == self.total_iter: - self.curr_iter = 0 self.writer.close() break @@ -373,7 +391,6 @@ class Model: dataset_config, popped_keys=['root_dir', 'cache_on'] ) - self._log_name = '_'.join((self._log_name, self._dataset_sig)) config: dict = dataset_config.copy() name = config.pop('name', 'CASIA-B') if name == 'CASIA-B': @@ -389,10 +406,8 @@ class Model: dataloader_config: DataloaderConfiguration ) -> DataLoader: config: dict = dataloader_config.copy() - (self.pr, self.k) = config.pop('batch_size') + (self.pr, self.k) = config.pop('batch_size', (8, 16)) if self.is_train: - self._log_name = '_'.join( - (self._log_name, str(self.pr), str(self.k))) triplet_sampler = TripletSampler(dataset, (self.pr, self.k)) return DataLoader(dataset, batch_sampler=triplet_sampler, @@ -424,7 +439,7 @@ class Model: _config = config.copy() if popped_keys: for key in popped_keys: - _config.pop(key) + _config.pop(key, None) return self._gen_sig(list(_config.values())) |