Diffstat (limited to 'supervised/baseline.py')
-rw-r--r-- | supervised/baseline.py | 21 |
1 file changed, 16 insertions, 5 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py
index bc4128e..63d3416 100644
--- a/supervised/baseline.py
+++ b/supervised/baseline.py
@@ -25,10 +25,12 @@ RESTORE_EPOCH = 0
 N_EPOCHS = 1000
 WARMUP_EPOCHS = 10
 N_WORKERS = 2
+SEED = 0
+
+OPTIM = 'lars'
 LR = 1
 MOMENTUM = 0.9
 WEIGHT_DECAY = 1e-6
-SEED = 0
 
 if not os.path.exists(CHECKPOINT_PATH):
     os.makedirs(CHECKPOINT_PATH)
@@ -103,7 +105,9 @@ criterion = CrossEntropyLoss()
 def exclude_from_wd_and_adaptation(name):
-    if 'bn' in name or 'bias' in name:
+    if 'bn' in name:
+        return True
+    if OPTIM == 'lars' and 'bias' in name:
         return True
@@ -121,8 +125,14 @@ param_groups = [
         'layer_adaptation': False,
     },
 ]
-optimizer = torch.optim.SGD(param_groups, lr=LR, momentum=MOMENTUM)
-
+if OPTIM == 'adam':
+    optimizer = torch.optim.Adam(param_groups, lr=LR, betas=(MOMENTUM, 0.999))
+elif OPTIM == 'sgd' or OPTIM == 'lars':
+    optimizer = torch.optim.SGD(param_groups, lr=LR, momentum=MOMENTUM)
+else:
+    raise NotImplementedError(f"Optimizer '{OPTIM}' is not implemented.")
+
+# Restore checkpoint
 if RESTORE_EPOCH > 0:
     checkpoint_path = os.path.join(CHECKPOINT_PATH, f'{RESTORE_EPOCH:04d}.pt')
     checkpoint = torch.load(checkpoint_path)
@@ -138,7 +148,8 @@ scheduler = LinearWarmupAndCosineAnneal(
     N_EPOCHS * num_train_batches,
     last_epoch=RESTORE_EPOCH * num_train_batches - 1
 )
-optimizer = LARS(optimizer)
+if OPTIM == 'lars':
+    optimizer = LARS(optimizer)
 
 writer = SummaryWriter(TENSORBOARD_PATH)
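For context, below is a minimal, self-contained sketch of how the new OPTIM switch is meant to behave. The OPTIM constant, exclude_from_wd_and_adaptation, the param_groups layout, and the Adam/SGD selection mirror the diff above; TinyNet, its layer sizes, and the stubbed-out LARS step are illustrative assumptions, since the repository's actual model and LARS wrapper are not reproduced here.

# Sketch of the optimizer selection introduced in this commit.
# TinyNet and its sizes are made up for illustration; the repository's
# real model and its LARS wrapper are not reproduced here.
import torch
from torch import nn

OPTIM = 'lars'  # one of 'adam', 'sgd', 'lars'
LR = 1
MOMENTUM = 0.9
WEIGHT_DECAY = 1e-6


def exclude_from_wd_and_adaptation(name):
    # Batch-norm parameters are always excluded from weight decay and layer
    # adaptation; biases are excluded only when the LARS wrapper is in use.
    if 'bn' in name:
        return True
    if OPTIM == 'lars' and 'bias' in name:
        return True
    return False


class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 8)
        self.bn = nn.BatchNorm1d(8)

    def forward(self, x):
        return self.bn(self.fc(x))


model = TinyNet()
param_groups = [
    {
        'params': [p for name, p in model.named_parameters()
                   if not exclude_from_wd_and_adaptation(name)],
        'weight_decay': WEIGHT_DECAY,
        'layer_adaptation': True,
    },
    {
        'params': [p for name, p in model.named_parameters()
                   if exclude_from_wd_and_adaptation(name)],
        'weight_decay': 0.,
        'layer_adaptation': False,
    },
]

if OPTIM == 'adam':
    optimizer = torch.optim.Adam(param_groups, lr=LR, betas=(MOMENTUM, 0.999))
elif OPTIM == 'sgd' or OPTIM == 'lars':
    # LARS starts from a plain SGD optimizer and wraps it afterwards.
    optimizer = torch.optim.SGD(param_groups, lr=LR, momentum=MOMENTUM)
else:
    raise NotImplementedError(f"Optimizer '{OPTIM}' is not implemented.")

if OPTIM == 'lars':
    # In baseline.py the SGD optimizer is wrapped by the repository's own
    # LARS implementation at this point (optimizer = LARS(optimizer));
    # the wrapper is omitted here to keep the sketch self-contained.
    pass

print(type(optimizer).__name__,
      [len(g['params']) for g in optimizer.param_groups])

With OPTIM = 'adam' the same parameter groups are handed to Adam, with MOMENTUM reused as beta1; the extra 'layer_adaptation' key is simply carried along in each group, which PyTorch optimizers permit.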