diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2022-03-17 12:33:14 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2022-03-17 12:33:14 +0800 |
commit | 832c162df2d8bbc49b6df204849dec98d157cefb (patch) | |
tree | c6eb4ed3567000877a01690c47b30fff9ebaca41 | |
parent | 9327a98bcdac624313763ec0ecdc767f99d8f271 (diff) |
Fix pass args to loggers
-rw-r--r-- | supervised/baseline.py | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py index 96017e5..3347cf6 100644 --- a/supervised/baseline.py +++ b/supervised/baseline.py @@ -271,7 +271,7 @@ def eval(args, test_loader, model, loss_fn): @training_log -def batch_logger(writer, batch, epoch, loss, lr): +def batch_logger(args, writer, batch, epoch, loss, lr): global_batch = epoch * args.num_train_batches + batch writer.add_scalar('Batch loss/train', loss, global_batch + 1) writer.add_scalar('Batch lr/train', lr, global_batch + 1) @@ -288,7 +288,7 @@ def batch_logger(writer, batch, epoch, loss, lr): @training_log -def epoch_logger(writer, epoch, train_loss, test_loss, test_accuracy): +def epoch_logger(args, writer, epoch, train_loss, test_loss, test_accuracy): train_loss_mean = train_loss.mean().item() test_loss_mean = test_loss.mean().item() test_accuracy_mean = test_accuracy.mean().item() @@ -305,7 +305,7 @@ def epoch_logger(writer, epoch, train_loss, test_loss, test_accuracy): } -def save_checkpoint(epoch_log, model, optimizer): +def save_checkpoint(args, epoch_log, model, optimizer): if not os.path.exists(args.checkpoint_root): os.makedirs(args.checkpoint_root) @@ -341,9 +341,9 @@ if __name__ == '__main__': test_accuracy = torch.zeros(args.num_test_batches, device=args.device) for batch, loss in train(args, train_loader, resnet, xent, optimizer, scheduler): train_loss[batch] = loss - batch_logger(writer, batch, epoch, loss, optimizer.param_groups[0]['lr']) + batch_logger(args, writer, batch, epoch, loss, optimizer.param_groups[0]['lr']) for batch, loss, accuracy in eval(args, test_loader, resnet, xent): test_loss[batch] = loss test_accuracy[batch] = accuracy - epoch_log = epoch_logger(writer, epoch, train_loss, test_loss, test_accuracy) - save_checkpoint(epoch_log, resnet, optimizer) + epoch_log = epoch_logger(args, writer, epoch, train_loss, test_loss, test_accuracy) + save_checkpoint(args, epoch_log, resnet, optimizer) |