diff options
Diffstat (limited to 'libs/utils.py')
-rw-r--r-- | libs/utils.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/libs/utils.py b/libs/utils.py index c237a77..46093c5 100644 --- a/libs/utils.py +++ b/libs/utils.py @@ -212,10 +212,10 @@ class Trainer(ABC): for module_name in modules.keys(): module_state_dict = checkpoint[f"{module_name}_state_dict"] module = modules[module_name] - if isinstance(module, nn.Module): - module.load_state_dict(module_state_dict) - else: + if isinstance(module, torch.Tensor): module.data = module_state_dict + else: + module.load_state_dict(module_state_dict) last_metrics = {k: v for k, v in checkpoint.items() if not k.endswith('state_dict')} |