From ea13f7e3fc3aec1b48d61896dcb3032897cc4b7a Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Fri, 19 Aug 2022 16:53:24 +0800 Subject: Add ViT with position shuffling/unshuffling and masking methods --- posrecon/models.py | 45 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) (limited to 'posrecon') diff --git a/posrecon/models.py b/posrecon/models.py index 62cf322..9e8cb22 100644 --- a/posrecon/models.py +++ b/posrecon/models.py @@ -105,3 +105,48 @@ class ShuffledVisionTransformer(VisionTransformer): features = self.forward_features(x, pos_embed, probe) x = self.forward_head(features) return x + + +class MaskedShuffledVisionTransformer(ShuffledVisionTransformer): + def __init__(self, *args, **kwargs): + super(MaskedShuffledVisionTransformer, self).__init__(*args, **kwargs) + self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim) * 0.02) + trunc_normal_(self.mask_token, std=.02) + + def generate_masks(self, mask_rate=0.75): + nmasks = int(self.patch_embed.num_patches * mask_rate) + shuffled_indices = torch.randperm(self.patch_embed.num_patches) + self.num_prefix_tokens + visible_indices, _ = shuffled_indices[:self.patch_embed.num_patches - nmasks].sort() + return visible_indices + + def mask_embed(self, embed, visible_indices): + nmasks = self.patch_embed.num_patches - len(visible_indices) + mask_tokens = self.mask_token.expand(embed.size(0), nmasks, -1) + masked_features = torch.cat([embed[:, visible_indices, :], mask_tokens], dim=1) + return masked_features + + def forward_features(self, x, pos_embed=None, visible_indices=None, probe=False): + patch_embed = self.patch_embed(x) + x = self._pos_embed(patch_embed, pos_embed) + x = self.mask_embed(x, visible_indices) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) + x = self.norm(x) + if probe: + return x, patch_embed + else: + return x + + def forward(self, x, pos_embed=None, visible_indices=None, probe=False): + assert pos_embed is not None + assert visible_indices is not None + if probe: + features, patch_embed = self.forward_features(x, pos_embed, visible_indices, probe) + x = self.forward_head(features) + return x, features, patch_embed + else: + features = self.forward_features(x, pos_embed, visible_indices, probe) + x = self.forward_head(features) + return x -- cgit v1.2.3