author    Jordan Gong <jordan.gong@protonmail.com>  2022-08-10 15:17:06 +0800
committer Jordan Gong <jordan.gong@protonmail.com>  2022-08-10 15:17:06 +0800
commit    f9b51f152e94896a122325fa2c48b0ce10c881c5 (patch)
tree      f777f3f9451318c0ed16428b8302fc3f02a08bb6
parent    ebb2f93ac01f40d00968daaf9a2ad96c24ce7ab3 (diff)
Add default schedulers
-rw-r--r--  libs/utils.py          | 24
-rw-r--r--  simclr/main.py         | 34
-rw-r--r--  supervised/baseline.py | 24
3 files changed, 22 insertions(+), 60 deletions(-)
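This commit replaces the abstract Trainer._configure_scheduler with a default implementation in libs/utils.py and moves warmup_iters into the base SchedConfig. The per-trainer overrides in simclr/main.py and supervised/baseline.py, together with their SchedConfig subclasses, are deleted, so both trainers now inherit the shared dispatch over 'warmup-anneal', 'linear', and constant (None/''/'const') schedules. A usage sketch follows the libs/utils.py hunk below.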
diff --git a/libs/utils.py b/libs/utils.py
index 5e8529b..c019ba9 100644
--- a/libs/utils.py
+++ b/libs/utils.py
@@ -12,6 +12,7 @@ from torch.utils.tensorboard import SummaryWriter
from libs.logging import CSV_EPOCH_LOGGER, CSV_BATCH_LOGGER, BaseBatchLogRecord, BaseEpochLogRecord, Loggers, \
init_csv_logger, csv_logger, tensorboard_logger
+from libs.schedulers import LinearWarmupAndCosineAnneal, LinearLR
@dataclass
@@ -33,6 +34,7 @@ class BaseConfig:
@dataclass
class SchedConfig:
sched: None
+ warmup_iters: int
dataset_config: DatasetConfig
dataloader_config: DataLoaderConfig
@@ -221,15 +223,31 @@ class Trainer(ABC):
return last_metrics
@staticmethod
- @abstractmethod
def _configure_scheduler(
optims: Iterable[tuple[str, torch.optim.Optimizer]],
last_iter: int, num_iters: int, sched_config: BaseConfig.SchedConfig,
) -> Iterable[tuple[str, torch.optim.lr_scheduler._LRScheduler]
| tuple[str, None]]:
for optim_name, optim in optims:
- sched = torch.optim.lr_scheduler._LRScheduler(optim, -1)
- yield f"{optim_name}_sched", sched
+ if sched_config.sched == 'warmup-anneal':
+ scheduler = LinearWarmupAndCosineAnneal(
+ optim,
+ warm_up=sched_config.warmup_iters / num_iters,
+ T_max=num_iters,
+ last_epoch=last_iter,
+ )
+ elif sched_config.sched == 'linear':
+ scheduler = LinearLR(
+ optim,
+ num_epochs=num_iters,
+ last_epoch=last_iter
+ )
+ elif sched_config.sched in {None, '', 'const'}:
+ scheduler = None
+ else:
+ raise NotImplementedError(f"Unimplemented scheduler: '{sched_config.sched}'")
+
+ yield f"{optim_name}_sched", scheduler
def _custom_init_fn(self, config: BaseConfig):
pass
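A minimal usage sketch of the new default, not part of the patch: the Linear model, SGD optimizer, and iteration counts are placeholders, and BaseConfig.SchedConfig is built by hand purely to show the dispatch; in this repo those values normally come from the parsed training config.

# Minimal sketch, assuming the repo root is on sys.path so libs/ resolves;
# the model, optimizer, and iteration counts are illustrative placeholders.
import torch

from libs.utils import BaseConfig, Trainer

model = torch.nn.Linear(8, 2)
optim = torch.optim.SGD(model.parameters(), lr=0.1)

sched_config = BaseConfig.SchedConfig(sched='warmup-anneal', warmup_iters=500)
scheds = dict(Trainer._configure_scheduler(
    optims=[('model_optim', optim)],
    last_iter=-1,       # fresh run, nothing to resume
    num_iters=10_000,   # total training iterations
    sched_config=sched_config,
))
# -> {'model_optim_sched': <LinearWarmupAndCosineAnneal, warm_up=0.05, T_max=10000>}
# With sched in {None, '', 'const'} the value is None instead; any other string
# raises NotImplementedError.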
diff --git a/simclr/main.py b/simclr/main.py
index b170355..b91b42b 100644
--- a/simclr/main.py
+++ b/simclr/main.py
@@ -131,11 +131,6 @@ class SimCLRConfig(BaseConfig):
betas: tuple[float, float]
weight_decay: float
- @dataclass
- class SchedConfig(BaseConfig.SchedConfig):
- sched: str | None
- warmup_iters: int
-
class SimCLRTrainer(Trainer):
def __init__(self, hid_dim, out_dim, **kwargs):
@@ -254,35 +249,6 @@ class SimCLRTrainer(Trainer):
yield f"{model_name}_optim", optimizer
- @staticmethod
- def _configure_scheduler(
- optims: Iterable[tuple[str, torch.optim.Optimizer]],
- last_iter: int,
- num_iters: int,
- sched_config: SimCLRConfig.SchedConfig,
- ) -> Iterable[tuple[str, torch.optim.lr_scheduler._LRScheduler]
- | tuple[str, None]]:
- for optim_name, optim in optims:
- if sched_config.sched == 'warmup-anneal':
- scheduler = LinearWarmupAndCosineAnneal(
- optim,
- warm_up=sched_config.warmup_iters / num_iters,
- T_max=num_iters,
- last_epoch=last_iter,
- )
- elif sched_config.sched == 'linear':
- scheduler = LinearLR(
- optim,
- num_epochs=num_iters,
- last_epoch=last_iter
- )
- elif sched_config.sched in {None, '', 'const'}:
- scheduler = None
- else:
- raise NotImplementedError(f"Unimplemented scheduler: '{sched_config.sched}'")
-
- yield f"{optim_name}_sched", scheduler
-
def _custom_init_fn(self, config: SimCLRConfig):
self.optims = {n: LARS(o) if config.optim_config.optim == 'lars' else o
for n, o in self.optims.items()}
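Note: the SimCLR override removed above duplicates the new base implementation verbatim (and its SchedConfig subclass only re-declared sched and warmup_iters, which now live in the base), so SimCLR training behavior is unchanged by this commit.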
diff --git a/supervised/baseline.py b/supervised/baseline.py
index 9ad107c..9c18f9f 100644
--- a/supervised/baseline.py
+++ b/supervised/baseline.py
@@ -1,10 +1,10 @@
import argparse
import os
-import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, Callable
+import sys
import torch
import yaml
from torch.utils.data import Dataset
@@ -15,7 +15,6 @@ path = str(Path(Path(__file__).parent.absolute()).parent.absolute())
sys.path.insert(0, path)
from libs.datautils import Clip
-from libs.schedulers import LinearLR
from libs.utils import Trainer, BaseConfig
from libs.logging import BaseBatchLogRecord, BaseEpochLogRecord, Loggers
from models import CIFARResNet50
@@ -103,10 +102,6 @@ class SupBaselineConfig(BaseConfig):
betas: tuple[float, float] | None
weight_decay: float
- @dataclass
- class SchedConfig(BaseConfig.SchedConfig):
- sched: str | None
-
class SupBaselineTrainer(Trainer):
def __init__(self, **kwargs):
@@ -201,23 +196,6 @@ class SupBaselineTrainer(Trainer):
yield f"{model_name}_optim", optimizer
- @staticmethod
- def _configure_scheduler(
- optims: Iterable[tuple[str, torch.optim.Optimizer]],
- last_iter: int,
- num_iters: int,
- sched_config: SupBaselineConfig.SchedConfig
- ) -> Iterable[tuple[str, torch.optim.lr_scheduler._LRScheduler]
- | tuple[str, None]]:
- for optim_name, optim in optims:
- if sched_config.sched == 'linear':
- sched = LinearLR(optim, num_iters, last_epoch=last_iter)
- elif sched_config.sched is None:
- sched = None
- else:
- raise NotImplementedError(f"Unimplemented scheduler: {sched_config.sched}")
- yield f"{optim_name}_sched", sched
-
def train(self, num_iters: int, loss_fn: Callable, logger: Loggers, device: torch.device):
model = self.models['model']
optim = self.optims['model_optim']
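Note on supervised/baseline.py: the removed override only handled sched='linear' and sched=None; those configs resolve through the equivalent branches of the base default above, while 'warmup-anneal', '', and 'const' become accepted values for the baseline rather than raising NotImplementedError.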