aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2022-03-16 18:01:43 +0800
committerJordan Gong <jordan.gong@protonmail.com>2022-03-16 18:01:43 +0800
commit35525c0bc6b85c06dda1e88e1addd9a1cfd5a675 (patch)
tree1d89b9100f7f6c49c990512d093233a74cad6989
parent5869d0248fa958acd3447e6bffa8761b91e8e921 (diff)
Add linear and no scheduler option
-rw-r--r--supervised/baseline.py32
-rw-r--r--supervised/schedulers.py14
2 files changed, 37 insertions, 9 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py
index 92d8d30..8a1b567 100644
--- a/supervised/baseline.py
+++ b/supervised/baseline.py
@@ -10,11 +10,11 @@ from torchvision.transforms import transforms, InterpolationMode
from tqdm import tqdm
from optimizers import LARS
-from schedulers import LinearWarmupAndCosineAnneal
+from schedulers import LinearWarmupAndCosineAnneal, LinearLR
from supervised.datautils import color_distortion
from supervised.models import CIFAR10ResNet50
-CODENAME = 'cifar10-resnet50-aug-lars-sched'
+CODENAME = 'cifar10-resnet50-aug-lars-warmup'
DATASET_ROOT = 'dataset'
TENSORBOARD_PATH = os.path.join('runs', CODENAME)
CHECKPOINT_PATH = os.path.join('checkpoints', CODENAME)
@@ -27,6 +27,7 @@ N_WORKERS = 2
SEED = 0
OPTIM = 'lars'
+SCHED = 'warmup-anneal'
LR = 1
MOMENTUM = 0.9
WEIGHT_DECAY = 1e-6
@@ -100,12 +101,24 @@ if RESTORE_EPOCH > 0:
f'Train loss: {checkpoint["train_loss"]:.4f}\t'
f'Test loss: {checkpoint["test_loss"]:.4f}')
-scheduler = LinearWarmupAndCosineAnneal(
- optimizer,
- WARMUP_EPOCHS / N_EPOCHS,
- N_EPOCHS * num_train_batches,
- last_epoch=RESTORE_EPOCH * num_train_batches - 1
-)
+if SCHED == 'warmup-anneal':
+ scheduler = LinearWarmupAndCosineAnneal(
+ optimizer,
+ warm_up=WARMUP_EPOCHS / N_EPOCHS,
+ T_max=N_EPOCHS * num_train_batches,
+ last_epoch=RESTORE_EPOCH * num_train_batches - 1
+ )
+elif SCHED == 'linear':
+ scheduler = LinearLR(
+ optimizer,
+ num_epochs=N_EPOCHS * num_train_batches,
+ last_epoch=RESTORE_EPOCH * num_train_batches - 1
+ )
+elif SCHED is None or SCHED == '' or SCHED == 'const':
+ scheduler = None
+else:
+ raise NotImplementedError(f"Scheduler '{SCHED}' is not implemented.")
+
if OPTIM == 'lars':
optimizer = LARS(optimizer)
@@ -130,7 +143,8 @@ for epoch in range(RESTORE_EPOCH, N_EPOCHS):
loss = criterion(output, targets)
loss.backward()
optimizer.step()
- scheduler.step()
+ if SCHED:
+ scheduler.step()
train_loss += loss.item()
train_loss_mean = train_loss / (batch + 1)
diff --git a/supervised/schedulers.py b/supervised/schedulers.py
index 828e547..7580bf3 100644
--- a/supervised/schedulers.py
+++ b/supervised/schedulers.py
@@ -4,6 +4,20 @@ import numpy as np
import torch
+class LinearLR(torch.optim.lr_scheduler._LRScheduler):
+ def __init__(self, optimizer, num_epochs, last_epoch=-1):
+ self.num_epochs = max(num_epochs, 1)
+ super().__init__(optimizer, last_epoch)
+
+ def get_lr(self):
+ res = []
+ for lr in self.base_lrs:
+ res.append(np.maximum(lr * np.minimum(
+ -self.last_epoch * 1. / self.num_epochs + 1., 1.
+ ), 0.))
+ return res
+
+
class LinearWarmupAndCosineAnneal(torch.optim.lr_scheduler._LRScheduler):
def __init__(self, optimizer, warm_up, T_max, last_epoch=-1):
self.warm_up = int(warm_up * T_max)