author    Jordan Gong <jordan.gong@protonmail.com>  2022-08-19 16:53:24 +0800
committer Jordan Gong <jordan.gong@protonmail.com>  2022-08-19 16:53:24 +0800
commit    ea13f7e3fc3aec1b48d61896dcb3032897cc4b7a
tree      f786666c439d39e5c4c1cda6bc698258f7795369 /posrecon/models.py
parent    26420733f98292639b9addb02e73fd8f12ee82e7
Add ViT with position shuffling/unshuffling and masking methods
Diffstat (limited to 'posrecon/models.py')
-rw-r--r--  posrecon/models.py  54
1 file changed, 54 insertions, 0 deletions
diff --git a/posrecon/models.py b/posrecon/models.py
index 62cf322..9e8cb22 100644
--- a/posrecon/models.py
+++ b/posrecon/models.py
@@ -105,3 +105,57 @@ class ShuffledVisionTransformer(VisionTransformer):
         features = self.forward_features(x, pos_embed, probe)
         x = self.forward_head(features)
         return x
+
+
+class MaskedShuffledVisionTransformer(ShuffledVisionTransformer):
+    def __init__(self, *args, **kwargs):
+        super(MaskedShuffledVisionTransformer, self).__init__(*args, **kwargs)
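+        # Learnable token substituted for masked-out patches, initialized
+        # from a truncated normal distribution (std 0.02)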
+        self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
+        trunc_normal_(self.mask_token, std=.02)
+
+    def generate_masks(self, mask_rate=0.75):
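+        # Sample a random subset of patch indices to keep visible; the
+        # indices are offset by the prefix tokens (e.g. the class token)
+        # so they address positions in the full embedded sequence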
+        nmasks = int(self.patch_embed.num_patches * mask_rate)
+        shuffled_indices = torch.randperm(self.patch_embed.num_patches) + self.num_prefix_tokens
+        visible_indices, _ = shuffled_indices[:self.patch_embed.num_patches - nmasks].sort()
+        return visible_indices
+
+    def mask_embed(self, embed, visible_indices):
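+        # Gather the visible patch embeddings and append mask tokens in
+        # place of the masked patches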
+        nmasks = self.patch_embed.num_patches - len(visible_indices)
+        mask_tokens = self.mask_token.expand(embed.size(0), nmasks, -1)
+        masked_features = torch.cat([embed[:, visible_indices, :], mask_tokens], dim=1)
+        return masked_features
+
+    def forward_features(self, x, pos_embed=None, visible_indices=None, probe=False):
+        patch_embed = self.patch_embed(x)
+        x = self._pos_embed(patch_embed, pos_embed)
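+        # Replace the token sequence with visible tokens plus mask tokens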
+        x = self.mask_embed(x, visible_indices)
+        if self.grad_checkpointing and not torch.jit.is_scripting():
+            x = checkpoint_seq(self.blocks, x)
+        else:
+            x = self.blocks(x)
+        x = self.norm(x)
+        if probe:
+            return x, patch_embed
+        else:
+            return x
+
+    def forward(self, x, pos_embed=None, visible_indices=None, probe=False):
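+        # probe=True additionally returns encoded features and raw patch embeddings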
+        assert pos_embed is not None
+        assert visible_indices is not None
+        if probe:
+            features, patch_embed = self.forward_features(x, pos_embed, visible_indices, probe)
+            x = self.forward_head(features)
+            return x, features, patch_embed
+        else:
+            features = self.forward_features(x, pos_embed, visible_indices, probe)
+            x = self.forward_head(features)
+            return x
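
For context, here is a minimal sketch of how the new class might be exercised. It is an illustration under assumptions, not part of the commit: the constructor keyword arguments mirror timm's VisionTransformer defaults, and passing the model's own (unshuffled) pos_embed is a placeholder, since the position shuffling/unshuffling helpers belong to ShuffledVisionTransformer earlier in this file and are not visible in this hunk.

    import torch

    from posrecon.models import MaskedShuffledVisionTransformer

    # Hypothetical hyperparameters, mirroring timm's ViT-Base defaults.
    model = MaskedShuffledVisionTransformer(
        img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12,
    )
    images = torch.randn(8, 3, 224, 224)

    # Keep 25% of the patches visible (mask_rate defaults to 0.75).
    visible = model.generate_masks(mask_rate=0.75)

    # Assumption: feed the model's own position embedding unshuffled; a
    # shuffled pos_embed would come from the parent class's helpers.
    logits = model(images, pos_embed=model.pos_embed, visible_indices=visible)

    # probe=True additionally returns the encoded sequence and the raw
    # patch embeddings, e.g. for probing or reconstruction losses.
    logits, features, patch_embed = model(
        images, pos_embed=model.pos_embed, visible_indices=visible, probe=True
    )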