diff options
Diffstat (limited to 'supervised/baseline.py')
-rw-r--r-- | supervised/baseline.py | 32 |
1 files changed, 23 insertions, 9 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py index 92d8d30..8a1b567 100644 --- a/supervised/baseline.py +++ b/supervised/baseline.py @@ -10,11 +10,11 @@ from torchvision.transforms import transforms, InterpolationMode from tqdm import tqdm from optimizers import LARS -from schedulers import LinearWarmupAndCosineAnneal +from schedulers import LinearWarmupAndCosineAnneal, LinearLR from supervised.datautils import color_distortion from supervised.models import CIFAR10ResNet50 -CODENAME = 'cifar10-resnet50-aug-lars-sched' +CODENAME = 'cifar10-resnet50-aug-lars-warmup' DATASET_ROOT = 'dataset' TENSORBOARD_PATH = os.path.join('runs', CODENAME) CHECKPOINT_PATH = os.path.join('checkpoints', CODENAME) @@ -27,6 +27,7 @@ N_WORKERS = 2 SEED = 0 OPTIM = 'lars' +SCHED = 'warmup-anneal' LR = 1 MOMENTUM = 0.9 WEIGHT_DECAY = 1e-6 @@ -100,12 +101,24 @@ if RESTORE_EPOCH > 0: f'Train loss: {checkpoint["train_loss"]:.4f}\t' f'Test loss: {checkpoint["test_loss"]:.4f}') -scheduler = LinearWarmupAndCosineAnneal( - optimizer, - WARMUP_EPOCHS / N_EPOCHS, - N_EPOCHS * num_train_batches, - last_epoch=RESTORE_EPOCH * num_train_batches - 1 -) +if SCHED == 'warmup-anneal': + scheduler = LinearWarmupAndCosineAnneal( + optimizer, + warm_up=WARMUP_EPOCHS / N_EPOCHS, + T_max=N_EPOCHS * num_train_batches, + last_epoch=RESTORE_EPOCH * num_train_batches - 1 + ) +elif SCHED == 'linear': + scheduler = LinearLR( + optimizer, + num_epochs=N_EPOCHS * num_train_batches, + last_epoch=RESTORE_EPOCH * num_train_batches - 1 + ) +elif SCHED is None or SCHED == '' or SCHED == 'const': + scheduler = None +else: + raise NotImplementedError(f"Scheduler '{SCHED}' is not implemented.") + if OPTIM == 'lars': optimizer = LARS(optimizer) @@ -130,7 +143,8 @@ for epoch in range(RESTORE_EPOCH, N_EPOCHS): loss = criterion(output, targets) loss.backward() optimizer.step() - scheduler.step() + if SCHED: + scheduler.step() train_loss += loss.item() train_loss_mean = train_loss / (batch + 1) |