diff options
Diffstat (limited to 'posrecon')
-rw-r--r-- | posrecon/models.py | 98 |
1 files changed, 73 insertions, 25 deletions
diff --git a/posrecon/models.py b/posrecon/models.py index 5f59510..1de435e 100644 --- a/posrecon/models.py +++ b/posrecon/models.py @@ -18,6 +18,31 @@ sys.path.insert(0, path) from simclr.models import SimCLRBase +class SyncFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, tensor): + ctx.batch_size = tensor.shape[0] + + gathered_tensor = [torch.zeros_like(tensor) + for _ in range(torch.distributed.get_world_size())] + + torch.distributed.all_gather(gathered_tensor, tensor) + gathered_tensor = torch.cat(gathered_tensor, 0) + + return gathered_tensor + + @staticmethod + def backward(ctx, grad_output): + grad_input = grad_output.clone() + torch.distributed.all_reduce(grad_input, + op=torch.distributed.ReduceOp.SUM, + async_op=False) + + idx_from = torch.distributed.get_rank() * ctx.batch_size + idx_to = (torch.distributed.get_rank() + 1) * ctx.batch_size + return grad_input[idx_from:idx_to] + + class MaskedPosReconCLRViT(nn.Module): """ Masked contrastive learning Vision Transformer w/ positional reconstruction @@ -152,32 +177,55 @@ class MaskedPosReconCLRViT(nn.Module): return pos_embed_pred - def forward_loss(self, batch_pos_embed_pred, vis_ids, features, temp=0.01): - batch_size, _, _ = batch_pos_embed_pred.size() + def pos_recon_loss(self, batch_pos_embed_pred, vis_ids): + batch_size = batch_pos_embed_pred.size(0) + # self.pos_embed: [1, seq_len, embed_dim] + # batch_pos_embed_pred: [batch_size*2, seq_len, embed_dim] + # vis_ids: [batch_size*2, seq_len * mask_ratio, embed_dim] batch_pos_embed_targ = self.pos_embed.expand(batch_size, -1, -1) + # batch_pos_embed_targ: [batch_size*2, seq_len, embed_dim] # Only compute loss on visible patches visible_pos_embed_targ = batch_pos_embed_targ.gather(1, vis_ids) - loss_recon = F.mse_loss(batch_pos_embed_pred, visible_pos_embed_targ) - - bz_clr = batch_size // 2 - feat_norm = F.normalize(features) - feat1_norm, feat2_norm = feat_norm.split(bz_clr) - # feat1_norm, feat2_norm: [batch_size, proj_dim] - logits = feat1_norm @ feat2_norm.T - # logits: [batch_size, batch_size] - pos_logits_mask = torch.eye(bz_clr, dtype=torch.bool) - pos_logits = logits[pos_logits_mask].unsqueeze(-1) - # pos_logits: [batch_size, 1] - neg_logits = logits[~pos_logits_mask].view(bz_clr, -1) - # neg_logits: [batch_size, batch_size - 1] - # Put the positive at first (0-th) and maximize its likelihood - logits = torch.cat([pos_logits, neg_logits], dim=1) - # logits: [batch_size, batch_size] - labels = torch.zeros(bz_clr, dtype=torch.long, device=features.device) - loss_clr = F.cross_entropy(logits / temp, labels) - acc_clr = (logits.argmax(dim=1) == labels).float().mean() - - return loss_recon, loss_clr, acc_clr + # visible_pos_embed_targ: [batch_size*2, seq_len * mask_ratio, embed_dim] + loss = F.mse_loss(batch_pos_embed_pred, visible_pos_embed_targ) + return loss + + @staticmethod + def info_nce_loss(features, temp, eps=1e-6): + feat = F.normalize(features) + # feat: [batch_size*2, proj_dim] + feat = torch.stack(feat.chunk(2), dim=1) + # feat: [batch_size, 2, proj_dim] + if torch.distributed.is_available() and torch.distributed.is_initialized(): + feat = SyncFunction.apply(feat) + # feat: [batch_size (* world_size), 2, proj_dim] + feat1, feat2 = feat[:, 0, :], feat[:, 1, :] + # feat{1,2}: [batch_size (* world_size), proj_dim] + feat = torch.cat((feat1, feat2)) + # feat: [batch_size*2 (* world_size), proj_dim] + + # All samples, filling diagonal to remove identity similarity ((2N)^2 - 2N) + all_sim = (feat @ feat.T).fill_diagonal_(0) + # all_sim: [batch_size*2 (* world_size), batch_size*2 (* world_size)] + all_ = torch.exp(all_sim / temp).sum(-1) + # all_: [batch_size*2 (* world_size)] + + # Positive samples (2N) + pos_sim = (feat1 * feat2).sum(-1) + # pos_sim: [batch_size (* world_size)] + pos = torch.exp(pos_sim / temp) + # Following all samples, compute positive similarity twice + pos = torch.cat((pos, pos)) + # pos: [batch_size*2 (* world_size)] + + loss = -torch.log(pos / (all_ + eps)).mean() + + return loss + + def forward_loss(self, batch_pos_embed_pred, vis_ids, features, temp=0.1): + loss_recon = self.pos_recon_loss(batch_pos_embed_pred, vis_ids) + loss_clr = self.info_nce_loss(features, temp) + return loss_recon, loss_clr def forward(self, img, mask_ratio=0.75, temp=0.01): # img: [batch_size*2, in_chans, height, weight] @@ -187,8 +235,8 @@ class MaskedPosReconCLRViT(nn.Module): # pos_pred: [batch_size*2, seq_len * mask_ratio, embed_dim] feat = self.proj_head(latent[:, 0, :]) # reps: [batch_size*2, proj_dim] - loss_recon, loss_clr, acc_clr = self.forward_loss(pos_pred, vis_ids, feat, temp) - return latent, pos_pred, feat, loss_recon, loss_clr, acc_clr + loss_recon, loss_clr = self.forward_loss(pos_pred, vis_ids, feat, temp) + return latent, pos_pred, feat, loss_recon, loss_clr class ShuffledVisionTransformer(VisionTransformer): |