author     Jordan Gong <jordan.gong@protonmail.com>  2022-08-19 14:00:21 +0800
committer  Jordan Gong <jordan.gong@protonmail.com>  2022-08-19 14:00:21 +0800
commit     bff36e9337bd9493e95588b8e342431eb31184f6 (patch)
tree       16a23a1c1967e93bf6ba3fa2096b5beccaaa954b /posrecon
parent     1877161db5056c0a86cfc680df94dbb7f98a438f (diff)
Add ViT with position shuffling/unshuffling methods
Diffstat (limited to 'posrecon')
-rw-r--r--  posrecon/models.py  107
1 file changed, 107 insertions, 0 deletions
diff --git a/posrecon/models.py b/posrecon/models.py
new file mode 100644
index 0000000..62cf322
--- /dev/null
+++ b/posrecon/models.py
@@ -0,0 +1,107 @@
+import math
+import torch
+from timm.models.helpers import named_apply, checkpoint_seq
+from timm.models.layers import trunc_normal_
+from timm.models.vision_transformer import VisionTransformer, get_init_weights_vit
+from torch import nn
+
+
+class ShuffledVisionTransformer(VisionTransformer):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
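+        # Positional embeddings are created, shuffled, and passed in by the
+        # caller (see init_pos_embed), so the built-in parameter is dropped.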
+ del self.pos_embed
+
+ def init_weights(self, mode=''):
+ assert mode in ('jax', 'jax_nlhb', 'moco', '')
+ head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
+ if self.cls_token is not None:
+ nn.init.normal_(self.cls_token, std=1e-6)
+ named_apply(get_init_weights_vit(mode, head_bias), self)
+
+ @staticmethod
+ def fixed_positional_encoding(embed_dim, embed_len, max_embed_len=5000):
+ """Fixed positional encoding from vanilla Transformer"""
+ position = torch.arange(max_embed_len).unsqueeze(1)
+ div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10_000.) / embed_dim))
+ pos_embed = torch.zeros(1, max_embed_len, embed_dim)
+ pos_embed[:, :, 0::2] = torch.sin(position * div_term)
+ pos_embed[:, :, 1::2] = torch.cos(position * div_term)
+
+ return pos_embed[:, :embed_len, :]
+
+ def init_pos_embed(self, device, fixed=False):
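+        # Returns a positional embedding tensor owned by the caller: either the
+        # fixed sinusoidal encoding or a learnable, truncated-normal-initialised one.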
+ num_patches = self.patch_embed.num_patches
+ embed_len = num_patches if self.no_embed_class else num_patches + self.num_prefix_tokens
+ if fixed:
+ pos_embed = self.fixed_positional_encoding(self.embed_dim, embed_len).to(device)
+ else:
+ pos_embed = (torch.randn(1, embed_len, self.embed_dim,
+ device=device) * .02).requires_grad_()
+ trunc_normal_(pos_embed, std=.02)
+ return pos_embed
+
+    def shuffle_pos_embed(self, pos_embed, shuff_rate=0.75):
+        embed_len = pos_embed.size(1)
+        if not self.no_embed_class:
+            # Prefix (e.g. class token) positions are never shuffled, so draw
+            # candidate indices from the patch positions only and shift them
+            # past the prefix entries afterwards; sampling over the full
+            # embedding length here could index out of range after the shift.
+            embed_len -= self.num_prefix_tokens
+        nshuffs = int(embed_len * shuff_rate)
+        shuffled_indices = torch.randperm(embed_len)[:nshuffs]
+        if not self.no_embed_class:
+            shuffled_indices = shuffled_indices + self.num_prefix_tokens
+ ordered_shuffled_indices, unshuffled_indices = shuffled_indices.sort()
+ shuffled_pos_embed = pos_embed.clone()
+ shuffled_pos_embed[:, ordered_shuffled_indices, :] = shuffled_pos_embed[:, shuffled_indices, :]
+ return shuffled_pos_embed, unshuffled_indices, ordered_shuffled_indices
+
+ @staticmethod
+ def unshuffle_pos_embed(shuffled_pos_embed, unshuffled_indices, ordered_shuffled_indices):
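+        # Invert the permutation applied by shuffle_pos_embed using the
+        # indices it returned, restoring the original position ordering.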
+ pos_embed = shuffled_pos_embed.clone()
+ pos_embed[:, ordered_shuffled_indices, :] \
+ = pos_embed[:, ordered_shuffled_indices, :][:, unshuffled_indices, :]
+ return pos_embed
+
+ @staticmethod
+ def reshuffle_pos_embed(pos_embed, ordered_shuffled_indices):
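+        # Draw a fresh permutation over the same subset of positions that was
+        # shuffled before, and return the indices needed to undo it.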
+ nshuffs = ordered_shuffled_indices.size(0)
+ reshuffled_indices = ordered_shuffled_indices[torch.randperm(nshuffs)]
+ _, unreshuffled_indices = reshuffled_indices.sort()
+ reshuffled_pos_embed = pos_embed.clone()
+ reshuffled_pos_embed[:, ordered_shuffled_indices, :] = reshuffled_pos_embed[:, reshuffled_indices, :]
+ return reshuffled_pos_embed, unreshuffled_indices
+
+ def _pos_embed(self, x, pos_embed=None):
+ if self.no_embed_class:
+ # deit-3, updated JAX (big vision)
+ # position embedding does not overlap with class token, add then concat
+ x = x + pos_embed
+ if self.cls_token is not None:
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+ else:
+ # original timm, JAX, and deit vit impl
+ # pos_embed has entry for class token, concat then add
+ if self.cls_token is not None:
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+ x = x + pos_embed
+ return self.pos_drop(x)
+
+ def forward_features(self, x, pos_embed=None, probe=False):
+ patch_embed = self.patch_embed(x)
+ x = self._pos_embed(patch_embed, pos_embed)
+ if self.grad_checkpointing and not torch.jit.is_scripting():
+ x = checkpoint_seq(self.blocks, x)
+ else:
+ x = self.blocks(x)
+ x = self.norm(x)
+ if probe:
+ return x, patch_embed
+ else:
+ return x
+
+ def forward(self, x, pos_embed=None, probe=False):
+ assert pos_embed is not None
+ if probe:
+ features, patch_embed = self.forward_features(x, pos_embed, probe)
+ x = self.forward_head(features)
+ return x, features, patch_embed
+ else:
+ features = self.forward_features(x, pos_embed, probe)
+ x = self.forward_head(features)
+ return x
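
For context, a minimal sketch of how these pieces are intended to compose (the hyper-parameters and tensor shapes below are illustrative, not part of this commit):

    import torch
    from posrecon.models import ShuffledVisionTransformer

    # Illustrative hyper-parameters; any timm VisionTransformer kwargs apply.
    model = ShuffledVisionTransformer(img_size=224, patch_size=16,
                                      embed_dim=192, depth=4, num_heads=3,
                                      num_classes=0)
    device = torch.device("cpu")

    # Positional embeddings are owned by the caller rather than the module.
    pos_embed = model.init_pos_embed(device, fixed=False)

    # Shuffle a random 75% subset of the patch positions ...
    shuffled_pos_embed, unshuffle_idx, shuffled_idx = model.shuffle_pos_embed(pos_embed)

    imgs = torch.randn(2, 3, 224, 224, device=device)
    feats = model(imgs, pos_embed=shuffled_pos_embed)

    # ... and restore the original ordering afterwards.
    restored = model.unshuffle_pos_embed(shuffled_pos_embed, unshuffle_idx, shuffled_idx)
    assert torch.allclose(restored, pos_embed)

Between steps, reshuffle_pos_embed re-randomises the same subset of positions and returns the indices needed to undo that new permutation, without re-sampling which positions are shuffled.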