From 4a284084c253b9114fc02e1782962556ff113761 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Thu, 7 Jan 2021 18:37:43 +0800 Subject: Add typical training script and some bug fixes 1. Resolve deprecated scheduler stepping issue 2. Make losses in the same scale(replace mean with sum in separate triplet loss, enlarge pose similarity loss 10x) 3. Add ReLU when compute distance in triplet loss 4. Remove classes except Model from `models` package init --- models/__init__.py | 5 +---- models/model.py | 6 +++--- models/rgb_part_net.py | 6 ++++-- 3 files changed, 8 insertions(+), 9 deletions(-) (limited to 'models') diff --git a/models/__init__.py b/models/__init__.py index c1b9fe8..7040c63 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -1,4 +1 @@ -from .auto_encoder import AutoEncoder -from .hpm import HorizontalPyramidMatching -from .part_net import PartNet -from .rgb_part_net import RGBPartNet +from .model import Model \ No newline at end of file diff --git a/models/model.py b/models/model.py index bf8b5fb..1dc0f23 100644 --- a/models/model.py +++ b/models/model.py @@ -9,7 +9,7 @@ from torch.utils.data import DataLoader from torch.utils.data.dataloader import default_collate from torch.utils.tensorboard import SummaryWriter -from models import RGBPartNet +from models.rgb_part_net import RGBPartNet from utils.configuration import DataloaderConfiguration, \ HyperparameterConfiguration, DatasetConfiguration, ModelConfiguration, \ SystemConfiguration @@ -116,7 +116,7 @@ class Model: loss.backward() self.optimizer.step() # Step scheduler - self.scheduler.step(self.curr_iter) + self.scheduler.step() # Write losses to TensorBoard self.writer.add_scalar('Loss/all', loss.item(), self.curr_iter) @@ -129,7 +129,7 @@ class Model: print('{0:5d} loss: {1:.3f}'.format(self.curr_iter, loss), '(xrecon = {:f}, pose_sim = {:f},' ' cano_cons = {:f}, ba_trip = {:f})'.format(*metrics), - 'lr:', self.scheduler.get_last_lr()) + 'lr:', self.scheduler.get_last_lr()[0]) if self.curr_iter % 1000 == 0: torch.save({ diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 73d5952..3037da0 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -4,7 +4,9 @@ import torch import torch.nn as nn import torch.nn.functional as F -from models import AutoEncoder, HorizontalPyramidMatching, PartNet +from models.auto_encoder import AutoEncoder +from models.hpm import HorizontalPyramidMatching +from models.part_net import PartNet from utils.triplet_loss import BatchAllTripletLoss @@ -117,7 +119,7 @@ class RGBPartNet(nn.Module): # Losses xrecon_loss = torch.sum(torch.stack(xrecon_loss)) - pose_sim_loss = self._pose_sim_loss(f_p_c1, f_p_c2) + pose_sim_loss = self._pose_sim_loss(f_p_c1, f_p_c2) * 10 cano_cons_loss = torch.mean(torch.stack(cano_cons_loss)) return ((x_c_c1, x_p_c1), -- cgit v1.2.3