From 90ff2e532e11fe0e5948adc1ef34ddf29924c8ed Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Wed, 16 Mar 2022 16:44:09 +0800 Subject: Add Adam and SDG optimizers --- supervised/baseline.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) (limited to 'supervised/baseline.py') diff --git a/supervised/baseline.py b/supervised/baseline.py index bc4128e..63d3416 100644 --- a/supervised/baseline.py +++ b/supervised/baseline.py @@ -25,10 +25,12 @@ RESTORE_EPOCH = 0 N_EPOCHS = 1000 WARMUP_EPOCHS = 10 N_WORKERS = 2 +SEED = 0 + +OPTIM = 'lars' LR = 1 MOMENTUM = 0.9 WEIGHT_DECAY = 1e-6 -SEED = 0 if not os.path.exists(CHECKPOINT_PATH): os.makedirs(CHECKPOINT_PATH) @@ -103,7 +105,9 @@ criterion = CrossEntropyLoss() def exclude_from_wd_and_adaptation(name): - if 'bn' in name or 'bias' in name: + if 'bn' in name: + return True + if OPTIM == 'lars' and 'bias' in name: return True @@ -121,8 +125,14 @@ param_groups = [ 'layer_adaptation': False, }, ] -optimizer = torch.optim.SGD(param_groups, lr=LR, momentum=MOMENTUM) - +if OPTIM == 'adam': + optimizer = torch.optim.Adam(param_groups, lr=LR, betas=(MOMENTUM, 0.999)) +elif OPTIM == 'sdg' or OPTIM == 'lars': + optimizer = torch.optim.SGD(param_groups, lr=LR, momentum=MOMENTUM) +else: + raise NotImplementedError(f"Optimizer '{OPTIM}' is not implemented.") + +# Restore checkpoint if RESTORE_EPOCH > 0: checkpoint_path = os.path.join(CHECKPOINT_PATH, f'{RESTORE_EPOCH:04d}.pt') checkpoint = torch.load(checkpoint_path) @@ -138,7 +148,8 @@ scheduler = LinearWarmupAndCosineAnneal( N_EPOCHS * num_train_batches, last_epoch=RESTORE_EPOCH * num_train_batches - 1 ) -optimizer = LARS(optimizer) +if OPTIM == 'lars': + optimizer = LARS(optimizer) writer = SummaryWriter(TENSORBOARD_PATH) -- cgit v1.2.3