summaryrefslogtreecommitdiff
path: root/models/model.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/model.py')
-rw-r--r--models/model.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/models/model.py b/models/model.py
index bf8b5fb..1dc0f23 100644
--- a/models/model.py
+++ b/models/model.py
@@ -9,7 +9,7 @@ from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.tensorboard import SummaryWriter
-from models import RGBPartNet
+from models.rgb_part_net import RGBPartNet
from utils.configuration import DataloaderConfiguration, \
HyperparameterConfiguration, DatasetConfiguration, ModelConfiguration, \
SystemConfiguration
@@ -116,7 +116,7 @@ class Model:
loss.backward()
self.optimizer.step()
# Step scheduler
- self.scheduler.step(self.curr_iter)
+ self.scheduler.step()
# Write losses to TensorBoard
self.writer.add_scalar('Loss/all', loss.item(), self.curr_iter)
@@ -129,7 +129,7 @@ class Model:
print('{0:5d} loss: {1:.3f}'.format(self.curr_iter, loss),
'(xrecon = {:f}, pose_sim = {:f},'
' cano_cons = {:f}, ba_trip = {:f})'.format(*metrics),
- 'lr:', self.scheduler.get_last_lr())
+ 'lr:', self.scheduler.get_last_lr()[0])
if self.curr_iter % 1000 == 0:
torch.save({