diff options
-rw-r--r-- | libs/utils.py | 9 |
1 files changed, 3 insertions, 6 deletions
diff --git a/libs/utils.py b/libs/utils.py index 23964bc..767adfc 100644 --- a/libs/utils.py +++ b/libs/utils.py @@ -67,13 +67,10 @@ class Trainer(ABC): if last_metrics is None: last_iter = -1 - restore_iter = 0 elif isinstance(last_metrics, BaseEpochLogRecord): - last_iter = last_metrics.epoch * len(train_loader) - 1 - restore_iter = last_metrics.epoch + last_iter = last_metrics.epoch elif isinstance(last_metrics, BaseBatchLogRecord): - last_iter = last_metrics.global_batch - 1 - restore_iter = last_metrics.global_batch + last_iter = last_metrics.global_batch else: raise NotImplementedError(f"Unknown log type: '{type(last_metrics)}'") if not inf_mode: @@ -84,7 +81,7 @@ class Trainer(ABC): self._custom_init_fn(config) - self.restore_iter = restore_iter + self.restore_iter = last_iter + 1 self.train_loader = train_loader self.test_loader = test_loader self.models = models |