diff options
Diffstat (limited to 'supervised/baseline.py')
-rw-r--r-- | supervised/baseline.py | 22 |
1 files changed, 14 insertions, 8 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py index 31e8b33..23272c3 100644 --- a/supervised/baseline.py +++ b/supervised/baseline.py @@ -208,18 +208,20 @@ def load_checkpoint(args, model, optimizer): def configure_scheduler(args, optimizer): + n_iters = args.n_epochs * args.num_train_batches + last_iter = args.restore_epoch * args.num_train_batches - 1 if args.sched == 'warmup-anneal': scheduler = LinearWarmupAndCosineAnneal( optimizer, warm_up=args.warmup_epochs / args.n_epochs, - T_max=args.n_epochs * args.num_train_batches, - last_epoch=args.restore_epoch * args.num_train_batches - 1 + T_max=n_iters, + last_epoch=last_iter ) elif args.sched == 'linear': scheduler = LinearLR( optimizer, - num_epochs=args.n_epochs * args.num_train_batches, - last_epoch=args.restore_epoch * args.num_train_batches - 1 + num_epochs=n_iters, + last_epoch=last_iter ) elif args.sched is None or args.sched == '' or args.sched == 'const': scheduler = None @@ -236,7 +238,7 @@ def wrap_lars(args, optimizer): return optimizer -def train(args, train_loader, model, loss_fn, optimizer, scheduler): +def train(args, train_loader, model, loss_fn, optimizer): model.train() for batch, (images, targets) in enumerate(train_loader): images, targets = images.to(args.device), targets.to(args.device) @@ -245,8 +247,6 @@ def train(args, train_loader, model, loss_fn, optimizer, scheduler): loss = loss_fn(output, targets) loss.backward() optimizer.step() - if args.sched: - scheduler.step() yield batch, loss.item() @@ -329,11 +329,17 @@ if __name__ == '__main__': train_loss = torch.zeros(args.num_train_batches, device=args.device) test_loss = torch.zeros(args.num_test_batches, device=args.device) test_accuracy = torch.zeros(args.num_test_batches, device=args.device) - for batch, loss in train(args, train_loader, resnet, xent, optimizer, scheduler): + for batch, loss in train(args, train_loader, resnet, xent, optimizer): train_loss[batch] = loss batch_logger(args, writer, batch, epoch, loss, optimizer.param_groups[0]['lr']) + if scheduler and batch != args.num_train_batches - 1: + scheduler.step() for batch, loss, accuracy in eval(args, test_loader, resnet, xent): test_loss[batch] = loss test_accuracy[batch] = accuracy epoch_log = epoch_logger(args, writer, epoch, train_loss, test_loss, test_accuracy) save_checkpoint(args, epoch_log, resnet, optimizer) + # Step after save checkpoint, otherwise the schedular + # will one iter ahead after restore + if scheduler: + scheduler.step() |