From 4befe59046fb3adf8ef8eb589999a74cf7136ff6 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Wed, 6 Jan 2021 21:26:56 +0800 Subject: Add TensorBoard support --- models/model.py | 75 ++++++++++++++++++++++++++++++++------------------ models/rgb_part_net.py | 2 +- 2 files changed, 49 insertions(+), 28 deletions(-) (limited to 'models') diff --git a/models/model.py b/models/model.py index 9e52527..3842844 100644 --- a/models/model.py +++ b/models/model.py @@ -7,6 +7,7 @@ import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader from torch.utils.data.dataloader import default_collate +from torch.utils.tensorboard import SummaryWriter from models import RGBPartNet from utils.configuration import DataloaderConfiguration, \ @@ -25,6 +26,11 @@ class Model: ): self.device = system_config['device'] self.save_dir = system_config['save_dir'] + self.checkpoint_dir = os.path.join(self.save_dir, 'checkpoint') + self.log_dir = os.path.join(self.save_dir, 'logs') + for dir_ in (self.save_dir, self.log_dir, self.checkpoint_dir): + if not os.path.exists(dir_): + os.mkdir(dir_) self.meta = model_config self.hp = hyperparameter_config @@ -40,24 +46,22 @@ class Model: self._model_sig: str = self._make_signature(self.meta, ['restore_iter']) self._hp_sig: str = self._make_signature(self.hp) self._dataset_sig: str = 'undefined' + self._log_sig: str = '_'.join((self._model_sig, self._hp_sig)) + self.log_name: str = os.path.join(self.log_dir, self._log_sig) self.rgb_pn: Optional[RGBPartNet] = None self.optimizer: Optional[optim.Adam] = None self.scheduler: Optional[optim.lr_scheduler.StepLR] = None + self.writer: Optional[SummaryWriter] = None @property def signature(self) -> str: return '_'.join((self._model_sig, str(self.curr_iter), self._hp_sig, - self._dataset_sig, str(self.batch_size))) + self._dataset_sig, str(self.pr), str(self.k))) @property - def batch_size(self) -> int: - if self.is_train: - if self.pr and self.k: - return self.pr * self.k - raise AttributeError('No dataset loaded') - else: - return 1 + def checkpoint_name(self) -> str: + return os.path.join(self.checkpoint_dir, self.signature) def fit( self, @@ -73,29 +77,39 @@ class Model: self.rgb_pn = RGBPartNet(self.train_size, self.in_channels, **hp) self.optimizer = optim.Adam(self.rgb_pn.parameters(), lr, betas) self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, 500, 0.9) + self.writer = SummaryWriter(self.log_name) self.rgb_pn.train() # Init weights at first iter if self.curr_iter == 0: self.rgb_pn.apply(self.init_weights) else: # Load saved state dicts - checkpoint = torch.load(os.path.join(self.save_dir, self.signature)) - iter, loss = checkpoint['iter'], checkpoint['loss'] - print('{0:5d} loss: {1:.3f}'.format(iter, loss)) + checkpoint = torch.load(self.checkpoint_name) + iter_, loss = checkpoint['iter'], checkpoint['loss'] + print('{0:5d} loss: {1:.3f}'.format(iter_, loss)) self.rgb_pn.load_state_dict(checkpoint['model_state_dict']) self.optimizer.load_state_dict(checkpoint['optim_state_dict']) - for (x_c1, x_c2) in dataloader: + + for (batch_c1, batch_c2) in dataloader: self.curr_iter += 1 # Zero the parameter gradients self.optimizer.zero_grad() # forward + backward + optimize - loss, metrics = self.rgb_pn(x_c1['clip'], x_c2['clip'], - x_c1['label']) + loss, metrics = self.rgb_pn( + batch_c1['clip'], batch_c2['clip'], batch_c1['label'] + ) loss.backward() self.optimizer.step() # Step scheduler self.scheduler.step(self.curr_iter) + # Write losses to TensorBoard + self.writer.add_scalar('Loss/all', loss.item(), self.curr_iter) + self.writer.add_scalars('Loss/details', dict(zip([ + 'Cross reconstruction loss', 'Pose similarity loss', + 'Canonical consistency loss', 'Batch All triplet loss' + ], metrics)), self.curr_iter) + if self.curr_iter % 100 == 0: print('{0:5d} loss: {1:.3f}'.format(self.curr_iter, loss), '(xrecon = {:f}, pose_sim = {:f},' @@ -108,9 +122,10 @@ class Model: 'model_state_dict': self.rgb_pn.state_dict(), 'optim_state_dict': self.optimizer.state_dict(), 'loss': loss, - }, os.path.join(self.save_dir, self.signature)) + }, self.checkpoint_name) if self.curr_iter == self.total_iter: + self.writer.close() break @staticmethod @@ -135,6 +150,7 @@ class Model: dataset_config, popped_keys=['root_dir', 'cache_on'] ) + self.log_name = '_'.join((self.log_name, self._dataset_sig)) config: dict = dataset_config.copy() name = config.pop('name') if name == 'CASIA-B': @@ -152,6 +168,7 @@ class Model: config: dict = dataloader_config.copy() if self.is_train: (self.pr, self.k) = config.pop('batch_size') + self.log_name = '_'.join((self.log_name, str(self.pr), str(self.k))) triplet_sampler = TripletSampler(dataset, (self.pr, self.k)) return DataLoader(dataset, batch_sampler=triplet_sampler, @@ -178,19 +195,23 @@ class Model: return default_collate(_batch[0]), default_collate(_batch[1]) - @staticmethod - def _make_signature(config: dict, + def _make_signature(self, + config: dict, popped_keys: Optional[list] = None) -> str: _config = config.copy() - for (key, value) in config.items(): - if popped_keys and key in popped_keys: + if popped_keys: + for key in popped_keys: _config.pop(key) - continue - if isinstance(value, str): - pass - elif isinstance(value, (tuple, list)): - _config[key] = '_'.join([str(v) for v in value]) - else: - _config[key] = str(value) - return '_'.join(_config.values()) + return self._gen_sig(list(_config.values())) + + def _gen_sig(self, values: Union[tuple, list, str, int, float]) -> str: + strings = [] + for v in values: + if isinstance(v, str): + strings.append(v) + elif isinstance(v, (tuple, list)): + strings.append(self._gen_sig(v)) + else: + strings.append(str(v)) + return '_'.join(strings) diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index a58be39..73d5952 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -78,7 +78,7 @@ class RGBPartNet(nn.Module): batch_all_triplet_loss = self.ba_triplet_loss(x, y) losses = (*losses, batch_all_triplet_loss) loss = torch.sum(torch.stack(losses)) - return loss, (loss.item() for loss in losses) + return loss, [loss.item() for loss in losses] else: return x -- cgit v1.2.3