diff options
Diffstat (limited to 'libs/utils.py')
-rw-r--r-- | libs/utils.py | 24 |
1 files changed, 21 insertions, 3 deletions
diff --git a/libs/utils.py b/libs/utils.py index 5e8529b..c019ba9 100644 --- a/libs/utils.py +++ b/libs/utils.py @@ -12,6 +12,7 @@ from torch.utils.tensorboard import SummaryWriter from libs.logging import CSV_EPOCH_LOGGER, CSV_BATCH_LOGGER, BaseBatchLogRecord, BaseEpochLogRecord, Loggers, \ init_csv_logger, csv_logger, tensorboard_logger +from libs.schedulers import LinearWarmupAndCosineAnneal, LinearLR @dataclass @@ -33,6 +34,7 @@ class BaseConfig: @dataclass class SchedConfig: sched: None + warmup_iters: int dataset_config: DatasetConfig dataloader_config: DataLoaderConfig @@ -221,15 +223,31 @@ class Trainer(ABC): return last_metrics @staticmethod - @abstractmethod def _configure_scheduler( optims: Iterable[tuple[str, torch.optim.Optimizer]], last_iter: int, num_iters: int, sched_config: BaseConfig.SchedConfig, ) -> Iterable[tuple[str, torch.optim.lr_scheduler._LRScheduler] | tuple[str, None]]: for optim_name, optim in optims: - sched = torch.optim.lr_scheduler._LRScheduler(optim, -1) - yield f"{optim_name}_sched", sched + if sched_config.sched == 'warmup-anneal': + scheduler = LinearWarmupAndCosineAnneal( + optim, + warm_up=sched_config.warmup_iters / num_iters, + T_max=num_iters, + last_epoch=last_iter, + ) + elif sched_config.sched == 'linear': + scheduler = LinearLR( + optim, + num_epochs=num_iters, + last_epoch=last_iter + ) + elif sched_config.sched in {None, '', 'const'}: + scheduler = None + else: + raise NotImplementedError(f"Unimplemented scheduler: '{sched_config.sched}'") + + yield f"{optim_name}_sched", scheduler def _custom_init_fn(self, config: BaseConfig): pass |