diff options
Diffstat (limited to 'supervised/schedulers.py')
-rw-r--r-- | supervised/schedulers.py | 14 |
1 files changed, 14 insertions, 0 deletions
diff --git a/supervised/schedulers.py b/supervised/schedulers.py index 828e547..7580bf3 100644 --- a/supervised/schedulers.py +++ b/supervised/schedulers.py @@ -4,6 +4,20 @@ import numpy as np import torch +class LinearLR(torch.optim.lr_scheduler._LRScheduler): + def __init__(self, optimizer, num_epochs, last_epoch=-1): + self.num_epochs = max(num_epochs, 1) + super().__init__(optimizer, last_epoch) + + def get_lr(self): + res = [] + for lr in self.base_lrs: + res.append(np.maximum(lr * np.minimum( + -self.last_epoch * 1. / self.num_epochs + 1., 1. + ), 0.)) + return res + + class LinearWarmupAndCosineAnneal(torch.optim.lr_scheduler._LRScheduler): def __init__(self, optimizer, warm_up, T_max, last_epoch=-1): self.warm_up = int(warm_up * T_max) |