diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-01-05 21:57:07 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-01-05 21:57:07 +0800 |
commit | d473b9d1d79bf185b1811ce403f82fdd68fb366c (patch) | |
tree | ed46cf1bb43dfbeb3cc44e9a9d91f7b9b5b63357 /models | |
parent | ab29067d6469473481cc73fe42bcaf69d7633a83 (diff) |
Implement checkpoint mechanism
Diffstat (limited to 'models')
-rw-r--r-- | models/model.py | 57 |
1 files changed, 46 insertions, 11 deletions
diff --git a/models/model.py b/models/model.py index 4354b35..9e52527 100644 --- a/models/model.py +++ b/models/model.py @@ -1,3 +1,4 @@ +import os from typing import Union, Optional import numpy as np @@ -9,7 +10,8 @@ from torch.utils.data.dataloader import default_collate from models import RGBPartNet from utils.configuration import DataloaderConfiguration, \ - HyperparameterConfiguration, DatasetConfiguration, ModelConfiguration + HyperparameterConfiguration, DatasetConfiguration, ModelConfiguration, \ + SystemConfiguration from utils.dataset import CASIAB from utils.sampler import TripletSampler @@ -17,12 +19,17 @@ from utils.sampler import TripletSampler class Model: def __init__( self, + system_config: SystemConfiguration, model_config: ModelConfiguration, hyperparameter_config: HyperparameterConfiguration ): + self.device = system_config['device'] + self.save_dir = system_config['save_dir'] + self.meta = model_config self.hp = hyperparameter_config self.curr_iter = self.meta['restore_iter'] + self.total_iter = self.meta['total_iter'] self.is_train: bool = True self.train_size: int = 74 @@ -34,7 +41,7 @@ class Model: self._hp_sig: str = self._make_signature(self.hp) self._dataset_sig: str = 'undefined' - self.rbg_pn: Optional[RGBPartNet] = None + self.rgb_pn: Optional[RGBPartNet] = None self.optimizer: Optional[optim.Adam] = None self.scheduler: Optional[optim.lr_scheduler.StepLR] = None @@ -63,19 +70,47 @@ class Model: # Prepare for model, optimizer and scheduler hp = self.hp.copy() lr, betas = hp.pop('lr', 1e-4), hp.pop('betas', (0.9, 0.999)) - self.rbg_pn = RGBPartNet(self.train_size, self.in_channels, **hp) - self.optimizer = optim.Adam(self.rbg_pn.parameters(), lr, betas) + self.rgb_pn = RGBPartNet(self.train_size, self.in_channels, **hp) + self.optimizer = optim.Adam(self.rgb_pn.parameters(), lr, betas) self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, 500, 0.9) - self.rbg_pn.train() - self.rbg_pn.apply(self.init_weights) - for iter_i, (x_c1, x_c2) in enumerate(dataloader): - loss = self.rbg_pn(x_c1['clip'], x_c2['clip'], x_c1['label']) + self.rgb_pn.train() + # Init weights at first iter + if self.curr_iter == 0: + self.rgb_pn.apply(self.init_weights) + else: # Load saved state dicts + checkpoint = torch.load(os.path.join(self.save_dir, self.signature)) + iter, loss = checkpoint['iter'], checkpoint['loss'] + print('{0:5d} loss: {1:.3f}'.format(iter, loss)) + self.rgb_pn.load_state_dict(checkpoint['model_state_dict']) + self.optimizer.load_state_dict(checkpoint['optim_state_dict']) + for (x_c1, x_c2) in dataloader: + self.curr_iter += 1 + # Zero the parameter gradients + self.optimizer.zero_grad() + # forward + backward + optimize + loss, metrics = self.rgb_pn(x_c1['clip'], x_c2['clip'], + x_c1['label']) loss.backward() self.optimizer.step() - self.scheduler.step(iter_i) - - if iter_i == self.meta['total_iter']: + # Step scheduler + self.scheduler.step(self.curr_iter) + + if self.curr_iter % 100 == 0: + print('{0:5d} loss: {1:.3f}'.format(self.curr_iter, loss), + '(xrecon = {:f}, pose_sim = {:f},' + ' cano_cons = {:f}, ba_trip = {:f})'.format(*metrics), + 'lr:', self.scheduler.get_last_lr()) + + if self.curr_iter % 1000 == 0: + torch.save({ + 'iter': self.curr_iter, + 'model_state_dict': self.rgb_pn.state_dict(), + 'optim_state_dict': self.optimizer.state_dict(), + 'loss': loss, + }, os.path.join(self.save_dir, self.signature)) + + if self.curr_iter == self.total_iter: break @staticmethod |