author    Jordan Gong <jordan.gong@protonmail.com>  2022-08-10 19:54:37 +0800
committer Jordan Gong <jordan.gong@protonmail.com>  2022-08-10 19:54:37 +0800
commit    e7313d916d783744012ac7bb3011469d72803d25 (patch)
tree      3b51f70c71cab6c0135909f04ae855936698d256
parent    81597cdd0a55140f50b32b69507bfa5309b75f44 (diff)
Fix epoch scheduler problem
-rw-r--r--  libs/utils.py            1
-rw-r--r--  simclr/evaluate.py      22
-rw-r--r--  supervised/baseline.py  21
3 files changed, 25 insertions(+), 19 deletions(-)
diff --git a/libs/utils.py b/libs/utils.py
index c019ba9..2f73705 100644
--- a/libs/utils.py
+++ b/libs/utils.py
@@ -91,6 +91,7 @@ class Trainer(ABC):
raise NotImplementedError(f"Unknown log type: '{type(last_metrics)}'")
if not inf_mode:
num_iters *= len(train_loader)
+ config.sched_config.warmup_iters *= len(train_loader) # FIXME: a little bit hacky here
scheds = dict(self._configure_scheduler(
optims.items(), last_iter, num_iters, config.sched_config,
))
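The hunk above rescales the scheduler's warmup length from epochs to iterations, mirroring how num_iters is multiplied by the number of batches per epoch just before it. A minimal sketch of that conversion, using illustrative names (sched_config, train_loader) rather than the repository's exact objects:

    # Sketch only: convert a warmup length given in epochs into iterations.
    # The patch does this in place on the config, which is why the author
    # flags it with a FIXME (running the setup twice would compound the factor).
    iters_per_epoch = len(train_loader)            # batches per epoch
    warmup_epochs = sched_config.warmup_iters      # value currently in epochs
    sched_config.warmup_iters = warmup_epochs * iters_per_epoch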
diff --git a/simclr/evaluate.py b/simclr/evaluate.py
index f4a8fda..a8005c4 100644
--- a/simclr/evaluate.py
+++ b/simclr/evaluate.py
@@ -222,16 +222,18 @@ class SimCLREvalTrainer(SimCLRTrainer):
batch, num_batches, global_batch, iter_, num_iters,
optim_c.param_groups[0]['lr'], train_loss.item()
))
- metrics = torch.Tensor(list(self.eval(loss_fn, device))).mean(0)
- eval_loss = metrics[0].item()
- eval_accuracy = metrics[1].item()
- epoch_log = self.EpochLogRecord(iter_, num_iters, eval_loss, eval_accuracy)
- self.log(logger, epoch_log)
- self.save_checkpoint(epoch_log)
- if sched_b is not None and self.finetune:
- sched_b.step()
- if sched_c is not None:
- sched_c.step()
+ if batch == loader_size - 1:
+ metrics = torch.Tensor(list(self.eval(loss_fn, device))).mean(0)
+ eval_loss = metrics[0].item()
+ eval_accuracy = metrics[1].item()
+ epoch_log = self.EpochLogRecord(iter_, num_iters,
+ eval_loss, eval_accuracy)
+ self.log(logger, epoch_log)
+ self.save_checkpoint(epoch_log)
+ if sched_b is not None and self.finetune:
+ sched_b.step()
+ if sched_c is not None:
+ sched_c.step()
def eval(self, loss_fn: Callable, device: torch.device):
backbone, classifier = self.models.values()
diff --git a/supervised/baseline.py b/supervised/baseline.py
index db93304..6072c10 100644
--- a/supervised/baseline.py
+++ b/supervised/baseline.py
@@ -231,15 +231,18 @@ class SupBaselineTrainer(Trainer):
batch, num_batches, global_batch, iter_, num_iters,
optim.param_groups[0]['lr'], train_loss.item()
))
- metrics = torch.Tensor(list(self.eval(loss_fn, device))).mean(0)
- eval_loss = metrics[0].item()
- eval_accuracy = metrics[1].item()
- epoch_log = self.EpochLogRecord(iter_, num_iters, eval_loss, eval_accuracy)
- self.log(logger, epoch_log)
- self.save_checkpoint(epoch_log)
- # Step after save checkpoint, otherwise the schedular will one iter ahead after restore
- if sched is not None:
- sched.step()
+ if batch == loader_size - 1:
+ metrics = torch.Tensor(list(self.eval(loss_fn, device))).mean(0)
+ eval_loss = metrics[0].item()
+ eval_accuracy = metrics[1].item()
+ epoch_log = self.EpochLogRecord(iter_, num_iters,
+ eval_loss, eval_accuracy)
+ self.log(logger, epoch_log)
+ self.save_checkpoint(epoch_log)
+ # Step after save checkpoint, otherwise the scheduler will
+ # be one iter ahead after restore
+ if sched is not None:
+ sched.step()
def eval(self, loss_fn: Callable, device: torch.device):
model = self.models['model']
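Taken together, the two training-loop changes gate the per-epoch work (evaluation, logging, checkpointing, scheduler step) on the final batch index instead of running it every iteration, while keeping the scheduler step after the checkpoint save. A minimal sketch of the resulting pattern with a standard PyTorch optimizer and scheduler, as an illustration rather than the repository's actual interfaces:

    def train(model, train_loader, loss_fn, optimizer, scheduler, device, num_epochs):
        model.train()
        for epoch in range(num_epochs):
            for batch, (images, targets) in enumerate(train_loader):
                images, targets = images.to(device), targets.to(device)
                loss = loss_fn(model(images), targets)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # Epoch-end work runs only on the last batch of the epoch.
                if batch == len(train_loader) - 1:
                    # ... evaluate, log, and save a checkpoint here ...
                    # Step after saving, otherwise a restored run resumes
                    # with the scheduler one step ahead of the saved state.
                    if scheduler is not None:
                        scheduler.step()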