summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.idea/gait-recognition.iml2
-rw-r--r--config.py14
-rw-r--r--models/model.py53
-rw-r--r--test/model.py4
-rw-r--r--utils/configuration.py2
5 files changed, 48 insertions, 27 deletions
diff --git a/.idea/gait-recognition.iml b/.idea/gait-recognition.iml
index e175a98..79e2f3e 100644
--- a/.idea/gait-recognition.iml
+++ b/.idea/gait-recognition.iml
@@ -5,7 +5,7 @@
<excludeFolder url="file://$MODULE_DIR$/venv" />
<excludeFolder url="file://$MODULE_DIR$/data" />
</content>
- <orderEntry type="inheritedJdk" />
+ <orderEntry type="jdk" jdkName="Python 3.9 (gait-recognition)" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="PyDocumentationSettings">
diff --git a/config.py b/config.py
index 8a8d93a..7517097 100644
--- a/config.py
+++ b/config.py
@@ -13,7 +13,7 @@ config: Configuration = {
'dataset': {
# Name of dataset (CASIA-B or FVG)
'name': 'CASIA-B',
- # Path to dataset root
+ # Path to dataset root (required)
'root_dir': 'data/CASIA-B-MRCNN/SEG',
# The number of subjects for training
'train_size': 74,
@@ -88,9 +88,13 @@ config: Configuration = {
'model': {
# Model name, used for naming checkpoint
'name': 'RGB-GaitPart',
- # Restoration iteration from checkpoint
- 'restore_iter': 0,
- # Total iteration for training
- 'total_iter': 80000,
+ # Restoration iteration from checkpoint (single model)
+ # 'restore_iter': 0,
+ # Total iteration for training (single model)
+ # 'total_iter': 80000,
+ # Restoration iteration (multiple models, e.g. nm, bg and cl)
+ 'restore_iters': (0, 0, 0),
+ # Total iteration for training (multiple models)
+ 'total_iter': (80_000, 80_000, 80_000),
},
}
diff --git a/models/model.py b/models/model.py
index 2c4e5a0..5ddcf90 100644
--- a/models/model.py
+++ b/models/model.py
@@ -27,7 +27,7 @@ class Model:
model_config: ModelConfiguration,
hyperparameter_config: HyperparameterConfiguration
):
- self.disable_acc = system_config['disable_acc']
+ self.disable_acc = system_config.get('disable_acc', False)
if self.disable_acc:
self.device = torch.device('cpu')
else: # Enable accelerator
@@ -37,17 +37,21 @@ class Model:
print('No accelerator available, fallback to CPU.')
self.device = torch.device('cpu')
- self.save_dir = system_config['save_dir']
+ self.save_dir = system_config.get('save_dir', 'runs')
+ if not os.path.exists(self.save_dir):
+ os.makedirs(self.save_dir)
self.checkpoint_dir = os.path.join(self.save_dir, 'checkpoint')
self.log_dir = os.path.join(self.save_dir, 'logs')
- for dir_ in (self.save_dir, self.log_dir, self.checkpoint_dir):
+ for dir_ in (self.log_dir, self.checkpoint_dir):
if not os.path.exists(dir_):
os.mkdir(dir_)
self.meta = model_config
self.hp = hyperparameter_config
- self.curr_iter = self.meta['restore_iter']
- self.total_iter = self.meta['total_iter']
+ self.curr_iter = self.meta.get('restore_iter', 0)
+ self.total_iter = self.meta.get('total_iter', 80_000)
+ self.curr_iters = self.meta.get('restore_iters', (0, 0, 0))
+ self.total_iters = self.meta.get('total_iters', (80000, 80000, 80000))
self.is_train: bool = True
self.train_size: int = 74
@@ -58,11 +62,9 @@ class Model:
self._gallery_dataset_meta: Optional[Dict[str, List]] = None
self._probe_datasets_meta: Optional[Dict[str, Dict[str, List]]] = None
- self._model_sig: str = self._make_signature(self.meta, ['restore_iter'])
+ self._model_name: str = self.meta.get('name', 'RGB-GaitPart')
self._hp_sig: str = self._make_signature(self.hp)
self._dataset_sig: str = 'undefined'
- self._log_sig: str = '_'.join((self._model_sig, self._hp_sig))
- self._log_name: str = os.path.join(self.log_dir, self._log_sig)
self.rgb_pn: Optional[RGBPartNet] = None
self.optimizer: Optional[optim.Adam] = None
@@ -79,13 +81,26 @@ class Model:
}
@property
- def _signature(self) -> str:
- return '_'.join((self._model_sig, str(self.curr_iter), self._hp_sig,
- self._dataset_sig, str(self.pr), str(self.k)))
+ def _model_sig(self) -> str:
+ return '_'.join((self._model_name, self.curr_iter, self.total_iter))
+
+ @property
+ def _checkpoint_sig(self) -> str:
+ return '_'.join((self._model_sig, self._hp_sig, self._dataset_sig,
+ str(self.pr), str(self.k)))
@property
def _checkpoint_name(self) -> str:
- return os.path.join(self.checkpoint_dir, self._signature)
+ return os.path.join(self.checkpoint_dir, self._checkpoint_sig)
+
+ @property
+ def _log_sig(self) -> str:
+ return '_'.join((self._model_name, self.total_iter, self._hp_sig,
+ self._dataset_sig, str(self.pr), str(self.k)))
+
+ @property
+ def _log_name(self) -> str:
+ return os.path.join(self.log_dir, self._log_sig)
def fit_all(
self,
@@ -95,8 +110,12 @@ class Model:
],
dataloader_config: DataloaderConfiguration,
):
- for (condition, selector) in dataset_selectors.items():
+ for (curr_iter, total_iter, (condition, selector)) in zip(
+ self.curr_iters, self.total_iters, dataset_selectors.items()
+ ):
print(f'Training model {condition} ...')
+ self.curr_iter = curr_iter
+ self.total_iter = total_iter
self.fit(
dict(**dataset_config, **{'selector': selector}),
dataloader_config
@@ -172,7 +191,6 @@ class Model:
start_time = datetime.now()
if self.curr_iter == self.total_iter:
- self.curr_iter = 0
self.writer.close()
break
@@ -373,7 +391,6 @@ class Model:
dataset_config,
popped_keys=['root_dir', 'cache_on']
)
- self._log_name = '_'.join((self._log_name, self._dataset_sig))
config: Dict = dataset_config.copy()
name = config.pop('name', 'CASIA-B')
if name == 'CASIA-B':
@@ -389,10 +406,8 @@ class Model:
dataloader_config: DataloaderConfiguration
) -> DataLoader:
config: Dict = dataloader_config.copy()
- (self.pr, self.k) = config.pop('batch_size')
+ (self.pr, self.k) = config.pop('batch_size', (8, 16))
if self.is_train:
- self._log_name = '_'.join(
- (self._log_name, str(self.pr), str(self.k)))
triplet_sampler = TripletSampler(dataset, (self.pr, self.k))
return DataLoader(dataset,
batch_sampler=triplet_sampler,
@@ -424,7 +439,7 @@ class Model:
_config = config.copy()
if popped_keys:
for key in popped_keys:
- _config.pop(key)
+ _config.pop(key, None)
return self._gen_sig(list(_config.values()))
diff --git a/test/model.py b/test/model.py
index 5d60475..8122989 100644
--- a/test/model.py
+++ b/test/model.py
@@ -16,7 +16,7 @@ def test_default_signature():
'runs', 'logs', 'RGB-GaitPart_80000_64_128_128_64_1_2_4_True_True_32_5_'
'3_3_3_3_3_2_1_1_1_1_1_0_2_3_4_16_256_0.2_0.0001_0.9_'
'0.999_0.001_500_0.9_CASIA-B_74_30_15_3_64_32_8_16')
- assert model._signature == ('RGB-GaitPart_80000_0_64_128_128_64_1_2_4_True_'
+ assert model._checkpoint_sig == ('RGB-GaitPart_80000_0_64_128_128_64_1_2_4_True_'
'True_32_5_3_3_3_3_3_2_1_1_1_1_1_0_2_3_4_16_'
'256_0.2_0.0001_0.9_0.999_0.001_500_0.9_CASIA-B'
'_74_30_15_3_64_32_8_16')
@@ -34,7 +34,7 @@ def test_default_signature_with_selector():
'3_3_3_3_3_2_1_1_1_1_1_0_2_3_4_16_256_0.2_0.0001_0.9_'
'0.999_0.001_500_0.9_CASIA-B_74_30_15_3_64_32_bg-0\\d_'
'nm-0\\d_8_16')
- assert model._signature == ('RGB-GaitPart_80000_0_64_128_128_64_1_2_4_True_'
+ assert model._checkpoint_sig == ('RGB-GaitPart_80000_0_64_128_128_64_1_2_4_True_'
'True_32_5_3_3_3_3_3_2_1_1_1_1_1_0_2_3_4_16_'
'256_0.2_0.0001_0.9_0.999_0.001_500_0.9_CASIA-B'
'_74_30_15_3_64_32_bg-0\\d_nm-0\\d_8_16')
diff --git a/utils/configuration.py b/utils/configuration.py
index 71584c0..d7ebc5e 100644
--- a/utils/configuration.py
+++ b/utils/configuration.py
@@ -64,6 +64,8 @@ class ModelConfiguration(TypedDict):
name: str
restore_iter: int
total_iter: int
+ restore_iters: tuple[int, ...]
+ total_iters: tuple[int, ...]
class Configuration(TypedDict):