author     Jordan Gong <jordan.gong@protonmail.com>  2022-03-16 18:01:43 +0800
committer  Jordan Gong <jordan.gong@protonmail.com>  2022-03-16 18:01:43 +0800
commit     35525c0bc6b85c06dda1e88e1addd9a1cfd5a675 (patch)
tree       1d89b9100f7f6c49c990512d093233a74cad6989 /supervised/baseline.py
parent     5869d0248fa958acd3447e6bffa8761b91e8e921 (diff)
Add linear and no scheduler option
Diffstat (limited to 'supervised/baseline.py')
-rw-r--r--  supervised/baseline.py  32
1 file changed, 23 insertions(+), 9 deletions(-)
diff --git a/supervised/baseline.py b/supervised/baseline.py
index 92d8d30..8a1b567 100644
--- a/supervised/baseline.py
+++ b/supervised/baseline.py
@@ -10,11 +10,11 @@ from torchvision.transforms import transforms, InterpolationMode
from tqdm import tqdm
from optimizers import LARS
-from schedulers import LinearWarmupAndCosineAnneal
+from schedulers import LinearWarmupAndCosineAnneal, LinearLR
from supervised.datautils import color_distortion
from supervised.models import CIFAR10ResNet50
-CODENAME = 'cifar10-resnet50-aug-lars-sched'
+CODENAME = 'cifar10-resnet50-aug-lars-warmup'
DATASET_ROOT = 'dataset'
TENSORBOARD_PATH = os.path.join('runs', CODENAME)
CHECKPOINT_PATH = os.path.join('checkpoints', CODENAME)
@@ -27,6 +27,7 @@ N_WORKERS = 2
SEED = 0
OPTIM = 'lars'
+SCHED = 'warmup-anneal'
LR = 1
MOMENTUM = 0.9
WEIGHT_DECAY = 1e-6
@@ -100,12 +101,24 @@ if RESTORE_EPOCH > 0:
f'Train loss: {checkpoint["train_loss"]:.4f}\t'
f'Test loss: {checkpoint["test_loss"]:.4f}')
-scheduler = LinearWarmupAndCosineAnneal(
- optimizer,
- WARMUP_EPOCHS / N_EPOCHS,
- N_EPOCHS * num_train_batches,
- last_epoch=RESTORE_EPOCH * num_train_batches - 1
-)
+if SCHED == 'warmup-anneal':
+ scheduler = LinearWarmupAndCosineAnneal(
+ optimizer,
+ warm_up=WARMUP_EPOCHS / N_EPOCHS,
+ T_max=N_EPOCHS * num_train_batches,
+ last_epoch=RESTORE_EPOCH * num_train_batches - 1
+ )
+elif SCHED == 'linear':
+ scheduler = LinearLR(
+ optimizer,
+ num_epochs=N_EPOCHS * num_train_batches,
+ last_epoch=RESTORE_EPOCH * num_train_batches - 1
+ )
+elif SCHED is None or SCHED == '' or SCHED == 'const':
+ scheduler = None
+else:
+ raise NotImplementedError(f"Scheduler '{SCHED}' is not implemented.")
+
if OPTIM == 'lars':
optimizer = LARS(optimizer)
@@ -130,7 +143,8 @@ for epoch in range(RESTORE_EPOCH, N_EPOCHS):
loss = criterion(output, targets)
loss.backward()
optimizer.step()
- scheduler.step()
+ if SCHED:
+ scheduler.step()
train_loss += loss.item()
train_loss_mean = train_loss / (batch + 1)
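
For readers without the repo's schedulers module checked out: the new dispatch selects between a warm-up-plus-cosine schedule, a plain linear decay, and a constant LR. Below is a minimal sketch of equivalent behaviour built on torch.optim.lr_scheduler.LambdaLR rather than the repo's own LinearWarmupAndCosineAnneal and LinearLR classes, whose internals this diff does not show; the constants mirror baseline.py, and the exact lambda shapes are assumptions.

import math

import torch
from torch.optim.lr_scheduler import LambdaLR

# Stand-ins for the constants in baseline.py.
N_EPOCHS = 10
WARMUP_EPOCHS = 1
NUM_TRAIN_BATCHES = 100   # assumed; num_train_batches in the script
SCHED = 'warmup-anneal'   # or 'linear', or None / '' / 'const'

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)

total_steps = N_EPOCHS * NUM_TRAIN_BATCHES
warmup_steps = int(WARMUP_EPOCHS / N_EPOCHS * total_steps)


def warmup_anneal(step):
    # Linear warm-up to the base LR, then cosine annealing towards zero.
    if step < warmup_steps:
        return (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1 + math.cos(math.pi * progress))


def linear_decay(step):
    # Straight-line decay from the base LR to zero over all steps.
    return 1 - step / total_steps


if SCHED == 'warmup-anneal':
    scheduler = LambdaLR(optimizer, warmup_anneal)
elif SCHED == 'linear':
    scheduler = LambdaLR(optimizer, linear_decay)
elif SCHED in (None, '', 'const'):
    scheduler = None      # constant LR, nothing to step
else:
    raise NotImplementedError(f"Scheduler '{SCHED}' is not implemented.")

for step in range(total_steps):
    optimizer.step()      # model update per batch
    if scheduler is not None:
        scheduler.step()  # LR update per batch, as in baseline.py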
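
One subtlety worth flagging: the training-loop guard added in the last hunk keys on the flag string ("if SCHED:") rather than on the scheduler object. With the 'const' spelling of the no-scheduler option, SCHED is a truthy string while scheduler is None, so scheduler.step() would raise AttributeError. A guard on the object itself, as used in the sketch above, avoids the mismatch:

# Alternative to the diff's `if SCHED:` guard: test the object,
# not the flag, so SCHED == 'const' (truthy string, scheduler is None)
# cannot reach a None.step() call.
if scheduler is not None:
    scheduler.step()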