From 8ee391b65e2b48d777a268749f54b3aa9e4b9142 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Wed, 13 Jan 2021 10:59:59 +0800 Subject: Add multiple checkpoints for different model and set default config value --- .idea/gait-recognition.iml | 2 +- config.py | 14 +++++++----- models/model.py | 53 +++++++++++++++++++++++++++++----------------- test/model.py | 4 ++-- utils/configuration.py | 2 ++ 5 files changed, 48 insertions(+), 27 deletions(-) diff --git a/.idea/gait-recognition.iml b/.idea/gait-recognition.iml index e175a98..79e2f3e 100644 --- a/.idea/gait-recognition.iml +++ b/.idea/gait-recognition.iml @@ -5,7 +5,7 @@ - + diff --git a/config.py b/config.py index 8a8d93a..7517097 100644 --- a/config.py +++ b/config.py @@ -13,7 +13,7 @@ config: Configuration = { 'dataset': { # Name of dataset (CASIA-B or FVG) 'name': 'CASIA-B', - # Path to dataset root + # Path to dataset root (required) 'root_dir': 'data/CASIA-B-MRCNN/SEG', # The number of subjects for training 'train_size': 74, @@ -88,9 +88,13 @@ config: Configuration = { 'model': { # Model name, used for naming checkpoint 'name': 'RGB-GaitPart', - # Restoration iteration from checkpoint - 'restore_iter': 0, - # Total iteration for training - 'total_iter': 80000, + # Restoration iteration from checkpoint (single model) + # 'restore_iter': 0, + # Total iteration for training (single model) + # 'total_iter': 80000, + # Restoration iteration (multiple models, e.g. nm, bg and cl) + 'restore_iters': (0, 0, 0), + # Total iteration for training (multiple models) + 'total_iter': (80_000, 80_000, 80_000), }, } diff --git a/models/model.py b/models/model.py index 3cae788..7373dbb 100644 --- a/models/model.py +++ b/models/model.py @@ -27,7 +27,7 @@ class Model: model_config: ModelConfiguration, hyperparameter_config: HyperparameterConfiguration ): - self.disable_acc = system_config['disable_acc'] + self.disable_acc = system_config.get('disable_acc', False) if self.disable_acc: self.device = torch.device('cpu') else: # Enable accelerator @@ -37,17 +37,21 @@ class Model: print('No accelerator available, fallback to CPU.') self.device = torch.device('cpu') - self.save_dir = system_config['save_dir'] + self.save_dir = system_config.get('save_dir', 'runs') + if not os.path.exists(self.save_dir): + os.makedirs(self.save_dir) self.checkpoint_dir = os.path.join(self.save_dir, 'checkpoint') self.log_dir = os.path.join(self.save_dir, 'logs') - for dir_ in (self.save_dir, self.log_dir, self.checkpoint_dir): + for dir_ in (self.log_dir, self.checkpoint_dir): if not os.path.exists(dir_): os.mkdir(dir_) self.meta = model_config self.hp = hyperparameter_config - self.curr_iter = self.meta['restore_iter'] - self.total_iter = self.meta['total_iter'] + self.curr_iter = self.meta.get('restore_iter', 0) + self.total_iter = self.meta.get('total_iter', 80_000) + self.curr_iters = self.meta.get('restore_iters', (0, 0, 0)) + self.total_iters = self.meta.get('total_iters', (80000, 80000, 80000)) self.is_train: bool = True self.train_size: int = 74 @@ -58,11 +62,9 @@ class Model: self._gallery_dataset_meta: Optional[dict[str, list]] = None self._probe_datasets_meta: Optional[dict[str, dict[str, list]]] = None - self._model_sig: str = self._make_signature(self.meta, ['restore_iter']) + self._model_name: str = self.meta.get('name', 'RGB-GaitPart') self._hp_sig: str = self._make_signature(self.hp) self._dataset_sig: str = 'undefined' - self._log_sig: str = '_'.join((self._model_sig, self._hp_sig)) - self._log_name: str = os.path.join(self.log_dir, self._log_sig) self.rgb_pn: Optional[RGBPartNet] = None self.optimizer: Optional[optim.Adam] = None @@ -79,13 +81,26 @@ class Model: } @property - def _signature(self) -> str: - return '_'.join((self._model_sig, str(self.curr_iter), self._hp_sig, - self._dataset_sig, str(self.pr), str(self.k))) + def _model_sig(self) -> str: + return '_'.join((self._model_name, self.curr_iter, self.total_iter)) + + @property + def _checkpoint_sig(self) -> str: + return '_'.join((self._model_sig, self._hp_sig, self._dataset_sig, + str(self.pr), str(self.k))) @property def _checkpoint_name(self) -> str: - return os.path.join(self.checkpoint_dir, self._signature) + return os.path.join(self.checkpoint_dir, self._checkpoint_sig) + + @property + def _log_sig(self) -> str: + return '_'.join((self._model_name, self.total_iter, self._hp_sig, + self._dataset_sig, str(self.pr), str(self.k))) + + @property + def _log_name(self) -> str: + return os.path.join(self.log_dir, self._log_sig) def fit_all( self, @@ -95,8 +110,12 @@ class Model: ], dataloader_config: DataloaderConfiguration, ): - for (condition, selector) in dataset_selectors.items(): + for (curr_iter, total_iter, (condition, selector)) in zip( + self.curr_iters, self.total_iters, dataset_selectors.items() + ): print(f'Training model {condition} ...') + self.curr_iter = curr_iter + self.total_iter = total_iter self.fit( dict(**dataset_config, **{'selector': selector}), dataloader_config @@ -172,7 +191,6 @@ class Model: start_time = datetime.now() if self.curr_iter == self.total_iter: - self.curr_iter = 0 self.writer.close() break @@ -373,7 +391,6 @@ class Model: dataset_config, popped_keys=['root_dir', 'cache_on'] ) - self._log_name = '_'.join((self._log_name, self._dataset_sig)) config: dict = dataset_config.copy() name = config.pop('name', 'CASIA-B') if name == 'CASIA-B': @@ -389,10 +406,8 @@ class Model: dataloader_config: DataloaderConfiguration ) -> DataLoader: config: dict = dataloader_config.copy() - (self.pr, self.k) = config.pop('batch_size') + (self.pr, self.k) = config.pop('batch_size', (8, 16)) if self.is_train: - self._log_name = '_'.join( - (self._log_name, str(self.pr), str(self.k))) triplet_sampler = TripletSampler(dataset, (self.pr, self.k)) return DataLoader(dataset, batch_sampler=triplet_sampler, @@ -424,7 +439,7 @@ class Model: _config = config.copy() if popped_keys: for key in popped_keys: - _config.pop(key) + _config.pop(key, None) return self._gen_sig(list(_config.values())) diff --git a/test/model.py b/test/model.py index 5d60475..8122989 100644 --- a/test/model.py +++ b/test/model.py @@ -16,7 +16,7 @@ def test_default_signature(): 'runs', 'logs', 'RGB-GaitPart_80000_64_128_128_64_1_2_4_True_True_32_5_' '3_3_3_3_3_2_1_1_1_1_1_0_2_3_4_16_256_0.2_0.0001_0.9_' '0.999_0.001_500_0.9_CASIA-B_74_30_15_3_64_32_8_16') - assert model._signature == ('RGB-GaitPart_80000_0_64_128_128_64_1_2_4_True_' + assert model._checkpoint_sig == ('RGB-GaitPart_80000_0_64_128_128_64_1_2_4_True_' 'True_32_5_3_3_3_3_3_2_1_1_1_1_1_0_2_3_4_16_' '256_0.2_0.0001_0.9_0.999_0.001_500_0.9_CASIA-B' '_74_30_15_3_64_32_8_16') @@ -34,7 +34,7 @@ def test_default_signature_with_selector(): '3_3_3_3_3_2_1_1_1_1_1_0_2_3_4_16_256_0.2_0.0001_0.9_' '0.999_0.001_500_0.9_CASIA-B_74_30_15_3_64_32_bg-0\\d_' 'nm-0\\d_8_16') - assert model._signature == ('RGB-GaitPart_80000_0_64_128_128_64_1_2_4_True_' + assert model._checkpoint_sig == ('RGB-GaitPart_80000_0_64_128_128_64_1_2_4_True_' 'True_32_5_3_3_3_3_3_2_1_1_1_1_1_0_2_3_4_16_' '256_0.2_0.0001_0.9_0.999_0.001_500_0.9_CASIA-B' '_74_30_15_3_64_32_bg-0\\d_nm-0\\d_8_16') diff --git a/utils/configuration.py b/utils/configuration.py index f1b5d5a..9a8c2ae 100644 --- a/utils/configuration.py +++ b/utils/configuration.py @@ -64,6 +64,8 @@ class ModelConfiguration(TypedDict): name: str restore_iter: int total_iter: int + restore_iters: tuple[int, ...] + total_iters: tuple[int, ...] class Configuration(TypedDict): -- cgit v1.2.3