-rw-r--r--  supervised/baseline.py | 24 +++++++-----------------
1 file changed, 7 insertions(+), 17 deletions(-)
diff --git a/supervised/baseline.py b/supervised/baseline.py
index 3347cf6..31e8b33 100644
--- a/supervised/baseline.py
+++ b/supervised/baseline.py
@@ -201,16 +201,10 @@ def configure_optimizer(args, model):
 def load_checkpoint(args, model, optimizer):
     checkpoint_path = os.path.join(args.checkpoint_root, f'{args.restore_epoch:04d}.pt')
     checkpoint = torch.load(checkpoint_path)
-    model.load_state_dict(checkpoint['model_state_dict'])
-    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
-    restore_log = {
-        'epoch': checkpoint['epoch'],
-        'train_loss': checkpoint['train_loss'],
-        'test_loss': checkpoint['train_loss'],
-        'test_accuracy': checkpoint['test_accuracy']
-    }
+    model.load_state_dict(checkpoint.pop('model_state_dict'))
+    optimizer.load_state_dict(checkpoint.pop('optimizer_state_dict'))
 
-    return restore_log
+    return checkpoint
 
 
 def configure_scheduler(args, optimizer):
@@ -309,14 +303,10 @@ def save_checkpoint(args, epoch_log, model, optimizer):
     if not os.path.exists(args.checkpoint_root):
         os.makedirs(args.checkpoint_root)
 
-    epoch = epoch_log['epoch']
-    torch.save({'epoch': epoch,
-                'model_state_dict': model.state_dict(),
-                'optimizer_state_dict': optimizer.state_dict(),
-                'train_loss': epoch_log['train_loss'],
-                'test_loss': epoch_log['test_loss'],
-                'test_accuracy': epoch_log['test_accuracy'],
-                }, os.path.join(args.checkpoint_root, f'{epoch:04d}.pt'))
+    torch.save(epoch_log | {
+        'model_state_dict': model.state_dict(),
+        'optimizer_state_dict': optimizer.state_dict(),
+    }, os.path.join(args.checkpoint_root, f"{epoch_log['epoch']:04d}.pt"))
 
 
 if __name__ == '__main__':
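The refactor leans on two standard dict behaviors: dict.pop() strips the state dicts out of the loaded checkpoint so the remaining keys (epoch, losses, accuracy) can be returned directly as the restore log, and the dict union operator | (PEP 584, Python 3.9+) merges the epoch log with the state dicts at save time. As a side effect it removes the copy-paste bug in the old load path, where 'test_loss' was read from checkpoint['train_loss']. A minimal sketch of the round-trip, with placeholder values standing in for real model/optimizer state:

    # Sketch of the save/load round-trip (placeholder values; in baseline.py
    # the state dicts come from model.state_dict() / optimizer.state_dict()
    # and torch.save / torch.load handle the file I/O).
    epoch_log = {'epoch': 7, 'train_loss': 0.41,
                 'test_loss': 0.45, 'test_accuracy': 0.88}

    # Save: dict union (Python 3.9+) merges the log with the state dicts
    # into a new dict, leaving epoch_log untouched.
    checkpoint = epoch_log | {
        'model_state_dict': {'weight': [1.0, 2.0]},
        'optimizer_state_dict': {'lr': 0.1},
    }

    # Load: pop() hands back each state dict and removes it from the mapping,
    # so what remains is exactly the original epoch log.
    model_state = checkpoint.pop('model_state_dict')
    optimizer_state = checkpoint.pop('optimizer_state_dict')
    assert checkpoint == epoch_log

On interpreters older than 3.9, {**epoch_log, 'model_state_dict': ..., 'optimizer_state_dict': ...} would achieve the same merge.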