diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2022-08-10 21:27:39 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2022-08-10 21:27:39 +0800 |
commit | 2acdd15d0b7c685926e23957ff0c07dacd12cf5b (patch) | |
tree | 388eb9e89986ba7dd3c4249e867cecf076135f06 /supervised/baseline.py | |
parent | e7313d916d783744012ac7bb3011469d72803d25 (diff) |
Make epoch scheduler step after epoch instead of iter
Diffstat (limited to 'supervised/baseline.py')
-rw-r--r-- | supervised/baseline.py | 23 |
1 files changed, 11 insertions, 12 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py index 6072c10..4be1c97 100644 --- a/supervised/baseline.py +++ b/supervised/baseline.py @@ -231,18 +231,17 @@ class SupBaselineTrainer(Trainer): batch, num_batches, global_batch, iter_, num_iters, optim.param_groups[0]['lr'], train_loss.item() )) - if batch == loader_size - 1: - metrics = torch.Tensor(list(self.eval(loss_fn, device))).mean(0) - eval_loss = metrics[0].item() - eval_accuracy = metrics[1].item() - epoch_log = self.EpochLogRecord(iter_, num_iters, - eval_loss, eval_accuracy) - self.log(logger, epoch_log) - self.save_checkpoint(epoch_log) - # Step after save checkpoint, otherwise the schedular will - # one iter ahead after restore - if sched is not None: - sched.step() + metrics = torch.Tensor(list(self.eval(loss_fn, device))).mean(0) + eval_loss = metrics[0].item() + eval_accuracy = metrics[1].item() + epoch_log = self.EpochLogRecord(iter_, num_iters, + eval_loss, eval_accuracy) + self.log(logger, epoch_log) + self.save_checkpoint(epoch_log) + # Step after save checkpoint, otherwise the schedular will + # one iter ahead after restore + if sched is not None: + sched.step() def eval(self, loss_fn: Callable, device: torch.device): model = self.models['model'] |