diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2022-03-16 17:49:51 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2022-03-16 17:49:51 +0800 |
commit | 5869d0248fa958acd3447e6bffa8761b91e8e921 (patch) | |
tree | 4e2c0744400d9204bdfd23c58bafcf534c2119fb /supervised/lars_optimizer.py | |
parent | 608178533e93dc7e6fac6059fa139233ab046b63 (diff) |
Regular refactoring
Diffstat (limited to 'supervised/lars_optimizer.py')
-rw-r--r-- | supervised/lars_optimizer.py | 70 |
1 files changed, 0 insertions, 70 deletions
diff --git a/supervised/lars_optimizer.py b/supervised/lars_optimizer.py deleted file mode 100644 index 1904e8d..0000000 --- a/supervised/lars_optimizer.py +++ /dev/null @@ -1,70 +0,0 @@ -import torch - - -class LARS(object): - """ - Slight modification of LARC optimizer from https://github.com/NVIDIA/apex/blob/d74fda260c403f775817470d87f810f816f3d615/apex/parallel/LARC.py - Matches one from SimCLR implementation https://github.com/google-research/simclr/blob/master/lars_optimizer.py - - Args: - optimizer: Pytorch optimizer to wrap and modify learning rate for. - trust_coefficient: Trust coefficient for calculating the adaptive lr. See https://arxiv.org/abs/1708.03888 - """ - - def __init__(self, - optimizer, - trust_coefficient=0.001, - ): - self.param_groups = optimizer.param_groups - self.optim = optimizer - self.trust_coefficient = trust_coefficient - - def __getstate__(self): - return self.optim.__getstate__() - - def __setstate__(self, state): - self.optim.__setstate__(state) - - def __repr__(self): - return self.optim.__repr__() - - def state_dict(self): - return self.optim.state_dict() - - def load_state_dict(self, state_dict): - self.optim.load_state_dict(state_dict) - - def zero_grad(self): - self.optim.zero_grad() - - def add_param_group(self, param_group): - self.optim.add_param_group(param_group) - - def step(self): - with torch.no_grad(): - weight_decays = [] - for group in self.optim.param_groups: - # absorb weight decay control from optimizer - weight_decay = group['weight_decay'] if 'weight_decay' in group else 0 - weight_decays.append(weight_decay) - group['weight_decay'] = 0 - for p in group['params']: - if p.grad is None: - continue - - if weight_decay != 0: - p.grad.data += weight_decay * p.data - - param_norm = torch.norm(p.data) - grad_norm = torch.norm(p.grad.data) - adaptive_lr = 1. - - if param_norm != 0 and grad_norm != 0 and group['layer_adaptation']: - adaptive_lr = self.trust_coefficient * param_norm / grad_norm - - p.grad.data *= adaptive_lr - - self.optim.step() - # return weight decay control to optimizer - for i, group in enumerate(self.optim.param_groups): - group['weight_decay'] = weight_decays[i] |