From 75dc76a0c0a32c6921fe27c3ce164ed4c16b159c Mon Sep 17 00:00:00 2001
From: Jordan Gong <jordan.gong@protonmail.com>
Date: Thu, 17 Mar 2022 12:56:14 +0800
Subject: Simplify logging method

---
 supervised/baseline.py | 24 +++++++-----------------
 1 file changed, 7 insertions(+), 17 deletions(-)

(limited to 'supervised')

diff --git a/supervised/baseline.py b/supervised/baseline.py
index 3347cf6..31e8b33 100644
--- a/supervised/baseline.py
+++ b/supervised/baseline.py
@@ -201,16 +201,10 @@ def configure_optimizer(args, model):
 def load_checkpoint(args, model, optimizer):
     checkpoint_path = os.path.join(args.checkpoint_root, f'{args.restore_epoch:04d}.pt')
     checkpoint = torch.load(checkpoint_path)
-    model.load_state_dict(checkpoint['model_state_dict'])
-    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
-    restore_log = {
-        'epoch': checkpoint['epoch'],
-        'train_loss': checkpoint['train_loss'],
-        'test_loss': checkpoint['train_loss'],
-        'test_accuracy': checkpoint['test_accuracy']
-    }
+    model.load_state_dict(checkpoint.pop('model_state_dict'))
+    optimizer.load_state_dict(checkpoint.pop('optimizer_state_dict'))
 
-    return restore_log
+    return checkpoint
 
 
 def configure_scheduler(args, optimizer):
@@ -309,14 +303,10 @@ def save_checkpoint(args, epoch_log, model, optimizer):
     if not os.path.exists(args.checkpoint_root):
         os.makedirs(args.checkpoint_root)
 
-    epoch = epoch_log['epoch']
-    torch.save({'epoch': epoch,
-                'model_state_dict': model.state_dict(),
-                'optimizer_state_dict': optimizer.state_dict(),
-                'train_loss': epoch_log['train_loss'],
-                'test_loss': epoch_log['test_loss'],
-                'test_accuracy': epoch_log['test_accuracy'],
-                }, os.path.join(args.checkpoint_root, f'{epoch:04d}.pt'))
+    torch.save(epoch_log | {
+        'model_state_dict': model.state_dict(),
+        'optimizer_state_dict': optimizer.state_dict(),
+    }, os.path.join(args.checkpoint_root, f"{epoch_log['epoch']:04d}.pt"))
 
 
 if __name__ == '__main__':
-- 
cgit v1.2.3