aboutsummaryrefslogtreecommitdiff
path: root/supervised/baseline.py
diff options
context:
space:
mode:
Diffstat (limited to 'supervised/baseline.py')
-rw-r--r--supervised/baseline.py5
1 files changed, 2 insertions, 3 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py
index 63d3416..d5671b1 100644
--- a/supervised/baseline.py
+++ b/supervised/baseline.py
@@ -32,9 +32,6 @@ LR = 1
MOMENTUM = 0.9
WEIGHT_DECAY = 1e-6
-if not os.path.exists(CHECKPOINT_PATH):
- os.makedirs(CHECKPOINT_PATH)
-
random.seed(SEED)
torch.manual_seed(SEED)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -151,6 +148,8 @@ scheduler = LinearWarmupAndCosineAnneal(
if OPTIM == 'lars':
optimizer = LARS(optimizer)
+if not os.path.exists(CHECKPOINT_PATH):
+ os.makedirs(CHECKPOINT_PATH)
writer = SummaryWriter(TENSORBOARD_PATH)
curr_train_iters = RESTORE_EPOCH * num_train_batches