diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2022-08-12 18:17:14 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2022-08-12 18:17:14 +0800 |
commit | f838e1e40ef365406be6c853b94e596849f842ac (patch) | |
tree | cfcd0a4c2a64b118117b2dc3610983cd6a507877 | |
parent | 3d57389a007e195415c4e7708d5a7a74860c194f (diff) |
Make SimCLR evaluation less frequent
-rw-r--r-- | simclr/evaluate.py | 23 |
1 files changed, 12 insertions, 11 deletions
diff --git a/simclr/evaluate.py b/simclr/evaluate.py index 0f26e16..5c41b84 100644 --- a/simclr/evaluate.py +++ b/simclr/evaluate.py @@ -222,17 +222,18 @@ class SimCLREvalTrainer(SimCLRTrainer): batch, num_batches, global_batch, iter_, num_iters, optim_c.param_groups[0]['lr'], train_loss.item() )) - metrics = torch.Tensor(list(self.eval(loss_fn, device))).mean(0) - eval_loss = metrics[0].item() - eval_accuracy = metrics[1].item() - epoch_log = self.EpochLogRecord(iter_, num_iters, - eval_loss, eval_accuracy) - self.log(logger, epoch_log) - self.save_checkpoint(epoch_log) - if sched_b is not None and self.finetune: - sched_b.step() - if sched_c is not None: - sched_c.step() + if (iter_ + 1) % (num_iters // 10) == 0: + metrics = torch.Tensor(list(self.eval(loss_fn, device))).mean(0) + eval_loss = metrics[0].item() + eval_accuracy = metrics[1].item() + epoch_log = self.EpochLogRecord(iter_, num_iters, + eval_loss, eval_accuracy) + self.log(logger, epoch_log) + self.save_checkpoint(epoch_log) + if sched_b is not None and self.finetune: + sched_b.step() + if sched_c is not None: + sched_c.step() def eval(self, loss_fn: Callable, device: torch.device): backbone, classifier = self.models.values() |