summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--models/model.py57
1 files changed, 46 insertions, 11 deletions
diff --git a/models/model.py b/models/model.py
index 4354b35..9e52527 100644
--- a/models/model.py
+++ b/models/model.py
@@ -1,3 +1,4 @@
+import os
from typing import Union, Optional
import numpy as np
@@ -9,7 +10,8 @@ from torch.utils.data.dataloader import default_collate
from models import RGBPartNet
from utils.configuration import DataloaderConfiguration, \
- HyperparameterConfiguration, DatasetConfiguration, ModelConfiguration
+ HyperparameterConfiguration, DatasetConfiguration, ModelConfiguration, \
+ SystemConfiguration
from utils.dataset import CASIAB
from utils.sampler import TripletSampler
@@ -17,12 +19,17 @@ from utils.sampler import TripletSampler
class Model:
def __init__(
self,
+ system_config: SystemConfiguration,
model_config: ModelConfiguration,
hyperparameter_config: HyperparameterConfiguration
):
+ self.device = system_config['device']
+ self.save_dir = system_config['save_dir']
+
self.meta = model_config
self.hp = hyperparameter_config
self.curr_iter = self.meta['restore_iter']
+ self.total_iter = self.meta['total_iter']
self.is_train: bool = True
self.train_size: int = 74
@@ -34,7 +41,7 @@ class Model:
self._hp_sig: str = self._make_signature(self.hp)
self._dataset_sig: str = 'undefined'
- self.rbg_pn: Optional[RGBPartNet] = None
+ self.rgb_pn: Optional[RGBPartNet] = None
self.optimizer: Optional[optim.Adam] = None
self.scheduler: Optional[optim.lr_scheduler.StepLR] = None
@@ -63,19 +70,47 @@ class Model:
# Prepare for model, optimizer and scheduler
hp = self.hp.copy()
lr, betas = hp.pop('lr', 1e-4), hp.pop('betas', (0.9, 0.999))
- self.rbg_pn = RGBPartNet(self.train_size, self.in_channels, **hp)
- self.optimizer = optim.Adam(self.rbg_pn.parameters(), lr, betas)
+ self.rgb_pn = RGBPartNet(self.train_size, self.in_channels, **hp)
+ self.optimizer = optim.Adam(self.rgb_pn.parameters(), lr, betas)
self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, 500, 0.9)
- self.rbg_pn.train()
- self.rbg_pn.apply(self.init_weights)
- for iter_i, (x_c1, x_c2) in enumerate(dataloader):
- loss = self.rbg_pn(x_c1['clip'], x_c2['clip'], x_c1['label'])
+ self.rgb_pn.train()
+ # Init weights at first iter
+ if self.curr_iter == 0:
+ self.rgb_pn.apply(self.init_weights)
+ else: # Load saved state dicts
+ checkpoint = torch.load(os.path.join(self.save_dir, self.signature))
+ iter, loss = checkpoint['iter'], checkpoint['loss']
+ print('{0:5d} loss: {1:.3f}'.format(iter, loss))
+ self.rgb_pn.load_state_dict(checkpoint['model_state_dict'])
+ self.optimizer.load_state_dict(checkpoint['optim_state_dict'])
+ for (x_c1, x_c2) in dataloader:
+ self.curr_iter += 1
+ # Zero the parameter gradients
+ self.optimizer.zero_grad()
+ # forward + backward + optimize
+ loss, metrics = self.rgb_pn(x_c1['clip'], x_c2['clip'],
+ x_c1['label'])
loss.backward()
self.optimizer.step()
- self.scheduler.step(iter_i)
-
- if iter_i == self.meta['total_iter']:
+ # Step scheduler
+ self.scheduler.step(self.curr_iter)
+
+ if self.curr_iter % 100 == 0:
+ print('{0:5d} loss: {1:.3f}'.format(self.curr_iter, loss),
+ '(xrecon = {:f}, pose_sim = {:f},'
+ ' cano_cons = {:f}, ba_trip = {:f})'.format(*metrics),
+ 'lr:', self.scheduler.get_last_lr())
+
+ if self.curr_iter % 1000 == 0:
+ torch.save({
+ 'iter': self.curr_iter,
+ 'model_state_dict': self.rgb_pn.state_dict(),
+ 'optim_state_dict': self.optimizer.state_dict(),
+ 'loss': loss,
+ }, os.path.join(self.save_dir, self.signature))
+
+ if self.curr_iter == self.total_iter:
break
@staticmethod