import math

import torch
from torch import nn

from timm.models.helpers import named_apply, checkpoint_seq
from timm.models.layers import trunc_normal_
from timm.models.vision_transformer import VisionTransformer, get_init_weights_vit


class ShuffledVisionTransformer(VisionTransformer):
    """timm VisionTransformer variant whose positional embedding lives outside the module
    and is passed in explicitly, so it can be shuffled, unshuffled, and reshuffled."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Remove the learned positional embedding created by the parent class; a
        # pos_embed tensor is built via init_pos_embed and passed to forward instead.
        del self.pos_embed
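
    # Same init scheme as the parent class, minus the trunc_normal_ init of
    # self.pos_embed, which no longer exists on this module.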
    def init_weights(self, mode=''):
        assert mode in ('jax', 'jax_nlhb', 'moco', '')
        head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
        if self.cls_token is not None:
            nn.init.normal_(self.cls_token, std=1e-6)
        named_apply(get_init_weights_vit(mode, head_bias), self)

    @staticmethod
    def fixed_positional_encoding(embed_dim, embed_len, max_embed_len=5000):
        """Fixed sinusoidal positional encoding from the vanilla Transformer."""
        position = torch.arange(max_embed_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10_000.) / embed_dim))
        pos_embed = torch.zeros(1, max_embed_len, embed_dim)
        pos_embed[:, :, 0::2] = torch.sin(position * div_term)
        pos_embed[:, :, 1::2] = torch.cos(position * div_term)
        return pos_embed[:, :embed_len, :]
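
    # Builds the externally managed positional embedding: either the fixed sinusoidal
    # encoding above or a learnable tensor initialised like timm's own pos_embed.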
    def init_pos_embed(self, device, fixed=False):
        num_patches = self.patch_embed.num_patches
        embed_len = num_patches if self.no_embed_class else num_patches + self.num_prefix_tokens
        if fixed:
            pos_embed = self.fixed_positional_encoding(self.embed_dim, embed_len).to(device)
        else:
            pos_embed = (torch.randn(1, embed_len, self.embed_dim, device=device) * .02).requires_grad_()
            trunc_normal_(pos_embed, std=.02)
        return pos_embed
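
    # Randomly permutes a fraction (shuff_rate) of the patch position embeddings and
    # returns the shuffled tensor plus the index tensors needed to undo or redo the shuffle.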
    def shuffle_pos_embed(self, pos_embed, shuff_rate=0.75):
        embed_len = pos_embed.size(1)
        nshuffs = int(embed_len * shuff_rate)
        if self.no_embed_class:
            shuffled_indices = torch.randperm(embed_len)[:nshuffs]
        else:
            # pos_embed also holds entries for the prefix (class) tokens: sample patch
            # positions only and offset past the prefix so the indices stay in range.
            shuffled_indices = torch.randperm(embed_len - self.num_prefix_tokens)[:nshuffs] \
                + self.num_prefix_tokens
        ordered_shuffled_indices, unshuffled_indices = shuffled_indices.sort()
        shuffled_pos_embed = pos_embed.clone()
        shuffled_pos_embed[:, ordered_shuffled_indices, :] = shuffled_pos_embed[:, shuffled_indices, :]
        return shuffled_pos_embed, unshuffled_indices, ordered_shuffled_indices
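
    # Inverse of shuffle_pos_embed: restores the original ordering from the indices it returned.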
    @staticmethod
    def unshuffle_pos_embed(shuffled_pos_embed, unshuffled_indices, ordered_shuffled_indices):
        pos_embed = shuffled_pos_embed.clone()
        pos_embed[:, ordered_shuffled_indices, :] = \
            pos_embed[:, ordered_shuffled_indices, :][:, unshuffled_indices, :]
        return pos_embed
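
    # Draws a fresh random permutation over the same set of previously shuffled positions.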
    @staticmethod
    def reshuffle_pos_embed(pos_embed, ordered_shuffled_indices):
        nshuffs = ordered_shuffled_indices.size(0)
        reshuffled_indices = ordered_shuffled_indices[torch.randperm(nshuffs)]
        _, unreshuffled_indices = reshuffled_indices.sort()
        reshuffled_pos_embed = pos_embed.clone()
        reshuffled_pos_embed[:, ordered_shuffled_indices, :] = reshuffled_pos_embed[:, reshuffled_indices, :]
        return reshuffled_pos_embed, unreshuffled_indices
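
    # Same logic as timm's VisionTransformer._pos_embed, except the positional embedding
    # is taken as an argument rather than read from self.pos_embed.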
    def _pos_embed(self, x, pos_embed=None):
        if self.no_embed_class:
            # deit-3, updated JAX (big vision)
            # position embedding does not overlap with class token, add then concat
            x = x + pos_embed
            if self.cls_token is not None:
                x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        else:
            # original timm, JAX, and deit vit impl
            # pos_embed has entry for class token, concat then add
            if self.cls_token is not None:
                x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
            x = x + pos_embed
        return self.pos_drop(x)
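
    # probe=True additionally returns the raw patch embeddings (before the positional
    # embedding is added) for probing/analysis.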
    def forward_features(self, x, pos_embed=None, probe=False):
        patch_embed = self.patch_embed(x)
        x = self._pos_embed(patch_embed, pos_embed)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.blocks, x)
        else:
            x = self.blocks(x)
        x = self.norm(x)
        if probe:
            return x, patch_embed
        return x

    def forward(self, x, pos_embed=None, probe=False):
        assert pos_embed is not None, 'a positional embedding must be passed in explicitly'
        if probe:
            features, patch_embed = self.forward_features(x, pos_embed, probe)
            x = self.forward_head(features)
            return x, features, patch_embed
        features = self.forward_features(x, pos_embed, probe)
        x = self.forward_head(features)
        return x
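

# ----------------------------------------------------------------------------
# Minimal usage sketch (not part of the class above). It assumes the standard
# timm VisionTransformer constructor keywords; the model sizes, shuffle rate,
# and dummy input below are illustrative only.
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = ShuffledVisionTransformer(
        img_size=224, patch_size=16, embed_dim=384, depth=12, num_heads=6, num_classes=10,
    ).to(device)

    # The positional embedding lives outside the module and is passed per call.
    pos_embed = model.init_pos_embed(device, fixed=False)

    # Shuffle 75% of the patch positions, run a forward pass, then restore the order.
    shuffled_pe, unshuffle_idx, shuffled_idx = model.shuffle_pos_embed(pos_embed, shuff_rate=0.75)
    x = torch.randn(2, 3, 224, 224, device=device)
    logits = model(x, pos_embed=shuffled_pe)

    restored_pe = model.unshuffle_pos_embed(shuffled_pe, unshuffle_idx, shuffled_idx)
    assert torch.allclose(restored_pe, pos_embed)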