aboutsummaryrefslogtreecommitdiff
path: root/libs/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'libs/utils.py')
-rw-r--r--libs/utils.py9
1 files changed, 3 insertions, 6 deletions
diff --git a/libs/utils.py b/libs/utils.py
index 23964bc..767adfc 100644
--- a/libs/utils.py
+++ b/libs/utils.py
@@ -67,13 +67,10 @@ class Trainer(ABC):
if last_metrics is None:
last_iter = -1
- restore_iter = 0
elif isinstance(last_metrics, BaseEpochLogRecord):
- last_iter = last_metrics.epoch * len(train_loader) - 1
- restore_iter = last_metrics.epoch
+ last_iter = last_metrics.epoch
elif isinstance(last_metrics, BaseBatchLogRecord):
- last_iter = last_metrics.global_batch - 1
- restore_iter = last_metrics.global_batch
+ last_iter = last_metrics.global_batch
else:
raise NotImplementedError(f"Unknown log type: '{type(last_metrics)}'")
if not inf_mode:
@@ -84,7 +81,7 @@ class Trainer(ABC):
self._custom_init_fn(config)
- self.restore_iter = restore_iter
+ self.restore_iter = last_iter + 1
self.train_loader = train_loader
self.test_loader = test_loader
self.models = models