aboutsummaryrefslogtreecommitdiff
path: root/supervised
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2022-03-17 12:56:14 +0800
committerJordan Gong <jordan.gong@protonmail.com>2022-03-17 12:56:14 +0800
commit75dc76a0c0a32c6921fe27c3ce164ed4c16b159c (patch)
tree9ddf1cfbff6e55b1594ef12af4fc3ff5c76d523f /supervised
parent832c162df2d8bbc49b6df204849dec98d157cefb (diff)
Simplify logging method
Diffstat (limited to 'supervised')
-rw-r--r--supervised/baseline.py24
1 files changed, 7 insertions, 17 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py
index 3347cf6..31e8b33 100644
--- a/supervised/baseline.py
+++ b/supervised/baseline.py
@@ -201,16 +201,10 @@ def configure_optimizer(args, model):
def load_checkpoint(args, model, optimizer):
checkpoint_path = os.path.join(args.checkpoint_root, f'{args.restore_epoch:04d}.pt')
checkpoint = torch.load(checkpoint_path)
- model.load_state_dict(checkpoint['model_state_dict'])
- optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
- restore_log = {
- 'epoch': checkpoint['epoch'],
- 'train_loss': checkpoint['train_loss'],
- 'test_loss': checkpoint['train_loss'],
- 'test_accuracy': checkpoint['test_accuracy']
- }
+ model.load_state_dict(checkpoint.pop('model_state_dict'))
+ optimizer.load_state_dict(checkpoint.pop('optimizer_state_dict'))
- return restore_log
+ return checkpoint
def configure_scheduler(args, optimizer):
@@ -309,14 +303,10 @@ def save_checkpoint(args, epoch_log, model, optimizer):
if not os.path.exists(args.checkpoint_root):
os.makedirs(args.checkpoint_root)
- epoch = epoch_log['epoch']
- torch.save({'epoch': epoch,
- 'model_state_dict': model.state_dict(),
- 'optimizer_state_dict': optimizer.state_dict(),
- 'train_loss': epoch_log['train_loss'],
- 'test_loss': epoch_log['test_loss'],
- 'test_accuracy': epoch_log['test_accuracy'],
- }, os.path.join(args.checkpoint_root, f'{epoch:04d}.pt'))
+ torch.save(epoch_log | {
+ 'model_state_dict': model.state_dict(),
+ 'optimizer_state_dict': optimizer.state_dict(),
+ }, os.path.join(args.checkpoint_root, f"{epoch_log['epoch']:04d}.pt"))
if __name__ == '__main__':