From 9b2c25d3c927b7533e5d7d9665b67962e4c6934b Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Thu, 17 Mar 2022 20:10:42 +0800 Subject: Move some utils to libs directory --- libs/optimizers.py | 70 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) create mode 100644 libs/optimizers.py (limited to 'libs/optimizers.py') diff --git a/libs/optimizers.py b/libs/optimizers.py new file mode 100644 index 0000000..1904e8d --- /dev/null +++ b/libs/optimizers.py @@ -0,0 +1,70 @@ +import torch + + +class LARS(object): + """ + Slight modification of LARC optimizer from https://github.com/NVIDIA/apex/blob/d74fda260c403f775817470d87f810f816f3d615/apex/parallel/LARC.py + Matches one from SimCLR implementation https://github.com/google-research/simclr/blob/master/lars_optimizer.py + + Args: + optimizer: Pytorch optimizer to wrap and modify learning rate for. + trust_coefficient: Trust coefficient for calculating the adaptive lr. See https://arxiv.org/abs/1708.03888 + """ + + def __init__(self, + optimizer, + trust_coefficient=0.001, + ): + self.param_groups = optimizer.param_groups + self.optim = optimizer + self.trust_coefficient = trust_coefficient + + def __getstate__(self): + return self.optim.__getstate__() + + def __setstate__(self, state): + self.optim.__setstate__(state) + + def __repr__(self): + return self.optim.__repr__() + + def state_dict(self): + return self.optim.state_dict() + + def load_state_dict(self, state_dict): + self.optim.load_state_dict(state_dict) + + def zero_grad(self): + self.optim.zero_grad() + + def add_param_group(self, param_group): + self.optim.add_param_group(param_group) + + def step(self): + with torch.no_grad(): + weight_decays = [] + for group in self.optim.param_groups: + # absorb weight decay control from optimizer + weight_decay = group['weight_decay'] if 'weight_decay' in group else 0 + weight_decays.append(weight_decay) + group['weight_decay'] = 0 + for p in group['params']: + if p.grad is None: + continue + + if weight_decay != 0: + p.grad.data += weight_decay * p.data + + param_norm = torch.norm(p.data) + grad_norm = torch.norm(p.grad.data) + adaptive_lr = 1. + + if param_norm != 0 and grad_norm != 0 and group['layer_adaptation']: + adaptive_lr = self.trust_coefficient * param_norm / grad_norm + + p.grad.data *= adaptive_lr + + self.optim.step() + # return weight decay control to optimizer + for i, group in enumerate(self.optim.param_groups): + group['weight_decay'] = weight_decays[i] -- cgit v1.2.3