diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2022-03-17 12:56:14 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2022-03-17 12:56:14 +0800 |
commit | 75dc76a0c0a32c6921fe27c3ce164ed4c16b159c (patch) | |
tree | 9ddf1cfbff6e55b1594ef12af4fc3ff5c76d523f /supervised/baseline.py | |
parent | 832c162df2d8bbc49b6df204849dec98d157cefb (diff) |
Simplify logging method
Diffstat (limited to 'supervised/baseline.py')
-rw-r--r-- | supervised/baseline.py | 24 |
1 files changed, 7 insertions, 17 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py index 3347cf6..31e8b33 100644 --- a/supervised/baseline.py +++ b/supervised/baseline.py @@ -201,16 +201,10 @@ def configure_optimizer(args, model): def load_checkpoint(args, model, optimizer): checkpoint_path = os.path.join(args.checkpoint_root, f'{args.restore_epoch:04d}.pt') checkpoint = torch.load(checkpoint_path) - model.load_state_dict(checkpoint['model_state_dict']) - optimizer.load_state_dict(checkpoint['optimizer_state_dict']) - restore_log = { - 'epoch': checkpoint['epoch'], - 'train_loss': checkpoint['train_loss'], - 'test_loss': checkpoint['train_loss'], - 'test_accuracy': checkpoint['test_accuracy'] - } + model.load_state_dict(checkpoint.pop('model_state_dict')) + optimizer.load_state_dict(checkpoint.pop('optimizer_state_dict')) - return restore_log + return checkpoint def configure_scheduler(args, optimizer): @@ -309,14 +303,10 @@ def save_checkpoint(args, epoch_log, model, optimizer): if not os.path.exists(args.checkpoint_root): os.makedirs(args.checkpoint_root) - epoch = epoch_log['epoch'] - torch.save({'epoch': epoch, - 'model_state_dict': model.state_dict(), - 'optimizer_state_dict': optimizer.state_dict(), - 'train_loss': epoch_log['train_loss'], - 'test_loss': epoch_log['test_loss'], - 'test_accuracy': epoch_log['test_accuracy'], - }, os.path.join(args.checkpoint_root, f'{epoch:04d}.pt')) + torch.save(epoch_log | { + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + }, os.path.join(args.checkpoint_root, f"{epoch_log['epoch']:04d}.pt")) if __name__ == '__main__': |