diff options
-rw-r--r-- | libs/logging.py | 3 | ||||
-rw-r--r-- | libs/utils.py | 9 | ||||
-rw-r--r-- | posrecon/models.py | 46 | ||||
-rw-r--r-- | simclr/evaluate.py | 8 |
4 files changed, 58 insertions, 8 deletions
diff --git a/libs/logging.py b/libs/logging.py index 3969ffa..6dfe5f9 100644 --- a/libs/logging.py +++ b/libs/logging.py @@ -113,7 +113,8 @@ def tensorboard_logger(function): for metric_name, metric_value in metrics.__dict__.items(): if metric_name not in metrics_exclude: if isinstance(metric_value, float): - logger.add_scalar(metric_name, metric_value, global_step + 1) + logger.add_scalar(metric_name.replace('_', '/', 1), + metric_value, global_step + 1) else: NotImplementedError(f"Unsupported type: '{type(metric_value)}'") return loggers, metrics diff --git a/libs/utils.py b/libs/utils.py index 63ea116..c237a77 100644 --- a/libs/utils.py +++ b/libs/utils.py @@ -6,6 +6,7 @@ from dataclasses import dataclass from typing import Iterable, Callable import torch +from torch import nn from torch.backends import cudnn from torch.utils.data import Dataset, DataLoader, RandomSampler from torch.utils.tensorboard import SummaryWriter @@ -94,6 +95,7 @@ class Trainer(ABC): )) self.restore_iter = last_step + 1 + self.num_iters = num_iters self.train_loader = train_loader self.test_loader = test_loader self.models = models @@ -209,7 +211,11 @@ class Trainer(ABC): checkpoint = torch.load(os.path.join(checkpoint_dir, latest_checkpoint)) for module_name in modules.keys(): module_state_dict = checkpoint[f"{module_name}_state_dict"] - modules[module_name].load_state_dict(module_state_dict) + module = modules[module_name] + if isinstance(module, nn.Module): + module.load_state_dict(module_state_dict) + else: + module.data = module_state_dict last_metrics = {k: v for k, v in checkpoint.items() if not k.endswith('state_dict')} @@ -260,6 +266,7 @@ class Trainer(ABC): os.makedirs(self._checkpoint_dir, exist_ok=True) checkpoint_name = os.path.join(self._checkpoint_dir, f"{metrics.epoch:06d}.pt") models_state_dict = {f"{model_name}_state_dict": model.state_dict() + if isinstance(model, nn.Module) else model.data for model_name, model in self.models.items()} optims_state_dict = {f"{optim_name}_state_dict": optim.state_dict() for optim_name, optim in self.optims.items()} diff --git a/posrecon/models.py b/posrecon/models.py index 9e8cb22..7a2ce09 100644 --- a/posrecon/models.py +++ b/posrecon/models.py @@ -1,10 +1,18 @@ +from pathlib import Path + import math +import sys import torch from timm.models.helpers import named_apply, checkpoint_seq from timm.models.layers import trunc_normal_ from timm.models.vision_transformer import VisionTransformer, get_init_weights_vit from torch import nn +path = str(Path(Path(__file__).parent.absolute()).parent.absolute()) +sys.path.insert(0, path) + +from simclr.models import SimCLRBase + class ShuffledVisionTransformer(VisionTransformer): def __init__(self, *args, **kwargs): @@ -135,7 +143,7 @@ class MaskedShuffledVisionTransformer(ShuffledVisionTransformer): x = self.blocks(x) x = self.norm(x) if probe: - return x, patch_embed + return x, patch_embed.detach().clone() else: return x @@ -145,8 +153,42 @@ class MaskedShuffledVisionTransformer(ShuffledVisionTransformer): if probe: features, patch_embed = self.forward_features(x, pos_embed, visible_indices, probe) x = self.forward_head(features) - return x, features, patch_embed + return x, features.detach().clone(), patch_embed else: features = self.forward_features(x, pos_embed, visible_indices, probe) x = self.forward_head(features) return x + + +class SimCLRPosRecon(SimCLRBase): + def __init__( + self, + vit: MaskedShuffledVisionTransformer, + hidden_dim: int = 2048, + probe: bool = False, + *args, **kwargs + ): + super(SimCLRPosRecon, self).__init__(vit, hidden_dim, *args, **kwargs) + self.hidden_dim = hidden_dim + self.probe = probe + + def forward(self, x, pos_embed=None, visible_indices=None): + if self.probe: + output, features, patch_embed = self.backbone(x, pos_embed, visible_indices, True) + else: + output = self.backbone(x, pos_embed, visible_indices, False) + h = output[:, :self.hidden_dim] + flatten_pos_embed = output[:, self.hidden_dim:] + if self.pretrain: + z = self.projector(h) + if self.probe: + return z, flatten_pos_embed, h.detach().clone(), features, patch_embed + else: + return z, flatten_pos_embed + else: + return h + + +def simclr_pos_recon_vit(vit_config: dict, *args, **kwargs): + vit = MaskedShuffledVisionTransformer(**vit_config) + return SimCLRPosRecon(vit, *args, **kwargs) diff --git a/simclr/evaluate.py b/simclr/evaluate.py index f18c417..1abb5ce 100644 --- a/simclr/evaluate.py +++ b/simclr/evaluate.py @@ -18,7 +18,7 @@ from libs.optimizers import LARS from libs.logging import Loggers, BaseBatchLogRecord, BaseEpochLogRecord from libs.utils import BaseConfig from simclr.main import SimCLRTrainer, SimCLRConfig -from simclr.models import CIFARSimCLRResNet50, ImageNetSimCLRResNet50, CIFARSimCLRViTTiny +from simclr.models import cifar_simclr_resnet50, cifar_simclr_vit_tiny, imagenet_simclr_resnet50 def parse_args_and_config(): @@ -172,9 +172,9 @@ class SimCLREvalTrainer(SimCLRTrainer): def _init_models(self, dataset: str) -> Iterable[tuple[str, torch.nn.Module]]: if dataset in {'cifar10', 'cifar100', 'cifar'}: if self.encoder == 'resnet': - backbone = CIFARSimCLRResNet50(self.hid_dim, pretrain=False) + backbone = cifar_simclr_resnet50(self.hid_dim, pretrain=False) elif self.encoder == 'vit': - backbone = CIFARSimCLRViTTiny(self.hid_dim, pretrain=False) + backbone = cifar_simclr_vit_tiny(self.hid_dim, pretrain=False) else: raise NotImplementedError(f"Unimplemented encoder: '{self.encoder}") if dataset in {'cifar10', 'cifar'}: @@ -183,7 +183,7 @@ class SimCLREvalTrainer(SimCLRTrainer): classifier = torch.nn.Linear(self.hid_dim, 100) elif dataset in {'imagenet1k', 'imagenet'}: if self.encoder == 'resnet': - backbone = ImageNetSimCLRResNet50(self.hid_dim, pretrain=False) + backbone = imagenet_simclr_resnet50(self.hid_dim, pretrain=False) else: raise NotImplementedError(f"Unimplemented encoder: '{self.encoder}") classifier = torch.nn.Linear(self.hid_dim, 1000) |