blob: 828e547577921f710f97f1739081229b2f26caca (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
|
import warnings
import numpy as np
import torch
class LinearWarmupAndCosineAnneal(torch.optim.lr_scheduler._LRScheduler):
    """Linear warm-up followed by a single half-cycle cosine decay to zero.

    With ``t = last_epoch`` (the number of ``step()`` calls so far),
    ``W = self.warm_up`` and ``T = self.T_max`` (annealing length):

    * ``0 <= t <= W``: ``lr = base_lr * (t + 1) / (W + 1)``, reaching
      ``base_lr`` exactly at ``t == W``;
    * ``W < t < W + T``: ``lr = base_lr * (1 + cos(pi * (t - W) / T)) / 2``;
    * ``t >= W + T``: ``lr = 0`` (annealing is finished).

    Args:
        optimizer: Wrapped optimizer.
        warm_up: Fraction of ``T_max`` spent in linear warm-up; the number of
            warm-up steps is ``int(warm_up * T_max)``.
        T_max: Total number of scheduler steps (warm-up plus annealing).
        last_epoch: Index of the last step; ``-1`` starts a fresh schedule.

    Attributes:
        warm_up (int): Number of warm-up steps.
        T_max (int): Number of annealing steps, i.e. constructor ``T_max``
            minus ``warm_up`` -- NOT the constructor argument itself.
    """

    def __init__(self, optimizer, warm_up, T_max, last_epoch=-1):
        self.warm_up = int(warm_up * T_max)
        # From here on T_max denotes the length of the cosine phase only.
        self.T_max = T_max - self.warm_up
        super().__init__(optimizer, last_epoch=last_epoch)

    def get_lr(self):
        """Return the next learning rates, derived recursively from the current ones.

        Like PyTorch's built-in "chainable" schedulers, each step multiplies the
        optimizer's current ``group['lr']`` by the ratio between consecutive
        points of the schedule, so this is meant to be called from ``step()``.

        Returns:
            list: One learning rate per optimizer parameter group.
        """
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.")

        if self.last_epoch == 0:
            # First value of the ramp: base_lr / (W + 1).
            return [lr / (self.warm_up + 1) for lr in self.base_lrs]

        if self.last_epoch <= self.warm_up:
            # On the linear ramp lr_t / lr_{t-1} = (t + 1) / t.
            c = (self.last_epoch + 1) / self.last_epoch
            return [group['lr'] * c for group in self.optimizer.param_groups]

        le = self.last_epoch - self.warm_up
        if le >= self.T_max:
            # Annealing is complete (or has zero length). The ratio below would
            # be x / (1 + cos(pi)) = x / 0 one step past the end (NaN once
            # multiplied by the zero lr), and would divide by zero outright when
            # T_max == 0, so hold the learning rate at its final value of 0.
            return [0.0 for _ in self.optimizer.param_groups]

        # ref: https://github.com/pytorch/pytorch/blob/2de4f245c6b1e1c294a8b2a9d7f916d43380af4b/torch/optim/lr_scheduler.py#L493
        # Ratio of consecutive cosine values; le - 1 < T_max here, so the
        # denominator is strictly positive.
        return [(1 + np.cos(np.pi * le / self.T_max)) /
                (1 + np.cos(np.pi * (le - 1) / self.T_max)) *
                group['lr']
                for group in self.optimizer.param_groups]

    def _get_closed_form_lr(self):
        """Return the learning rates at ``last_epoch`` computed directly from ``base_lrs``.

        PyTorch's ``step(epoch)`` (and ``SequentialLR``, which calls ``step(0)``)
        uses this hook when present. Jumping to an arbitrary step cannot be
        expressed with the recursive ratios of :meth:`get_lr`.

        Returns:
            list: One learning rate per optimizer parameter group.
        """
        t = self.last_epoch
        if t <= self.warm_up:
            factor = (t + 1) / (self.warm_up + 1)
        elif t - self.warm_up >= self.T_max:
            factor = 0.0
        else:
            factor = (1 + np.cos(np.pi * (t - self.warm_up) / self.T_max)) / 2
        return [base_lr * factor for base_lr in self.base_lrs]
|