Diffstat (limited to 'posrecon')
-rw-r--r--  posrecon/models.py  98
1 file changed, 73 insertions(+), 25 deletions(-)
diff --git a/posrecon/models.py b/posrecon/models.py
index 5f59510..1de435e 100644
--- a/posrecon/models.py
+++ b/posrecon/models.py
@@ -18,6 +18,31 @@ sys.path.insert(0, path)
from simclr.models import SimCLRBase
+class SyncFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, tensor):
+ ctx.batch_size = tensor.shape[0]
+
+ gathered_tensor = [torch.zeros_like(tensor)
+ for _ in range(torch.distributed.get_world_size())]
+
+ torch.distributed.all_gather(gathered_tensor, tensor)
+ gathered_tensor = torch.cat(gathered_tensor, 0)
+
+ return gathered_tensor
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ grad_input = grad_output.clone()
+ torch.distributed.all_reduce(grad_input,
+ op=torch.distributed.ReduceOp.SUM,
+ async_op=False)
+
+ idx_from = torch.distributed.get_rank() * ctx.batch_size
+ idx_to = (torch.distributed.get_rank() + 1) * ctx.batch_size
+ return grad_input[idx_from:idx_to]
+
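A minimal sketch of how this gradient-aware all_gather can be exercised; the
single-process gloo group below is purely illustrative and not part of this
commit:

    import os
    import torch
    import torch.distributed as dist

    def sync_function_demo():
        # With world_size == 1, all_gather is an identity, which makes the
        # backward slice easy to verify by hand.
        os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
        os.environ.setdefault("MASTER_PORT", "29500")
        dist.init_process_group("gloo", rank=0, world_size=1)
        x = torch.randn(4, 8, requires_grad=True)
        gathered = SyncFunction.apply(x)   # [4 * world_size, 8]
        gathered.sum().backward()          # gradient routed to the local slice
        assert torch.allclose(x.grad, torch.ones_like(x))
        dist.destroy_process_group()
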
+
class MaskedPosReconCLRViT(nn.Module):
"""
Masked contrastive learning Vision Transformer w/ positional reconstruction
@@ -152,32 +177,55 @@ class MaskedPosReconCLRViT(nn.Module):
return pos_embed_pred
- def forward_loss(self, batch_pos_embed_pred, vis_ids, features, temp=0.01):
- batch_size, _, _ = batch_pos_embed_pred.size()
+ def pos_recon_loss(self, batch_pos_embed_pred, vis_ids):
+ batch_size = batch_pos_embed_pred.size(0)
+ # self.pos_embed: [1, seq_len, embed_dim]
+        # batch_pos_embed_pred: [batch_size*2, num_visible, embed_dim]
+        # vis_ids: [batch_size*2, num_visible, embed_dim],
+        #   where num_visible = seq_len * (1 - mask_ratio)
batch_pos_embed_targ = self.pos_embed.expand(batch_size, -1, -1)
+ # batch_pos_embed_targ: [batch_size*2, seq_len, embed_dim]
# Only compute loss on visible patches
visible_pos_embed_targ = batch_pos_embed_targ.gather(1, vis_ids)
- loss_recon = F.mse_loss(batch_pos_embed_pred, visible_pos_embed_targ)
-
- bz_clr = batch_size // 2
- feat_norm = F.normalize(features)
- feat1_norm, feat2_norm = feat_norm.split(bz_clr)
- # feat1_norm, feat2_norm: [batch_size, proj_dim]
- logits = feat1_norm @ feat2_norm.T
- # logits: [batch_size, batch_size]
- pos_logits_mask = torch.eye(bz_clr, dtype=torch.bool)
- pos_logits = logits[pos_logits_mask].unsqueeze(-1)
- # pos_logits: [batch_size, 1]
- neg_logits = logits[~pos_logits_mask].view(bz_clr, -1)
- # neg_logits: [batch_size, batch_size - 1]
- # Put the positive at first (0-th) and maximize its likelihood
- logits = torch.cat([pos_logits, neg_logits], dim=1)
- # logits: [batch_size, batch_size]
- labels = torch.zeros(bz_clr, dtype=torch.long, device=features.device)
- loss_clr = F.cross_entropy(logits / temp, labels)
- acc_clr = (logits.argmax(dim=1) == labels).float().mean()
-
- return loss_recon, loss_clr, acc_clr
+        # visible_pos_embed_targ: [batch_size*2, num_visible, embed_dim]
+ loss = F.mse_loss(batch_pos_embed_pred, visible_pos_embed_targ)
+ return loss
+
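This hunk does not show how vis_ids is produced; a plausible construction that
matches the gather call and shape comments above, with toy sizes (ids_keep is
a hypothetical name, mirroring MAE-style random masking):

    import torch

    batch2, seq_len, embed_dim = 8, 16, 32    # toy sizes; batch2 = batch_size*2
    num_visible = seq_len // 4                # e.g. mask_ratio = 0.75
    pos_embed = torch.randn(1, seq_len, embed_dim)
    ids_keep = torch.rand(batch2, seq_len).argsort(-1)[:, :num_visible]
    vis_ids = ids_keep.unsqueeze(-1).expand(-1, -1, embed_dim)
    targ = pos_embed.expand(batch2, -1, -1).gather(1, vis_ids)
    # targ: [batch2, num_visible, embed_dim]
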
+ @staticmethod
+ def info_nce_loss(features, temp, eps=1e-6):
+ feat = F.normalize(features)
+ # feat: [batch_size*2, proj_dim]
+ feat = torch.stack(feat.chunk(2), dim=1)
+ # feat: [batch_size, 2, proj_dim]
+ if torch.distributed.is_available() and torch.distributed.is_initialized():
+ feat = SyncFunction.apply(feat)
+ # feat: [batch_size (* world_size), 2, proj_dim]
+ feat1, feat2 = feat[:, 0, :], feat[:, 1, :]
+ # feat{1,2}: [batch_size (* world_size), proj_dim]
+ feat = torch.cat((feat1, feat2))
+ # feat: [batch_size*2 (* world_size), proj_dim]
+
+        # All samples; mask self-similarity with -inf so exp() drops it from
+        # the denominator ((2N)^2 - 2N terms survive; fill_diagonal_(0)
+        # would still leave a spurious exp(0) = 1 per row)
+        self_mask = torch.eye(feat.size(0), dtype=torch.bool, device=feat.device)
+        all_sim = (feat @ feat.T).masked_fill(self_mask, float('-inf'))
+ # all_sim: [batch_size*2 (* world_size), batch_size*2 (* world_size)]
+ all_ = torch.exp(all_sim / temp).sum(-1)
+ # all_: [batch_size*2 (* world_size)]
+
+ # Positive samples (2N)
+ pos_sim = (feat1 * feat2).sum(-1)
+ # pos_sim: [batch_size (* world_size)]
+ pos = torch.exp(pos_sim / temp)
+        # Each pair yields two anchors (i -> j and j -> i), so repeat pos
+ pos = torch.cat((pos, pos))
+ # pos: [batch_size*2 (* world_size)]
+
+ loss = -torch.log(pos / (all_ + eps)).mean()
+
+ return loss
+
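A single-process sanity check (the distributed branch is skipped when no
process group is initialized): with self-similarity masked out, the exp-sum
form above should agree with the familiar cross-entropy NT-Xent; the tolerance
absorbs the eps term:

    import torch
    import torch.nn.functional as F

    feats = torch.randn(8, 32)       # two views of 4 samples, stacked on dim 0
    loss = MaskedPosReconCLRViT.info_nce_loss(feats, temp=0.1)

    n = feats.size(0)
    f = F.normalize(feats)
    sim = (f @ f.T) / 0.1
    sim.masked_fill_(torch.eye(n, dtype=torch.bool), float('-inf'))
    labels = (torch.arange(n) + n // 2) % n  # positive sits half a batch away
    ref = F.cross_entropy(sim, labels)
    assert torch.allclose(loss, ref, atol=1e-4)
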
+ def forward_loss(self, batch_pos_embed_pred, vis_ids, features, temp=0.1):
+ loss_recon = self.pos_recon_loss(batch_pos_embed_pred, vis_ids)
+ loss_clr = self.info_nce_loss(features, temp)
+ return loss_recon, loss_clr
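
How the two objectives are weighted is left to the caller; a hypothetical
training step with an assumed 1:1 weighting (model and optimizer are not
defined in this diff):

    latent, pos_pred, feat, loss_recon, loss_clr = model(img)
    loss = loss_recon + loss_clr       # equal weighting is an assumption
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
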
def forward(self, img, mask_ratio=0.75, temp=0.01):
    # img: [batch_size*2, in_chans, height, width]
@@ -187,8 +235,8 @@ class MaskedPosReconCLRViT(nn.Module):
    # pos_pred: [batch_size*2, num_visible, embed_dim]
feat = self.proj_head(latent[:, 0, :])
    # feat: [batch_size*2, proj_dim]
- loss_recon, loss_clr, acc_clr = self.forward_loss(pos_pred, vis_ids, feat, temp)
- return latent, pos_pred, feat, loss_recon, loss_clr, acc_clr
+ loss_recon, loss_clr = self.forward_loss(pos_pred, vis_ids, feat, temp)
+ return latent, pos_pred, feat, loss_recon, loss_clr
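
Per the shape comments, forward expects the two augmented views of a batch
concatenated along dim 0; a hypothetical call (constructor arguments are not
shown in this diff and are assumed to have defaults):

    import torch

    model = MaskedPosReconCLRViT()    # ctor args assumed, not shown here
    view1, view2 = torch.randn(2, 4, 3, 224, 224)  # two augmentations of 4 images
    img = torch.cat((view1, view2))   # [batch_size*2, 3, 224, 224]
    latent, pos_pred, feat, loss_recon, loss_clr = model(img, mask_ratio=0.75, temp=0.1)
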
class ShuffledVisionTransformer(VisionTransformer):