aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--supervised/baseline.py21
1 files changed, 16 insertions, 5 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py
index bc4128e..63d3416 100644
--- a/supervised/baseline.py
+++ b/supervised/baseline.py
@@ -25,10 +25,12 @@ RESTORE_EPOCH = 0
N_EPOCHS = 1000
WARMUP_EPOCHS = 10
N_WORKERS = 2
+SEED = 0
+
+OPTIM = 'lars'
LR = 1
MOMENTUM = 0.9
WEIGHT_DECAY = 1e-6
-SEED = 0
if not os.path.exists(CHECKPOINT_PATH):
os.makedirs(CHECKPOINT_PATH)
@@ -103,7 +105,9 @@ criterion = CrossEntropyLoss()
def exclude_from_wd_and_adaptation(name):
- if 'bn' in name or 'bias' in name:
+ if 'bn' in name:
+ return True
+ if OPTIM == 'lars' and 'bias' in name:
return True
@@ -121,8 +125,14 @@ param_groups = [
'layer_adaptation': False,
},
]
-optimizer = torch.optim.SGD(param_groups, lr=LR, momentum=MOMENTUM)
-
+if OPTIM == 'adam':
+ optimizer = torch.optim.Adam(param_groups, lr=LR, betas=(MOMENTUM, 0.999))
+elif OPTIM == 'sdg' or OPTIM == 'lars':
+ optimizer = torch.optim.SGD(param_groups, lr=LR, momentum=MOMENTUM)
+else:
+ raise NotImplementedError(f"Optimizer '{OPTIM}' is not implemented.")
+
+# Restore checkpoint
if RESTORE_EPOCH > 0:
checkpoint_path = os.path.join(CHECKPOINT_PATH, f'{RESTORE_EPOCH:04d}.pt')
checkpoint = torch.load(checkpoint_path)
@@ -138,7 +148,8 @@ scheduler = LinearWarmupAndCosineAnneal(
N_EPOCHS * num_train_batches,
last_epoch=RESTORE_EPOCH * num_train_batches - 1
)
-optimizer = LARS(optimizer)
+if OPTIM == 'lars':
+ optimizer = LARS(optimizer)
writer = SummaryWriter(TENSORBOARD_PATH)