From f9b51f152e94896a122325fa2c48b0ce10c881c5 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Wed, 10 Aug 2022 15:17:06 +0800 Subject: Add default schedulers --- libs/utils.py | 24 +++++++++++++++++++++--- simclr/main.py | 34 ---------------------------------- supervised/baseline.py | 24 +----------------------- 3 files changed, 22 insertions(+), 60 deletions(-) diff --git a/libs/utils.py b/libs/utils.py index 5e8529b..c019ba9 100644 --- a/libs/utils.py +++ b/libs/utils.py @@ -12,6 +12,7 @@ from torch.utils.tensorboard import SummaryWriter from libs.logging import CSV_EPOCH_LOGGER, CSV_BATCH_LOGGER, BaseBatchLogRecord, BaseEpochLogRecord, Loggers, \ init_csv_logger, csv_logger, tensorboard_logger +from libs.schedulers import LinearWarmupAndCosineAnneal, LinearLR @dataclass @@ -33,6 +34,7 @@ class BaseConfig: @dataclass class SchedConfig: sched: None + warmup_iters: int dataset_config: DatasetConfig dataloader_config: DataLoaderConfig @@ -221,15 +223,31 @@ class Trainer(ABC): return last_metrics @staticmethod - @abstractmethod def _configure_scheduler( optims: Iterable[tuple[str, torch.optim.Optimizer]], last_iter: int, num_iters: int, sched_config: BaseConfig.SchedConfig, ) -> Iterable[tuple[str, torch.optim.lr_scheduler._LRScheduler] | tuple[str, None]]: for optim_name, optim in optims: - sched = torch.optim.lr_scheduler._LRScheduler(optim, -1) - yield f"{optim_name}_sched", sched + if sched_config.sched == 'warmup-anneal': + scheduler = LinearWarmupAndCosineAnneal( + optim, + warm_up=sched_config.warmup_iters / num_iters, + T_max=num_iters, + last_epoch=last_iter, + ) + elif sched_config.sched == 'linear': + scheduler = LinearLR( + optim, + num_epochs=num_iters, + last_epoch=last_iter + ) + elif sched_config.sched in {None, '', 'const'}: + scheduler = None + else: + raise NotImplementedError(f"Unimplemented scheduler: '{sched_config.sched}'") + + yield f"{optim_name}_sched", scheduler def _custom_init_fn(self, config: BaseConfig): pass diff --git a/simclr/main.py b/simclr/main.py index b170355..b91b42b 100644 --- a/simclr/main.py +++ b/simclr/main.py @@ -131,11 +131,6 @@ class SimCLRConfig(BaseConfig): betas: tuple[float, float] weight_decay: float - @dataclass - class SchedConfig(BaseConfig.SchedConfig): - sched: str | None - warmup_iters: int - class SimCLRTrainer(Trainer): def __init__(self, hid_dim, out_dim, **kwargs): @@ -254,35 +249,6 @@ class SimCLRTrainer(Trainer): yield f"{model_name}_optim", optimizer - @staticmethod - def _configure_scheduler( - optims: Iterable[tuple[str, torch.optim.Optimizer]], - last_iter: int, - num_iters: int, - sched_config: SimCLRConfig.SchedConfig, - ) -> Iterable[tuple[str, torch.optim.lr_scheduler._LRScheduler] - | tuple[str, None]]: - for optim_name, optim in optims: - if sched_config.sched == 'warmup-anneal': - scheduler = LinearWarmupAndCosineAnneal( - optim, - warm_up=sched_config.warmup_iters / num_iters, - T_max=num_iters, - last_epoch=last_iter, - ) - elif sched_config.sched == 'linear': - scheduler = LinearLR( - optim, - num_epochs=num_iters, - last_epoch=last_iter - ) - elif sched_config.sched in {None, '', 'const'}: - scheduler = None - else: - raise NotImplementedError(f"Unimplemented scheduler: '{sched_config.sched}'") - - yield f"{optim_name}_sched", scheduler - def _custom_init_fn(self, config: SimCLRConfig): self.optims = {n: LARS(o) if config.optim_config.optim == 'lars' else o for n, o in self.optims.items()} diff --git a/supervised/baseline.py b/supervised/baseline.py index 9ad107c..9c18f9f 100644 --- a/supervised/baseline.py +++ b/supervised/baseline.py @@ -1,10 +1,10 @@ import argparse import os -import sys from dataclasses import dataclass from pathlib import Path from typing import Iterable, Callable +import sys import torch import yaml from torch.utils.data import Dataset @@ -15,7 +15,6 @@ path = str(Path(Path(__file__).parent.absolute()).parent.absolute()) sys.path.insert(0, path) from libs.datautils import Clip -from libs.schedulers import LinearLR from libs.utils import Trainer, BaseConfig from libs.logging import BaseBatchLogRecord, BaseEpochLogRecord, Loggers from models import CIFARResNet50 @@ -103,10 +102,6 @@ class SupBaselineConfig(BaseConfig): betas: tuple[float, float] | None weight_decay: float - @dataclass - class SchedConfig(BaseConfig.SchedConfig): - sched: str | None - class SupBaselineTrainer(Trainer): def __init__(self, **kwargs): @@ -201,23 +196,6 @@ class SupBaselineTrainer(Trainer): yield f"{model_name}_optim", optimizer - @staticmethod - def _configure_scheduler( - optims: Iterable[tuple[str, torch.optim.Optimizer]], - last_iter: int, - num_iters: int, - sched_config: SupBaselineConfig.SchedConfig - ) -> Iterable[tuple[str, torch.optim.lr_scheduler._LRScheduler] - | tuple[str, None]]: - for optim_name, optim in optims: - if sched_config.sched == 'linear': - sched = LinearLR(optim, num_iters, last_epoch=last_iter) - elif sched_config.sched is None: - sched = None - else: - raise NotImplementedError(f"Unimplemented scheduler: {sched_config.sched}") - yield f"{optim_name}_sched", sched - def train(self, num_iters: int, loss_fn: Callable, logger: Loggers, device: torch.device): model = self.models['model'] optim = self.optims['model_optim'] -- cgit v1.2.3