about summary refs log tree commit diff
diff options
context:
space:
mode:
author Jordan Gong <jordan.gong@protonmail.com> 2022-08-10 21:27:39 +0800
committer Jordan Gong <jordan.gong@protonmail.com> 2022-08-10 21:27:39 +0800
commit 2acdd15d0b7c685926e23957ff0c07dacd12cf5b (patch)
tree 388eb9e89986ba7dd3c4249e867cecf076135f06
parent e7313d916d783744012ac7bb3011469d72803d25 (diff)
Make epoch scheduler step after epoch instead of iter
-rw-r--r-- libs/utils.py 19
-rw-r--r-- simclr/evaluate.py 23
-rw-r--r-- supervised/baseline.py 23
3 files changed, 30 insertions, 35 deletions
diff --git a/libs/utils.py b/libs/utils.py
index 2f73705..63ea116 100644
--- a/libs/utils.py
+++ b/libs/utils.py
@@ -82,21 +82,18 @@ class Trainer(ABC):
)
if last_metrics is None:
- last_iter = -1
+ last_step = -1
elif isinstance(last_metrics, BaseEpochLogRecord):
- last_iter = last_metrics.epoch
+ last_step = last_metrics.epoch
elif isinstance(last_metrics, BaseBatchLogRecord):
- last_iter = last_metrics.global_batch
+ last_step = last_metrics.global_batch
else:
raise NotImplementedError(f"Unknown log type: '{type(last_metrics)}'")
- if not inf_mode:
- num_iters *= len(train_loader)
- config.sched_config.warmup_iters *= len(train_loader) # FIXME: a little bit hacky here
scheds = dict(self._configure_scheduler(
- optims.items(), last_iter, num_iters, config.sched_config,
+ optims.items(), last_step, num_iters, config.sched_config,
))
- self.restore_iter = last_iter + 1
+ self.restore_iter = last_step + 1
self.train_loader = train_loader
self.test_loader = test_loader
self.models = models
@@ -226,7 +223,7 @@ class Trainer(ABC):
@staticmethod
def _configure_scheduler(
optims: Iterable[tuple[str, torch.optim.Optimizer]],
- last_iter: int, num_iters: int, sched_config: BaseConfig.SchedConfig,
+ last_step: int, num_iters: int, sched_config: BaseConfig.SchedConfig,
) -> Iterable[tuple[str, torch.optim.lr_scheduler._LRScheduler]
| tuple[str, None]]:
for optim_name, optim in optims:
@@ -235,13 +232,13 @@ class Trainer(ABC):
optim,
warm_up=sched_config.warmup_iters / num_iters,
T_max=num_iters,
- last_epoch=last_iter,
+ last_epoch=last_step,
)
elif sched_config.sched == 'linear':
scheduler = LinearLR(
optim,
num_epochs=num_iters,
- last_epoch=last_iter
+ last_epoch=last_step
)
elif sched_config.sched in {None, '', 'const'}:
scheduler = None
diff --git a/simclr/evaluate.py b/simclr/evaluate.py
index a8005c4..0f26e16 100644
--- a/simclr/evaluate.py
+++ b/simclr/evaluate.py
@@ -222,18 +222,17 @@ class SimCLREvalTrainer(SimCLRTrainer):
batch, num_batches, global_batch, iter_, num_iters,
optim_c.param_groups[0]['lr'], train_loss.item()
))
- if batch == loader_size - 1:
- metrics = torch.Tensor(list(self.eval(loss_fn, device))).mean(0)
- eval_loss = metrics[0].item()
- eval_accuracy = metrics[1].item()
- epoch_log = self.EpochLogRecord(iter_, num_iters,
- eval_loss, eval_accuracy)
- self.log(logger, epoch_log)
- self.save_checkpoint(epoch_log)
- if sched_b is not None and self.finetune:
- sched_b.step()
- if sched_c is not None:
- sched_c.step()
+ metrics = torch.Tensor(list(self.eval(loss_fn, device))).mean(0)
+ eval_loss = metrics[0].item()
+ eval_accuracy = metrics[1].item()
+ epoch_log = self.EpochLogRecord(iter_, num_iters,
+ eval_loss, eval_accuracy)
+ self.log(logger, epoch_log)
+ self.save_checkpoint(epoch_log)
+ if sched_b is not None and self.finetune:
+ sched_b.step()
+ if sched_c is not None:
+ sched_c.step()
def eval(self, loss_fn: Callable, device: torch.device):
backbone, classifier = self.models.values()
diff --git a/supervised/baseline.py b/supervised/baseline.py
index 6072c10..4be1c97 100644
--- a/supervised/baseline.py
+++ b/supervised/baseline.py
@@ -231,18 +231,17 @@ class SupBaselineTrainer(Trainer):
batch, num_batches, global_batch, iter_, num_iters,
optim.param_groups[0]['lr'], train_loss.item()
))
- if batch == loader_size - 1:
- metrics = torch.Tensor(list(self.eval(loss_fn, device))).mean(0)
- eval_loss = metrics[0].item()
- eval_accuracy = metrics[1].item()
- epoch_log = self.EpochLogRecord(iter_, num_iters,
- eval_loss, eval_accuracy)
- self.log(logger, epoch_log)
- self.save_checkpoint(epoch_log)
- # Step after save checkpoint, otherwise the schedular will
- # one iter ahead after restore
- if sched is not None:
- sched.step()
+ metrics = torch.Tensor(list(self.eval(loss_fn, device))).mean(0)
+ eval_loss = metrics[0].item()
+ eval_accuracy = metrics[1].item()
+ epoch_log = self.EpochLogRecord(iter_, num_iters,
+ eval_loss, eval_accuracy)
+ self.log(logger, epoch_log)
+ self.save_checkpoint(epoch_log)
+ # Step after save checkpoint, otherwise the schedular will
+ # one iter ahead after restore
+ if sched is not None:
+ sched.step()
def eval(self, loss_fn: Callable, device: torch.device):
model = self.models['model']