From 8ee391b65e2b48d777a268749f54b3aa9e4b9142 Mon Sep 17 00:00:00 2001
From: Jordan Gong <jordan.gong@protonmail.com>
Date: Wed, 13 Jan 2021 10:59:59 +0800
Subject: Add multiple checkpoints for different model and set default config
 value

---
 models/model.py | 53 ++++++++++++++++++++++++++++++++++-------------------
 1 file changed, 34 insertions(+), 19 deletions(-)

(limited to 'models')

diff --git a/models/model.py b/models/model.py
index 3cae788..7373dbb 100644
--- a/models/model.py
+++ b/models/model.py
@@ -27,7 +27,7 @@ class Model:
             model_config: ModelConfiguration,
             hyperparameter_config: HyperparameterConfiguration
     ):
-        self.disable_acc = system_config['disable_acc']
+        self.disable_acc = system_config.get('disable_acc', False)
         if self.disable_acc:
             self.device = torch.device('cpu')
         else:  # Enable accelerator
@@ -37,17 +37,21 @@ class Model:
                 print('No accelerator available, fallback to CPU.')
                 self.device = torch.device('cpu')
 
-        self.save_dir = system_config['save_dir']
+        self.save_dir = system_config.get('save_dir', 'runs')
+        if not os.path.exists(self.save_dir):
+            os.makedirs(self.save_dir)
         self.checkpoint_dir = os.path.join(self.save_dir, 'checkpoint')
         self.log_dir = os.path.join(self.save_dir, 'logs')
-        for dir_ in (self.save_dir, self.log_dir, self.checkpoint_dir):
+        for dir_ in (self.log_dir, self.checkpoint_dir):
             if not os.path.exists(dir_):
                 os.mkdir(dir_)
 
         self.meta = model_config
         self.hp = hyperparameter_config
-        self.curr_iter = self.meta['restore_iter']
-        self.total_iter = self.meta['total_iter']
+        self.curr_iter = self.meta.get('restore_iter', 0)
+        self.total_iter = self.meta.get('total_iter', 80_000)
+        self.curr_iters = self.meta.get('restore_iters', (0, 0, 0))
+        self.total_iters = self.meta.get('total_iters', (80000, 80000, 80000))
 
         self.is_train: bool = True
         self.train_size: int = 74
@@ -58,11 +62,9 @@ class Model:
         self._gallery_dataset_meta: Optional[dict[str, list]] = None
         self._probe_datasets_meta: Optional[dict[str, dict[str, list]]] = None
 
-        self._model_sig: str = self._make_signature(self.meta, ['restore_iter'])
+        self._model_name: str = self.meta.get('name', 'RGB-GaitPart')
         self._hp_sig: str = self._make_signature(self.hp)
         self._dataset_sig: str = 'undefined'
-        self._log_sig: str = '_'.join((self._model_sig, self._hp_sig))
-        self._log_name: str = os.path.join(self.log_dir, self._log_sig)
 
         self.rgb_pn: Optional[RGBPartNet] = None
         self.optimizer: Optional[optim.Adam] = None
@@ -79,13 +81,26 @@ class Model:
         }
 
     @property
-    def _signature(self) -> str:
-        return '_'.join((self._model_sig, str(self.curr_iter), self._hp_sig,
-                         self._dataset_sig, str(self.pr), str(self.k)))
+    def _model_sig(self) -> str:
+        return '_'.join((self._model_name, self.curr_iter, self.total_iter))
+
+    @property
+    def _checkpoint_sig(self) -> str:
+        return '_'.join((self._model_sig, self._hp_sig, self._dataset_sig,
+                         str(self.pr), str(self.k)))
 
     @property
     def _checkpoint_name(self) -> str:
-        return os.path.join(self.checkpoint_dir, self._signature)
+        return os.path.join(self.checkpoint_dir, self._checkpoint_sig)
+
+    @property
+    def _log_sig(self) -> str:
+        return '_'.join((self._model_name, self.total_iter, self._hp_sig,
+                         self._dataset_sig, str(self.pr), str(self.k)))
+
+    @property
+    def _log_name(self) -> str:
+        return os.path.join(self.log_dir, self._log_sig)
 
     def fit_all(
             self,
@@ -95,8 +110,12 @@ class Model:
             ],
             dataloader_config: DataloaderConfiguration,
     ):
-        for (condition, selector) in dataset_selectors.items():
+        for (curr_iter, total_iter, (condition, selector)) in zip(
+                self.curr_iters, self.total_iters, dataset_selectors.items()
+        ):
             print(f'Training model {condition} ...')
+            self.curr_iter = curr_iter
+            self.total_iter = total_iter
             self.fit(
                 dict(**dataset_config, **{'selector': selector}),
                 dataloader_config
@@ -172,7 +191,6 @@ class Model:
                 start_time = datetime.now()
 
             if self.curr_iter == self.total_iter:
-                self.curr_iter = 0
                 self.writer.close()
                 break
 
@@ -373,7 +391,6 @@ class Model:
             dataset_config,
             popped_keys=['root_dir', 'cache_on']
         )
-        self._log_name = '_'.join((self._log_name, self._dataset_sig))
         config: dict = dataset_config.copy()
         name = config.pop('name', 'CASIA-B')
         if name == 'CASIA-B':
@@ -389,10 +406,8 @@ class Model:
             dataloader_config: DataloaderConfiguration
     ) -> DataLoader:
         config: dict = dataloader_config.copy()
-        (self.pr, self.k) = config.pop('batch_size')
+        (self.pr, self.k) = config.pop('batch_size', (8, 16))
         if self.is_train:
-            self._log_name = '_'.join(
-                (self._log_name, str(self.pr), str(self.k)))
             triplet_sampler = TripletSampler(dataset, (self.pr, self.k))
             return DataLoader(dataset,
                               batch_sampler=triplet_sampler,
@@ -424,7 +439,7 @@ class Model:
         _config = config.copy()
         if popped_keys:
             for key in popped_keys:
-                _config.pop(key)
+                _config.pop(key, None)
 
         return self._gen_sig(list(_config.values()))
 
-- 
cgit v1.2.3