| author    | Jordan Gong <jordan.gong@protonmail.com> | 2022-08-19 16:53:24 +0800 |
|-----------|------------------------------------------|---------------------------|
| committer | Jordan Gong <jordan.gong@protonmail.com> | 2022-08-19 16:53:24 +0800 |
| commit    | ea13f7e3fc3aec1b48d61896dcb3032897cc4b7a (patch) | |
| tree      | f786666c439d39e5c4c1cda6bc698258f7795369 /posrecon | |
| parent    | 26420733f98292639b9addb02e73fd8f12ee82e7 (diff) | |
Add ViT with position shuffling/unshuffling and masking methods
Diffstat (limited to 'posrecon')
| -rw-r--r-- | posrecon/models.py | 45 |
|------------|--------------------|----|

1 file changed, 45 insertions, 0 deletions
```diff
diff --git a/posrecon/models.py b/posrecon/models.py
index 62cf322..9e8cb22 100644
--- a/posrecon/models.py
+++ b/posrecon/models.py
@@ -105,3 +105,48 @@ class ShuffledVisionTransformer(VisionTransformer):
             features = self.forward_features(x, pos_embed, probe)
             x = self.forward_head(features)
             return x
+
+
+class MaskedShuffledVisionTransformer(ShuffledVisionTransformer):
+    def __init__(self, *args, **kwargs):
+        super(MaskedShuffledVisionTransformer, self).__init__(*args, **kwargs)
+        self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim) * 0.02)
+        trunc_normal_(self.mask_token, std=.02)
+
+    def generate_masks(self, mask_rate=0.75):
+        nmasks = int(self.patch_embed.num_patches * mask_rate)
+        shuffled_indices = torch.randperm(self.patch_embed.num_patches) + self.num_prefix_tokens
+        visible_indices, _ = shuffled_indices[:self.patch_embed.num_patches - nmasks].sort()
+        return visible_indices
+
+    def mask_embed(self, embed, visible_indices):
+        nmasks = self.patch_embed.num_patches - len(visible_indices)
+        mask_tokens = self.mask_token.expand(embed.size(0), nmasks, -1)
+        masked_features = torch.cat([embed[:, visible_indices, :], mask_tokens], dim=1)
+        return masked_features
+
+    def forward_features(self, x, pos_embed=None, visible_indices=None, probe=False):
+        patch_embed = self.patch_embed(x)
+        x = self._pos_embed(patch_embed, pos_embed)
+        x = self.mask_embed(x, visible_indices)
+        if self.grad_checkpointing and not torch.jit.is_scripting():
+            x = checkpoint_seq(self.blocks, x)
+        else:
+            x = self.blocks(x)
+        x = self.norm(x)
+        if probe:
+            return x, patch_embed
+        else:
+            return x
+
+    def forward(self, x, pos_embed=None, visible_indices=None, probe=False):
+        assert pos_embed is not None
+        assert visible_indices is not None
+        if probe:
+            features, patch_embed = self.forward_features(x, pos_embed, visible_indices, probe)
+            x = self.forward_head(features)
+            return x, features, patch_embed
+        else:
+            features = self.forward_features(x, pos_embed, visible_indices, probe)
+            x = self.forward_head(features)
+            return x
```
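A minimal usage sketch of the new class follows; it is not part of the commit. It assumes `MaskedShuffledVisionTransformer` is importable from `posrecon.models`, that its constructor forwards timm's `VisionTransformer` keyword arguments via `*args`/`**kwargs`, and that the caller supplies the positional embedding explicitly (e.g. the model's own `pos_embed`, possibly permuted), as the `assert`s in `forward` require.

```python
# Usage sketch only -- not part of the commit. Assumes posrecon.models exposes
# MaskedShuffledVisionTransformer and that it accepts timm's VisionTransformer
# constructor keyword arguments.
import torch
from posrecon.models import MaskedShuffledVisionTransformer

model = MaskedShuffledVisionTransformer(
    img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12,
)
model.eval()

images = torch.randn(2, 3, 224, 224)

# Keep 25% of the patch positions visible; the remaining 75% are replaced by
# the learned mask token inside mask_embed().
visible_indices = model.generate_masks(mask_rate=0.75)

# forward() asserts that both pos_embed and visible_indices are provided.
# Here the model's own (unshuffled) positional embedding is reused; passing a
# permuted copy instead would shuffle the positions.
with torch.no_grad():
    out = model(images, pos_embed=model.pos_embed, visible_indices=visible_indices)
```

Note that `generate_masks` offsets the sampled patch indices by `num_prefix_tokens`, so the returned `visible_indices` index into the embedded sequence after the prefix tokens are prepended by `_pos_embed`.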
