From 0e466574d07b2f018d99752bcd799ce8ccfd8a96 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Sun, 14 Aug 2022 17:59:05 +0800 Subject: Fix a feature gathering problem in InfoNCE loss --- libs/criteria.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) (limited to 'libs/criteria.py') diff --git a/libs/criteria.py b/libs/criteria.py index 7d367c1..93fc4d1 100644 --- a/libs/criteria.py +++ b/libs/criteria.py @@ -9,17 +9,16 @@ class InfoNCELoss(nn.Module): def __init__(self, temp=0.01): super().__init__() self.temp = temp + self.local_feat_norm = None - @staticmethod - def _norm_and_stack(feat: Tensor) -> Tensor: - local_feat_norm = F.normalize(feat) - local_feat_norm_stack = torch.stack(local_feat_norm.chunk(2)) - - return local_feat_norm_stack + def get_local_feat_norm(self): + return self.local_feat_norm def forward(self, feature: Tensor) -> tuple[Tensor, Tensor]: + local_feat_norm = F.normalize(feature) + self.local_feat_norm = torch.stack(local_feat_norm.chunk(2)) feat_norm = torch.cat([ - rpc.rpc_sync(f"worker{i}", self._norm_and_stack, (feature,)) + rpc.rpc_sync(f"worker{i}", self.get_local_feat_norm) for i in range(dist.get_world_size()) ], dim=1) bz = feat_norm.size(1) -- cgit v1.2.3