diff options
Diffstat (limited to 'libs')
-rw-r--r-- | libs/utils.py | 19 |
1 files changed, 8 insertions, 11 deletions
diff --git a/libs/utils.py b/libs/utils.py index 2f73705..63ea116 100644 --- a/libs/utils.py +++ b/libs/utils.py @@ -82,21 +82,18 @@ class Trainer(ABC): ) if last_metrics is None: - last_iter = -1 + last_step = -1 elif isinstance(last_metrics, BaseEpochLogRecord): - last_iter = last_metrics.epoch + last_step = last_metrics.epoch elif isinstance(last_metrics, BaseBatchLogRecord): - last_iter = last_metrics.global_batch + last_step = last_metrics.global_batch else: raise NotImplementedError(f"Unknown log type: '{type(last_metrics)}'") - if not inf_mode: - num_iters *= len(train_loader) - config.sched_config.warmup_iters *= len(train_loader) # FIXME: a little bit hacky here scheds = dict(self._configure_scheduler( - optims.items(), last_iter, num_iters, config.sched_config, + optims.items(), last_step, num_iters, config.sched_config, )) - self.restore_iter = last_iter + 1 + self.restore_iter = last_step + 1 self.train_loader = train_loader self.test_loader = test_loader self.models = models @@ -226,7 +223,7 @@ class Trainer(ABC): @staticmethod def _configure_scheduler( optims: Iterable[tuple[str, torch.optim.Optimizer]], - last_iter: int, num_iters: int, sched_config: BaseConfig.SchedConfig, + last_step: int, num_iters: int, sched_config: BaseConfig.SchedConfig, ) -> Iterable[tuple[str, torch.optim.lr_scheduler._LRScheduler] | tuple[str, None]]: for optim_name, optim in optims: @@ -235,13 +232,13 @@ class Trainer(ABC): optim, warm_up=sched_config.warmup_iters / num_iters, T_max=num_iters, - last_epoch=last_iter, + last_epoch=last_step, ) elif sched_config.sched == 'linear': scheduler = LinearLR( optim, num_epochs=num_iters, - last_epoch=last_iter + last_epoch=last_step ) elif sched_config.sched in {None, '', 'const'}: scheduler = None |