summaryrefslogtreecommitdiff
path: root/models/model.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/model.py')
-rw-r--r--models/model.py9
1 files changed, 6 insertions, 3 deletions
diff --git a/models/model.py b/models/model.py
index 497a0ea..3242141 100644
--- a/models/model.py
+++ b/models/model.py
@@ -185,11 +185,14 @@ class Model:
], **optim_hp)
sched_final_gamma = sched_hp.get('final_gamma', 0.001)
sched_start_step = sched_hp.get('start_step', 15_000)
+ all_step = self.total_iter - sched_start_step
def lr_lambda(epoch):
- passed_step = epoch - sched_start_step
- all_step = self.total_iter - sched_start_step
- return sched_final_gamma ** (passed_step / all_step)
+ if epoch > sched_start_step:
+ passed_step = epoch - sched_start_step
+ return sched_final_gamma ** (passed_step / all_step)
+ else:
+ return 1
self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=[
lr_lambda, lr_lambda, lr_lambda, lr_lambda
])