Diffstat (limited to 'supervised/lars_optimizer.py')
-rw-r--r--  supervised/lars_optimizer.py  70
1 file changed, 0 insertions, 70 deletions
diff --git a/supervised/lars_optimizer.py b/supervised/lars_optimizer.py
deleted file mode 100644
index 1904e8d..0000000
--- a/supervised/lars_optimizer.py
+++ /dev/null
@@ -1,70 +0,0 @@
-import torch
-
-
-class LARS(object):
- """
- Slight modification of LARC optimizer from https://github.com/NVIDIA/apex/blob/d74fda260c403f775817470d87f810f816f3d615/apex/parallel/LARC.py
- Matches one from SimCLR implementation https://github.com/google-research/simclr/blob/master/lars_optimizer.py
-
- Args:
- optimizer: Pytorch optimizer to wrap and modify learning rate for.
- trust_coefficient: Trust coefficient for calculating the adaptive lr. See https://arxiv.org/abs/1708.03888
- """
-
- def __init__(self,
- optimizer,
- trust_coefficient=0.001,
- ):
- self.param_groups = optimizer.param_groups
- self.optim = optimizer
- self.trust_coefficient = trust_coefficient
-
- def __getstate__(self):
- return self.optim.__getstate__()
-
- def __setstate__(self, state):
- self.optim.__setstate__(state)
-
- def __repr__(self):
- return self.optim.__repr__()
-
- def state_dict(self):
- return self.optim.state_dict()
-
- def load_state_dict(self, state_dict):
- self.optim.load_state_dict(state_dict)
-
- def zero_grad(self):
- self.optim.zero_grad()
-
- def add_param_group(self, param_group):
- self.optim.add_param_group(param_group)
-
- def step(self):
- with torch.no_grad():
- weight_decays = []
- for group in self.optim.param_groups:
- # absorb weight decay control from optimizer
- weight_decay = group['weight_decay'] if 'weight_decay' in group else 0
- weight_decays.append(weight_decay)
- group['weight_decay'] = 0
- for p in group['params']:
- if p.grad is None:
- continue
-
- if weight_decay != 0:
- p.grad.data += weight_decay * p.data
-
- param_norm = torch.norm(p.data)
- grad_norm = torch.norm(p.grad.data)
- adaptive_lr = 1.
-
- if param_norm != 0 and grad_norm != 0 and group['layer_adaptation']:
- adaptive_lr = self.trust_coefficient * param_norm / grad_norm
-
- p.grad.data *= adaptive_lr
-
- self.optim.step()
- # return weight decay control to optimizer
- for i, group in enumerate(self.optim.param_groups):
- group['weight_decay'] = weight_decays[i]
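
For context, a minimal usage sketch of the wrapper that was removed above (not part of the repository): step() looks up a 'layer_adaptation' flag on every param group, so each group passed to the wrapped optimizer has to carry that key. The import path mirrors the deleted file, and the model, learning rate, weight decay values, and bias/BatchNorm split are illustrative assumptions only.

import torch
from torch import nn

from supervised.lars_optimizer import LARS  # the class defined in the deleted file

# hypothetical toy model; any nn.Module works
model = nn.Sequential(nn.Linear(128, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Linear(256, 10))

# common LARS convention (assumed here): no weight decay and no layer
# adaptation for biases and BatchNorm parameters
decay, no_decay = [], []
for name, p in model.named_parameters():
    if p.ndim == 1 or name.endswith('.bias'):
        no_decay.append(p)
    else:
        decay.append(p)

param_groups = [
    {'params': decay, 'weight_decay': 1e-6, 'layer_adaptation': True},
    {'params': no_decay, 'weight_decay': 0.0, 'layer_adaptation': False},
]

base_optimizer = torch.optim.SGD(param_groups, lr=0.3, momentum=0.9)
optimizer = LARS(base_optimizer, trust_coefficient=0.001)

# one training step: LARS rescales each gradient by
# trust_coefficient * ||w|| / ||grad|| before delegating to SGD
inputs, targets = torch.randn(32, 128), torch.randint(0, 10, (32,))
loss = nn.functional.cross_entropy(model(inputs), targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()

Because the wrapper scales p.grad rather than each group's lr, it composes with the inner optimizer's momentum and with lr schedules applied to the wrapped SGD instance.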