diff options
-rw-r--r-- | .idea/gait-recognition.iml | 2 | ||||
-rw-r--r-- | config.py | 14 | ||||
-rw-r--r-- | models/model.py | 55 | ||||
-rw-r--r-- | test/model.py | 17 |
4 files changed, 55 insertions, 33 deletions
diff --git a/.idea/gait-recognition.iml b/.idea/gait-recognition.iml index e175a98..79e2f3e 100644 --- a/.idea/gait-recognition.iml +++ b/.idea/gait-recognition.iml @@ -5,7 +5,7 @@ <excludeFolder url="file://$MODULE_DIR$/venv" /> <excludeFolder url="file://$MODULE_DIR$/data" /> </content> - <orderEntry type="inheritedJdk" /> + <orderEntry type="jdk" jdkName="Python 3.9 (gait-recognition)" jdkType="Python SDK" /> <orderEntry type="sourceFolder" forTests="false" /> </component> <component name="PyDocumentationSettings"> @@ -11,7 +11,7 @@ config = { 'dataset': { # Name of dataset (CASIA-B or FVG) 'name': 'CASIA-B', - # Path to dataset root + # Path to dataset root (required) 'root_dir': 'data/CASIA-B-MRCNN/SEG', # The number of subjects for training 'train_size': 74, @@ -86,9 +86,13 @@ config = { 'model': { # Model name, used for naming checkpoint 'name': 'RGB-GaitPart', - # Restoration iteration from checkpoint - 'restore_iter': 0, - # Total iteration for training - 'total_iter': 80000, + # Restoration iteration from checkpoint (single model) + # 'restore_iter': 0, + # Total iteration for training (single model) + # 'total_iter': 80000, + # Restoration iteration (multiple models, e.g. nm, bg and cl) + 'restore_iters': (0, 0, 0), + # Total iteration for training (multiple models) + 'total_iters': (80_000, 80_000, 80_000), }, } diff --git a/models/model.py b/models/model.py index b7bde9e..258fd71 100644 --- a/models/model.py +++ b/models/model.py @@ -24,7 +24,7 @@ class Model: model_config: Dict, hyperparameter_config: Dict ): - self.disable_acc = system_config['disable_acc'] + self.disable_acc = system_config.get('disable_acc', False) if self.disable_acc: self.device = torch.device('cpu') else: # Enable accelerator @@ -34,17 +34,21 @@ class Model: print('No accelerator available, fallback to CPU.') self.device = torch.device('cpu') - self.save_dir = system_config['save_dir'] + self.save_dir = system_config.get('save_dir', 'runs') + if not os.path.exists(self.save_dir): + os.makedirs(self.save_dir) self.checkpoint_dir = os.path.join(self.save_dir, 'checkpoint') self.log_dir = os.path.join(self.save_dir, 'logs') - for dir_ in (self.save_dir, self.log_dir, self.checkpoint_dir): + for dir_ in (self.log_dir, self.checkpoint_dir): if not os.path.exists(dir_): os.mkdir(dir_) self.meta = model_config self.hp = hyperparameter_config - self.curr_iter = self.meta['restore_iter'] - self.total_iter = self.meta['total_iter'] + self.curr_iter = self.meta.get('restore_iter', 0) + self.total_iter = self.meta.get('total_iter', 80_000) + self.curr_iters = self.meta.get('restore_iters', (0, 0, 0)) + self.total_iters = self.meta.get('total_iters', (80000, 80000, 80000)) self.is_train: bool = True self.train_size: int = 74 @@ -55,11 +59,9 @@ class Model: self._gallery_dataset_meta: Optional[Dict[str, List]] = None self._probe_datasets_meta: Optional[Dict[str, Dict[str, List]]] = None - self._model_sig: str = self._make_signature(self.meta, ['restore_iter']) + self._model_name: str = self.meta.get('name', 'RGB-GaitPart') self._hp_sig: str = self._make_signature(self.hp) self._dataset_sig: str = 'undefined' - self._log_sig: str = '_'.join((self._model_sig, self._hp_sig)) - self._log_name: str = os.path.join(self.log_dir, self._log_sig) self.rgb_pn: Optional[RGBPartNet] = None self.optimizer: Optional[optim.Adam] = None @@ -76,13 +78,28 @@ class Model: } @property - def _signature(self) -> str: - return '_'.join((self._model_sig, str(self.curr_iter), self._hp_sig, - self._dataset_sig, str(self.pr), str(self.k))) + def _model_sig(self) -> str: + return '_'.join( + (self._model_name, str(self.curr_iter), str(self.total_iter)) + ) + + @property + def _checkpoint_sig(self) -> str: + return '_'.join((self._model_sig, self._hp_sig, self._dataset_sig, + str(self.pr), str(self.k))) @property def _checkpoint_name(self) -> str: - return os.path.join(self.checkpoint_dir, self._signature) + return os.path.join(self.checkpoint_dir, self._checkpoint_sig) + + @property + def _log_sig(self) -> str: + return '_'.join((self._model_name, str(self.total_iter), self._hp_sig, + self._dataset_sig, str(self.pr), str(self.k))) + + @property + def _log_name(self) -> str: + return os.path.join(self.log_dir, self._log_sig) def fit_all( self, @@ -92,8 +109,12 @@ class Model: ], dataloader_config: Dict, ): - for (condition, selector) in dataset_selectors.items(): + for (curr_iter, total_iter, (condition, selector)) in zip( + self.curr_iters, self.total_iters, dataset_selectors.items() + ): print(f'Training model {condition} ...') + self.curr_iter = curr_iter + self.total_iter = total_iter self.fit( dict(**dataset_config, **{'selector': selector}), dataloader_config @@ -169,7 +190,6 @@ class Model: start_time = datetime.now() if self.curr_iter == self.total_iter: - self.curr_iter = 0 self.writer.close() break @@ -370,7 +390,6 @@ class Model: dataset_config, popped_keys=['root_dir', 'cache_on'] ) - self._log_name = '_'.join((self._log_name, self._dataset_sig)) config: Dict = dataset_config.copy() name = config.pop('name', 'CASIA-B') if name == 'CASIA-B': @@ -386,10 +405,8 @@ class Model: dataloader_config: Dict ) -> DataLoader: config: Dict = dataloader_config.copy() - (self.pr, self.k) = config.pop('batch_size') + (self.pr, self.k) = config.pop('batch_size', (8, 16)) if self.is_train: - self._log_name = '_'.join( - (self._log_name, str(self.pr), str(self.k))) triplet_sampler = TripletSampler(dataset, (self.pr, self.k)) return DataLoader(dataset, batch_sampler=triplet_sampler, @@ -421,7 +438,7 @@ class Model: _config = config.copy() if popped_keys: for key in popped_keys: - _config.pop(key) + _config.pop(key, None) return self._gen_sig(list(_config.values())) diff --git a/test/model.py b/test/model.py index 5d60475..f7fc57e 100644 --- a/test/model.py +++ b/test/model.py @@ -16,10 +16,10 @@ def test_default_signature(): 'runs', 'logs', 'RGB-GaitPart_80000_64_128_128_64_1_2_4_True_True_32_5_' '3_3_3_3_3_2_1_1_1_1_1_0_2_3_4_16_256_0.2_0.0001_0.9_' '0.999_0.001_500_0.9_CASIA-B_74_30_15_3_64_32_8_16') - assert model._signature == ('RGB-GaitPart_80000_0_64_128_128_64_1_2_4_True_' - 'True_32_5_3_3_3_3_3_2_1_1_1_1_1_0_2_3_4_16_' - '256_0.2_0.0001_0.9_0.999_0.001_500_0.9_CASIA-B' - '_74_30_15_3_64_32_8_16') + assert model._checkpoint_sig == ('RGB-GaitPart_0_80000_64_128_128_64_1_2_4_' + 'True_True_32_5_3_3_3_3_3_2_1_1_1_1_1_0_2_' + '3_4_16_256_0.2_0.0001_0.9_0.999_0.001_' + '500_0.9_CASIA-B_74_30_15_3_64_32_8_16') def test_default_signature_with_selector(): @@ -34,7 +34,8 @@ def test_default_signature_with_selector(): '3_3_3_3_3_2_1_1_1_1_1_0_2_3_4_16_256_0.2_0.0001_0.9_' '0.999_0.001_500_0.9_CASIA-B_74_30_15_3_64_32_bg-0\\d_' 'nm-0\\d_8_16') - assert model._signature == ('RGB-GaitPart_80000_0_64_128_128_64_1_2_4_True_' - 'True_32_5_3_3_3_3_3_2_1_1_1_1_1_0_2_3_4_16_' - '256_0.2_0.0001_0.9_0.999_0.001_500_0.9_CASIA-B' - '_74_30_15_3_64_32_bg-0\\d_nm-0\\d_8_16') + assert model._checkpoint_sig == ('RGB-GaitPart_0_80000_64_128_128_64_1_2_4_' + 'True_True_32_5_3_3_3_3_3_2_1_1_1_1_1_0_2_' + '3_4_16_256_0.2_0.0001_0.9_0.999_0.001_' + '500_0.9_CASIA-B_74_30_15_3_64_32_bg-0\\d_' + 'nm-0\\d_8_16') |