aboutsummaryrefslogtreecommitdiff
path: root/supervised
diff options
context:
space:
mode:
Diffstat (limited to 'supervised')
-rw-r--r--supervised/baseline.py21
1 files changed, 12 insertions, 9 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py
index db93304..6072c10 100644
--- a/supervised/baseline.py
+++ b/supervised/baseline.py
@@ -231,15 +231,18 @@ class SupBaselineTrainer(Trainer):
batch, num_batches, global_batch, iter_, num_iters,
optim.param_groups[0]['lr'], train_loss.item()
))
- metrics = torch.Tensor(list(self.eval(loss_fn, device))).mean(0)
- eval_loss = metrics[0].item()
- eval_accuracy = metrics[1].item()
- epoch_log = self.EpochLogRecord(iter_, num_iters, eval_loss, eval_accuracy)
- self.log(logger, epoch_log)
- self.save_checkpoint(epoch_log)
- # Step after save checkpoint, otherwise the schedular will one iter ahead after restore
- if sched is not None:
- sched.step()
+ if batch == loader_size - 1:
+ metrics = torch.Tensor(list(self.eval(loss_fn, device))).mean(0)
+ eval_loss = metrics[0].item()
+ eval_accuracy = metrics[1].item()
+ epoch_log = self.EpochLogRecord(iter_, num_iters,
+ eval_loss, eval_accuracy)
+ self.log(logger, epoch_log)
+ self.save_checkpoint(epoch_log)
+ # Step after save checkpoint, otherwise the schedular will
+ # one iter ahead after restore
+ if sched is not None:
+ sched.step()
def eval(self, loss_fn: Callable, device: torch.device):
model = self.models['model']