From e120a99556deb51edbe66ab1ab46009704f9d898 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Thu, 25 Aug 2022 09:50:54 +0800 Subject: Fix optimizer checkpoint loading problem --- libs/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/libs/utils.py b/libs/utils.py index c237a77..46093c5 100644 --- a/libs/utils.py +++ b/libs/utils.py @@ -212,10 +212,10 @@ class Trainer(ABC): for module_name in modules.keys(): module_state_dict = checkpoint[f"{module_name}_state_dict"] module = modules[module_name] - if isinstance(module, nn.Module): - module.load_state_dict(module_state_dict) - else: + if isinstance(module, torch.Tensor): module.data = module_state_dict + else: + module.load_state_dict(module_state_dict) last_metrics = {k: v for k, v in checkpoint.items() if not k.endswith('state_dict')} -- cgit v1.2.3