author    | Jordan Gong <jordan.gong@protonmail.com> | 2022-08-21 17:28:07 +0800
committer | Jordan Gong <jordan.gong@protonmail.com> | 2022-08-21 17:28:07 +0800
commit    | 49822d3234cb67e4996ad13fdbc3c44e1a0bbf29 (patch)
tree      | 6f6286045cd68de054a602587631a283c64aeb7d /libs
parent    | 4c242c1383afb8072ce6d2904f51cdb005eced4c (diff)
Some modifications for PosRecon trainer
Diffstat (limited to 'libs')
-rw-r--r-- | libs/logging.py | 3
-rw-r--r-- | libs/utils.py   | 9
2 files changed, 10 insertions, 2 deletions
diff --git a/libs/logging.py b/libs/logging.py
index 3969ffa..6dfe5f9 100644
--- a/libs/logging.py
+++ b/libs/logging.py
@@ -113,7 +113,8 @@ def tensorboard_logger(function):
         for metric_name, metric_value in metrics.__dict__.items():
             if metric_name not in metrics_exclude:
                 if isinstance(metric_value, float):
-                    logger.add_scalar(metric_name, metric_value, global_step + 1)
+                    logger.add_scalar(metric_name.replace('_', '/', 1),
+                                      metric_value, global_step + 1)
                 else:
                     NotImplementedError(f"Unsupported type: '{type(metric_value)}'")
         return loggers, metrics
diff --git a/libs/utils.py b/libs/utils.py
index 63ea116..c237a77 100644
--- a/libs/utils.py
+++ b/libs/utils.py
@@ -6,6 +6,7 @@ from dataclasses import dataclass
 from typing import Iterable, Callable
 
 import torch
+from torch import nn
 from torch.backends import cudnn
 from torch.utils.data import Dataset, DataLoader, RandomSampler
 from torch.utils.tensorboard import SummaryWriter
@@ -94,6 +95,7 @@ class Trainer(ABC):
 
         ))
         self.restore_iter = last_step + 1
+        self.num_iters = num_iters
         self.train_loader = train_loader
         self.test_loader = test_loader
         self.models = models
@@ -209,7 +211,11 @@ class Trainer(ABC):
         checkpoint = torch.load(os.path.join(checkpoint_dir, latest_checkpoint))
         for module_name in modules.keys():
             module_state_dict = checkpoint[f"{module_name}_state_dict"]
-            modules[module_name].load_state_dict(module_state_dict)
+            module = modules[module_name]
+            if isinstance(module, nn.Module):
+                module.load_state_dict(module_state_dict)
+            else:
+                module.data = module_state_dict
 
         last_metrics = {k: v for k, v in checkpoint.items()
                         if not k.endswith('state_dict')}
@@ -260,6 +266,7 @@ class Trainer(ABC):
         os.makedirs(self._checkpoint_dir, exist_ok=True)
         checkpoint_name = os.path.join(self._checkpoint_dir, f"{metrics.epoch:06d}.pt")
         models_state_dict = {f"{model_name}_state_dict": model.state_dict()
+                             if isinstance(model, nn.Module) else model.data
                              for model_name, model in self.models.items()}
         optims_state_dict = {f"{optim_name}_state_dict": optim.state_dict()
                              for optim_name, optim in self.optims.items()}
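For context, below is a minimal sketch of the checkpoint handling these hunks introduce. The helper names (`save_modules`, `load_modules`) and the example entries are hypothetical, not part of the repository's `Trainer`: the point is only that a tracked object may be a full `nn.Module` or a bare parameter/tensor, so `state_dict()`/`load_state_dict()` is used for modules and `.data` otherwise. (The `libs/logging.py` hunk is independent: `metric_name.replace('_', '/', 1)` turns a tag such as `train_loss` into `train/loss` so TensorBoard groups scalars by their prefix.)

```python
import torch
from torch import nn


def save_modules(modules: dict, path: str) -> None:
    # Save-side branch: nn.Module entries are serialized via state_dict();
    # anything else (e.g. a standalone nn.Parameter) stores its raw tensor
    # under the same "<name>_state_dict" key.
    checkpoint = {f"{name}_state_dict":
                      m.state_dict() if isinstance(m, nn.Module) else m.data
                  for name, m in modules.items()}
    torch.save(checkpoint, path)


def load_modules(modules: dict, path: str) -> None:
    # Load-side branch: check the entry type before restoring.
    checkpoint = torch.load(path)
    for name, m in modules.items():
        state = checkpoint[f"{name}_state_dict"]
        if isinstance(m, nn.Module):
            m.load_state_dict(state)
        else:
            m.data = state  # restore a bare parameter/tensor in place


# Hypothetical usage: a small model plus a learnable scalar temperature.
modules = {"encoder": nn.Linear(8, 8),
           "temperature": nn.Parameter(torch.tensor(0.07))}
save_modules(modules, "000001.pt")
load_modules(modules, "000001.pt")
```

Handled this way, non-module state (a temperature, prototype tensor, and the like) can round-trip through the same checkpoint files as the models themselves, which appears to be what the PosRecon trainer needs here.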