diff options
Diffstat (limited to 'models')
-rw-r--r-- | models/__init__.py | 5 | ||||
-rw-r--r-- | models/model.py | 6 | ||||
-rw-r--r-- | models/rgb_part_net.py | 6 |
3 files changed, 8 insertions, 9 deletions
diff --git a/models/__init__.py b/models/__init__.py index c1b9fe8..7040c63 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -1,4 +1 @@ -from .auto_encoder import AutoEncoder -from .hpm import HorizontalPyramidMatching -from .part_net import PartNet -from .rgb_part_net import RGBPartNet +from .model import Model
\ No newline at end of file diff --git a/models/model.py b/models/model.py index bf8b5fb..1dc0f23 100644 --- a/models/model.py +++ b/models/model.py @@ -9,7 +9,7 @@ from torch.utils.data import DataLoader from torch.utils.data.dataloader import default_collate from torch.utils.tensorboard import SummaryWriter -from models import RGBPartNet +from models.rgb_part_net import RGBPartNet from utils.configuration import DataloaderConfiguration, \ HyperparameterConfiguration, DatasetConfiguration, ModelConfiguration, \ SystemConfiguration @@ -116,7 +116,7 @@ class Model: loss.backward() self.optimizer.step() # Step scheduler - self.scheduler.step(self.curr_iter) + self.scheduler.step() # Write losses to TensorBoard self.writer.add_scalar('Loss/all', loss.item(), self.curr_iter) @@ -129,7 +129,7 @@ class Model: print('{0:5d} loss: {1:.3f}'.format(self.curr_iter, loss), '(xrecon = {:f}, pose_sim = {:f},' ' cano_cons = {:f}, ba_trip = {:f})'.format(*metrics), - 'lr:', self.scheduler.get_last_lr()) + 'lr:', self.scheduler.get_last_lr()[0]) if self.curr_iter % 1000 == 0: torch.save({ diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 73d5952..3037da0 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -4,7 +4,9 @@ import torch import torch.nn as nn import torch.nn.functional as F -from models import AutoEncoder, HorizontalPyramidMatching, PartNet +from models.auto_encoder import AutoEncoder +from models.hpm import HorizontalPyramidMatching +from models.part_net import PartNet from utils.triplet_loss import BatchAllTripletLoss @@ -117,7 +119,7 @@ class RGBPartNet(nn.Module): # Losses xrecon_loss = torch.sum(torch.stack(xrecon_loss)) - pose_sim_loss = self._pose_sim_loss(f_p_c1, f_p_c2) + pose_sim_loss = self._pose_sim_loss(f_p_c1, f_p_c2) * 10 cano_cons_loss = torch.mean(torch.stack(cano_cons_loss)) return ((x_c_c1, x_p_c1), |