From 543a35d814754c86ccafa1243bece387c1a780d6 Mon Sep 17 00:00:00 2001
From: Jordan Gong <jordan.gong@protonmail.com>
Date: Tue, 2 Mar 2021 19:22:25 +0800
Subject: Fix bugs in new scheduler

---
 models/model.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

(limited to 'models')

diff --git a/models/model.py b/models/model.py
index 497a0ea..3242141 100644
--- a/models/model.py
+++ b/models/model.py
@@ -185,11 +185,14 @@ class Model:
         ], **optim_hp)
         sched_final_gamma = sched_hp.get('final_gamma', 0.001)
         sched_start_step = sched_hp.get('start_step', 15_000)
+        all_step = self.total_iter - sched_start_step
 
         def lr_lambda(epoch):
-            passed_step = epoch - sched_start_step
-            all_step = self.total_iter - sched_start_step
-            return sched_final_gamma ** (passed_step / all_step)
+            if epoch > sched_start_step:
+                passed_step = epoch - sched_start_step
+                return sched_final_gamma ** (passed_step / all_step)
+            else:
+                return 1
         self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=[
             lr_lambda, lr_lambda, lr_lambda, lr_lambda
         ])
-- 
cgit v1.2.3


From 02780c31385af7e1103448bd1994012ac95dd2bb Mon Sep 17 00:00:00 2001
From: Jordan Gong <jordan.gong@protonmail.com>
Date: Tue, 2 Mar 2021 19:26:11 +0800
Subject: Record learning rate every step

---
 models/model.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

(limited to 'models')

diff --git a/models/model.py b/models/model.py
index 3242141..acccbff 100644
--- a/models/model.py
+++ b/models/model.py
@@ -282,13 +282,14 @@ class Model:
                 'Embedding/PartNet norm', mean_pa_norm,
                 self.k, self.pr * self.k, self.curr_iter
             )
+            # Learning rate
+            lrs = self.scheduler.get_last_lr()
+            # Write learning rates
+            self.writer.add_scalar(
+                'Learning rate', lrs[0], self.curr_iter
+            )
 
             if self.curr_iter % 100 == 0:
-                lrs = self.scheduler.get_last_lr()
-                # Write learning rates
-                self.writer.add_scalar(
-                    'Learning rate', lrs[0], self.curr_iter
-                )
                 # Write disentangled images
                 if self.image_log_on:
                     i_a, i_c, i_p = images
-- 
cgit v1.2.3