aboutsummaryrefslogtreecommitdiff
path: root/supervised/lars_optimizer.py
diff options
context:
space:
mode:
Diffstat (limited to 'supervised/lars_optimizer.py')
-rw-r--r--supervised/lars_optimizer.py70
1 files changed, 70 insertions, 0 deletions
diff --git a/supervised/lars_optimizer.py b/supervised/lars_optimizer.py
new file mode 100644
index 0000000..1904e8d
--- /dev/null
+++ b/supervised/lars_optimizer.py
@@ -0,0 +1,70 @@
+import torch
+
+
class LARS:
    """Layer-wise adaptive rate scaling (LARS) wrapper around a PyTorch optimizer.

    Slight modification of LARC optimizer from https://github.com/NVIDIA/apex/blob/d74fda260c403f775817470d87f810f816f3d615/apex/parallel/LARC.py
    Matches one from SimCLR implementation https://github.com/google-research/simclr/blob/master/lars_optimizer.py

    Before delegating to the wrapped optimizer, ``step`` folds each group's
    weight decay into the gradients and then rescales every gradient by
    ``trust_coefficient * ||param|| / ||grad||``.  A param group can opt out
    of the rescaling by carrying ``layer_adaptation=False``; groups without
    that key are adapted.

    Args:
        optimizer: Pytorch optimizer to wrap and modify learning rate for.
        trust_coefficient: Trust coefficient for calculating the adaptive lr. See https://arxiv.org/abs/1708.03888
    """

    def __init__(self,
                 optimizer,
                 trust_coefficient=0.001,
                 ):
        self.optim = optimizer
        self.trust_coefficient = trust_coefficient

    @property
    def param_groups(self):
        """Param groups of the wrapped optimizer.

        Looked up on every access: the wrapped optimizer rebinds its
        ``param_groups`` list in ``load_state_dict``, so a reference cached at
        construction time would silently go stale.
        """
        return self.optim.param_groups

    @param_groups.setter
    def param_groups(self, value):
        self.optim.param_groups = value

    # NOTE: both pickling hooks delegate to the wrapped optimizer, so
    # ``__setstate__`` only works on an instance that already has ``optim``
    # (i.e. one that went through ``__init__``).
    def __getstate__(self):
        return self.optim.__getstate__()

    def __setstate__(self, state):
        self.optim.__setstate__(state)

    def __repr__(self):
        return self.optim.__repr__()

    def state_dict(self):
        """Return the wrapped optimizer's state dict."""
        return self.optim.state_dict()

    def load_state_dict(self, state_dict):
        """Load ``state_dict`` into the wrapped optimizer."""
        self.optim.load_state_dict(state_dict)

    def zero_grad(self, *args, **kwargs):
        """Clear gradients; arguments are forwarded to the wrapped optimizer."""
        self.optim.zero_grad(*args, **kwargs)

    def add_param_group(self, param_group):
        """Register an additional param group on the wrapped optimizer."""
        self.optim.add_param_group(param_group)

    def step(self, closure=None):
        """Apply weight decay and LARS scaling to the gradients, then step.

        Args:
            closure: Optional callable that re-evaluates the model and
                returns the loss.  It is called once, with gradients enabled,
                before the gradients are rescaled.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        saved_weight_decays = []
        try:
            with torch.no_grad():
                for group in self.optim.param_groups:
                    # absorb weight decay control from optimizer
                    weight_decay = group.get('weight_decay', 0)
                    saved_weight_decays.append(weight_decay)
                    group['weight_decay'] = 0

                    # Groups without the flag are adapted by default.
                    adapt = group.get('layer_adaptation', True)

                    for p in group['params']:
                        if p.grad is None:
                            continue

                        if weight_decay != 0:
                            p.grad.data += weight_decay * p.data

                        if not adapt:
                            # Skip the norm computations entirely.
                            continue

                        param_norm = torch.norm(p.data)
                        grad_norm = torch.norm(p.grad.data)

                        # A zero norm would give a zero or infinite ratio;
                        # leave such gradients unscaled.
                        if param_norm != 0 and grad_norm != 0:
                            p.grad.data *= (
                                self.trust_coefficient * param_norm / grad_norm
                            )

                self.optim.step()
        finally:
            # return weight decay control to optimizer, even if stepping failed
            for group, weight_decay in zip(self.optim.param_groups,
                                           saved_weight_decays):
                group['weight_decay'] = weight_decay

        return loss