From f9b51f152e94896a122325fa2c48b0ce10c881c5 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Wed, 10 Aug 2022 15:17:06 +0800 Subject: Add default schedulers --- libs/utils.py | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) (limited to 'libs') diff --git a/libs/utils.py b/libs/utils.py index 5e8529b..c019ba9 100644 --- a/libs/utils.py +++ b/libs/utils.py @@ -12,6 +12,7 @@ from torch.utils.tensorboard import SummaryWriter from libs.logging import CSV_EPOCH_LOGGER, CSV_BATCH_LOGGER, BaseBatchLogRecord, BaseEpochLogRecord, Loggers, \ init_csv_logger, csv_logger, tensorboard_logger +from libs.schedulers import LinearWarmupAndCosineAnneal, LinearLR @dataclass @@ -33,6 +34,7 @@ class BaseConfig: @dataclass class SchedConfig: sched: None + warmup_iters: int dataset_config: DatasetConfig dataloader_config: DataLoaderConfig @@ -221,15 +223,31 @@ class Trainer(ABC): return last_metrics @staticmethod - @abstractmethod def _configure_scheduler( optims: Iterable[tuple[str, torch.optim.Optimizer]], last_iter: int, num_iters: int, sched_config: BaseConfig.SchedConfig, ) -> Iterable[tuple[str, torch.optim.lr_scheduler._LRScheduler] | tuple[str, None]]: for optim_name, optim in optims: - sched = torch.optim.lr_scheduler._LRScheduler(optim, -1) - yield f"{optim_name}_sched", sched + if sched_config.sched == 'warmup-anneal': + scheduler = LinearWarmupAndCosineAnneal( + optim, + warm_up=sched_config.warmup_iters / num_iters, + T_max=num_iters, + last_epoch=last_iter, + ) + elif sched_config.sched == 'linear': + scheduler = LinearLR( + optim, + num_epochs=num_iters, + last_epoch=last_iter + ) + elif sched_config.sched in {None, '', 'const'}: + scheduler = None + else: + raise NotImplementedError(f"Unimplemented scheduler: '{sched_config.sched}'") + + yield f"{optim_name}_sched", scheduler def _custom_init_fn(self, config: BaseConfig): pass -- cgit v1.2.3