diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2022-08-21 17:28:07 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2022-08-21 17:28:07 +0800 |
commit | 49822d3234cb67e4996ad13fdbc3c44e1a0bbf29 (patch) | |
tree | 6f6286045cd68de054a602587631a283c64aeb7d /posrecon/models.py | |
parent | 4c242c1383afb8072ce6d2904f51cdb005eced4c (diff) |
Some modifications for PosRecon trainer
Diffstat (limited to 'posrecon/models.py')
-rw-r--r-- | posrecon/models.py | 46 |
1 files changed, 44 insertions, 2 deletions
diff --git a/posrecon/models.py b/posrecon/models.py index 9e8cb22..7a2ce09 100644 --- a/posrecon/models.py +++ b/posrecon/models.py @@ -1,10 +1,18 @@ +from pathlib import Path + import math +import sys import torch from timm.models.helpers import named_apply, checkpoint_seq from timm.models.layers import trunc_normal_ from timm.models.vision_transformer import VisionTransformer, get_init_weights_vit from torch import nn +path = str(Path(Path(__file__).parent.absolute()).parent.absolute()) +sys.path.insert(0, path) + +from simclr.models import SimCLRBase + class ShuffledVisionTransformer(VisionTransformer): def __init__(self, *args, **kwargs): @@ -135,7 +143,7 @@ class MaskedShuffledVisionTransformer(ShuffledVisionTransformer): x = self.blocks(x) x = self.norm(x) if probe: - return x, patch_embed + return x, patch_embed.detach().clone() else: return x @@ -145,8 +153,42 @@ class MaskedShuffledVisionTransformer(ShuffledVisionTransformer): if probe: features, patch_embed = self.forward_features(x, pos_embed, visible_indices, probe) x = self.forward_head(features) - return x, features, patch_embed + return x, features.detach().clone(), patch_embed else: features = self.forward_features(x, pos_embed, visible_indices, probe) x = self.forward_head(features) return x + + +class SimCLRPosRecon(SimCLRBase): + def __init__( + self, + vit: MaskedShuffledVisionTransformer, + hidden_dim: int = 2048, + probe: bool = False, + *args, **kwargs + ): + super(SimCLRPosRecon, self).__init__(vit, hidden_dim, *args, **kwargs) + self.hidden_dim = hidden_dim + self.probe = probe + + def forward(self, x, pos_embed=None, visible_indices=None): + if self.probe: + output, features, patch_embed = self.backbone(x, pos_embed, visible_indices, True) + else: + output = self.backbone(x, pos_embed, visible_indices, False) + h = output[:, :self.hidden_dim] + flatten_pos_embed = output[:, self.hidden_dim:] + if self.pretrain: + z = self.projector(h) + if self.probe: + return z, flatten_pos_embed, h.detach().clone(), features, patch_embed + else: + return z, flatten_pos_embed + else: + return h + + +def simclr_pos_recon_vit(vit_config: dict, *args, **kwargs): + vit = MaskedShuffledVisionTransformer(**vit_config) + return SimCLRPosRecon(vit, *args, **kwargs) |