aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--posrecon/models.py179
-rw-r--r--posrecon/pos_embed.py97
2 files changed, 275 insertions, 1 deletions
diff --git a/posrecon/models.py b/posrecon/models.py
index 7a2ce09..5f59510 100644
--- a/posrecon/models.py
+++ b/posrecon/models.py
@@ -3,10 +3,14 @@ from pathlib import Path
import math
import sys
import torch
+import torch.nn.functional as F
from timm.models.helpers import named_apply, checkpoint_seq
from timm.models.layers import trunc_normal_
-from timm.models.vision_transformer import VisionTransformer, get_init_weights_vit
+from timm.models.vision_transformer import VisionTransformer, get_init_weights_vit, PatchEmbed, Block
from torch import nn
+from typing import Callable
+
+from posrecon.pos_embed import get_2d_sincos_pos_embed
path = str(Path(Path(__file__).parent.absolute()).parent.absolute())
sys.path.insert(0, path)
@@ -14,6 +18,179 @@ sys.path.insert(0, path)
from simclr.models import SimCLRBase
+class MaskedPosReconCLRViT(nn.Module):
+ """
+ Masked contrastive learning Vision Transformer w/ positional reconstruction
+ Default params are from ViT-Base.
+ """
+
+ def __init__(
+ self,
+ img_size: int = 224,
+ patch_size: int = 16,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ depth: int = 12,
+ num_heads: int = 12,
+ mlp_ratio: int = 4,
+ proj_dim: int = 128,
+ norm_layer: Callable = nn.LayerNorm,
+ ):
+ super(MaskedPosReconCLRViT, self).__init__()
+
+ # Encoder
+ self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ # Following DeiT-3, exclude pos_embed from cls_token
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim),
+ requires_grad=False)
+
+ self.blocks = nn.Sequential(*[
+ Block(embed_dim, num_heads, mlp_ratio,
+ qkv_bias=True, norm_layer=norm_layer)
+ for _ in range(depth)
+ ])
+ self.norm = norm_layer(embed_dim)
+
+ # Position predictor (linear layer equiv.)
+ self.pos_decoder = nn.Conv1d(embed_dim, embed_dim, kernel_size=1)
+
+ # Projection head
+ self.proj_head = nn.Sequential(
+ nn.Linear(embed_dim, embed_dim),
+ nn.GELU(),
+ nn.Linear(embed_dim, proj_dim),
+ )
+
+ self.init_weights()
+
+ def init_weights(self):
+ pos_embed = get_2d_sincos_pos_embed(
+ self.pos_embed.size(-1), int(self.patch_embed.num_patches ** .5)
+ )
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
+
+ # Init weights in convolutional layers like in MLPs
+ patch_conv_weight = self.patch_embed.proj.weight.data
+ pos_conv_weight = self.pos_decoder.weight.data
+ nn.init.xavier_uniform_(patch_conv_weight.view(patch_conv_weight.size(0), -1))
+ nn.init.xavier_uniform_(pos_conv_weight.view(pos_conv_weight.size(0), -1))
+
+ nn.init.normal_(self.cls_token, std=.02)
+
+ self.apply(self._init_other_weights)
+
+ def _init_other_weights(self, m):
+ if isinstance(m, nn.Linear):
+ # we use xavier_uniform following official JAX ViT:
+ nn.init.xavier_uniform_(m.weight)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ @staticmethod
+ def rand_shuffle(x, pos_embed):
+ batch_size, seq_len, embed_dim = x.size()
+ # pos_embed: [1, seq_len, embed_dim]
+ batch_pos_embed = pos_embed.expand(batch_size, -1, -1)
+ # batch_pos_embed: [batch_size, seq_len, embed_dim]
+ noise = torch.rand(batch_size, seq_len, device=x.device)
+ shuffled_indices = noise.argsort()
+ # shuffled_indices: [batch_size, seq_len]
+ expand_shuffled_indices = shuffled_indices.unsqueeze(-1).expand(-1, -1, embed_dim)
+ # expand_shuffled_indices: [batch_size, seq_len, embed_dim]
+ batch_shuffled_pos_embed = batch_pos_embed.gather(1, expand_shuffled_indices)
+ # batch_shuffled_pos_embed: [batch_size, seq_len, embed_dim]
+ return x + batch_shuffled_pos_embed
+
+ @staticmethod
+ def rand_mask(x, mask_ratio):
+ batch_size, seq_len, embed_dim = x.size()
+ visible_len = int(seq_len * (1 - mask_ratio))
+ noise = torch.rand(batch_size, seq_len, device=x.device)
+ shuffled_indices = noise.argsort()
+ # shuffled_indices: [batch_size, seq_len]
+ unshuffled_indices = shuffled_indices.argsort()
+ # unshuffled_indices: [batch_size, seq_len]
+ visible_indices = shuffled_indices[:, :visible_len]
+ # visible_indices: [batch_size, seq_len * mask_ratio]
+ expand_visible_indices = visible_indices.unsqueeze(-1).expand(-1, -1, embed_dim)
+ # expand_visible_indices: [batch_size, seq_len * mask_ratio, embed_dim]
+ x_masked = x.gather(1, expand_visible_indices)
+ # x_masked: [batch_size, seq_len * mask_ratio, embed_dim]
+
+ return x_masked, expand_visible_indices
+
+ def forward_encoder(self, x, mask_ratio):
+ x = self.patch_embed(x)
+
+ x = self.rand_shuffle(x, self.pos_embed)
+ # batch_size*2, seq_len, embed_dim
+ x, visible_indices = self.rand_mask(x, mask_ratio)
+ # batch_size*2, seq_len * mask_ratio, embed_dim
+
+ # Concatenate [CLS] tokens w/o pos_embed
+ cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
+ x = torch.cat((cls_tokens, x), dim=1)
+ # batch_size*2, 1 + seq_len * mask_ratio, embed_dim
+
+ x = self.blocks(x)
+ x = self.norm(x)
+
+ return x, visible_indices
+
+ def forward_pos_decoder(self, latent):
+ # Exchange channel and length dimension for Conv1d
+ latent = latent.permute(0, 2, 1)
+ pos_embed_pred = self.pos_decoder(latent)
+ # Restore dimension
+ pos_embed_pred = pos_embed_pred.permute(0, 2, 1)
+
+ return pos_embed_pred
+
+ def forward_loss(self, batch_pos_embed_pred, vis_ids, features, temp=0.01):
+ batch_size, _, _ = batch_pos_embed_pred.size()
+ batch_pos_embed_targ = self.pos_embed.expand(batch_size, -1, -1)
+ # Only compute loss on visible patches
+ visible_pos_embed_targ = batch_pos_embed_targ.gather(1, vis_ids)
+ loss_recon = F.mse_loss(batch_pos_embed_pred, visible_pos_embed_targ)
+
+ bz_clr = batch_size // 2
+ feat_norm = F.normalize(features)
+ feat1_norm, feat2_norm = feat_norm.split(bz_clr)
+ # feat1_norm, feat2_norm: [batch_size, proj_dim]
+ logits = feat1_norm @ feat2_norm.T
+ # logits: [batch_size, batch_size]
+ pos_logits_mask = torch.eye(bz_clr, dtype=torch.bool)
+ pos_logits = logits[pos_logits_mask].unsqueeze(-1)
+ # pos_logits: [batch_size, 1]
+ neg_logits = logits[~pos_logits_mask].view(bz_clr, -1)
+ # neg_logits: [batch_size, batch_size - 1]
+ # Put the positive at first (0-th) and maximize its likelihood
+ logits = torch.cat([pos_logits, neg_logits], dim=1)
+ # logits: [batch_size, batch_size]
+ labels = torch.zeros(bz_clr, dtype=torch.long, device=features.device)
+ loss_clr = F.cross_entropy(logits / temp, labels)
+ acc_clr = (logits.argmax(dim=1) == labels).float().mean()
+
+ return loss_recon, loss_clr, acc_clr
+
+ def forward(self, img, mask_ratio=0.75, temp=0.01):
+ # img: [batch_size*2, in_chans, height, weight]
+ latent, vis_ids = self.forward_encoder(img, mask_ratio)
+ # latent: [batch_size*2, 1 + seq_len * mask_ratio, embed_dim]
+ pos_pred = self.forward_pos_decoder(latent[:, 1:, :])
+ # pos_pred: [batch_size*2, seq_len * mask_ratio, embed_dim]
+ feat = self.proj_head(latent[:, 0, :])
+ # reps: [batch_size*2, proj_dim]
+ loss_recon, loss_clr, acc_clr = self.forward_loss(pos_pred, vis_ids, feat, temp)
+ return latent, pos_pred, feat, loss_recon, loss_clr, acc_clr
+
+
class ShuffledVisionTransformer(VisionTransformer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
diff --git a/posrecon/pos_embed.py b/posrecon/pos_embed.py
new file mode 100644
index 0000000..ea3fe7d
--- /dev/null
+++ b/posrecon/pos_embed.py
@@ -0,0 +1,97 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# Position embedding utils
+# --------------------------------------------------------
+
+import numpy as np
+
+import torch
+
+
+# --------------------------------------------------------
+# 2D sine-cosine position embedding
+# References:
+# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
+# MoCo v3: https://github.com/facebookresearch/moco-v3
+# --------------------------------------------------------
+def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
+ """
+ grid_size: int of the grid height and width
+ return:
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+ """
+ grid_h = np.arange(grid_size, dtype=np.float32)
+ grid_w = np.arange(grid_size, dtype=np.float32)
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, grid_size, grid_size])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if cls_token:
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+ return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+ assert embed_dim % 2 == 0
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
+
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+ """
+ embed_dim: output dimension for each position
+ pos: a list of positions to be encoded: size (M,)
+ out: (M, D)
+ """
+ assert embed_dim % 2 == 0
+ omega = np.arange(embed_dim // 2, dtype=np.float)
+ omega /= embed_dim / 2.
+ omega = 1. / 10000 ** omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
+
+ emb_sin = np.sin(out) # (M, D/2)
+ emb_cos = np.cos(out) # (M, D/2)
+
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
+ return emb
+
+
+# --------------------------------------------------------
+# Interpolate position embeddings for high-resolution
+# References:
+# DeiT: https://github.com/facebookresearch/deit
+# --------------------------------------------------------
+def interpolate_pos_embed(model, checkpoint_model):
+ if 'pos_embed' in checkpoint_model:
+ pos_embed_checkpoint = checkpoint_model['pos_embed']
+ embedding_size = pos_embed_checkpoint.shape[-1]
+ num_patches = model.patch_embed.num_patches
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches
+ # height (== width) for the checkpoint position embedding
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+ # height (== width) for the new position embedding
+ new_size = int(num_patches ** 0.5)
+ # class_token and dist_token are kept unchanged
+ if orig_size != new_size:
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+ # only the position tokens are interpolated
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+ pos_tokens = torch.nn.functional.interpolate(
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+ checkpoint_model['pos_embed'] = new_pos_embed