diff options
-rw-r--r-- | libs/utils.py | 19 | ||||
-rw-r--r-- | simclr/evaluate.py | 23 | ||||
-rw-r--r-- | supervised/baseline.py | 23 |
3 files changed, 30 insertions, 35 deletions
diff --git a/libs/utils.py b/libs/utils.py index 2f73705..63ea116 100644 --- a/libs/utils.py +++ b/libs/utils.py @@ -82,21 +82,18 @@ class Trainer(ABC): ) if last_metrics is None: - last_iter = -1 + last_step = -1 elif isinstance(last_metrics, BaseEpochLogRecord): - last_iter = last_metrics.epoch + last_step = last_metrics.epoch elif isinstance(last_metrics, BaseBatchLogRecord): - last_iter = last_metrics.global_batch + last_step = last_metrics.global_batch else: raise NotImplementedError(f"Unknown log type: '{type(last_metrics)}'") - if not inf_mode: - num_iters *= len(train_loader) - config.sched_config.warmup_iters *= len(train_loader) # FIXME: a little bit hacky here scheds = dict(self._configure_scheduler( - optims.items(), last_iter, num_iters, config.sched_config, + optims.items(), last_step, num_iters, config.sched_config, )) - self.restore_iter = last_iter + 1 + self.restore_iter = last_step + 1 self.train_loader = train_loader self.test_loader = test_loader self.models = models @@ -226,7 +223,7 @@ class Trainer(ABC): @staticmethod def _configure_scheduler( optims: Iterable[tuple[str, torch.optim.Optimizer]], - last_iter: int, num_iters: int, sched_config: BaseConfig.SchedConfig, + last_step: int, num_iters: int, sched_config: BaseConfig.SchedConfig, ) -> Iterable[tuple[str, torch.optim.lr_scheduler._LRScheduler] | tuple[str, None]]: for optim_name, optim in optims: @@ -235,13 +232,13 @@ class Trainer(ABC): optim, warm_up=sched_config.warmup_iters / num_iters, T_max=num_iters, - last_epoch=last_iter, + last_epoch=last_step, ) elif sched_config.sched == 'linear': scheduler = LinearLR( optim, num_epochs=num_iters, - last_epoch=last_iter + last_epoch=last_step ) elif sched_config.sched in {None, '', 'const'}: scheduler = None diff --git a/simclr/evaluate.py b/simclr/evaluate.py index a8005c4..0f26e16 100644 --- a/simclr/evaluate.py +++ b/simclr/evaluate.py @@ -222,18 +222,17 @@ class SimCLREvalTrainer(SimCLRTrainer): batch, num_batches, global_batch, iter_, num_iters, optim_c.param_groups[0]['lr'], train_loss.item() )) - if batch == loader_size - 1: - metrics = torch.Tensor(list(self.eval(loss_fn, device))).mean(0) - eval_loss = metrics[0].item() - eval_accuracy = metrics[1].item() - epoch_log = self.EpochLogRecord(iter_, num_iters, - eval_loss, eval_accuracy) - self.log(logger, epoch_log) - self.save_checkpoint(epoch_log) - if sched_b is not None and self.finetune: - sched_b.step() - if sched_c is not None: - sched_c.step() + metrics = torch.Tensor(list(self.eval(loss_fn, device))).mean(0) + eval_loss = metrics[0].item() + eval_accuracy = metrics[1].item() + epoch_log = self.EpochLogRecord(iter_, num_iters, + eval_loss, eval_accuracy) + self.log(logger, epoch_log) + self.save_checkpoint(epoch_log) + if sched_b is not None and self.finetune: + sched_b.step() + if sched_c is not None: + sched_c.step() def eval(self, loss_fn: Callable, device: torch.device): backbone, classifier = self.models.values() diff --git a/supervised/baseline.py b/supervised/baseline.py index 6072c10..4be1c97 100644 --- a/supervised/baseline.py +++ b/supervised/baseline.py @@ -231,18 +231,17 @@ class SupBaselineTrainer(Trainer): batch, num_batches, global_batch, iter_, num_iters, optim.param_groups[0]['lr'], train_loss.item() )) - if batch == loader_size - 1: - metrics = torch.Tensor(list(self.eval(loss_fn, device))).mean(0) - eval_loss = metrics[0].item() - eval_accuracy = metrics[1].item() - epoch_log = self.EpochLogRecord(iter_, num_iters, - eval_loss, eval_accuracy) - self.log(logger, epoch_log) - self.save_checkpoint(epoch_log) - # Step after save checkpoint, otherwise the schedular will - # one iter ahead after restore - if sched is not None: - sched.step() + metrics = torch.Tensor(list(self.eval(loss_fn, device))).mean(0) + eval_loss = metrics[0].item() + eval_accuracy = metrics[1].item() + epoch_log = self.EpochLogRecord(iter_, num_iters, + eval_loss, eval_accuracy) + self.log(logger, epoch_log) + self.save_checkpoint(epoch_log) + # Step after save checkpoint, otherwise the schedular will + # one iter ahead after restore + if sched is not None: + sched.step() def eval(self, loss_fn: Callable, device: torch.device): model = self.models['model'] |