From 543a35d814754c86ccafa1243bece387c1a780d6 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Tue, 2 Mar 2021 19:22:25 +0800 Subject: Fix bugs in new scheduler --- models/model.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/models/model.py b/models/model.py index 497a0ea..3242141 100644 --- a/models/model.py +++ b/models/model.py @@ -185,11 +185,14 @@ class Model: ], **optim_hp) sched_final_gamma = sched_hp.get('final_gamma', 0.001) sched_start_step = sched_hp.get('start_step', 15_000) + all_step = self.total_iter - sched_start_step def lr_lambda(epoch): - passed_step = epoch - sched_start_step - all_step = self.total_iter - sched_start_step - return sched_final_gamma ** (passed_step / all_step) + if epoch > sched_start_step: + passed_step = epoch - sched_start_step + return sched_final_gamma ** (passed_step / all_step) + else: + return 1 self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=[ lr_lambda, lr_lambda, lr_lambda, lr_lambda ]) -- cgit v1.2.3