aboutsummaryrefslogtreecommitdiff
path: root/posrecon/models.py
blob: 62cf322e71f0f0e19636c298b84ad5d99249d65d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import math
import torch
from timm.models.helpers import named_apply, checkpoint_seq
from timm.models.layers import trunc_normal_
from timm.models.vision_transformer import VisionTransformer, get_init_weights_vit
from torch import nn


class ShuffledVisionTransformer(VisionTransformer):
    def __init__(self, *args, **kwargs):
        """Build the parent ViT, then drop its ``pos_embed`` attribute.

        All positional and keyword arguments are forwarded unchanged to the
        parent constructor. The position embedding is not stored on the
        module afterwards; ``forward`` receives it as an argument instead.
        """
        super().__init__(*args, **kwargs)
        # Equivalent to ``del self.pos_embed``.
        delattr(self, 'pos_embed')

    def init_weights(self, mode=''):
        """Initialise the class token and apply the ViT weight-init scheme.

        This method itself never references a position embedding.

        Args:
            mode: Initialisation scheme; must be one of ``'jax'``,
                ``'jax_nlhb'``, ``'moco'`` or ``''``.
        """
        assert mode in ('jax', 'jax_nlhb', 'moco', '')
        # The 'nlhb' variant starts the head bias at -log(num_classes).
        if 'nlhb' in mode:
            head_bias = -math.log(self.num_classes)
        else:
            head_bias = 0.
        cls_token = self.cls_token
        if cls_token is not None:
            nn.init.normal_(cls_token, std=1e-6)
        init_fn = get_init_weights_vit(mode, head_bias)
        named_apply(init_fn, self)

    @staticmethod
    def fixed_positional_encoding(embed_dim, embed_len, max_embed_len=5000):
        """Fixed sinusoidal positional encoding from the vanilla Transformer.

        Even channels hold ``sin(pos * w_k)`` and odd channels hold
        ``cos(pos * w_k)`` with ``w_k = 10000 ** (-2k / embed_dim)``.

        Args:
            embed_dim: Number of channels per position; may be even or odd.
            embed_len: Number of positions requested.
            max_embed_len: Upper bound on the number of positions returned.

        Returns:
            A tensor of shape ``(1, min(embed_len, max_embed_len), embed_dim)``
            (zero positions if that minimum is negative).
        """
        # Build only the rows that are returned instead of a full
        # ``max_embed_len``-row table that is sliced afterwards.
        num_positions = max(0, min(embed_len, max_embed_len))
        position = torch.arange(num_positions).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10_000.) / embed_dim))
        angles = position * div_term

        pos_embed = torch.zeros(1, num_positions, embed_dim)
        pos_embed[:, :, 0::2] = torch.sin(angles)
        # An odd embed_dim has one fewer cosine channel than sine channels,
        # so the surplus frequency column is dropped to keep shapes aligned.
        pos_embed[:, :, 1::2] = torch.cos(angles[:, :embed_dim // 2])

        return pos_embed

    def init_pos_embed(self, device, fixed=False):
        """Create a position embedding sized for this model's token sequence.

        The length is ``patch_embed.num_patches``, plus ``num_prefix_tokens``
        unless ``no_embed_class`` is set.

        Args:
            device: Device on which the returned tensor lives.
            fixed: If true, return the sinusoidal encoding; this method does
                not enable gradient tracking on it. Otherwise return a random
                tensor with ``requires_grad`` set.

        Returns:
            A tensor of shape ``(1, embed_len, embed_dim)``. It is returned to
            the caller and not stored on the module.
        """
        num_patches = self.patch_embed.num_patches
        embed_len = num_patches if self.no_embed_class else num_patches + self.num_prefix_tokens
        if fixed:
            # Request exactly ``embed_len`` positions. Relying on the helper's
            # default cap (5000) would return a table shorter than
            # ``embed_len`` for models with more tokens than that.
            encoding = self.fixed_positional_encoding(
                self.embed_dim, embed_len, max_embed_len=embed_len
            )
            return encoding.to(device)

        pos_embed = (torch.randn(1, embed_len, self.embed_dim,
                                 device=device) * .02).requires_grad_()
        # Values are re-drawn in place by trunc_normal_ with std .02.
        trunc_normal_(pos_embed, std=.02)
        return pos_embed

    def shuffle_pos_embed(self, pos_embed, shuff_rate=0.75):
        """Randomly permute a subset of the rows of a position embedding.

        When ``no_embed_class`` is false, the first ``num_prefix_tokens``
        positions are treated as prefix slots and are never selected.

        Args:
            pos_embed: Tensor indexed as ``[:, position, :]``.
            shuff_rate: Fraction of ``pos_embed.size(1)`` used as the number of
                positions to shuffle (capped at the number of selectable
                positions).

        Returns:
            A tuple ``(shuffled_pos_embed, unshuffled_indices,
            ordered_shuffled_indices)`` where

            * ``shuffled_pos_embed`` is a copy of ``pos_embed`` in which row
              ``ordered_shuffled_indices[i]`` holds the original row
              ``shuffled_indices[i]``;
            * ``unshuffled_indices`` is the argsort of the random index
              selection, as consumed by ``unshuffle_pos_embed``;
            * ``ordered_shuffled_indices`` is the selected positions in
              ascending order.
        """
        embed_len = pos_embed.size(1)
        nshuffs = int(embed_len * shuff_rate)
        num_prefix = 0 if self.no_embed_class else self.num_prefix_tokens
        num_selectable = max(embed_len - num_prefix, 0)
        # Draw from the non-prefix range first and shift afterwards. Drawing
        # from all ``embed_len`` positions and then shifting would produce
        # indices >= embed_len and an out-of-bounds index error.
        shuffled_indices = torch.randperm(num_selectable)[:nshuffs] + num_prefix
        ordered_shuffled_indices, unshuffled_indices = shuffled_indices.sort()
        shuffled_pos_embed = pos_embed.clone()
        shuffled_pos_embed[:, ordered_shuffled_indices, :] = pos_embed[:, shuffled_indices, :]
        return shuffled_pos_embed, unshuffled_indices, ordered_shuffled_indices

    @staticmethod
    def unshuffle_pos_embed(shuffled_pos_embed, unshuffled_indices, ordered_shuffled_indices):
        """Undo a shuffle of position-embedding rows.

        Args:
            shuffled_pos_embed: Tensor indexed as ``[:, position, :]``.
            unshuffled_indices: Index tensor applied to the gathered rows to
                put them back in their pre-shuffle order.
            ordered_shuffled_indices: Positions (along dim 1) that were
                shuffled.

        Returns:
            A copy of ``shuffled_pos_embed`` whose rows at
            ``ordered_shuffled_indices`` have been re-ordered by
            ``unshuffled_indices``; the input tensor is not modified.
        """
        restored = shuffled_pos_embed.clone()
        # Gather the shuffled rows first, then write them back re-ordered.
        shuffled_rows = restored[:, ordered_shuffled_indices, :]
        restored[:, ordered_shuffled_indices, :] = shuffled_rows[:, unshuffled_indices, :]
        return restored

    @staticmethod
    def reshuffle_pos_embed(pos_embed, ordered_shuffled_indices):
        """Apply a fresh random permutation to the given set of positions.

        Args:
            pos_embed: Tensor indexed as ``[:, position, :]``.
            ordered_shuffled_indices: 1-D index tensor of the positions (along
                dim 1) whose rows are permuted among themselves.

        Returns:
            A tuple ``(reshuffled_pos_embed, unreshuffled_indices)``: a copy
            of ``pos_embed`` with the selected rows permuted, and the argsort
            of the randomly permuted index selection.
        """
        num_rows = ordered_shuffled_indices.size(0)
        permutation = torch.randperm(num_rows)
        reshuffled_indices = ordered_shuffled_indices[permutation]
        # ``sort`` returns (values, indices); only the indices are needed.
        unreshuffled_indices = reshuffled_indices.sort()[1]

        result = pos_embed.clone()
        source_rows = result[:, reshuffled_indices, :]
        result[:, ordered_shuffled_indices, :] = source_rows
        return result, unreshuffled_indices

    def _pos_embed(self, x, pos_embed=None):
        """Add the position embedding and prepend the class token, if any.

        Args:
            x: Patch tokens; ``x.shape[0]`` is used as the batch size when
                expanding the class token.
            pos_embed: Position embedding added to the tokens.

        Returns:
            The result of ``self.pos_drop`` applied to the combined tokens.
        """
        def prepend_cls(tokens):
            # Leave the sequence untouched when the model has no class token.
            if self.cls_token is None:
                return tokens
            cls_tokens = self.cls_token.expand(tokens.shape[0], -1, -1)
            return torch.cat((cls_tokens, tokens), dim=1)

        if self.no_embed_class:
            # deit-3, updated JAX (big vision): the position embedding has no
            # class-token entry, so add first and concatenate afterwards.
            x = prepend_cls(x + pos_embed)
        else:
            # original timm, JAX, and deit vit impl: the position embedding
            # has a class-token entry, so concatenate first and then add.
            x = prepend_cls(x) + pos_embed
        return self.pos_drop(x)

    def forward_features(self, x, pos_embed=None, probe=False):
        """Run patch embedding, position embedding, the blocks and the norm.

        Args:
            x: Model input passed to ``self.patch_embed``.
            pos_embed: Position embedding forwarded to ``_pos_embed``.
            probe: If true, also return the raw patch embedding.

        Returns:
            The normalised features, or ``(features, patch_embed)`` when
            ``probe`` is true.
        """
        patch_tokens = self.patch_embed(x)
        tokens = self._pos_embed(patch_tokens, pos_embed)

        # Checkpointing is bypassed while the model is being scripted.
        use_checkpoint = self.grad_checkpointing and not torch.jit.is_scripting()
        tokens = checkpoint_seq(self.blocks, tokens) if use_checkpoint else self.blocks(tokens)

        features = self.norm(tokens)
        return (features, patch_tokens) if probe else features

    def forward(self, x, pos_embed=None, probe=False):
        """Compute the head output for ``x`` using the supplied position embedding.

        Args:
            x: Model input.
            pos_embed: Position embedding; must not be ``None``.
            probe: If true, also return intermediate tensors.

        Returns:
            The output of ``forward_head``, or
            ``(head_output, features, patch_embed)`` when ``probe`` is true.
        """
        assert pos_embed is not None
        if not probe:
            return self.forward_head(self.forward_features(x, pos_embed, probe))

        features, patch_embed = self.forward_features(x, pos_embed, probe)
        return self.forward_head(features), features, patch_embed