summaryrefslogtreecommitdiff
path: root/models/model.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-03-02 19:52:03 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-03-02 19:52:03 +0800
commit419d1de5669d55426afbeff82c41c25ba56745d7 (patch)
treef179958d9f5c6f11260b891dd77e4b8aef8254ea /models/model.py
parentf0f3d5fbc3306b00c5c59a8baccd3cb4fab77fed (diff)
parent5ef8caf8472c0bbc28bf40797660c620aee49ac6 (diff)
Merge branch 'python3.8' into python3.7
Diffstat (limited to 'models/model.py')
-rw-r--r--models/model.py20
1 files changed, 12 insertions, 8 deletions
diff --git a/models/model.py b/models/model.py
index 2fb2b39..0eb9823 100644
--- a/models/model.py
+++ b/models/model.py
@@ -182,11 +182,14 @@ class Model:
], **optim_hp)
sched_final_gamma = sched_hp.get('final_gamma', 0.001)
sched_start_step = sched_hp.get('start_step', 15_000)
+ all_step = self.total_iter - sched_start_step
def lr_lambda(epoch):
- passed_step = epoch - sched_start_step
- all_step = self.total_iter - sched_start_step
- return sched_final_gamma ** (passed_step / all_step)
+ if epoch > sched_start_step:
+ passed_step = epoch - sched_start_step
+ return sched_final_gamma ** (passed_step / all_step)
+ else:
+ return 1
self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=[
lr_lambda, lr_lambda, lr_lambda, lr_lambda
])
@@ -276,13 +279,14 @@ class Model:
'Embedding/PartNet norm', mean_pa_norm,
self.k, self.pr * self.k, self.curr_iter
)
+ # Learning rate
+ lrs = self.scheduler.get_last_lr()
+ # Write learning rates
+ self.writer.add_scalar(
+ 'Learning rate', lrs[0], self.curr_iter
+ )
if self.curr_iter % 100 == 0:
- lrs = self.scheduler.get_last_lr()
- # Write learning rates
- self.writer.add_scalar(
- 'Learning rate', lrs[0], self.curr_iter
- )
# Write disentangled images
if self.image_log_on:
i_a, i_c, i_p = images