summary refs log tree commit diff
path: root/models
diff options
context:
space:
mode:
Diffstat (limited to 'models')
-rw-r--r-- models/model.py 75
-rw-r--r-- models/rgb_part_net.py 2
2 files changed, 49 insertions, 28 deletions
diff --git a/models/model.py b/models/model.py
index 9e52527..3842844 100644
--- a/models/model.py
+++ b/models/model.py
@@ -7,6 +7,7 @@ import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
+from torch.utils.tensorboard import SummaryWriter
from models import RGBPartNet
from utils.configuration import DataloaderConfiguration, \
@@ -25,6 +26,11 @@ class Model:
):
self.device = system_config['device']
self.save_dir = system_config['save_dir']
+ self.checkpoint_dir = os.path.join(self.save_dir, 'checkpoint')
+ self.log_dir = os.path.join(self.save_dir, 'logs')
+ for dir_ in (self.save_dir, self.log_dir, self.checkpoint_dir):
+ if not os.path.exists(dir_):
+ os.mkdir(dir_)
self.meta = model_config
self.hp = hyperparameter_config
@@ -40,24 +46,22 @@ class Model:
self._model_sig: str = self._make_signature(self.meta, ['restore_iter'])
self._hp_sig: str = self._make_signature(self.hp)
self._dataset_sig: str = 'undefined'
+ self._log_sig: str = '_'.join((self._model_sig, self._hp_sig))
+ self.log_name: str = os.path.join(self.log_dir, self._log_sig)
self.rgb_pn: Optional[RGBPartNet] = None
self.optimizer: Optional[optim.Adam] = None
self.scheduler: Optional[optim.lr_scheduler.StepLR] = None
+ self.writer: Optional[SummaryWriter] = None
@property
def signature(self) -> str:
return '_'.join((self._model_sig, str(self.curr_iter), self._hp_sig,
- self._dataset_sig, str(self.batch_size)))
+ self._dataset_sig, str(self.pr), str(self.k)))
@property
- def batch_size(self) -> int:
- if self.is_train:
- if self.pr and self.k:
- return self.pr * self.k
- raise AttributeError('No dataset loaded')
- else:
- return 1
+ def checkpoint_name(self) -> str:
+ return os.path.join(self.checkpoint_dir, self.signature)
def fit(
self,
@@ -73,29 +77,39 @@ class Model:
self.rgb_pn = RGBPartNet(self.train_size, self.in_channels, **hp)
self.optimizer = optim.Adam(self.rgb_pn.parameters(), lr, betas)
self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, 500, 0.9)
+ self.writer = SummaryWriter(self.log_name)
self.rgb_pn.train()
# Init weights at first iter
if self.curr_iter == 0:
self.rgb_pn.apply(self.init_weights)
else: # Load saved state dicts
- checkpoint = torch.load(os.path.join(self.save_dir, self.signature))
- iter, loss = checkpoint['iter'], checkpoint['loss']
- print('{0:5d} loss: {1:.3f}'.format(iter, loss))
+ checkpoint = torch.load(self.checkpoint_name)
+ iter_, loss = checkpoint['iter'], checkpoint['loss']
+ print('{0:5d} loss: {1:.3f}'.format(iter_, loss))
self.rgb_pn.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optim_state_dict'])
- for (x_c1, x_c2) in dataloader:
+
+ for (batch_c1, batch_c2) in dataloader:
self.curr_iter += 1
# Zero the parameter gradients
self.optimizer.zero_grad()
# forward + backward + optimize
- loss, metrics = self.rgb_pn(x_c1['clip'], x_c2['clip'],
- x_c1['label'])
+ loss, metrics = self.rgb_pn(
+ batch_c1['clip'], batch_c2['clip'], batch_c1['label']
+ )
loss.backward()
self.optimizer.step()
# Step scheduler
self.scheduler.step(self.curr_iter)
+ # Write losses to TensorBoard
+ self.writer.add_scalar('Loss/all', loss.item(), self.curr_iter)
+ self.writer.add_scalars('Loss/details', dict(zip([
+ 'Cross reconstruction loss', 'Pose similarity loss',
+ 'Canonical consistency loss', 'Batch All triplet loss'
+ ], metrics)), self.curr_iter)
+
if self.curr_iter % 100 == 0:
print('{0:5d} loss: {1:.3f}'.format(self.curr_iter, loss),
'(xrecon = {:f}, pose_sim = {:f},'
@@ -108,9 +122,10 @@ class Model:
'model_state_dict': self.rgb_pn.state_dict(),
'optim_state_dict': self.optimizer.state_dict(),
'loss': loss,
- }, os.path.join(self.save_dir, self.signature))
+ }, self.checkpoint_name)
if self.curr_iter == self.total_iter:
+ self.writer.close()
break
@staticmethod
@@ -135,6 +150,7 @@ class Model:
dataset_config,
popped_keys=['root_dir', 'cache_on']
)
+ self.log_name = '_'.join((self.log_name, self._dataset_sig))
config: dict = dataset_config.copy()
name = config.pop('name')
if name == 'CASIA-B':
@@ -152,6 +168,7 @@ class Model:
config: dict = dataloader_config.copy()
if self.is_train:
(self.pr, self.k) = config.pop('batch_size')
+ self.log_name = '_'.join((self.log_name, str(self.pr), str(self.k)))
triplet_sampler = TripletSampler(dataset, (self.pr, self.k))
return DataLoader(dataset,
batch_sampler=triplet_sampler,
@@ -178,19 +195,23 @@ class Model:
return default_collate(_batch[0]), default_collate(_batch[1])
- @staticmethod
- def _make_signature(config: dict,
+ def _make_signature(self,
+ config: dict,
popped_keys: Optional[list] = None) -> str:
_config = config.copy()
- for (key, value) in config.items():
- if popped_keys and key in popped_keys:
+ if popped_keys:
+ for key in popped_keys:
_config.pop(key)
- continue
- if isinstance(value, str):
- pass
- elif isinstance(value, (tuple, list)):
- _config[key] = '_'.join([str(v) for v in value])
- else:
- _config[key] = str(value)
- return '_'.join(_config.values())
+ return self._gen_sig(list(_config.values()))
+
+ def _gen_sig(self, values: Union[tuple, list, str, int, float]) -> str:
+ strings = []
+ for v in values:
+ if isinstance(v, str):
+ strings.append(v)
+ elif isinstance(v, (tuple, list)):
+ strings.append(self._gen_sig(v))
+ else:
+ strings.append(str(v))
+ return '_'.join(strings)
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index a58be39..73d5952 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -78,7 +78,7 @@ class RGBPartNet(nn.Module):
batch_all_triplet_loss = self.ba_triplet_loss(x, y)
losses = (*losses, batch_all_triplet_loss)
loss = torch.sum(torch.stack(losses))
- return loss, (loss.item() for loss in losses)
+ return loss, [loss.item() for loss in losses]
else:
return x