Diffstat (limited to 'supervised/baseline.py')
-rw-r--r--  supervised/baseline.py  24
1 file changed, 1 insertion(+), 23 deletions(-)
diff --git a/supervised/baseline.py b/supervised/baseline.py
index 9ad107c..9c18f9f 100644
--- a/supervised/baseline.py
+++ b/supervised/baseline.py
@@ -1,10 +1,10 @@
import argparse
import os
-import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, Callable
+import sys
import torch
import yaml
from torch.utils.data import Dataset
@@ -15,7 +15,6 @@ path = str(Path(Path(__file__).parent.absolute()).parent.absolute())
sys.path.insert(0, path)
from libs.datautils import Clip
-from libs.schedulers import LinearLR
from libs.utils import Trainer, BaseConfig
from libs.logging import BaseBatchLogRecord, BaseEpochLogRecord, Loggers
from models import CIFARResNet50
@@ -103,10 +102,6 @@ class SupBaselineConfig(BaseConfig):
betas: tuple[float, float] | None
weight_decay: float
- @dataclass
- class SchedConfig(BaseConfig.SchedConfig):
- sched: str | None
-
class SupBaselineTrainer(Trainer):
def __init__(self, **kwargs):
@@ -201,23 +196,6 @@ class SupBaselineTrainer(Trainer):
yield f"{model_name}_optim", optimizer
- @staticmethod
- def _configure_scheduler(
- optims: Iterable[tuple[str, torch.optim.Optimizer]],
- last_iter: int,
- num_iters: int,
- sched_config: SupBaselineConfig.SchedConfig
- ) -> Iterable[tuple[str, torch.optim.lr_scheduler._LRScheduler]
- | tuple[str, None]]:
- for optim_name, optim in optims:
- if sched_config.sched == 'linear':
- sched = LinearLR(optim, num_iters, last_epoch=last_iter)
- elif sched_config.sched is None:
- sched = None
- else:
- raise NotImplementedError(f"Unimplemented scheduler: {sched_config.sched}")
- yield f"{optim_name}_sched", sched
-
def train(self, num_iters: int, loss_fn: Callable, logger: Loggers, device: torch.device):
model = self.models['model']
optim = self.optims['model_optim']
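
Note: after this commit, optimizer construction is the only per-model wiring left in the trainer. The following is a minimal standalone sketch of the surviving pattern (yield one named optimizer per named model, keyed f"{model_name}_optim"). The function name, the SGD choice, and its hyperparameters are illustrative assumptions; the diff above only shows the final yield line.

    import torch
    from torch import nn

    def configure_optimizers(models, lr=0.1, momentum=0.9, weight_decay=1e-4):
        # One named optimizer per named model, keyed "<model_name>_optim",
        # matching the yield that remains in SupBaselineTrainer.
        # SGD and these hyperparameters are illustrative assumptions.
        for model_name, model in models:
            optimizer = torch.optim.SGD(
                model.parameters(), lr=lr,
                momentum=momentum, weight_decay=weight_decay,
            )
            yield f"{model_name}_optim", optimizer

    if __name__ == "__main__":
        optims = dict(configure_optimizers([("model", nn.Linear(8, 2))]))
        print(list(optims))  # ['model_optim']

With _configure_scheduler removed, no "<optim_name>_sched" entries are produced, so the learning rate presumably stays at the optimizer's configured value for the whole run. Anyone wanting the old linear decay back could reach for torch's built-in torch.optim.lr_scheduler.LinearLR, though its signature (start_factor/end_factor/total_iters) differs from the custom libs.schedulers.LinearLR that this commit stops importing.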