summaryrefslogtreecommitdiff
path: root/models
diff options
context:
space:
mode:
Diffstat (limited to 'models')
-rw-r--r--models/__init__.py5
-rw-r--r--models/model.py6
-rw-r--r--models/rgb_part_net.py6
3 files changed, 8 insertions, 9 deletions
diff --git a/models/__init__.py b/models/__init__.py
index c1b9fe8..7040c63 100644
--- a/models/__init__.py
+++ b/models/__init__.py
@@ -1,4 +1 @@
-from .auto_encoder import AutoEncoder
-from .hpm import HorizontalPyramidMatching
-from .part_net import PartNet
-from .rgb_part_net import RGBPartNet
+from .model import Model \ No newline at end of file
diff --git a/models/model.py b/models/model.py
index bf8b5fb..1dc0f23 100644
--- a/models/model.py
+++ b/models/model.py
@@ -9,7 +9,7 @@ from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.tensorboard import SummaryWriter
-from models import RGBPartNet
+from models.rgb_part_net import RGBPartNet
from utils.configuration import DataloaderConfiguration, \
HyperparameterConfiguration, DatasetConfiguration, ModelConfiguration, \
SystemConfiguration
@@ -116,7 +116,7 @@ class Model:
loss.backward()
self.optimizer.step()
# Step scheduler
- self.scheduler.step(self.curr_iter)
+ self.scheduler.step()
# Write losses to TensorBoard
self.writer.add_scalar('Loss/all', loss.item(), self.curr_iter)
@@ -129,7 +129,7 @@ class Model:
print('{0:5d} loss: {1:.3f}'.format(self.curr_iter, loss),
'(xrecon = {:f}, pose_sim = {:f},'
' cano_cons = {:f}, ba_trip = {:f})'.format(*metrics),
- 'lr:', self.scheduler.get_last_lr())
+ 'lr:', self.scheduler.get_last_lr()[0])
if self.curr_iter % 1000 == 0:
torch.save({
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 73d5952..3037da0 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -4,7 +4,9 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
-from models import AutoEncoder, HorizontalPyramidMatching, PartNet
+from models.auto_encoder import AutoEncoder
+from models.hpm import HorizontalPyramidMatching
+from models.part_net import PartNet
from utils.triplet_loss import BatchAllTripletLoss
@@ -117,7 +119,7 @@ class RGBPartNet(nn.Module):
# Losses
xrecon_loss = torch.sum(torch.stack(xrecon_loss))
- pose_sim_loss = self._pose_sim_loss(f_p_c1, f_p_c2)
+ pose_sim_loss = self._pose_sim_loss(f_p_c1, f_p_c2) * 10
cano_cons_loss = torch.mean(torch.stack(cano_cons_loss))
return ((x_c_c1, x_p_c1),