1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
|
import math
import torch
from timm.models.helpers import named_apply, checkpoint_seq
from timm.models.layers import trunc_normal_
from timm.models.vision_transformer import VisionTransformer, get_init_weights_vit
from torch import nn
class ShuffledVisionTransformer(VisionTransformer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
del self.pos_embed
def init_weights(self, mode=''):
assert mode in ('jax', 'jax_nlhb', 'moco', '')
head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
if self.cls_token is not None:
nn.init.normal_(self.cls_token, std=1e-6)
named_apply(get_init_weights_vit(mode, head_bias), self)
@staticmethod
def fixed_positional_encoding(embed_dim, embed_len, max_embed_len=5000):
"""Fixed positional encoding from vanilla Transformer"""
position = torch.arange(max_embed_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10_000.) / embed_dim))
pos_embed = torch.zeros(1, max_embed_len, embed_dim)
pos_embed[:, :, 0::2] = torch.sin(position * div_term)
pos_embed[:, :, 1::2] = torch.cos(position * div_term)
return pos_embed[:, :embed_len, :]
def init_pos_embed(self, device, fixed=False):
num_patches = self.patch_embed.num_patches
embed_len = num_patches if self.no_embed_class else num_patches + self.num_prefix_tokens
if fixed:
pos_embed = self.fixed_positional_encoding(self.embed_dim, embed_len).to(device)
else:
pos_embed = (torch.randn(1, embed_len, self.embed_dim,
device=device) * .02).requires_grad_()
trunc_normal_(pos_embed, std=.02)
return pos_embed
def shuffle_pos_embed(self, pos_embed, shuff_rate=0.75):
embed_len = pos_embed.size(1)
nshuffs = int(embed_len * shuff_rate)
shuffled_indices = torch.randperm(embed_len)[:nshuffs]
if not self.no_embed_class:
shuffled_indices += self.num_prefix_tokens
ordered_shuffled_indices, unshuffled_indices = shuffled_indices.sort()
shuffled_pos_embed = pos_embed.clone()
shuffled_pos_embed[:, ordered_shuffled_indices, :] = shuffled_pos_embed[:, shuffled_indices, :]
return shuffled_pos_embed, unshuffled_indices, ordered_shuffled_indices
@staticmethod
def unshuffle_pos_embed(shuffled_pos_embed, unshuffled_indices, ordered_shuffled_indices):
pos_embed = shuffled_pos_embed.clone()
pos_embed[:, ordered_shuffled_indices, :] \
= pos_embed[:, ordered_shuffled_indices, :][:, unshuffled_indices, :]
return pos_embed
@staticmethod
def reshuffle_pos_embed(pos_embed, ordered_shuffled_indices):
nshuffs = ordered_shuffled_indices.size(0)
reshuffled_indices = ordered_shuffled_indices[torch.randperm(nshuffs)]
_, unreshuffled_indices = reshuffled_indices.sort()
reshuffled_pos_embed = pos_embed.clone()
reshuffled_pos_embed[:, ordered_shuffled_indices, :] = reshuffled_pos_embed[:, reshuffled_indices, :]
return reshuffled_pos_embed, unreshuffled_indices
def _pos_embed(self, x, pos_embed=None):
if self.no_embed_class:
# deit-3, updated JAX (big vision)
# position embedding does not overlap with class token, add then concat
x = x + pos_embed
if self.cls_token is not None:
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
else:
# original timm, JAX, and deit vit impl
# pos_embed has entry for class token, concat then add
if self.cls_token is not None:
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
x = x + pos_embed
return self.pos_drop(x)
def forward_features(self, x, pos_embed=None, probe=False):
patch_embed = self.patch_embed(x)
x = self._pos_embed(patch_embed, pos_embed)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x)
else:
x = self.blocks(x)
x = self.norm(x)
if probe:
return x, patch_embed
else:
return x
def forward(self, x, pos_embed=None, probe=False):
assert pos_embed is not None
if probe:
features, patch_embed = self.forward_features(x, pos_embed, probe)
x = self.forward_head(features)
return x, features, patch_embed
else:
features = self.forward_features(x, pos_embed, probe)
x = self.forward_head(features)
return x
class MaskedShuffledVisionTransformer(ShuffledVisionTransformer):
def __init__(self, *args, **kwargs):
super(MaskedShuffledVisionTransformer, self).__init__(*args, **kwargs)
self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim) * 0.02)
trunc_normal_(self.mask_token, std=.02)
def generate_masks(self, mask_rate=0.75):
nmasks = int(self.patch_embed.num_patches * mask_rate)
shuffled_indices = torch.randperm(self.patch_embed.num_patches) + self.num_prefix_tokens
visible_indices, _ = shuffled_indices[:self.patch_embed.num_patches - nmasks].sort()
return visible_indices
def mask_embed(self, embed, visible_indices):
nmasks = self.patch_embed.num_patches - len(visible_indices)
mask_tokens = self.mask_token.expand(embed.size(0), nmasks, -1)
masked_features = torch.cat([embed[:, visible_indices, :], mask_tokens], dim=1)
return masked_features
def forward_features(self, x, pos_embed=None, visible_indices=None, probe=False):
patch_embed = self.patch_embed(x)
x = self._pos_embed(patch_embed, pos_embed)
x = self.mask_embed(x, visible_indices)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x)
else:
x = self.blocks(x)
x = self.norm(x)
if probe:
return x, patch_embed
else:
return x
def forward(self, x, pos_embed=None, visible_indices=None, probe=False):
assert pos_embed is not None
assert visible_indices is not None
if probe:
features, patch_embed = self.forward_features(x, pos_embed, visible_indices, probe)
x = self.forward_head(features)
return x, features, patch_embed
else:
features = self.forward_features(x, pos_embed, visible_indices, probe)
x = self.forward_head(features)
return x
|