aboutsummaryrefslogtreecommitdiff
path: root/simclr/evaluate.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2022-08-10 19:54:37 +0800
committerJordan Gong <jordan.gong@protonmail.com>2022-08-10 19:54:37 +0800
commite7313d916d783744012ac7bb3011469d72803d25 (patch)
tree3b51f70c71cab6c0135909f04ae855936698d256 /simclr/evaluate.py
parent81597cdd0a55140f50b32b69507bfa5309b75f44 (diff)
Fix epoch scheduler problem
Diffstat (limited to 'simclr/evaluate.py')
-rw-r--r--simclr/evaluate.py22
1 files changed, 12 insertions, 10 deletions
diff --git a/simclr/evaluate.py b/simclr/evaluate.py
index f4a8fda..a8005c4 100644
--- a/simclr/evaluate.py
+++ b/simclr/evaluate.py
@@ -222,16 +222,18 @@ class SimCLREvalTrainer(SimCLRTrainer):
batch, num_batches, global_batch, iter_, num_iters,
optim_c.param_groups[0]['lr'], train_loss.item()
))
- metrics = torch.Tensor(list(self.eval(loss_fn, device))).mean(0)
- eval_loss = metrics[0].item()
- eval_accuracy = metrics[1].item()
- epoch_log = self.EpochLogRecord(iter_, num_iters, eval_loss, eval_accuracy)
- self.log(logger, epoch_log)
- self.save_checkpoint(epoch_log)
- if sched_b is not None and self.finetune:
- sched_b.step()
- if sched_c is not None:
- sched_c.step()
+ if batch == loader_size - 1:
+ metrics = torch.Tensor(list(self.eval(loss_fn, device))).mean(0)
+ eval_loss = metrics[0].item()
+ eval_accuracy = metrics[1].item()
+ epoch_log = self.EpochLogRecord(iter_, num_iters,
+ eval_loss, eval_accuracy)
+ self.log(logger, epoch_log)
+ self.save_checkpoint(epoch_log)
+ if sched_b is not None and self.finetune:
+ sched_b.step()
+ if sched_c is not None:
+ sched_c.step()
def eval(self, loss_fn: Callable, device: torch.device):
backbone, classifier = self.models.values()