Diffstat (limited to 'models/model.py')
-rw-r--r--  models/model.py  53
1 file changed, 34 insertions(+), 19 deletions(-)
diff --git a/models/model.py b/models/model.py
index 3cae788..7373dbb 100644
--- a/models/model.py
+++ b/models/model.py
@@ -27,7 +27,7 @@ class Model:
model_config: ModelConfiguration,
hyperparameter_config: HyperparameterConfiguration
):
- self.disable_acc = system_config['disable_acc']
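+ # Default to the accelerator being enabled when 'disable_acc' is absent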
+ self.disable_acc = system_config.get('disable_acc', False)
if self.disable_acc:
self.device = torch.device('cpu')
else: # Enable accelerator
@@ -37,17 +37,21 @@ class Model:
print('No accelerator available, fallback to CPU.')
self.device = torch.device('cpu')
- self.save_dir = system_config['save_dir']
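+ # Default save directory is 'runs'; create it if it does not exist yet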
+ self.save_dir = system_config.get('save_dir', 'runs')
+ if not os.path.exists(self.save_dir):
+ os.makedirs(self.save_dir)
self.checkpoint_dir = os.path.join(self.save_dir, 'checkpoint')
self.log_dir = os.path.join(self.save_dir, 'logs')
- for dir_ in (self.save_dir, self.log_dir, self.checkpoint_dir):
+ for dir_ in (self.log_dir, self.checkpoint_dir):
if not os.path.exists(dir_):
os.mkdir(dir_)
self.meta = model_config
self.hp = hyperparameter_config
- self.curr_iter = self.meta['restore_iter']
- self.total_iter = self.meta['total_iter']
+ self.curr_iter = self.meta.get('restore_iter', 0)
+ self.total_iter = self.meta.get('total_iter', 80_000)
+ self.curr_iters = self.meta.get('restore_iters', (0, 0, 0))
+ self.total_iters = self.meta.get('total_iters', (80_000, 80_000, 80_000))
self.is_train: bool = True
self.train_size: int = 74
@@ -58,11 +62,9 @@ class Model:
self._gallery_dataset_meta: Optional[dict[str, list]] = None
self._probe_datasets_meta: Optional[dict[str, dict[str, list]]] = None
- self._model_sig: str = self._make_signature(self.meta, ['restore_iter'])
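+ # Signatures are now derived on demand from the model name (see properties below)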
+ self._model_name: str = self.meta.get('name', 'RGB-GaitPart')
self._hp_sig: str = self._make_signature(self.hp)
self._dataset_sig: str = 'undefined'
- self._log_sig: str = '_'.join((self._model_sig, self._hp_sig))
- self._log_name: str = os.path.join(self.log_dir, self._log_sig)
self.rgb_pn: Optional[RGBPartNet] = None
self.optimizer: Optional[optim.Adam] = None
@@ -79,13 +81,26 @@ class Model:
}
@property
- def _signature(self) -> str:
- return '_'.join((self._model_sig, str(self.curr_iter), self._hp_sig,
- self._dataset_sig, str(self.pr), str(self.k)))
+ def _model_sig(self) -> str:
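+ # e.g. 'RGB-GaitPart_0_80000': model name, current iteration, total iterations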
+ return '_'.join((self._model_name, str(self.curr_iter),
+ str(self.total_iter)))
+
+ @property
+ def _checkpoint_sig(self) -> str:
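+ # Extends the model signature with hyperparameter, dataset and (pr, k) info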
+ return '_'.join((self._model_sig, self._hp_sig, self._dataset_sig,
+ str(self.pr), str(self.k)))
@property
def _checkpoint_name(self) -> str:
- return os.path.join(self.checkpoint_dir, self._signature)
+ return os.path.join(self.checkpoint_dir, self._checkpoint_sig)
+
+ @property
+ def _log_sig(self) -> str:
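+ # Unlike the checkpoint signature, the log signature omits the current iteration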
+ return '_'.join((self._model_name, str(self.total_iter), self._hp_sig,
+ self._dataset_sig, str(self.pr), str(self.k)))
+
+ @property
+ def _log_name(self) -> str:
+ return os.path.join(self.log_dir, self._log_sig)
def fit_all(
self,
@@ -95,8 +110,12 @@ class Model:
],
dataloader_config: DataloaderConfiguration,
):
- for (condition, selector) in dataset_selectors.items():
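+ # Each condition trains with its own (restore, total) iteration pair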
+ for (curr_iter, total_iter, (condition, selector)) in zip(
+ self.curr_iters, self.total_iters, dataset_selectors.items()
+ ):
print(f'Training model {condition} ...')
+ self.curr_iter = curr_iter
+ self.total_iter = total_iter
self.fit(
dict(**dataset_config, **{'selector': selector}),
dataloader_config
@@ -172,7 +191,6 @@ class Model:
start_time = datetime.now()
if self.curr_iter == self.total_iter:
- self.curr_iter = 0
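+ # Leave curr_iter untouched so fit_all can manage per-condition progress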
self.writer.close()
break
@@ -373,7 +391,6 @@ class Model:
dataset_config,
popped_keys=['root_dir', 'cache_on']
)
- self._log_name = '_'.join((self._log_name, self._dataset_sig))
config: dict = dataset_config.copy()
name = config.pop('name', 'CASIA-B')
if name == 'CASIA-B':
@@ -389,10 +406,8 @@ class Model:
dataloader_config: DataloaderConfiguration
) -> DataLoader:
config: dict = dataloader_config.copy()
- (self.pr, self.k) = config.pop('batch_size')
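+ # (pr, k) batch shape for the triplet sampler, defaulting to (8, 16)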
+ (self.pr, self.k) = config.pop('batch_size', (8, 16))
if self.is_train:
- self._log_name = '_'.join(
- (self._log_name, str(self.pr), str(self.k)))
triplet_sampler = TripletSampler(dataset, (self.pr, self.k))
return DataLoader(dataset,
batch_sampler=triplet_sampler,
@@ -424,7 +439,7 @@ class Model:
_config = config.copy()
if popped_keys:
for key in popped_keys:
- _config.pop(key)
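+ # Ignore keys that are absent instead of raising KeyError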
+ _config.pop(key, None)
return self._gen_sig(list(_config.values()))