From a66e5a456a2b036b2b787371da3160efc559031f Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Thu, 8 Apr 2021 12:57:40 +0800 Subject: Add stop step for scheduler --- config.py | 12 ++++++++---- models/model.py | 27 +++++++++++++++++---------- utils/configuration.py | 2 ++ 3 files changed, 27 insertions(+), 14 deletions(-) diff --git a/config.py b/config.py index df7c64c..726d06e 100644 --- a/config.py +++ b/config.py @@ -41,7 +41,7 @@ config: Configuration = { # Batch size (pr, k) # `pr` denotes number of persons # `k` denotes number of sequences per person - 'batch_size': (4, 6), + 'batch_size': (4, 5), # Number of workers of Dataloader 'num_workers': 4, # Faster data transfer from RAM to GPU if enabled @@ -88,18 +88,22 @@ config: Configuration = { # Local parameters (override global ones) # 'auto_encoder': { - # 'weight_decay': 0.001 + # 'lr': 1e-3 # }, }, 'scheduler': { # Step start to decay 'start_step': 500, + # Step stop decaying + # 'stop_step': 30_000, # Multiplicative factor of decay in the end 'final_gamma': 0.01, # Local parameters (override global ones) - # 'hpm': { - # 'final_gamma': 0.001 + # 'auto_encoder': { + # 'start_step': 0, + # 'stop_step': 500, + # 'final_gamma': 0.5 # } } }, diff --git a/models/model.py b/models/model.py index 9dbf609..45067e6 100644 --- a/models/model.py +++ b/models/model.py @@ -212,24 +212,31 @@ class Model: ], **optim_hp) # Scheduler - start_step = sched_hp.get('start_step', 15_000) + start_step = sched_hp.get('start_step', 0) + stop_step = sched_hp.get('stop_step', self.total_iter) final_gamma = sched_hp.get('final_gamma', 0.001) ae_start_step = ae_sched_hp.get('start_step', start_step) + ae_stop_step = ae_sched_hp.get('stop_step', stop_step) ae_final_gamma = ae_sched_hp.get('final_gamma', final_gamma) - ae_all_step = self.total_iter - ae_start_step + ae_all_step = ae_stop_step - ae_start_step hpm_start_step = hpm_sched_hp.get('start_step', start_step) + hpm_stop_step = hpm_sched_hp.get('stop_step', stop_step) hpm_final_gamma = hpm_sched_hp.get('final_gamma', final_gamma) - hpm_all_step = self.total_iter - hpm_start_step + hpm_all_step = hpm_stop_step - hpm_start_step pn_start_step = pn_sched_hp.get('start_step', start_step) + pn_stop_step = pn_sched_hp.get('stop_step', stop_step) pn_final_gamma = pn_sched_hp.get('final_gamma', final_gamma) - pn_all_step = self.total_iter - pn_start_step + pn_all_step = pn_stop_step - pn_start_step self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=[ - lambda t: ae_final_gamma ** ((t - ae_start_step) / ae_all_step) - if t > ae_start_step else 1, - lambda t: hpm_final_gamma ** ((t - hpm_start_step) / hpm_all_step) - if t > hpm_start_step else 1, - lambda t: pn_final_gamma ** ((t - pn_start_step) / pn_all_step) - if t > pn_start_step else 1, + lambda t: 1 if t <= ae_start_step + else ae_final_gamma ** ((t - ae_start_step) / ae_all_step) + if ae_start_step < t <= ae_stop_step else ae_final_gamma, + lambda t: 1 if t <= hpm_start_step + else hpm_final_gamma ** ((t - hpm_start_step) / hpm_all_step) + if hpm_start_step < t <= hpm_stop_step else hpm_final_gamma, + lambda t: 1 if t <= pn_start_step + else pn_final_gamma ** ((t - pn_start_step) / pn_all_step) + if pn_start_step < t <= pn_stop_step else pn_final_gamma, ]) self.writer = SummaryWriter(self._log_name) diff --git a/utils/configuration.py b/utils/configuration.py index 608d413..579d2f2 100644 --- a/utils/configuration.py +++ b/utils/configuration.py @@ -66,11 +66,13 @@ class OptimizerHPConfiguration(TypedDict): class SubSchedulerHPConfiguration(TypedDict): start_step: int + stop_step: int final_gamma: float class SchedulerHPConfiguration(TypedDict): start_step: int + stop_step: int final_gamma: float auto_encoder: SubSchedulerHPConfiguration hpm: SubSchedulerHPConfiguration -- cgit v1.2.3