aboutsummaryrefslogtreecommitdiff
path: root/simclr
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2022-08-10 21:27:39 +0800
committerJordan Gong <jordan.gong@protonmail.com>2022-08-10 21:27:39 +0800
commit2acdd15d0b7c685926e23957ff0c07dacd12cf5b (patch)
tree388eb9e89986ba7dd3c4249e867cecf076135f06 /simclr
parente7313d916d783744012ac7bb3011469d72803d25 (diff)
Make epoch scheduler step after epoch instead of iter
Diffstat (limited to 'simclr')
-rw-r--r--simclr/evaluate.py23
1 files changed, 11 insertions, 12 deletions
diff --git a/simclr/evaluate.py b/simclr/evaluate.py
index a8005c4..0f26e16 100644
--- a/simclr/evaluate.py
+++ b/simclr/evaluate.py
@@ -222,18 +222,17 @@ class SimCLREvalTrainer(SimCLRTrainer):
batch, num_batches, global_batch, iter_, num_iters,
optim_c.param_groups[0]['lr'], train_loss.item()
))
- if batch == loader_size - 1:
- metrics = torch.Tensor(list(self.eval(loss_fn, device))).mean(0)
- eval_loss = metrics[0].item()
- eval_accuracy = metrics[1].item()
- epoch_log = self.EpochLogRecord(iter_, num_iters,
- eval_loss, eval_accuracy)
- self.log(logger, epoch_log)
- self.save_checkpoint(epoch_log)
- if sched_b is not None and self.finetune:
- sched_b.step()
- if sched_c is not None:
- sched_c.step()
+ metrics = torch.Tensor(list(self.eval(loss_fn, device))).mean(0)
+ eval_loss = metrics[0].item()
+ eval_accuracy = metrics[1].item()
+ epoch_log = self.EpochLogRecord(iter_, num_iters,
+ eval_loss, eval_accuracy)
+ self.log(logger, epoch_log)
+ self.save_checkpoint(epoch_log)
+ if sched_b is not None and self.finetune:
+ sched_b.step()
+ if sched_c is not None:
+ sched_c.step()
def eval(self, loss_fn: Callable, device: torch.device):
backbone, classifier = self.models.values()