author    Jordan Gong <jordan.gong@protonmail.com>    2022-03-16 15:44:43 +0800
committer Jordan Gong <jordan.gong@protonmail.com>    2022-03-16 15:44:43 +0800
commit    3c12e8de3e94ecbb1047b225e5a1d814e7245b71 (patch)
tree      ab3b305834a9d743de3b74e7500af2fce8b1887b /supervised/baseline.py
parent    2f06ac98982323c3775faba1f5f64b52b5586b70 (diff)
Add checkpoint restore support
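This adds a RESTORE_EPOCH constant (0 preserves the old start-from-scratch
behaviour). When set above zero, the script loads the checkpoint at
checkpoints/<CODENAME>/<RESTORE_EPOCH:04d>.pt, restores the ResNet and SGD
optimizer state, fast-forwards the LR scheduler by
RESTORE_EPOCH * num_train_batches steps, and resumes the epoch loop and the
TensorBoard iteration counters from where the checkpoint left off.

The restore path assumes checkpoints are written elsewhere in the script with
matching keys; a minimal sketch of such a save call (the save site and the
epoch numbering here are assumptions, not part of this diff):

    # Hypothetical save counterpart; the keys mirror what the restore reads.
    torch.save({
        'resnet_state_dict': resnet.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss_mean,
        'test_loss': test_loss_mean,
    }, os.path.join(CHECKPOINT_PATH, f'{epoch + 1:04d}.pt'))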
Diffstat (limited to 'supervised/baseline.py')
-rw-r--r--  supervised/baseline.py | 27
1 file changed, 19 insertions(+), 8 deletions(-)
diff --git a/supervised/baseline.py b/supervised/baseline.py
index 9a83079..bc4128e 100644
--- a/supervised/baseline.py
+++ b/supervised/baseline.py
@@ -21,6 +21,7 @@ TENSORBOARD_PATH = os.path.join('runs', CODENAME)
 CHECKPOINT_PATH = os.path.join('checkpoints', CODENAME)
 
 BATCH_SIZE = 256
+RESTORE_EPOCH = 0
 N_EPOCHS = 1000
 WARMUP_EPOCHS = 10
 N_WORKERS = 2
@@ -121,19 +122,29 @@ param_groups = [
     },
 ]
 optimizer = torch.optim.SGD(param_groups, lr=LR, momentum=MOMENTUM)
+
+if RESTORE_EPOCH > 0:
+    checkpoint_path = os.path.join(CHECKPOINT_PATH, f'{RESTORE_EPOCH:04d}.pt')
+    checkpoint = torch.load(checkpoint_path)
+    resnet.load_state_dict(checkpoint['resnet_state_dict'])
+    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+    print(f'[RESTORED][{RESTORE_EPOCH}/{N_EPOCHS}]\t'
+          f'Train loss: {checkpoint["train_loss"]:.4f}\t'
+          f'Test loss: {checkpoint["test_loss"]:.4f}')
+
 scheduler = LinearWarmupAndCosineAnneal(
     optimizer,
     WARMUP_EPOCHS / N_EPOCHS,
     N_EPOCHS * num_train_batches,
-    last_epoch=-1,
+    last_epoch=RESTORE_EPOCH * num_train_batches - 1
 )
 optimizer = LARS(optimizer)
 
 writer = SummaryWriter(TENSORBOARD_PATH)
 
-train_iters = 0
-test_iters = 0
-for epoch in range(N_EPOCHS):
+curr_train_iters = RESTORE_EPOCH * num_train_batches
+curr_test_iters = RESTORE_EPOCH * num_test_batches
+for epoch in range(RESTORE_EPOCH, N_EPOCHS):
     train_loss = 0
     training_progress = tqdm(
         enumerate(train_loader), desc='Train loss: ', total=num_train_batches
@@ -153,8 +164,8 @@ for epoch in range(N_EPOCHS):
         train_loss += loss.item()
         train_loss_mean = train_loss / (batch + 1)
         training_progress.set_description(f'Train loss: {train_loss_mean:.4f}')
-        writer.add_scalar('Loss/train', loss, train_iters + 1)
-        train_iters += 1
+        writer.add_scalar('Loss/train', loss, curr_train_iters + 1)
+        curr_train_iters += 1
 
     test_loss = 0
     test_acc = 0
@@ -176,8 +187,8 @@ for epoch in range(N_EPOCHS):
         test_progress.set_description(f'Test loss: {test_loss_mean:.4f}')
         test_acc += (prediction == targets).float().mean()
         test_acc_mean = test_acc / (batch + 1)
-        writer.add_scalar('Loss/test', loss, test_iters + 1)
-        test_iters += 1
+        writer.add_scalar('Loss/test', loss, curr_test_iters + 1)
+        curr_test_iters += 1
 
     train_loss_mean = train_loss / num_train_batches
     test_loss_mean = test_loss / num_test_batches
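Note on the scheduler change: rather than serializing scheduler state, the
diff rebuilds LinearWarmupAndCosineAnneal with
last_epoch=RESTORE_EPOCH * num_train_batches - 1, relying on the scheduler to
recompute its position from the step count alone (it is stepped once per
batch, hence the multiplication by num_train_batches). A standalone sketch of
the same fast-forward idea using a stock PyTorch scheduler (CosineAnnealingLR
stands in for the custom class; resume_step is a hypothetical value):

    import torch

    model = torch.nn.Linear(4, 4)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)

    # With last_epoch != -1, PyTorch schedulers expect each param group to
    # record the LR it started from, so set 'initial_lr' before resuming.
    for group in opt.param_groups:
        group.setdefault('initial_lr', group['lr'])

    resume_step = 500  # e.g. RESTORE_EPOCH * num_train_batches
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(
        opt, T_max=10000, last_epoch=resume_step - 1
    )
    print(sched.get_last_lr())  # LR as of step resume_step, not step 0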