path: root/supervised/schedulers.py
author    Jordan Gong <jordan.gong@protonmail.com>  2022-03-16 18:01:43 +0800
committer Jordan Gong <jordan.gong@protonmail.com>  2022-03-16 18:01:43 +0800
commit    35525c0bc6b85c06dda1e88e1addd9a1cfd5a675 (patch)
tree      1d89b9100f7f6c49c990512d093233a74cad6989 /supervised/schedulers.py
parent    5869d0248fa958acd3447e6bffa8761b91e8e921 (diff)
Add linear and no scheduler option
Diffstat (limited to 'supervised/schedulers.py')
-rw-r--r--  supervised/schedulers.py  14
1 file changed, 14 insertions(+), 0 deletions(-)
diff --git a/supervised/schedulers.py b/supervised/schedulers.py
index 828e547..7580bf3 100644
--- a/supervised/schedulers.py
+++ b/supervised/schedulers.py
@@ -4,6 +4,20 @@ import numpy as np
import torch
+class LinearLR(torch.optim.lr_scheduler._LRScheduler):
+ def __init__(self, optimizer, num_epochs, last_epoch=-1):
+ self.num_epochs = max(num_epochs, 1)
+ super().__init__(optimizer, last_epoch)
+
+ def get_lr(self):
+ res = []
+ for lr in self.base_lrs:
+ res.append(np.maximum(lr * np.minimum(
+ -self.last_epoch * 1. / self.num_epochs + 1., 1.
+ ), 0.))
+ return res
+
+
class LinearWarmupAndCosineAnneal(torch.optim.lr_scheduler._LRScheduler):
def __init__(self, optimizer, warm_up, T_max, last_epoch=-1):
self.warm_up = int(warm_up * T_max)
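
The new LinearLR class follows torch's _LRScheduler interface: it scales each base learning rate by max(1 - last_epoch / num_epochs, 0), i.e. a straight-line decay from the initial rate to zero over num_epochs steps. Below is a minimal usage sketch; the toy model, the base lr of 0.1, the epoch count, and the import path (assuming the repository root is on PYTHONPATH) are illustrative assumptions, not part of the commit.

    import torch

    from supervised.schedulers import LinearLR  # assumes repo root is on PYTHONPATH

    model = torch.nn.Linear(10, 2)                            # toy model (illustrative)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)   # illustrative base lr
    scheduler = LinearLR(optimizer, num_epochs=100)           # decay to 0 over 100 steps

    for epoch in range(100):
        # ... one epoch of training (forward/backward passes) would go here ...
        optimizer.step()
        scheduler.step()  # after epoch e, lr = 0.1 * max(1 - (e + 1) / 100, 0)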