From 75dc76a0c0a32c6921fe27c3ce164ed4c16b159c Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Thu, 17 Mar 2022 12:56:14 +0800 Subject: Simplify logging method --- supervised/baseline.py | 24 +++++++----------------- 1 file changed, 7 insertions(+), 17 deletions(-) (limited to 'supervised/baseline.py') diff --git a/supervised/baseline.py b/supervised/baseline.py index 3347cf6..31e8b33 100644 --- a/supervised/baseline.py +++ b/supervised/baseline.py @@ -201,16 +201,10 @@ def configure_optimizer(args, model): def load_checkpoint(args, model, optimizer): checkpoint_path = os.path.join(args.checkpoint_root, f'{args.restore_epoch:04d}.pt') checkpoint = torch.load(checkpoint_path) - model.load_state_dict(checkpoint['model_state_dict']) - optimizer.load_state_dict(checkpoint['optimizer_state_dict']) - restore_log = { - 'epoch': checkpoint['epoch'], - 'train_loss': checkpoint['train_loss'], - 'test_loss': checkpoint['train_loss'], - 'test_accuracy': checkpoint['test_accuracy'] - } + model.load_state_dict(checkpoint.pop('model_state_dict')) + optimizer.load_state_dict(checkpoint.pop('optimizer_state_dict')) - return restore_log + return checkpoint def configure_scheduler(args, optimizer): @@ -309,14 +303,10 @@ def save_checkpoint(args, epoch_log, model, optimizer): if not os.path.exists(args.checkpoint_root): os.makedirs(args.checkpoint_root) - epoch = epoch_log['epoch'] - torch.save({'epoch': epoch, - 'model_state_dict': model.state_dict(), - 'optimizer_state_dict': optimizer.state_dict(), - 'train_loss': epoch_log['train_loss'], - 'test_loss': epoch_log['test_loss'], - 'test_accuracy': epoch_log['test_accuracy'], - }, os.path.join(args.checkpoint_root, f'{epoch:04d}.pt')) + torch.save(epoch_log | { + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + }, os.path.join(args.checkpoint_root, f"{epoch_log['epoch']:04d}.pt")) if __name__ == '__main__': -- cgit v1.2.3