author     Jordan Gong <jordan.gong@protonmail.com>  2022-03-17 20:10:42 +0800
committer  Jordan Gong <jordan.gong@protonmail.com>  2022-03-17 20:10:42 +0800
commit     9b2c25d3c927b7533e5d7d9665b67962e4c6934b (patch)
tree       b4a80004ffdfe8165f5f8d077afb8b054d4472ce /libs/optimizers.py
parent     568569c764ffdd73cd660434df50d30d26203f63 (diff)
Move some utils to libs directory
Diffstat (limited to 'libs/optimizers.py')
-rw-r--r--  libs/optimizers.py  70
1 file changed, 70 insertions(+), 0 deletions(-)
diff --git a/libs/optimizers.py b/libs/optimizers.py
new file mode 100644
index 0000000..1904e8d
--- /dev/null
+++ b/libs/optimizers.py
@@ -0,0 +1,70 @@
+import torch
+
+
+class LARS(object):
+    """
+    Slight modification of the LARC optimizer from https://github.com/NVIDIA/apex/blob/d74fda260c403f775817470d87f810f816f3d615/apex/parallel/LARC.py
+    Matches the one from the SimCLR implementation https://github.com/google-research/simclr/blob/master/lars_optimizer.py
+
+    Args:
+        optimizer: PyTorch optimizer to wrap and modify the learning rate for.
+        trust_coefficient: Trust coefficient for calculating the adaptive lr. See https://arxiv.org/abs/1708.03888
+    """
+
+    def __init__(self,
+                 optimizer,
+                 trust_coefficient=0.001,
+                 ):
+        self.param_groups = optimizer.param_groups
+        self.optim = optimizer
+        self.trust_coefficient = trust_coefficient
+
+    def __getstate__(self):
+        return self.optim.__getstate__()
+
+    def __setstate__(self, state):
+        self.optim.__setstate__(state)
+
+    def __repr__(self):
+        return self.optim.__repr__()
+
+    def state_dict(self):
+        return self.optim.state_dict()
+
+    def load_state_dict(self, state_dict):
+        self.optim.load_state_dict(state_dict)
+
+    def zero_grad(self):
+        self.optim.zero_grad()
+
+    def add_param_group(self, param_group):
+        self.optim.add_param_group(param_group)
+
+    def step(self):
+        with torch.no_grad():
+            weight_decays = []
+            for group in self.optim.param_groups:
+                # absorb weight decay control from optimizer
+                weight_decay = group['weight_decay'] if 'weight_decay' in group else 0
+                weight_decays.append(weight_decay)
+                group['weight_decay'] = 0
+                for p in group['params']:
+                    if p.grad is None:
+                        continue
+
+                    if weight_decay != 0:
+                        p.grad.data += weight_decay * p.data
+
+                    param_norm = torch.norm(p.data)
+                    grad_norm = torch.norm(p.grad.data)
+                    adaptive_lr = 1.
+
+                    if param_norm != 0 and grad_norm != 0 and group['layer_adaptation']:
+                        adaptive_lr = self.trust_coefficient * param_norm / grad_norm
+
+                    p.grad.data *= adaptive_lr
+
+        self.optim.step()
+        # return weight decay control to optimizer
+        for i, group in enumerate(self.optim.param_groups):
+            group['weight_decay'] = weight_decays[i]
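
Usage note (not part of the commit): the sketch below shows one way to drive this wrapper. It matters because step() indexes group['layer_adaptation'] directly, so every parameter group handed to the wrapped optimizer needs that key. The two-group weight/bias split, the Linear model, and the hyperparameter values are illustrative assumptions, not taken from this repository.

import torch
from libs.optimizers import LARS  # module added by this commit

model = torch.nn.Linear(128, 10)

# Illustrative grouping: weights get weight decay and layer adaptation;
# biases (and, typically, normalization parameters) get neither.
decay, no_decay = [], []
for name, param in model.named_parameters():
    (no_decay if name.endswith('bias') else decay).append(param)

base_optimizer = torch.optim.SGD(
    [
        {'params': decay, 'weight_decay': 1e-6, 'layer_adaptation': True},
        {'params': no_decay, 'weight_decay': 0.0, 'layer_adaptation': False},
    ],
    lr=0.3, momentum=0.9,
)
optimizer = LARS(base_optimizer, trust_coefficient=0.001)

# One training step: LARS rescales each gradient in-place by
# trust_coefficient * ||w|| / ||g|| (for groups with layer_adaptation),
# then delegates the actual parameter update to the wrapped SGD.
inputs, targets = torch.randn(32, 128), torch.randint(0, 10, (32,))
loss = torch.nn.functional.cross_entropy(model(inputs), targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()

Because the scaling is applied to the gradients rather than to the learning rate, any state in the wrapped optimizer (e.g. SGD momentum buffers) accumulates the already-rescaled gradients.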