author     Jordan Gong <jordan.gong@protonmail.com>  2022-08-10 21:27:39 +0800
committer  Jordan Gong <jordan.gong@protonmail.com>  2022-08-10 21:27:39 +0800
commit     2acdd15d0b7c685926e23957ff0c07dacd12cf5b (patch)
tree       388eb9e89986ba7dd3c4249e867cecf076135f06 /libs/utils.py
parent     e7313d916d783744012ac7bb3011469d72803d25 (diff)
Make epoch scheduler step after epoch instead of iter
Diffstat (limited to 'libs/utils.py')
-rw-r--r--  libs/utils.py  19  ++++++++-----------
1 file changed, 8 insertions(+), 11 deletions(-)
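
Note on the change below: scheduler horizons are no longer inflated by len(train_loader), so an epoch-level scheduler ticks once per epoch rather than once per batch, and warmup_iters no longer needs the hacky rescaling. A minimal sketch of the two stepping styles, with torch's built-in CosineAnnealingLR standing in for this repo's CosineScheduler (the model, loader size, and epoch count are illustrative assumptions):

    import torch

    model = torch.nn.Linear(4, 2)
    optim = torch.optim.SGD(model.parameters(), lr=0.1)

    # Before this commit: T_max was stretched to num_epochs * len(train_loader)
    # and sched.step() sat inside the batch loop. After: T_max stays in epochs
    # and the scheduler steps once per epoch, after the inner loop finishes.
    num_epochs = 5
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(
        optim, T_max=num_epochs, last_epoch=-1)
    for epoch in range(num_epochs):
        for _batch in range(3):  # stand-in for iterating train_loader
            optim.step()         # parameter update every batch
        sched.step()             # per-epoch step: lr changes between epochs
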
diff --git a/libs/utils.py b/libs/utils.py
index 2f73705..63ea116 100644
--- a/libs/utils.py
+++ b/libs/utils.py
@@ -82,21 +82,18 @@ class Trainer(ABC):
)
if last_metrics is None:
- last_iter = -1
+ last_step = -1
elif isinstance(last_metrics, BaseEpochLogRecord):
- last_iter = last_metrics.epoch
+ last_step = last_metrics.epoch
elif isinstance(last_metrics, BaseBatchLogRecord):
- last_iter = last_metrics.global_batch
+ last_step = last_metrics.global_batch
else:
raise NotImplementedError(f"Unknown log type: '{type(last_metrics)}'")
- if not inf_mode:
- num_iters *= len(train_loader)
- config.sched_config.warmup_iters *= len(train_loader) # FIXME: a little bit hacky here
scheds = dict(self._configure_scheduler(
- optims.items(), last_iter, num_iters, config.sched_config,
+ optims.items(), last_step, num_iters, config.sched_config,
))
- self.restore_iter = last_iter + 1
+ self.restore_iter = last_step + 1
self.train_loader = train_loader
self.test_loader = test_loader
self.models = models
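
Aside on the rename above: last_step generalises the old last_iter so it can hold either the last finished epoch (epoch-level logs) or the last global batch (batch-level logs), with -1 meaning a fresh run; restore_iter = last_step + 1 is then the point training resumes from. A hedged sketch of that branch in isolation (resolve_last_step is a hypothetical helper; the record classes are assumed minimal shapes of the repo's log types):

    from dataclasses import dataclass

    @dataclass
    class BaseEpochLogRecord:    # assumed shape of the repo's epoch record
        epoch: int

    @dataclass
    class BaseBatchLogRecord:    # assumed shape of the repo's batch record
        global_batch: int

    def resolve_last_step(last_metrics) -> int:
        # Mirrors the branch in Trainer.__init__: -1 means a fresh run.
        if last_metrics is None:
            return -1
        if isinstance(last_metrics, BaseEpochLogRecord):
            return last_metrics.epoch
        if isinstance(last_metrics, BaseBatchLogRecord):
            return last_metrics.global_batch
        raise NotImplementedError(f"Unknown log type: '{type(last_metrics)}'")

    assert resolve_last_step(None) == -1
    assert resolve_last_step(BaseEpochLogRecord(epoch=4)) + 1 == 5  # restore_iter
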
@@ -226,7 +223,7 @@ class Trainer(ABC):
@staticmethod
def _configure_scheduler(
optims: Iterable[tuple[str, torch.optim.Optimizer]],
- last_iter: int, num_iters: int, sched_config: BaseConfig.SchedConfig,
+ last_step: int, num_iters: int, sched_config: BaseConfig.SchedConfig,
) -> Iterable[tuple[str, torch.optim.lr_scheduler._LRScheduler]
| tuple[str, None]]:
for optim_name, optim in optims:
@@ -235,13 +232,13 @@ class Trainer(ABC):
optim,
warm_up=sched_config.warmup_iters / num_iters,
T_max=num_iters,
- last_epoch=last_iter,
+ last_epoch=last_step,
)
elif sched_config.sched == 'linear':
scheduler = LinearLR(
optim,
num_epochs=num_iters,
- last_epoch=last_iter
+ last_epoch=last_step
)
elif sched_config.sched in {None, '', 'const'}:
scheduler = None
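
One subtlety when resuming: handing last_epoch=last_step to a freshly built scheduler only works if every param group records an initial_lr, otherwise torch schedulers raise a KeyError. A hedged sketch with CosineAnnealingLR again standing in for CosineScheduler (an assumption; last_step would come from a checkpoint):

    import torch

    model = torch.nn.Linear(4, 2)
    optim = torch.optim.SGD(model.parameters(), lr=0.1)

    # torch's _LRScheduler refuses last_epoch != -1 unless param groups
    # remember the lr they started from.
    for group in optim.param_groups:
        group.setdefault("initial_lr", group["lr"])

    last_step = 2  # e.g. the last finished epoch read from a checkpoint
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(
        optim, T_max=10, last_epoch=last_step)
    print(sched.get_last_lr())  # lr resumes mid-curve instead of at 0.1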