summaryrefslogtreecommitdiff
path: root/models/model.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-02-10 13:45:38 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-02-10 13:45:38 +0800
commitba705ac60f463b034782ba03796dbfa2776a5631 (patch)
tree984a50254777583fd384bd4d72acf12843ad8de0 /models/model.py
parent58ef39d75098bce92654492e09edf1e83033d0c8 (diff)
parent2f7c09fb4fb985db1cbf6e2bdc6622a2c51ebfc3 (diff)
Merge branch 'master' into python3.8
Diffstat (limited to 'models/model.py')
-rw-r--r--models/model.py2
1 files changed, 2 insertions, 0 deletions
diff --git a/models/model.py b/models/model.py
index b5daa54..a8843c5 100644
--- a/models/model.py
+++ b/models/model.py
@@ -177,6 +177,7 @@ class Model:
print('{0:5d} loss: {1:.3f}'.format(iter_, loss))
self.rgb_pn.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optim_state_dict'])
+ self.scheduler.load_state_dict(checkpoint['sched_state_dict'])
# Training start
start_time = datetime.now()
@@ -249,6 +250,7 @@ class Model:
'iter': self.curr_iter,
'model_state_dict': self.rgb_pn.state_dict(),
'optim_state_dict': self.optimizer.state_dict(),
+ 'sched_state_dict': self.scheduler.state_dict(),
'loss': loss,
}, self._checkpoint_name)