aboutsummaryrefslogtreecommitdiff
path: root/libs
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2022-08-21 17:28:07 +0800
committerJordan Gong <jordan.gong@protonmail.com>2022-08-21 17:28:07 +0800
commit49822d3234cb67e4996ad13fdbc3c44e1a0bbf29 (patch)
tree6f6286045cd68de054a602587631a283c64aeb7d /libs
parent4c242c1383afb8072ce6d2904f51cdb005eced4c (diff)
Some modifications for PosRecon trainer
Diffstat (limited to 'libs')
-rw-r--r--libs/logging.py3
-rw-r--r--libs/utils.py9
2 files changed, 10 insertions, 2 deletions
diff --git a/libs/logging.py b/libs/logging.py
index 3969ffa..6dfe5f9 100644
--- a/libs/logging.py
+++ b/libs/logging.py
@@ -113,7 +113,8 @@ def tensorboard_logger(function):
for metric_name, metric_value in metrics.__dict__.items():
if metric_name not in metrics_exclude:
if isinstance(metric_value, float):
- logger.add_scalar(metric_name, metric_value, global_step + 1)
+ logger.add_scalar(metric_name.replace('_', '/', 1),
+ metric_value, global_step + 1)
else:
NotImplementedError(f"Unsupported type: '{type(metric_value)}'")
return loggers, metrics
diff --git a/libs/utils.py b/libs/utils.py
index 63ea116..c237a77 100644
--- a/libs/utils.py
+++ b/libs/utils.py
@@ -6,6 +6,7 @@ from dataclasses import dataclass
from typing import Iterable, Callable
import torch
+from torch import nn
from torch.backends import cudnn
from torch.utils.data import Dataset, DataLoader, RandomSampler
from torch.utils.tensorboard import SummaryWriter
@@ -94,6 +95,7 @@ class Trainer(ABC):
))
self.restore_iter = last_step + 1
+ self.num_iters = num_iters
self.train_loader = train_loader
self.test_loader = test_loader
self.models = models
@@ -209,7 +211,11 @@ class Trainer(ABC):
checkpoint = torch.load(os.path.join(checkpoint_dir, latest_checkpoint))
for module_name in modules.keys():
module_state_dict = checkpoint[f"{module_name}_state_dict"]
- modules[module_name].load_state_dict(module_state_dict)
+ module = modules[module_name]
+ if isinstance(module, nn.Module):
+ module.load_state_dict(module_state_dict)
+ else:
+ module.data = module_state_dict
last_metrics = {k: v for k, v in checkpoint.items()
if not k.endswith('state_dict')}
@@ -260,6 +266,7 @@ class Trainer(ABC):
os.makedirs(self._checkpoint_dir, exist_ok=True)
checkpoint_name = os.path.join(self._checkpoint_dir, f"{metrics.epoch:06d}.pt")
models_state_dict = {f"{model_name}_state_dict": model.state_dict()
+ if isinstance(model, nn.Module) else model.data
for model_name, model in self.models.items()}
optims_state_dict = {f"{optim_name}_state_dict": optim.state_dict()
for optim_name, optim in self.optims.items()}