diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2022-08-10 15:17:06 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2022-08-10 15:17:06 +0800 |
commit | f9b51f152e94896a122325fa2c48b0ce10c881c5 (patch) | |
tree | f777f3f9451318c0ed16428b8302fc3f02a08bb6 /libs | |
parent | ebb2f93ac01f40d00968daaf9a2ad96c24ce7ab3 (diff) |
Add default schedulers
Diffstat (limited to 'libs')
-rw-r--r-- | libs/utils.py | 24 |
1 file changed, 21 insertions, 3 deletions
diff --git a/libs/utils.py b/libs/utils.py index 5e8529b..c019ba9 100644 --- a/libs/utils.py +++ b/libs/utils.py @@ -12,6 +12,7 @@ from torch.utils.tensorboard import SummaryWriter from libs.logging import CSV_EPOCH_LOGGER, CSV_BATCH_LOGGER, BaseBatchLogRecord, BaseEpochLogRecord, Loggers, \ init_csv_logger, csv_logger, tensorboard_logger +from libs.schedulers import LinearWarmupAndCosineAnneal, LinearLR @dataclass @@ -33,6 +34,7 @@ class BaseConfig: @dataclass class SchedConfig: sched: None + warmup_iters: int dataset_config: DatasetConfig dataloader_config: DataLoaderConfig @@ -221,15 +223,31 @@ class Trainer(ABC): return last_metrics @staticmethod - @abstractmethod def _configure_scheduler( optims: Iterable[tuple[str, torch.optim.Optimizer]], last_iter: int, num_iters: int, sched_config: BaseConfig.SchedConfig, ) -> Iterable[tuple[str, torch.optim.lr_scheduler._LRScheduler] | tuple[str, None]]: for optim_name, optim in optims: - sched = torch.optim.lr_scheduler._LRScheduler(optim, -1) - yield f"{optim_name}_sched", sched + if sched_config.sched == 'warmup-anneal': + scheduler = LinearWarmupAndCosineAnneal( + optim, + warm_up=sched_config.warmup_iters / num_iters, + T_max=num_iters, + last_epoch=last_iter, + ) + elif sched_config.sched == 'linear': + scheduler = LinearLR( + optim, + num_epochs=num_iters, + last_epoch=last_iter + ) + elif sched_config.sched in {None, '', 'const'}: + scheduler = None + else: + raise NotImplementedError(f"Unimplemented scheduler: '{sched_config.sched}'") + + yield f"{optim_name}_sched", scheduler def _custom_init_fn(self, config: BaseConfig): pass |