diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2022-03-16 18:01:43 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2022-03-16 18:01:43 +0800 |
commit | 35525c0bc6b85c06dda1e88e1addd9a1cfd5a675 (patch) | |
tree | 1d89b9100f7f6c49c990512d093233a74cad6989 | |
parent | 5869d0248fa958acd3447e6bffa8761b91e8e921 (diff) |
Add linear and no scheduler option
-rw-r--r-- | supervised/baseline.py | 32 | ||||
-rw-r--r-- | supervised/schedulers.py | 14 |
2 files changed, 37 insertions, 9 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py index 92d8d30..8a1b567 100644 --- a/supervised/baseline.py +++ b/supervised/baseline.py @@ -10,11 +10,11 @@ from torchvision.transforms import transforms, InterpolationMode from tqdm import tqdm from optimizers import LARS -from schedulers import LinearWarmupAndCosineAnneal +from schedulers import LinearWarmupAndCosineAnneal, LinearLR from supervised.datautils import color_distortion from supervised.models import CIFAR10ResNet50 -CODENAME = 'cifar10-resnet50-aug-lars-sched' +CODENAME = 'cifar10-resnet50-aug-lars-warmup' DATASET_ROOT = 'dataset' TENSORBOARD_PATH = os.path.join('runs', CODENAME) CHECKPOINT_PATH = os.path.join('checkpoints', CODENAME) @@ -27,6 +27,7 @@ N_WORKERS = 2 SEED = 0 OPTIM = 'lars' +SCHED = 'warmup-anneal' LR = 1 MOMENTUM = 0.9 WEIGHT_DECAY = 1e-6 @@ -100,12 +101,24 @@ if RESTORE_EPOCH > 0: f'Train loss: {checkpoint["train_loss"]:.4f}\t' f'Test loss: {checkpoint["test_loss"]:.4f}') -scheduler = LinearWarmupAndCosineAnneal( - optimizer, - WARMUP_EPOCHS / N_EPOCHS, - N_EPOCHS * num_train_batches, - last_epoch=RESTORE_EPOCH * num_train_batches - 1 -) +if SCHED == 'warmup-anneal': + scheduler = LinearWarmupAndCosineAnneal( + optimizer, + warm_up=WARMUP_EPOCHS / N_EPOCHS, + T_max=N_EPOCHS * num_train_batches, + last_epoch=RESTORE_EPOCH * num_train_batches - 1 + ) +elif SCHED == 'linear': + scheduler = LinearLR( + optimizer, + num_epochs=N_EPOCHS * num_train_batches, + last_epoch=RESTORE_EPOCH * num_train_batches - 1 + ) +elif SCHED is None or SCHED == '' or SCHED == 'const': + scheduler = None +else: + raise NotImplementedError(f"Scheduler '{SCHED}' is not implemented.") + if OPTIM == 'lars': optimizer = LARS(optimizer) @@ -130,7 +143,8 @@ for epoch in range(RESTORE_EPOCH, N_EPOCHS): loss = criterion(output, targets) loss.backward() optimizer.step() - scheduler.step() + if SCHED: + scheduler.step() train_loss += loss.item() train_loss_mean = train_loss / (batch + 1) diff --git a/supervised/schedulers.py b/supervised/schedulers.py index 828e547..7580bf3 100644 --- a/supervised/schedulers.py +++ b/supervised/schedulers.py @@ -4,6 +4,20 @@ import numpy as np import torch +class LinearLR(torch.optim.lr_scheduler._LRScheduler): + def __init__(self, optimizer, num_epochs, last_epoch=-1): + self.num_epochs = max(num_epochs, 1) + super().__init__(optimizer, last_epoch) + + def get_lr(self): + res = [] + for lr in self.base_lrs: + res.append(np.maximum(lr * np.minimum( + -self.last_epoch * 1. / self.num_epochs + 1., 1. + ), 0.)) + return res + + class LinearWarmupAndCosineAnneal(torch.optim.lr_scheduler._LRScheduler): def __init__(self, optimizer, warm_up, T_max, last_epoch=-1): self.warm_up = int(warm_up * T_max) |