aboutsummaryrefslogtreecommitdiff
path: root/libs/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'libs/utils.py')
-rw-r--r--libs/utils.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/libs/utils.py b/libs/utils.py
index c237a77..46093c5 100644
--- a/libs/utils.py
+++ b/libs/utils.py
@@ -212,10 +212,10 @@ class Trainer(ABC):
for module_name in modules.keys():
module_state_dict = checkpoint[f"{module_name}_state_dict"]
module = modules[module_name]
- if isinstance(module, nn.Module):
- module.load_state_dict(module_state_dict)
- else:
+ if isinstance(module, torch.Tensor):
module.data = module_state_dict
+ else:
+ module.load_state_dict(module_state_dict)
last_metrics = {k: v for k, v in checkpoint.items()
if not k.endswith('state_dict')}