diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2022-08-25 09:50:54 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2022-08-25 09:50:54 +0800 |
commit | e120a99556deb51edbe66ab1ab46009704f9d898 (patch) | |
tree | 2d3d8c522ef2b3aca0558918ee2f74b946a7ab56 /libs | |
parent | 49822d3234cb67e4996ad13fdbc3c44e1a0bbf29 (diff) |
Fix optimizer checkpoint loading problem
Diffstat (limited to 'libs')
-rw-r--r-- | libs/utils.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/libs/utils.py b/libs/utils.py index c237a77..46093c5 100644 --- a/libs/utils.py +++ b/libs/utils.py @@ -212,10 +212,10 @@ class Trainer(ABC): for module_name in modules.keys(): module_state_dict = checkpoint[f"{module_name}_state_dict"] module = modules[module_name] - if isinstance(module, nn.Module): - module.load_state_dict(module_state_dict) - else: + if isinstance(module, torch.Tensor): module.data = module_state_dict + else: + module.load_state_dict(module_state_dict) last_metrics = {k: v for k, v in checkpoint.items() if not k.endswith('state_dict')} |