aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2022-08-25 09:50:54 +0800
committerJordan Gong <jordan.gong@protonmail.com>2022-08-25 09:50:54 +0800
commite120a99556deb51edbe66ab1ab46009704f9d898 (patch)
tree2d3d8c522ef2b3aca0558918ee2f74b946a7ab56
parent49822d3234cb67e4996ad13fdbc3c44e1a0bbf29 (diff)
Fix optimizer checkpoint loading problem
-rw-r--r--libs/utils.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/libs/utils.py b/libs/utils.py
index c237a77..46093c5 100644
--- a/libs/utils.py
+++ b/libs/utils.py
@@ -212,10 +212,10 @@ class Trainer(ABC):
for module_name in modules.keys():
module_state_dict = checkpoint[f"{module_name}_state_dict"]
module = modules[module_name]
- if isinstance(module, nn.Module):
- module.load_state_dict(module_state_dict)
- else:
+ if isinstance(module, torch.Tensor):
module.data = module_state_dict
+ else:
+ module.load_state_dict(module_state_dict)
last_metrics = {k: v for k, v in checkpoint.items()
if not k.endswith('state_dict')}