aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2022-07-14 09:56:08 +0800
committerJordan Gong <jordan.gong@protonmail.com>2022-07-14 09:56:08 +0800
commitd547f7384afc61661d8df09763c8102a916143ed (patch)
tree8009c90a3b4f6e9cf2125e0ad0cf35dc3ba91eae
parent80f132510161cd8ae75ba153d24030a27b772815 (diff)
Correct last iter and restore iter
-rw-r--r--libs/utils.py9
1 files changed, 3 insertions, 6 deletions
diff --git a/libs/utils.py b/libs/utils.py
index 23964bc..767adfc 100644
--- a/libs/utils.py
+++ b/libs/utils.py
@@ -67,13 +67,10 @@ class Trainer(ABC):
if last_metrics is None:
last_iter = -1
- restore_iter = 0
elif isinstance(last_metrics, BaseEpochLogRecord):
- last_iter = last_metrics.epoch * len(train_loader) - 1
- restore_iter = last_metrics.epoch
+ last_iter = last_metrics.epoch
elif isinstance(last_metrics, BaseBatchLogRecord):
- last_iter = last_metrics.global_batch - 1
- restore_iter = last_metrics.global_batch
+ last_iter = last_metrics.global_batch
else:
raise NotImplementedError(f"Unknown log type: '{type(last_metrics)}'")
if not inf_mode:
@@ -84,7 +81,7 @@ class Trainer(ABC):
self._custom_init_fn(config)
- self.restore_iter = restore_iter
+ self.restore_iter = last_iter + 1
self.train_loader = train_loader
self.test_loader = test_loader
self.models = models