aboutsummaryrefslogtreecommitdiff
path: root/simclr
diff options
context:
space:
mode:
Diffstat (limited to 'simclr')
-rw-r--r--simclr/main.py34
1 files changed, 0 insertions, 34 deletions
diff --git a/simclr/main.py b/simclr/main.py
index b170355..b91b42b 100644
--- a/simclr/main.py
+++ b/simclr/main.py
@@ -131,11 +131,6 @@ class SimCLRConfig(BaseConfig):
betas: tuple[float, float]
weight_decay: float
- @dataclass
- class SchedConfig(BaseConfig.SchedConfig):
- sched: str | None
- warmup_iters: int
-
class SimCLRTrainer(Trainer):
def __init__(self, hid_dim, out_dim, **kwargs):
@@ -254,35 +249,6 @@ class SimCLRTrainer(Trainer):
yield f"{model_name}_optim", optimizer
- @staticmethod
- def _configure_scheduler(
- optims: Iterable[tuple[str, torch.optim.Optimizer]],
- last_iter: int,
- num_iters: int,
- sched_config: SimCLRConfig.SchedConfig,
- ) -> Iterable[tuple[str, torch.optim.lr_scheduler._LRScheduler]
- | tuple[str, None]]:
- for optim_name, optim in optims:
- if sched_config.sched == 'warmup-anneal':
- scheduler = LinearWarmupAndCosineAnneal(
- optim,
- warm_up=sched_config.warmup_iters / num_iters,
- T_max=num_iters,
- last_epoch=last_iter,
- )
- elif sched_config.sched == 'linear':
- scheduler = LinearLR(
- optim,
- num_epochs=num_iters,
- last_epoch=last_iter
- )
- elif sched_config.sched in {None, '', 'const'}:
- scheduler = None
- else:
- raise NotImplementedError(f"Unimplemented scheduler: '{sched_config.sched}'")
-
- yield f"{optim_name}_sched", scheduler
-
def _custom_init_fn(self, config: SimCLRConfig):
self.optims = {n: LARS(o) if config.optim_config.optim == 'lars' else o
for n, o in self.optims.items()}