diff options
Diffstat (limited to 'simclr/main.py')
-rw-r--r-- | simclr/main.py | 34 |
1 files changed, 0 insertions, 34 deletions
diff --git a/simclr/main.py b/simclr/main.py index b170355..b91b42b 100644 --- a/simclr/main.py +++ b/simclr/main.py @@ -131,11 +131,6 @@ class SimCLRConfig(BaseConfig): betas: tuple[float, float] weight_decay: float - @dataclass - class SchedConfig(BaseConfig.SchedConfig): - sched: str | None - warmup_iters: int - class SimCLRTrainer(Trainer): def __init__(self, hid_dim, out_dim, **kwargs): @@ -254,35 +249,6 @@ class SimCLRTrainer(Trainer): yield f"{model_name}_optim", optimizer - @staticmethod - def _configure_scheduler( - optims: Iterable[tuple[str, torch.optim.Optimizer]], - last_iter: int, - num_iters: int, - sched_config: SimCLRConfig.SchedConfig, - ) -> Iterable[tuple[str, torch.optim.lr_scheduler._LRScheduler] - | tuple[str, None]]: - for optim_name, optim in optims: - if sched_config.sched == 'warmup-anneal': - scheduler = LinearWarmupAndCosineAnneal( - optim, - warm_up=sched_config.warmup_iters / num_iters, - T_max=num_iters, - last_epoch=last_iter, - ) - elif sched_config.sched == 'linear': - scheduler = LinearLR( - optim, - num_epochs=num_iters, - last_epoch=last_iter - ) - elif sched_config.sched in {None, '', 'const'}: - scheduler = None - else: - raise NotImplementedError(f"Unimplemented scheduler: '{sched_config.sched}'") - - yield f"{optim_name}_sched", scheduler - def _custom_init_fn(self, config: SimCLRConfig): self.optims = {n: LARS(o) if config.optim_config.optim == 'lars' else o for n, o in self.optims.items()} |