aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--supervised/baseline.py22
1 files changed, 14 insertions, 8 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py
index 31e8b33..23272c3 100644
--- a/supervised/baseline.py
+++ b/supervised/baseline.py
@@ -208,18 +208,20 @@ def load_checkpoint(args, model, optimizer):
def configure_scheduler(args, optimizer):
+ n_iters = args.n_epochs * args.num_train_batches
+ last_iter = args.restore_epoch * args.num_train_batches - 1
if args.sched == 'warmup-anneal':
scheduler = LinearWarmupAndCosineAnneal(
optimizer,
warm_up=args.warmup_epochs / args.n_epochs,
- T_max=args.n_epochs * args.num_train_batches,
- last_epoch=args.restore_epoch * args.num_train_batches - 1
+ T_max=n_iters,
+ last_epoch=last_iter
)
elif args.sched == 'linear':
scheduler = LinearLR(
optimizer,
- num_epochs=args.n_epochs * args.num_train_batches,
- last_epoch=args.restore_epoch * args.num_train_batches - 1
+ num_epochs=n_iters,
+ last_epoch=last_iter
)
elif args.sched is None or args.sched == '' or args.sched == 'const':
scheduler = None
@@ -236,7 +238,7 @@ def wrap_lars(args, optimizer):
return optimizer
-def train(args, train_loader, model, loss_fn, optimizer, scheduler):
+def train(args, train_loader, model, loss_fn, optimizer):
model.train()
for batch, (images, targets) in enumerate(train_loader):
images, targets = images.to(args.device), targets.to(args.device)
@@ -245,8 +247,6 @@ def train(args, train_loader, model, loss_fn, optimizer, scheduler):
loss = loss_fn(output, targets)
loss.backward()
optimizer.step()
- if args.sched:
- scheduler.step()
yield batch, loss.item()
@@ -329,11 +329,17 @@ if __name__ == '__main__':
train_loss = torch.zeros(args.num_train_batches, device=args.device)
test_loss = torch.zeros(args.num_test_batches, device=args.device)
test_accuracy = torch.zeros(args.num_test_batches, device=args.device)
- for batch, loss in train(args, train_loader, resnet, xent, optimizer, scheduler):
+ for batch, loss in train(args, train_loader, resnet, xent, optimizer):
train_loss[batch] = loss
batch_logger(args, writer, batch, epoch, loss, optimizer.param_groups[0]['lr'])
+ if scheduler and batch != args.num_train_batches - 1:
+ scheduler.step()
for batch, loss, accuracy in eval(args, test_loader, resnet, xent):
test_loss[batch] = loss
test_accuracy[batch] = accuracy
epoch_log = epoch_logger(args, writer, epoch, train_loss, test_loss, test_accuracy)
save_checkpoint(args, epoch_log, resnet, optimizer)
+ # Step after save checkpoint, otherwise the schedular
+ # will one iter ahead after restore
+ if scheduler:
+ scheduler.step()