aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2022-08-14 17:59:05 +0800
committerJordan Gong <jordan.gong@protonmail.com>2022-08-14 17:59:05 +0800
commit0e466574d07b2f018d99752bcd799ce8ccfd8a96 (patch)
tree9e88e739f8edd1ec4ade4f172ef8c6ee0064b8dd
parent957a2a46e7725184776c3c72860e8215164cc4ef (diff)
Fix a feature gathering problem in InfoNCE lossddp
-rw-r--r--libs/criteria.py13
1 files changed, 6 insertions, 7 deletions
diff --git a/libs/criteria.py b/libs/criteria.py
index 7d367c1..93fc4d1 100644
--- a/libs/criteria.py
+++ b/libs/criteria.py
@@ -9,17 +9,16 @@ class InfoNCELoss(nn.Module):
def __init__(self, temp=0.01):
super().__init__()
self.temp = temp
+ self.local_feat_norm = None
- @staticmethod
- def _norm_and_stack(feat: Tensor) -> Tensor:
- local_feat_norm = F.normalize(feat)
- local_feat_norm_stack = torch.stack(local_feat_norm.chunk(2))
-
- return local_feat_norm_stack
+ def get_local_feat_norm(self):
+ return self.local_feat_norm
def forward(self, feature: Tensor) -> tuple[Tensor, Tensor]:
+ local_feat_norm = F.normalize(feature)
+ self.local_feat_norm = torch.stack(local_feat_norm.chunk(2))
feat_norm = torch.cat([
- rpc.rpc_sync(f"worker{i}", self._norm_and_stack, (feature,))
+ rpc.rpc_sync(f"worker{i}", self.get_local_feat_norm)
for i in range(dist.get_world_size())
], dim=1)
bz = feat_norm.size(1)