From e7313d916d783744012ac7bb3011469d72803d25 Mon Sep 17 00:00:00 2001
From: Jordan Gong
Date: Wed, 10 Aug 2022 19:54:37 +0800
Subject: Fix epoch scheduler problem

---
 libs/utils.py          |  1 +
 simclr/evaluate.py     | 22 ++++++++++++----------
 supervised/baseline.py | 21 ++++++++++++---------
 3 files changed, 25 insertions(+), 19 deletions(-)

diff --git a/libs/utils.py b/libs/utils.py
index c019ba9..2f73705 100644
--- a/libs/utils.py
+++ b/libs/utils.py
@@ -91,6 +91,7 @@ class Trainer(ABC):
             raise NotImplementedError(f"Unknown log type: '{type(last_metrics)}'")
         if not inf_mode:
             num_iters *= len(train_loader)
+            config.sched_config.warmup_iters *= len(train_loader)  # FIXME: a little bit hacky here
         scheds = dict(self._configure_scheduler(
             optims.items(), last_iter, num_iters, config.sched_config,
         ))
diff --git a/simclr/evaluate.py b/simclr/evaluate.py
index f4a8fda..a8005c4 100644
--- a/simclr/evaluate.py
+++ b/simclr/evaluate.py
@@ -222,16 +222,18 @@ class SimCLREvalTrainer(SimCLRTrainer):
                     batch, num_batches, global_batch, iter_, num_iters,
                     optim_c.param_groups[0]['lr'], train_loss.item()
                 ))
-            metrics = torch.Tensor(list(self.eval(loss_fn, device))).mean(0)
-            eval_loss = metrics[0].item()
-            eval_accuracy = metrics[1].item()
-            epoch_log = self.EpochLogRecord(iter_, num_iters, eval_loss, eval_accuracy)
-            self.log(logger, epoch_log)
-            self.save_checkpoint(epoch_log)
-            if sched_b is not None and self.finetune:
-                sched_b.step()
-            if sched_c is not None:
-                sched_c.step()
+            if batch == loader_size - 1:
+                metrics = torch.Tensor(list(self.eval(loss_fn, device))).mean(0)
+                eval_loss = metrics[0].item()
+                eval_accuracy = metrics[1].item()
+                epoch_log = self.EpochLogRecord(iter_, num_iters,
+                                                eval_loss, eval_accuracy)
+                self.log(logger, epoch_log)
+                self.save_checkpoint(epoch_log)
+                if sched_b is not None and self.finetune:
+                    sched_b.step()
+                if sched_c is not None:
+                    sched_c.step()
 
     def eval(self, loss_fn: Callable, device: torch.device):
         backbone, classifier = self.models.values()
diff --git a/supervised/baseline.py b/supervised/baseline.py
index db93304..6072c10 100644
--- a/supervised/baseline.py
+++ b/supervised/baseline.py
@@ -231,15 +231,18 @@ class SupBaselineTrainer(Trainer):
                     batch, num_batches, global_batch, iter_, num_iters,
                     optim.param_groups[0]['lr'], train_loss.item()
                 ))
-            metrics = torch.Tensor(list(self.eval(loss_fn, device))).mean(0)
-            eval_loss = metrics[0].item()
-            eval_accuracy = metrics[1].item()
-            epoch_log = self.EpochLogRecord(iter_, num_iters, eval_loss, eval_accuracy)
-            self.log(logger, epoch_log)
-            self.save_checkpoint(epoch_log)
-            # Step after save checkpoint, otherwise the schedular will one iter ahead after restore
-            if sched is not None:
-                sched.step()
+            if batch == loader_size - 1:
+                metrics = torch.Tensor(list(self.eval(loss_fn, device))).mean(0)
+                eval_loss = metrics[0].item()
+                eval_accuracy = metrics[1].item()
+                epoch_log = self.EpochLogRecord(iter_, num_iters,
+                                                eval_loss, eval_accuracy)
+                self.log(logger, epoch_log)
+                self.save_checkpoint(epoch_log)
+                # Step after saving the checkpoint, otherwise the scheduler
+                # will be one iter ahead after restore
+                if sched is not None:
+                    sched.step()
 
     def eval(self, loss_fn: Callable, device: torch.device):
         model = self.models['model']
-- 
cgit v1.2.3
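
For readers outside the repository, the pattern both trainer fixes adopt looks roughly like the sketch below: per-epoch bookkeeping (evaluation, checkpointing, scheduler.step()) is guarded on the last batch of the loader, so the LR scheduler advances once per epoch rather than once per batch, and it steps only after the checkpoint is saved so a restored run is not one scheduler step ahead. This is a minimal, self-contained illustration with toy data and a stock PyTorch CosineAnnealingLR; none of the names below come from the repository itself.

import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset

# Toy setup; every name here is an illustrative stand-in, not the repo's code.
dataset = TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,)))
train_loader = DataLoader(dataset, batch_size=16)
loader_size = len(train_loader)

model = nn.Linear(10, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

num_epochs = 5
# If a warmup length is configured in epochs but the scheduler counts
# iterations, convert it once up front, analogous to the libs/utils.py hunk:
#     warmup_iters *= loader_size
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

for epoch in range(num_epochs):
    for batch, (inputs, targets) in enumerate(train_loader):
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()
        # Last batch of the epoch: evaluation and checkpointing would go
        # here, *before* scheduler.step(), preserving the ordering the patch
        # relies on for clean restores.
        if batch == loader_size - 1:
            scheduler.step()
    print(f"epoch {epoch}: lr={optimizer.param_groups[0]['lr']:.6f}")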