Diffstat (limited to 'posrecon/models.py')
-rw-r--r--  posrecon/models.py  46
1 file changed, 44 insertions, 2 deletions
diff --git a/posrecon/models.py b/posrecon/models.py
index 9e8cb22..7a2ce09 100644
--- a/posrecon/models.py
+++ b/posrecon/models.py
@@ -1,10 +1,18 @@
+from pathlib import Path
+
 import math
+import sys
 import torch
 from timm.models.helpers import named_apply, checkpoint_seq
 from timm.models.layers import trunc_normal_
 from timm.models.vision_transformer import VisionTransformer, get_init_weights_vit
 from torch import nn
 
+path = str(Path(Path(__file__).parent.absolute()).parent.absolute())
+sys.path.insert(0, path)
+
+from simclr.models import SimCLRBase
+
 
 class ShuffledVisionTransformer(VisionTransformer):
     def __init__(self, *args, **kwargs):
@@ -135,7 +143,7 @@ class MaskedShuffledVisionTransformer(ShuffledVisionTransformer):
         x = self.blocks(x)
         x = self.norm(x)
         if probe:
-            return x, patch_embed
+            return x, patch_embed.detach().clone()
         else:
             return x
 
@@ -145,8 +153,42 @@ class MaskedShuffledVisionTransformer(ShuffledVisionTransformer):
         if probe:
             features, patch_embed = self.forward_features(x, pos_embed, visible_indices, probe)
             x = self.forward_head(features)
-            return x, features, patch_embed
+            return x, features.detach().clone(), patch_embed
         else:
             features = self.forward_features(x, pos_embed, visible_indices, probe)
             x = self.forward_head(features)
             return x
+
+
+class SimCLRPosRecon(SimCLRBase):
+    def __init__(
+            self,
+            vit: MaskedShuffledVisionTransformer,
+            hidden_dim: int = 2048,
+            probe: bool = False,
+            *args, **kwargs
+    ):
+        super(SimCLRPosRecon, self).__init__(vit, hidden_dim, *args, **kwargs)
+        self.hidden_dim = hidden_dim
+        self.probe = probe
+
+    def forward(self, x, pos_embed=None, visible_indices=None):
+        if self.probe:
+            output, features, patch_embed = self.backbone(x, pos_embed, visible_indices, True)
+        else:
+            output = self.backbone(x, pos_embed, visible_indices, False)
+        h = output[:, :self.hidden_dim]
+        flatten_pos_embed = output[:, self.hidden_dim:]
+        if self.pretrain:
+            z = self.projector(h)
+            if self.probe:
+                return z, flatten_pos_embed, h.detach().clone(), features, patch_embed
+            else:
+                return z, flatten_pos_embed
+        else:
+            return h
+
+
+def simclr_pos_recon_vit(vit_config: dict, *args, **kwargs):
+    vit = MaskedShuffledVisionTransformer(**vit_config)
+    return SimCLRPosRecon(vit, *args, **kwargs)
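
A minimal usage sketch of the simclr_pos_recon_vit factory added above (not part of the commit): it assumes a timm-style ViT config dict, that SimCLRBase provides the pretrain flag and projector referenced in forward(), and that the backbone tolerates pos_embed=None and visible_indices=None. The config keys and tensor shapes are illustrative guesses, not values taken from this repository.

    import torch

    from posrecon.models import simclr_pos_recon_vit

    # Hypothetical ViT hyper-parameters; keys follow timm's VisionTransformer
    # constructor and are not taken from this repository's configs.
    vit_config = dict(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12)

    model = simclr_pos_recon_vit(vit_config, hidden_dim=2048, probe=False)
    imgs = torch.randn(8, 3, 224, 224)  # illustrative batch

    # What forward() returns depends on flags inherited from SimCLRBase:
    # with pretraining enabled it yields (z, flatten_pos_embed); with probe=True
    # it additionally yields h, features, and patch_embed; otherwise just h.
    out = model(imgs)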