diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2022-08-14 17:59:05 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2022-08-14 17:59:05 +0800 |
commit | 0e466574d07b2f018d99752bcd799ce8ccfd8a96 (patch) | |
tree | 9e88e739f8edd1ec4ade4f172ef8c6ee0064b8dd | |
parent | 957a2a46e7725184776c3c72860e8215164cc4ef (diff) |
Fix a feature gathering problem in InfoNCE lossddp
-rw-r--r-- | libs/criteria.py | 13 |
1 files changed, 6 insertions, 7 deletions
diff --git a/libs/criteria.py b/libs/criteria.py index 7d367c1..93fc4d1 100644 --- a/libs/criteria.py +++ b/libs/criteria.py @@ -9,17 +9,16 @@ class InfoNCELoss(nn.Module): def __init__(self, temp=0.01): super().__init__() self.temp = temp + self.local_feat_norm = None - @staticmethod - def _norm_and_stack(feat: Tensor) -> Tensor: - local_feat_norm = F.normalize(feat) - local_feat_norm_stack = torch.stack(local_feat_norm.chunk(2)) - - return local_feat_norm_stack + def get_local_feat_norm(self): + return self.local_feat_norm def forward(self, feature: Tensor) -> tuple[Tensor, Tensor]: + local_feat_norm = F.normalize(feature) + self.local_feat_norm = torch.stack(local_feat_norm.chunk(2)) feat_norm = torch.cat([ - rpc.rpc_sync(f"worker{i}", self._norm_and_stack, (feature,)) + rpc.rpc_sync(f"worker{i}", self.get_local_feat_norm) for i in range(dist.get_world_size()) ], dim=1) bz = feat_norm.size(1) |