diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-03-02 19:51:52 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-03-02 19:51:52 +0800 |
commit | 5ef8caf8472c0bbc28bf40797660c620aee49ac6 (patch) | |
tree | a2d86bb3ff761cf689e41bbe70defcda7f6a8aad /models/model.py | |
parent | 7489bf339e13282b06a78659f8b8fe9d505e82dd (diff) | |
parent | 02780c31385af7e1103448bd1994012ac95dd2bb (diff) |
Merge branch 'master' into python3.8
Diffstat (limited to 'models/model.py')
-rw-r--r-- | models/model.py | 20 |
1 files changed, 12 insertions, 8 deletions
diff --git a/models/model.py b/models/model.py index d49e701..3f6a49d 100644 --- a/models/model.py +++ b/models/model.py @@ -185,11 +185,14 @@ class Model: ], **optim_hp) sched_final_gamma = sched_hp.get('final_gamma', 0.001) sched_start_step = sched_hp.get('start_step', 15_000) + all_step = self.total_iter - sched_start_step def lr_lambda(epoch): - passed_step = epoch - sched_start_step - all_step = self.total_iter - sched_start_step - return sched_final_gamma ** (passed_step / all_step) + if epoch > sched_start_step: + passed_step = epoch - sched_start_step + return sched_final_gamma ** (passed_step / all_step) + else: + return 1 self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=[ lr_lambda, lr_lambda, lr_lambda, lr_lambda ]) @@ -279,13 +282,14 @@ class Model: 'Embedding/PartNet norm', mean_pa_norm, self.k, self.pr * self.k, self.curr_iter ) + # Learning rate + lrs = self.scheduler.get_last_lr() + # Write learning rates + self.writer.add_scalar( + 'Learning rate', lrs[0], self.curr_iter + ) if self.curr_iter % 100 == 0: - lrs = self.scheduler.get_last_lr() - # Write learning rates - self.writer.add_scalar( - 'Learning rate', lrs[0], self.curr_iter - ) # Write disentangled images if self.image_log_on: i_a, i_c, i_p = images |