-rw-r--r--  libs/logging.py    |  3
-rw-r--r--  libs/utils.py      |  9
-rw-r--r--  posrecon/models.py | 46
-rw-r--r--  simclr/evaluate.py |  8
4 files changed, 58 insertions(+), 8 deletions(-)
diff --git a/libs/logging.py b/libs/logging.py
index 3969ffa..6dfe5f9 100644
--- a/libs/logging.py
+++ b/libs/logging.py
@@ -113,7 +113,8 @@ def tensorboard_logger(function):
         for metric_name, metric_value in metrics.__dict__.items():
             if metric_name not in metrics_exclude:
                 if isinstance(metric_value, float):
-                    logger.add_scalar(metric_name, metric_value, global_step + 1)
+                    logger.add_scalar(metric_name.replace('_', '/', 1),
+                                      metric_value, global_step + 1)
                 else:
                     NotImplementedError(f"Unsupported type: '{type(metric_value)}'")
         return loggers, metrics
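
Note: the logging.py hunk only rewrites the TensorBoard tag. The first underscore in a metric name becomes a slash, so scalars sharing a prefix are grouped on one dashboard tab. A minimal sketch of the effect, assuming illustrative metric names such as 'train_loss' and 'eval_accuracy' (the real field names live on the metrics record and are not shown in this diff):

    from torch.utils.tensorboard import SummaryWriter

    metrics = {'train_loss': 0.42, 'eval_accuracy': 0.91}   # made-up values
    writer = SummaryWriter('runs/demo')
    for step, (name, value) in enumerate(metrics.items()):
        # 'train_loss' -> 'train/loss': only the first '_' is replaced,
        # so TensorBoard groups the scalars under 'train', 'eval', ...
        writer.add_scalar(name.replace('_', '/', 1), value, step + 1)
    writer.close()
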
diff --git a/libs/utils.py b/libs/utils.py
index 63ea116..c237a77 100644
--- a/libs/utils.py
+++ b/libs/utils.py
@@ -6,6 +6,7 @@ from dataclasses import dataclass
 from typing import Iterable, Callable
 
 import torch
+from torch import nn
 from torch.backends import cudnn
 from torch.utils.data import Dataset, DataLoader, RandomSampler
 from torch.utils.tensorboard import SummaryWriter
@@ -94,6 +95,7 @@ class Trainer(ABC):
         ))
         self.restore_iter = last_step + 1
+        self.num_iters = num_iters
         self.train_loader = train_loader
         self.test_loader = test_loader
         self.models = models
@@ -209,7 +211,11 @@ class Trainer(ABC):
         checkpoint = torch.load(os.path.join(checkpoint_dir, latest_checkpoint))
         for module_name in modules.keys():
             module_state_dict = checkpoint[f"{module_name}_state_dict"]
-            modules[module_name].load_state_dict(module_state_dict)
+            module = modules[module_name]
+            if isinstance(module, nn.Module):
+                module.load_state_dict(module_state_dict)
+            else:
+                module.data = module_state_dict
 
         last_metrics = {k: v for k, v in checkpoint.items()
                         if not k.endswith('state_dict')}
@@ -260,6 +266,7 @@ class Trainer(ABC):
         os.makedirs(self._checkpoint_dir, exist_ok=True)
         checkpoint_name = os.path.join(self._checkpoint_dir, f"{metrics.epoch:06d}.pt")
         models_state_dict = {f"{model_name}_state_dict": model.state_dict()
+                             if isinstance(model, nn.Module) else model.data
                              for model_name, model in self.models.items()}
         optims_state_dict = {f"{optim_name}_state_dict": optim.state_dict()
                              for optim_name, optim in self.optims.items()}
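
Note: the utils.py changes let the Trainer checkpoint entries that are not nn.Module instances (for example a bare tensor or nn.Parameter such as a learnable positional embedding): modules go through state_dict()/load_state_dict(), everything else round-trips through .data. A small sketch of that save/load symmetry, with made-up names and shapes:

    import torch
    from torch import nn

    models = {'backbone': nn.Linear(8, 4),
              'pos_embed': nn.Parameter(torch.zeros(1, 16, 8))}

    # Save: modules contribute a state_dict, bare parameters contribute their .data.
    checkpoint = {f"{name}_state_dict": m.state_dict() if isinstance(m, nn.Module) else m.data
                  for name, m in models.items()}
    torch.save(checkpoint, '/tmp/demo.pt')

    # Load: modules go through load_state_dict, bare parameters are restored in place.
    restored = torch.load('/tmp/demo.pt')
    for name, m in models.items():
        state = restored[f"{name}_state_dict"]
        if isinstance(m, nn.Module):
            m.load_state_dict(state)
        else:
            m.data = state
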
diff --git a/posrecon/models.py b/posrecon/models.py
index 9e8cb22..7a2ce09 100644
--- a/posrecon/models.py
+++ b/posrecon/models.py
@@ -1,10 +1,18 @@
+from pathlib import Path
+
 import math
+import sys
 import torch
 from timm.models.helpers import named_apply, checkpoint_seq
 from timm.models.layers import trunc_normal_
 from timm.models.vision_transformer import VisionTransformer, get_init_weights_vit
 from torch import nn
+path = str(Path(Path(__file__).parent.absolute()).parent.absolute())
+sys.path.insert(0, path)
+
+from simclr.models import SimCLRBase
+
 
 
 class ShuffledVisionTransformer(VisionTransformer):
     def __init__(self, *args, **kwargs):
@@ -135,7 +143,7 @@ class MaskedShuffledVisionTransformer(ShuffledVisionTransformer):
         x = self.blocks(x)
         x = self.norm(x)
         if probe:
-            return x, patch_embed
+            return x, patch_embed.detach().clone()
         else:
             return x
 
@@ -145,8 +153,42 @@ class MaskedShuffledVisionTransformer(ShuffledVisionTransformer):
         if probe:
             features, patch_embed = self.forward_features(x, pos_embed, visible_indices, probe)
             x = self.forward_head(features)
-            return x, features, patch_embed
+            return x, features.detach().clone(), patch_embed
         else:
             features = self.forward_features(x, pos_embed, visible_indices, probe)
             x = self.forward_head(features)
             return x
+
+
+class SimCLRPosRecon(SimCLRBase):
+    def __init__(
+        self,
+        vit: MaskedShuffledVisionTransformer,
+        hidden_dim: int = 2048,
+        probe: bool = False,
+        *args, **kwargs
+    ):
+        super(SimCLRPosRecon, self).__init__(vit, hidden_dim, *args, **kwargs)
+        self.hidden_dim = hidden_dim
+        self.probe = probe
+
+    def forward(self, x, pos_embed=None, visible_indices=None):
+        if self.probe:
+            output, features, patch_embed = self.backbone(x, pos_embed, visible_indices, True)
+        else:
+            output = self.backbone(x, pos_embed, visible_indices, False)
+        h = output[:, :self.hidden_dim]
+        flatten_pos_embed = output[:, self.hidden_dim:]
+        if self.pretrain:
+            z = self.projector(h)
+            if self.probe:
+                return z, flatten_pos_embed, h.detach().clone(), features, patch_embed
+            else:
+                return z, flatten_pos_embed
+        else:
+            return h
+
+
+def simclr_pos_recon_vit(vit_config: dict, *args, **kwargs):
+    vit = MaskedShuffledVisionTransformer(**vit_config)
+    return SimCLRPosRecon(vit, *args, **kwargs)
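
Note: posrecon/models.py gains a SimCLRPosRecon wrapper. The backbone output is split by index into the first hidden_dim entries (the representation h) and the remaining entries (the flattened positional embedding), and the probe path returns detached copies so a linear probe cannot backpropagate into the encoder. A rough usage sketch; the SimCLRBase attributes (pretrain, projector) and the ViT config keys are assumptions, since neither is shown in this diff:

    import torch
    from posrecon.models import simclr_pos_recon_vit

    # Hypothetical timm-style ViT config; the real keys depend on MaskedShuffledVisionTransformer.
    vit_config = dict(img_size=32, patch_size=4, embed_dim=192, depth=12, num_heads=3)
    model = simclr_pos_recon_vit(vit_config, hidden_dim=2048, probe=False)

    x = torch.randn(8, 3, 32, 32)
    # With pretrain enabled, the forward pass returns the projection z plus the
    # flattened positional embedding recovered from the tail of the head output.
    z, flatten_pos_embed = model(x)
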
diff --git a/simclr/evaluate.py b/simclr/evaluate.py
index f18c417..1abb5ce 100644
--- a/simclr/evaluate.py
+++ b/simclr/evaluate.py
@@ -18,7 +18,7 @@ from libs.optimizers import LARS
 from libs.logging import Loggers, BaseBatchLogRecord, BaseEpochLogRecord
 from libs.utils import BaseConfig
 from simclr.main import SimCLRTrainer, SimCLRConfig
-from simclr.models import CIFARSimCLRResNet50, ImageNetSimCLRResNet50, CIFARSimCLRViTTiny
+from simclr.models import cifar_simclr_resnet50, cifar_simclr_vit_tiny, imagenet_simclr_resnet50
 
 
 def parse_args_and_config():
@@ -172,9 +172,9 @@ class SimCLREvalTrainer(SimCLRTrainer):
     def _init_models(self, dataset: str) -> Iterable[tuple[str, torch.nn.Module]]:
         if dataset in {'cifar10', 'cifar100', 'cifar'}:
             if self.encoder == 'resnet':
-                backbone = CIFARSimCLRResNet50(self.hid_dim, pretrain=False)
+                backbone = cifar_simclr_resnet50(self.hid_dim, pretrain=False)
             elif self.encoder == 'vit':
-                backbone = CIFARSimCLRViTTiny(self.hid_dim, pretrain=False)
+                backbone = cifar_simclr_vit_tiny(self.hid_dim, pretrain=False)
             else:
                 raise NotImplementedError(f"Unimplemented encoder: '{self.encoder}")
             if dataset in {'cifar10', 'cifar'}:
@@ -183,7 +183,7 @@ class SimCLREvalTrainer(SimCLRTrainer):
                 classifier = torch.nn.Linear(self.hid_dim, 100)
         elif dataset in {'imagenet1k', 'imagenet'}:
             if self.encoder == 'resnet':
-                backbone = ImageNetSimCLRResNet50(self.hid_dim, pretrain=False)
+                backbone = imagenet_simclr_resnet50(self.hid_dim, pretrain=False)
             else:
                 raise NotImplementedError(f"Unimplemented encoder: '{self.encoder}")
             classifier = torch.nn.Linear(self.hid_dim, 1000)
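
Note: the evaluate.py hunks only swap the old class constructors for the new lowercase factory functions; the call signature visible in this diff stays (hid_dim, pretrain=False). A minimal sketch of the resulting linear-evaluation setup, assuming the backbone maps a batch of CIFAR-sized images to hid_dim-dimensional features when pretrain=False (the 2048-dim width and CIFAR-10 head are illustrative):

    import torch
    from simclr.models import cifar_simclr_resnet50

    hid_dim = 2048
    backbone = cifar_simclr_resnet50(hid_dim, pretrain=False)   # pretrain=False: features h, not the projection
    classifier = torch.nn.Linear(hid_dim, 10)                   # linear probe head for CIFAR-10
    features = backbone(torch.randn(4, 3, 32, 32))
    logits = classifier(features)
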