diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2022-03-16 15:44:43 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2022-03-16 15:44:43 +0800 |
commit | 3c12e8de3e94ecbb1047b225e5a1d814e7245b71 (patch) | |
tree | ab3b305834a9d743de3b74e7500af2fce8b1887b /supervised | |
parent | 2f06ac98982323c3775faba1f5f64b52b5586b70 (diff) |
Add checkpoint restore support
Diffstat (limited to 'supervised')
-rw-r--r-- | supervised/baseline.py | 27 |
1 files changed, 19 insertions, 8 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py index 9a83079..bc4128e 100644 --- a/supervised/baseline.py +++ b/supervised/baseline.py @@ -21,6 +21,7 @@ TENSORBOARD_PATH = os.path.join('runs', CODENAME) CHECKPOINT_PATH = os.path.join('checkpoints', CODENAME) BATCH_SIZE = 256 +RESTORE_EPOCH = 0 N_EPOCHS = 1000 WARMUP_EPOCHS = 10 N_WORKERS = 2 @@ -121,19 +122,29 @@ param_groups = [ }, ] optimizer = torch.optim.SGD(param_groups, lr=LR, momentum=MOMENTUM) + +if RESTORE_EPOCH > 0: + checkpoint_path = os.path.join(CHECKPOINT_PATH, f'{RESTORE_EPOCH:04d}.pt') + checkpoint = torch.load(checkpoint_path) + resnet.load_state_dict(checkpoint['resnet_state_dict']) + optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + print(f'[RESTORED][{RESTORE_EPOCH}/{N_EPOCHS}]\t' + f'Train loss: {checkpoint["train_loss"]:.4f}\t' + f'Test loss: {checkpoint["test_loss"]:.4f}') + scheduler = LinearWarmupAndCosineAnneal( optimizer, WARMUP_EPOCHS / N_EPOCHS, N_EPOCHS * num_train_batches, - last_epoch=-1, + last_epoch=RESTORE_EPOCH * num_train_batches - 1 ) optimizer = LARS(optimizer) writer = SummaryWriter(TENSORBOARD_PATH) -train_iters = 0 -test_iters = 0 -for epoch in range(N_EPOCHS): +curr_train_iters = RESTORE_EPOCH * num_train_batches +curr_test_iters = RESTORE_EPOCH * num_test_batches +for epoch in range(RESTORE_EPOCH, N_EPOCHS): train_loss = 0 training_progress = tqdm( enumerate(train_loader), desc='Train loss: ', total=num_train_batches @@ -153,8 +164,8 @@ for epoch in range(N_EPOCHS): train_loss += loss.item() train_loss_mean = train_loss / (batch + 1) training_progress.set_description(f'Train loss: {train_loss_mean:.4f}') - writer.add_scalar('Loss/train', loss, train_iters + 1) - train_iters += 1 + writer.add_scalar('Loss/train', loss, curr_train_iters + 1) + curr_train_iters += 1 test_loss = 0 test_acc = 0 @@ -176,8 +187,8 @@ for epoch in range(N_EPOCHS): test_progress.set_description(f'Test loss: {test_loss_mean:.4f}') test_acc += (prediction == targets).float().mean() test_acc_mean = test_acc / (batch + 1) - writer.add_scalar('Loss/test', loss, test_iters + 1) - test_iters += 1 + writer.add_scalar('Loss/test', loss, curr_test_iters + 1) + curr_test_iters += 1 train_loss_mean = train_loss / num_train_batches test_loss_mean = test_loss / num_test_batches |