diff options
Diffstat (limited to 'supervised/baseline.py')
-rw-r--r-- | supervised/baseline.py | 5 |
1 files changed, 2 insertions, 3 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py index 63d3416..d5671b1 100644 --- a/supervised/baseline.py +++ b/supervised/baseline.py @@ -32,9 +32,6 @@ LR = 1 MOMENTUM = 0.9 WEIGHT_DECAY = 1e-6 -if not os.path.exists(CHECKPOINT_PATH): - os.makedirs(CHECKPOINT_PATH) - random.seed(SEED) torch.manual_seed(SEED) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') @@ -151,6 +148,8 @@ scheduler = LinearWarmupAndCosineAnneal( if OPTIM == 'lars': optimizer = LARS(optimizer) +if not os.path.exists(CHECKPOINT_PATH): + os.makedirs(CHECKPOINT_PATH) writer = SummaryWriter(TENSORBOARD_PATH) curr_train_iters = RESTORE_EPOCH * num_train_batches |