diff options
Diffstat (limited to 'supervised/baseline.py')
-rw-r--r-- | supervised/baseline.py | 24 |
1 files changed, 1 insertions, 23 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py index 9ad107c..9c18f9f 100644 --- a/supervised/baseline.py +++ b/supervised/baseline.py @@ -1,10 +1,10 @@ import argparse import os -import sys from dataclasses import dataclass from pathlib import Path from typing import Iterable, Callable +import sys import torch import yaml from torch.utils.data import Dataset @@ -15,7 +15,6 @@ path = str(Path(Path(__file__).parent.absolute()).parent.absolute()) sys.path.insert(0, path) from libs.datautils import Clip -from libs.schedulers import LinearLR from libs.utils import Trainer, BaseConfig from libs.logging import BaseBatchLogRecord, BaseEpochLogRecord, Loggers from models import CIFARResNet50 @@ -103,10 +102,6 @@ class SupBaselineConfig(BaseConfig): betas: tuple[float, float] | None weight_decay: float - @dataclass - class SchedConfig(BaseConfig.SchedConfig): - sched: str | None - class SupBaselineTrainer(Trainer): def __init__(self, **kwargs): @@ -201,23 +196,6 @@ class SupBaselineTrainer(Trainer): yield f"{model_name}_optim", optimizer - @staticmethod - def _configure_scheduler( - optims: Iterable[tuple[str, torch.optim.Optimizer]], - last_iter: int, - num_iters: int, - sched_config: SupBaselineConfig.SchedConfig - ) -> Iterable[tuple[str, torch.optim.lr_scheduler._LRScheduler] - | tuple[str, None]]: - for optim_name, optim in optims: - if sched_config.sched == 'linear': - sched = LinearLR(optim, num_iters, last_epoch=last_iter) - elif sched_config.sched is None: - sched = None - else: - raise NotImplementedError(f"Unimplemented scheduler: {sched_config.sched}") - yield f"{optim_name}_sched", sched - def train(self, num_iters: int, loss_fn: Callable, logger: Loggers, device: torch.device): model = self.models['model'] optim = self.optims['model_optim'] |