From 33a95dca1dee3a1a768cf557d4e2a878416a3a96 Mon Sep 17 00:00:00 2001
From: Jordan Gong <jordan.gong@protonmail.com>
Date: Sat, 27 Aug 2022 13:11:16 +0800
Subject: Make InfoNCE loss distributable

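Previously, the contrastive branch only drew negatives from the local
batch, so training under DistributedDataParallel shrank the effective
number of negatives by the world size. Add a SyncFunction autograd op
that all-gathers the projected features from every rank on the forward
pass and, on the backward pass, all-reduces the gradients and returns
each rank's slice, so the InfoNCE loss contrasts against the global
batch.

The loss is also rewritten in the explicit NT-Xent form: for an anchor
z_i with positive z_j and temperature t,

    l(i, j) = -log( exp(sim(z_i, z_j) / t)
                    / sum_{k != i} exp(sim(z_i, z_k) / t) )

averaged over all 2N anchors. Split forward_loss() into
pos_recon_loss() and info_nce_loss(), and drop the contrastive
accuracy from the returned values.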
---
 posrecon/models.py | 111 +++++++++++++++++++++++++++++++++++++++++++---------------
 1 file changed, 86 insertions(+), 25 deletions(-)

(limited to 'posrecon')

diff --git a/posrecon/models.py b/posrecon/models.py
index 5f59510..1de435e 100644
--- a/posrecon/models.py
+++ b/posrecon/models.py
@@ -18,6 +18,39 @@ sys.path.insert(0, path)
 from simclr.models import SimCLRBase
 
 
+class SyncFunction(torch.autograd.Function):
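+    """
+    Autograd-aware all_gather: gather tensors from every rank on the
+    forward pass; on the backward pass, all-reduce the incoming gradients
+    and return the slice that belongs to the local rank, so gradients
+    flow back through the gather.
+    """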
+    @staticmethod
+    def forward(ctx, tensor):
+        ctx.batch_size = tensor.shape[0]
+
+        gathered_tensor = [torch.zeros_like(tensor)
+                           for _ in range(torch.distributed.get_world_size())]
+
+        torch.distributed.all_gather(gathered_tensor, tensor)
+        gathered_tensor = torch.cat(gathered_tensor, 0)
+
+        return gathered_tensor
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        grad_input = grad_output.clone()
+        torch.distributed.all_reduce(grad_input,
+                                     op=torch.distributed.ReduceOp.SUM,
+                                     async_op=False)
+
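+        # After the all-reduce, every rank holds the summed gradient for
+        # the full gathered batch; return only this rank's slice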
+        idx_from = torch.distributed.get_rank() * ctx.batch_size
+        idx_to = (torch.distributed.get_rank() + 1) * ctx.batch_size
+        return grad_input[idx_from:idx_to]
+
+
 class MaskedPosReconCLRViT(nn.Module):
     """
     Masked contrastive learning Vision Transformer w/ positional reconstruction
@@ -152,32 +185,60 @@ class MaskedPosReconCLRViT(nn.Module):
 
         return pos_embed_pred
 
-    def forward_loss(self, batch_pos_embed_pred, vis_ids, features, temp=0.01):
-        batch_size, _, _ = batch_pos_embed_pred.size()
+    def pos_recon_loss(self, batch_pos_embed_pred, vis_ids):
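+        """MSE between predicted and target positional embeddings of visible patches."""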
+        batch_size = batch_pos_embed_pred.size(0)
+        # self.pos_embed: [1, seq_len, embed_dim]
+        # batch_pos_embed_pred: [batch_size*2, seq_len * mask_ratio, embed_dim]
+        # vis_ids: [batch_size*2, seq_len * mask_ratio, embed_dim]
         batch_pos_embed_targ = self.pos_embed.expand(batch_size, -1, -1)
+        # batch_pos_embed_targ: [batch_size*2, seq_len, embed_dim]
         # Only compute loss on visible patches
         visible_pos_embed_targ = batch_pos_embed_targ.gather(1, vis_ids)
-        loss_recon = F.mse_loss(batch_pos_embed_pred, visible_pos_embed_targ)
-
-        bz_clr = batch_size // 2
-        feat_norm = F.normalize(features)
-        feat1_norm, feat2_norm = feat_norm.split(bz_clr)
-        # feat1_norm, feat2_norm: [batch_size, proj_dim]
-        logits = feat1_norm @ feat2_norm.T
-        # logits: [batch_size, batch_size]
-        pos_logits_mask = torch.eye(bz_clr, dtype=torch.bool)
-        pos_logits = logits[pos_logits_mask].unsqueeze(-1)
-        # pos_logits: [batch_size, 1]
-        neg_logits = logits[~pos_logits_mask].view(bz_clr, -1)
-        # neg_logits: [batch_size, batch_size - 1]
-        # Put the positive at first (0-th) and maximize its likelihood
-        logits = torch.cat([pos_logits, neg_logits], dim=1)
-        # logits: [batch_size, batch_size]
-        labels = torch.zeros(bz_clr, dtype=torch.long, device=features.device)
-        loss_clr = F.cross_entropy(logits / temp, labels)
-        acc_clr = (logits.argmax(dim=1) == labels).float().mean()
-
-        return loss_recon, loss_clr, acc_clr
+        # visible_pos_embed_targ: [batch_size*2, seq_len * mask_ratio, embed_dim]
+        loss = F.mse_loss(batch_pos_embed_pred, visible_pos_embed_targ)
+        return loss
+
+    @staticmethod
+    def info_nce_loss(features, temp, eps=1e-6):
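+        """NT-Xent (InfoNCE) loss; under DDP, negatives span the gathered global batch."""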
+        feat = F.normalize(features)
+        # feat: [batch_size*2, proj_dim]
+        feat = torch.stack(feat.chunk(2), dim=1)
+        # feat: [batch_size, 2, proj_dim]
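+        # Under DDP, all-gather features so negatives come from the global
+        # batch rather than this rank's shard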
+        if torch.distributed.is_available() and torch.distributed.is_initialized():
+            feat = SyncFunction.apply(feat)
+        # feat: [batch_size (* world_size), 2, proj_dim]
+        feat1, feat2 = feat[:, 0, :], feat[:, 1, :]
+        # feat{1,2}: [batch_size (* world_size), proj_dim]
+        feat = torch.cat((feat1, feat2))
+        # feat: [batch_size*2 (* world_size), proj_dim]
+
+        # All pairs; -inf on the diagonal removes self-similarity after exp ((2N)^2 - 2N)
+        all_sim = (feat @ feat.T).fill_diagonal_(float('-inf'))
+        # all_sim: [batch_size*2 (* world_size), batch_size*2 (* world_size)]
+        all_ = torch.exp(all_sim / temp).sum(-1)
+        # all_: [batch_size*2 (* world_size)]
+
+        # Positive samples (2N)
+        pos_sim = (feat1 * feat2).sum(-1)
+        # pos_sim: [batch_size (* world_size)]
+        pos = torch.exp(pos_sim / temp)
+        # Each positive pair appears twice among the 2N anchors (once per view)
+        pos = torch.cat((pos, pos))
+        # pos: [batch_size*2 (* world_size)]
+
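+        # eps guards against a vanishing denominator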
+        loss = -torch.log(pos / (all_ + eps)).mean()
+
+        return loss
+
+    def forward_loss(self, batch_pos_embed_pred, vis_ids, features, temp=0.1):
+        loss_recon = self.pos_recon_loss(batch_pos_embed_pred, vis_ids)
+        loss_clr = self.info_nce_loss(features, temp)
+        return loss_recon, loss_clr
 
     def forward(self, img, mask_ratio=0.75, temp=0.01):
         # img: [batch_size*2, in_chans, height, width]
@@ -187,8 +248,8 @@ class MaskedPosReconCLRViT(nn.Module):
         # pos_pred: [batch_size*2, seq_len * mask_ratio, embed_dim]
         feat = self.proj_head(latent[:, 0, :])
         # feat: [batch_size*2, proj_dim]
-        loss_recon, loss_clr, acc_clr = self.forward_loss(pos_pred, vis_ids, feat, temp)
-        return latent, pos_pred, feat, loss_recon, loss_clr, acc_clr
+        loss_recon, loss_clr = self.forward_loss(pos_pred, vis_ids, feat, temp)
+        return latent, pos_pred, feat, loss_recon, loss_clr
 
 
 class ShuffledVisionTransformer(VisionTransformer):
-- 
cgit v1.2.3