diff options
Diffstat (limited to 'posrecon')
-rw-r--r-- | posrecon/models.py | 179 | ||||
-rw-r--r-- | posrecon/pos_embed.py | 97 |
2 files changed, 275 insertions, 1 deletions
# diff --git a/posrecon/models.py b/posrecon/models.py
# index 7a2ce09..5f59510 100644
# Post-patch view of the two hunks shown for posrecon/models.py
# (originally "@@ -3,10 +3,14 @@" and "@@ -14,6 +18,179 @@").
# Lines of models.py outside those hunks are not visible here.
from pathlib import Path
import math
import sys
import torch
import torch.nn.functional as F
from timm.models.helpers import named_apply, checkpoint_seq
from timm.models.layers import trunc_normal_
from timm.models.vision_transformer import VisionTransformer, get_init_weights_vit, PatchEmbed, Block
from torch import nn
from typing import Callable

path = str(Path(Path(__file__).parent.absolute()).parent.absolute())
sys.path.insert(0, path)

# Project-local packages are imported only after the repository root has been
# put on sys.path above, so both resolve the same way regardless of the
# directory the script is launched from.
from posrecon.pos_embed import get_2d_sincos_pos_embed
from simclr.models import SimCLRBase


class MaskedPosReconCLRViT(nn.Module):
    """
    Masked contrastive learning Vision Transformer w/ positional reconstruction.
    Default params are from ViT-Base.

    The encoder adds a randomly permuted (per sample) fixed sin-cos position
    embedding to the patch tokens, drops a random subset of tokens, and runs
    the rest (plus a [CLS] token) through the transformer blocks. Two heads
    sit on top: a per-token linear position predictor and a projection head
    on the [CLS] token used for the contrastive loss.
    """

    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: int = 4,
        proj_dim: int = 128,
        norm_layer: Callable = nn.LayerNorm,
    ):
        super().__init__()

        # Encoder
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Following DeiT-3, exclude pos_embed from cls_token.
        # The embedding is fixed (filled with sin-cos values in init_weights),
        # hence requires_grad=False.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim),
                                      requires_grad=False)

        self.blocks = nn.Sequential(*[
            Block(embed_dim, num_heads, mlp_ratio,
                  qkv_bias=True, norm_layer=norm_layer)
            for _ in range(depth)
        ])
        self.norm = norm_layer(embed_dim)

        # Position predictor (linear layer equiv.): a kernel-size-1 Conv1d
        # applies the same linear map to every token independently.
        self.pos_decoder = nn.Conv1d(embed_dim, embed_dim, kernel_size=1)

        # Projection head
        self.proj_head = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, proj_dim),
        )

        self.init_weights()

    def init_weights(self):
        """Fill the fixed position embedding and initialise all learnable weights."""
        pos_embed = get_2d_sincos_pos_embed(
            self.pos_embed.size(-1), int(self.patch_embed.num_patches ** .5)
        )
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # Init weights in convolutional layers like in MLPs: flatten each
        # kernel to [out_channels, -1] and apply Xavier-uniform in place
        # (the view shares storage with the weight tensor).
        patch_conv_weight = self.patch_embed.proj.weight.data
        pos_conv_weight = self.pos_decoder.weight.data
        nn.init.xavier_uniform_(patch_conv_weight.view(patch_conv_weight.size(0), -1))
        nn.init.xavier_uniform_(pos_conv_weight.view(pos_conv_weight.size(0), -1))

        nn.init.normal_(self.cls_token, std=.02)

        self.apply(self._init_other_weights)

    def _init_other_weights(self, m):
        """Per-module initialiser used with ``self.apply``: Linear and LayerNorm only."""
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @staticmethod
    def rand_shuffle(x, pos_embed):
        """
        Add a per-sample random permutation of ``pos_embed`` to ``x``.

        x: [batch_size, seq_len, embed_dim]
        pos_embed: [1, seq_len, embed_dim]
        return: [batch_size, seq_len, embed_dim]
        """
        batch_size, seq_len, embed_dim = x.size()
        # pos_embed: [1, seq_len, embed_dim]
        batch_pos_embed = pos_embed.expand(batch_size, -1, -1)
        # batch_pos_embed: [batch_size, seq_len, embed_dim]
        # argsort of i.i.d. uniform noise gives an independent random
        # permutation of the positions for every sample.
        noise = torch.rand(batch_size, seq_len, device=x.device)
        shuffled_indices = noise.argsort()
        # shuffled_indices: [batch_size, seq_len]
        expand_shuffled_indices = shuffled_indices.unsqueeze(-1).expand(-1, -1, embed_dim)
        # expand_shuffled_indices: [batch_size, seq_len, embed_dim]
        batch_shuffled_pos_embed = batch_pos_embed.gather(1, expand_shuffled_indices)
        # batch_shuffled_pos_embed: [batch_size, seq_len, embed_dim]
        return x + batch_shuffled_pos_embed

    @staticmethod
    def rand_mask(x, mask_ratio):
        """
        Keep a random subset of ``int(seq_len * (1 - mask_ratio))`` tokens per sample.

        x: [batch_size, seq_len, embed_dim]
        return:
            x_masked: [batch_size, visible_len, embed_dim], the kept tokens
            expand_visible_indices: [batch_size, visible_len, embed_dim],
                sequence indices of the kept tokens, repeated along the last
                dimension so they can be passed to ``gather(1, ...)``
        """
        batch_size, seq_len, embed_dim = x.size()
        visible_len = int(seq_len * (1 - mask_ratio))
        noise = torch.rand(batch_size, seq_len, device=x.device)
        shuffled_indices = noise.argsort()
        # shuffled_indices: [batch_size, seq_len]
        visible_indices = shuffled_indices[:, :visible_len]
        # visible_indices: [batch_size, visible_len]
        expand_visible_indices = visible_indices.unsqueeze(-1).expand(-1, -1, embed_dim)
        # expand_visible_indices: [batch_size, visible_len, embed_dim]
        x_masked = x.gather(1, expand_visible_indices)
        # x_masked: [batch_size, visible_len, embed_dim]

        return x_masked, expand_visible_indices

    def forward_encoder(self, x, mask_ratio):
        """
        Embed, shuffle positions, mask and encode a batch of images.

        return:
            x: [batch_size*2, 1 + visible_len, embed_dim], [CLS] token first
            visible_indices: [batch_size*2, visible_len, embed_dim]
        """
        x = self.patch_embed(x)

        x = self.rand_shuffle(x, self.pos_embed)
        # batch_size*2, seq_len, embed_dim
        x, visible_indices = self.rand_mask(x, mask_ratio)
        # batch_size*2, visible_len, embed_dim

        # Concatenate [CLS] tokens w/o pos_embed
        cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        # batch_size*2, 1 + visible_len, embed_dim

        x = self.blocks(x)
        x = self.norm(x)

        return x, visible_indices

    def forward_pos_decoder(self, latent):
        """Predict a position embedding for every token in ``latent`` ([B, L, D] -> [B, L, D])."""
        # Exchange channel and length dimension for Conv1d
        latent = latent.permute(0, 2, 1)
        pos_embed_pred = self.pos_decoder(latent)
        # Restore dimension
        pos_embed_pred = pos_embed_pred.permute(0, 2, 1)

        return pos_embed_pred

    def forward_loss(self, batch_pos_embed_pred, vis_ids, features, temp=0.01):
        """
        Compute the position-reconstruction loss and the contrastive loss.

        batch_pos_embed_pred: [batch_size*2, visible_len, embed_dim]
        vis_ids: gather indices returned by ``rand_mask``
        features: [batch_size*2, proj_dim]; the first half and the second half
            of the batch are treated as the two views of the same samples
        temp: softmax temperature for the contrastive logits
        return: (loss_recon, loss_clr, acc_clr)
        raises: ValueError if the batch does not split into two equal halves
        """
        batch_size = batch_pos_embed_pred.size(0)
        if batch_size == 0 or batch_size % 2 != 0:
            raise ValueError(
                f"batch must hold two views per sample (even, non-zero size), got {batch_size}"
            )
        batch_pos_embed_targ = self.pos_embed.expand(batch_size, -1, -1)
        # Only compute loss on visible patches
        visible_pos_embed_targ = batch_pos_embed_targ.gather(1, vis_ids)
        loss_recon = F.mse_loss(batch_pos_embed_pred, visible_pos_embed_targ)

        bz_clr = batch_size // 2
        feat_norm = F.normalize(features)
        feat1_norm, feat2_norm = feat_norm.split(bz_clr)
        # feat1_norm, feat2_norm: [bz_clr, proj_dim]
        logits = feat1_norm @ feat2_norm.T
        # logits: [bz_clr, bz_clr]
        # Build the mask on the same device as the logits (like `labels`
        # below) so indexing never mixes devices.
        pos_logits_mask = torch.eye(bz_clr, dtype=torch.bool, device=logits.device)
        pos_logits = logits[pos_logits_mask].unsqueeze(-1)
        # pos_logits: [bz_clr, 1]
        neg_logits = logits[~pos_logits_mask].view(bz_clr, -1)
        # neg_logits: [bz_clr, bz_clr - 1]
        # Put the positive at first (0-th) and maximize its likelihood
        logits = torch.cat([pos_logits, neg_logits], dim=1)
        # logits: [bz_clr, bz_clr]
        labels = torch.zeros(bz_clr, dtype=torch.long, device=features.device)
        loss_clr = F.cross_entropy(logits / temp, labels)
        acc_clr = (logits.argmax(dim=1) == labels).float().mean()

        return loss_recon, loss_clr, acc_clr

    def forward(self, img, mask_ratio=0.75, temp=0.01):
        """
        Run the encoder and both heads, and compute the losses.

        img: [batch_size*2, in_chans, height, width]
        return: (latent, pos_pred, feat, loss_recon, loss_clr, acc_clr)
        """
        latent, vis_ids = self.forward_encoder(img, mask_ratio)
        # latent: [batch_size*2, 1 + visible_len, embed_dim]
        pos_pred = self.forward_pos_decoder(latent[:, 1:, :])
        # pos_pred: [batch_size*2, visible_len, embed_dim]
        feat = self.proj_head(latent[:, 0, :])
        # feat: [batch_size*2, proj_dim]
        loss_recon, loss_clr, acc_clr = self.forward_loss(pos_pred, vis_ids, feat, temp)
        return latent, pos_pred, feat, loss_recon, loss_clr, acc_clr


class ShuffledVisionTransformer(VisionTransformer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # NOTE(review): the rest of this class lies outside the hunk shown
        # in this patch and is intentionally not reproduced here.


# diff --git a/posrecon/pos_embed.py b/posrecon/pos_embed.py
# new file mode 100644
# index 0000000..ea3fe7d
# (originally hunk "@@ -0,0 +1,97 @@" against /dev/null)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Position embedding utils
# --------------------------------------------------------

import numpy as np

import torch


# --------------------------------------------------------
# 2D sine-cosine position embedding
# References:
# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
# MoCo v3: https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    Build a fixed 2D sine-cosine position embedding for a square grid.

    embed_dim: embedding size per position (must be divisible by 4, since
        each of the two axes is encoded with an even number of channels)
    grid_size: int of the grid height and width
    cls_token: if True, prepend one all-zero row for a class token
    return:
        pos_embed: [grid_size*grid_size, embed_dim] (w/o cls_token) or
                   [1+grid_size*grid_size, embed_dim] (w/ cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """
    Encode a stacked pair of coordinate grids, half of the channels per axis.

    embed_dim: output dimension for each position (must be even)
    grid: array whose first axis has length 2 (one coordinate grid per axis)
    return: (H*W, embed_dim)
    """
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    # NOTE: names follow the upstream code; when called from
    # get_2d_sincos_pos_embed, grid[0] is the meshgrid output built from
    # grid_w ("w goes first") and grid[1] the one built from grid_h.
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position (must be even)
    pos: an array of positions to be encoded; it is flattened to size (M,)
    out: (M, D), sines in the first D/2 columns and cosines in the last D/2
    """
    assert embed_dim % 2 == 0
    # np.float was only an alias of the builtin float (i.e. float64); it was
    # deprecated in NumPy 1.20 and removed in 1.24, so name the dtype explicitly.
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000 ** omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


# --------------------------------------------------------
# Interpolate position embeddings for high-resolution
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
def interpolate_pos_embed(model, checkpoint_model):
    """
    Resize the checkpoint's position embedding to the model's patch grid.

    model: object exposing ``patch_embed.num_patches`` and ``pos_embed``
    checkpoint_model: state-dict-like mapping; if it has a 'pos_embed' entry
        whose (square) grid size differs from the model's, that entry is
        replaced IN PLACE by a bicubically interpolated version. Nothing is
        returned, and nothing happens when the key is absent or sizes match.
    """
    if 'pos_embed' in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        # tokens in the model's embedding that are not patches (e.g. class token)
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) for the new position embedding
        new_size = int(num_patches ** 0.5)
        # class_token and dist_token are kept unchanged
        if orig_size != new_size:
            print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            # [1, N, D] -> [1, D, H, W] so that interpolate sees a 2D image per channel
            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
            # back to [1, H*W, D]
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model['pos_embed'] = new_pos_embed