aboutsummaryrefslogtreecommitdiff
path: root/libs
diff options
context:
space:
mode:
Diffstat (limited to 'libs')
-rw-r--r--libs/criteria.py13
1 files changed, 6 insertions, 7 deletions
diff --git a/libs/criteria.py b/libs/criteria.py
index 7d367c1..93fc4d1 100644
--- a/libs/criteria.py
+++ b/libs/criteria.py
@@ -9,17 +9,16 @@ class InfoNCELoss(nn.Module):
def __init__(self, temp=0.01):
super().__init__()
self.temp = temp
+ self.local_feat_norm = None
- @staticmethod
- def _norm_and_stack(feat: Tensor) -> Tensor:
- local_feat_norm = F.normalize(feat)
- local_feat_norm_stack = torch.stack(local_feat_norm.chunk(2))
-
- return local_feat_norm_stack
+ def get_local_feat_norm(self):
+ return self.local_feat_norm
def forward(self, feature: Tensor) -> tuple[Tensor, Tensor]:
+ local_feat_norm = F.normalize(feature)
+ self.local_feat_norm = torch.stack(local_feat_norm.chunk(2))
feat_norm = torch.cat([
- rpc.rpc_sync(f"worker{i}", self._norm_and_stack, (feature,))
+ rpc.rpc_sync(f"worker{i}", self.get_local_feat_norm)
for i in range(dist.get_world_size())
], dim=1)
bz = feat_norm.size(1)