From 4f81e3f44ed60459ddbd311cb7356932974c800d Mon Sep 17 00:00:00 2001
From: Jordan Gong <jordan.gong@protonmail.com>
Date: Wed, 10 Feb 2021 11:41:49 +0800
Subject: Save scheduler state_dict

---
 models/model.py | 2 ++
 1 file changed, 2 insertions(+)

(limited to 'models')

diff --git a/models/model.py b/models/model.py
index 70c43b3..9cb6a8e 100644
--- a/models/model.py
+++ b/models/model.py
@@ -177,6 +177,7 @@ class Model:
             print('{0:5d} loss: {1:.3f}'.format(iter_, loss))
             self.rgb_pn.load_state_dict(checkpoint['model_state_dict'])
             self.optimizer.load_state_dict(checkpoint['optim_state_dict'])
+            self.scheduler.load_state_dict(checkpoint['sched_state_dict'])
 
         # Training start
         start_time = datetime.now()
@@ -249,6 +250,7 @@ class Model:
                     'iter': self.curr_iter,
                     'model_state_dict': self.rgb_pn.state_dict(),
                     'optim_state_dict': self.optimizer.state_dict(),
+                    'sched_state_dict': self.scheduler.state_dict(),
                     'loss': loss,
                 }, self._checkpoint_name)
 
-- 
cgit v1.2.3