summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-01-13 11:15:49 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-01-13 11:15:49 +0800
commit6de814e06664cd51474232823930faba9ebb6d3d (patch)
tree9f8d9e17919a98695844700fdc5bf3ec680b1a80
parentcad72ca5f1f3f7fc37ce3429af0bf255fed453fb (diff)
parent2ed37d21aac148dcec068590decb8eaff892d23b (diff)
Merge branch 'python3.8' into python3.7
# Conflicts: # utils/configuration.py
-rw-r--r--.idea/gait-recognition.iml2
-rw-r--r--config.py14
-rw-r--r--models/model.py55
-rw-r--r--test/model.py17
4 files changed, 55 insertions, 33 deletions
diff --git a/.idea/gait-recognition.iml b/.idea/gait-recognition.iml
index e175a98..79e2f3e 100644
--- a/.idea/gait-recognition.iml
+++ b/.idea/gait-recognition.iml
@@ -5,7 +5,7 @@
<excludeFolder url="file://$MODULE_DIR$/venv" />
<excludeFolder url="file://$MODULE_DIR$/data" />
</content>
- <orderEntry type="inheritedJdk" />
+ <orderEntry type="jdk" jdkName="Python 3.9 (gait-recognition)" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="PyDocumentationSettings">
diff --git a/config.py b/config.py
index cd36cc5..f8e3711 100644
--- a/config.py
+++ b/config.py
@@ -11,7 +11,7 @@ config = {
'dataset': {
# Name of dataset (CASIA-B or FVG)
'name': 'CASIA-B',
- # Path to dataset root
+ # Path to dataset root (required)
'root_dir': 'data/CASIA-B-MRCNN/SEG',
# The number of subjects for training
'train_size': 74,
@@ -86,9 +86,13 @@ config = {
'model': {
# Model name, used for naming checkpoint
'name': 'RGB-GaitPart',
- # Restoration iteration from checkpoint
- 'restore_iter': 0,
- # Total iteration for training
- 'total_iter': 80000,
+ # Restoration iteration from checkpoint (single model)
+ # 'restore_iter': 0,
+ # Total iteration for training (single model)
+ # 'total_iter': 80000,
+ # Restoration iteration (multiple models, e.g. nm, bg and cl)
+ 'restore_iters': (0, 0, 0),
+ # Total iteration for training (multiple models)
+ 'total_iters': (80_000, 80_000, 80_000),
},
}
diff --git a/models/model.py b/models/model.py
index b7bde9e..258fd71 100644
--- a/models/model.py
+++ b/models/model.py
@@ -24,7 +24,7 @@ class Model:
model_config: Dict,
hyperparameter_config: Dict
):
- self.disable_acc = system_config['disable_acc']
+ self.disable_acc = system_config.get('disable_acc', False)
if self.disable_acc:
self.device = torch.device('cpu')
else: # Enable accelerator
@@ -34,17 +34,21 @@ class Model:
print('No accelerator available, fallback to CPU.')
self.device = torch.device('cpu')
- self.save_dir = system_config['save_dir']
+ self.save_dir = system_config.get('save_dir', 'runs')
+ if not os.path.exists(self.save_dir):
+ os.makedirs(self.save_dir)
self.checkpoint_dir = os.path.join(self.save_dir, 'checkpoint')
self.log_dir = os.path.join(self.save_dir, 'logs')
- for dir_ in (self.save_dir, self.log_dir, self.checkpoint_dir):
+ for dir_ in (self.log_dir, self.checkpoint_dir):
if not os.path.exists(dir_):
os.mkdir(dir_)
self.meta = model_config
self.hp = hyperparameter_config
- self.curr_iter = self.meta['restore_iter']
- self.total_iter = self.meta['total_iter']
+ self.curr_iter = self.meta.get('restore_iter', 0)
+ self.total_iter = self.meta.get('total_iter', 80_000)
+ self.curr_iters = self.meta.get('restore_iters', (0, 0, 0))
+ self.total_iters = self.meta.get('total_iters', (80000, 80000, 80000))
self.is_train: bool = True
self.train_size: int = 74
@@ -55,11 +59,9 @@ class Model:
self._gallery_dataset_meta: Optional[Dict[str, List]] = None
self._probe_datasets_meta: Optional[Dict[str, Dict[str, List]]] = None
- self._model_sig: str = self._make_signature(self.meta, ['restore_iter'])
+ self._model_name: str = self.meta.get('name', 'RGB-GaitPart')
self._hp_sig: str = self._make_signature(self.hp)
self._dataset_sig: str = 'undefined'
- self._log_sig: str = '_'.join((self._model_sig, self._hp_sig))
- self._log_name: str = os.path.join(self.log_dir, self._log_sig)
self.rgb_pn: Optional[RGBPartNet] = None
self.optimizer: Optional[optim.Adam] = None
@@ -76,13 +78,28 @@ class Model:
}
@property
- def _signature(self) -> str:
- return '_'.join((self._model_sig, str(self.curr_iter), self._hp_sig,
- self._dataset_sig, str(self.pr), str(self.k)))
+ def _model_sig(self) -> str:
+ return '_'.join(
+ (self._model_name, str(self.curr_iter), str(self.total_iter))
+ )
+
+ @property
+ def _checkpoint_sig(self) -> str:
+ return '_'.join((self._model_sig, self._hp_sig, self._dataset_sig,
+ str(self.pr), str(self.k)))
@property
def _checkpoint_name(self) -> str:
- return os.path.join(self.checkpoint_dir, self._signature)
+ return os.path.join(self.checkpoint_dir, self._checkpoint_sig)
+
+ @property
+ def _log_sig(self) -> str:
+ return '_'.join((self._model_name, str(self.total_iter), self._hp_sig,
+ self._dataset_sig, str(self.pr), str(self.k)))
+
+ @property
+ def _log_name(self) -> str:
+ return os.path.join(self.log_dir, self._log_sig)
def fit_all(
self,
@@ -92,8 +109,12 @@ class Model:
],
dataloader_config: Dict,
):
- for (condition, selector) in dataset_selectors.items():
+ for (curr_iter, total_iter, (condition, selector)) in zip(
+ self.curr_iters, self.total_iters, dataset_selectors.items()
+ ):
print(f'Training model {condition} ...')
+ self.curr_iter = curr_iter
+ self.total_iter = total_iter
self.fit(
dict(**dataset_config, **{'selector': selector}),
dataloader_config
@@ -169,7 +190,6 @@ class Model:
start_time = datetime.now()
if self.curr_iter == self.total_iter:
- self.curr_iter = 0
self.writer.close()
break
@@ -370,7 +390,6 @@ class Model:
dataset_config,
popped_keys=['root_dir', 'cache_on']
)
- self._log_name = '_'.join((self._log_name, self._dataset_sig))
config: Dict = dataset_config.copy()
name = config.pop('name', 'CASIA-B')
if name == 'CASIA-B':
@@ -386,10 +405,8 @@ class Model:
dataloader_config: Dict
) -> DataLoader:
config: Dict = dataloader_config.copy()
- (self.pr, self.k) = config.pop('batch_size')
+ (self.pr, self.k) = config.pop('batch_size', (8, 16))
if self.is_train:
- self._log_name = '_'.join(
- (self._log_name, str(self.pr), str(self.k)))
triplet_sampler = TripletSampler(dataset, (self.pr, self.k))
return DataLoader(dataset,
batch_sampler=triplet_sampler,
@@ -421,7 +438,7 @@ class Model:
_config = config.copy()
if popped_keys:
for key in popped_keys:
- _config.pop(key)
+ _config.pop(key, None)
return self._gen_sig(list(_config.values()))
diff --git a/test/model.py b/test/model.py
index 5d60475..f7fc57e 100644
--- a/test/model.py
+++ b/test/model.py
@@ -16,10 +16,10 @@ def test_default_signature():
'runs', 'logs', 'RGB-GaitPart_80000_64_128_128_64_1_2_4_True_True_32_5_'
'3_3_3_3_3_2_1_1_1_1_1_0_2_3_4_16_256_0.2_0.0001_0.9_'
'0.999_0.001_500_0.9_CASIA-B_74_30_15_3_64_32_8_16')
- assert model._signature == ('RGB-GaitPart_80000_0_64_128_128_64_1_2_4_True_'
- 'True_32_5_3_3_3_3_3_2_1_1_1_1_1_0_2_3_4_16_'
- '256_0.2_0.0001_0.9_0.999_0.001_500_0.9_CASIA-B'
- '_74_30_15_3_64_32_8_16')
+ assert model._checkpoint_sig == ('RGB-GaitPart_0_80000_64_128_128_64_1_2_4_'
+ 'True_True_32_5_3_3_3_3_3_2_1_1_1_1_1_0_2_'
+ '3_4_16_256_0.2_0.0001_0.9_0.999_0.001_'
+ '500_0.9_CASIA-B_74_30_15_3_64_32_8_16')
def test_default_signature_with_selector():
@@ -34,7 +34,8 @@ def test_default_signature_with_selector():
'3_3_3_3_3_2_1_1_1_1_1_0_2_3_4_16_256_0.2_0.0001_0.9_'
'0.999_0.001_500_0.9_CASIA-B_74_30_15_3_64_32_bg-0\\d_'
'nm-0\\d_8_16')
- assert model._signature == ('RGB-GaitPart_80000_0_64_128_128_64_1_2_4_True_'
- 'True_32_5_3_3_3_3_3_2_1_1_1_1_1_0_2_3_4_16_'
- '256_0.2_0.0001_0.9_0.999_0.001_500_0.9_CASIA-B'
- '_74_30_15_3_64_32_bg-0\\d_nm-0\\d_8_16')
+ assert model._checkpoint_sig == ('RGB-GaitPart_0_80000_64_128_128_64_1_2_4_'
+ 'True_True_32_5_3_3_3_3_3_2_1_1_1_1_1_0_2_'
+ '3_4_16_256_0.2_0.0001_0.9_0.999_0.001_'
+ '500_0.9_CASIA-B_74_30_15_3_64_32_bg-0\\d_'
+ 'nm-0\\d_8_16')