From 832c162df2d8bbc49b6df204849dec98d157cefb Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Thu, 17 Mar 2022 12:33:14 +0800 Subject: Fix pass args to loggers --- supervised/baseline.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) (limited to 'supervised/baseline.py') diff --git a/supervised/baseline.py b/supervised/baseline.py index 96017e5..3347cf6 100644 --- a/supervised/baseline.py +++ b/supervised/baseline.py @@ -271,7 +271,7 @@ def eval(args, test_loader, model, loss_fn): @training_log -def batch_logger(writer, batch, epoch, loss, lr): +def batch_logger(args, writer, batch, epoch, loss, lr): global_batch = epoch * args.num_train_batches + batch writer.add_scalar('Batch loss/train', loss, global_batch + 1) writer.add_scalar('Batch lr/train', lr, global_batch + 1) @@ -288,7 +288,7 @@ def batch_logger(writer, batch, epoch, loss, lr): @training_log -def epoch_logger(writer, epoch, train_loss, test_loss, test_accuracy): +def epoch_logger(args, writer, epoch, train_loss, test_loss, test_accuracy): train_loss_mean = train_loss.mean().item() test_loss_mean = test_loss.mean().item() test_accuracy_mean = test_accuracy.mean().item() @@ -305,7 +305,7 @@ def epoch_logger(writer, epoch, train_loss, test_loss, test_accuracy): } -def save_checkpoint(epoch_log, model, optimizer): +def save_checkpoint(args, epoch_log, model, optimizer): if not os.path.exists(args.checkpoint_root): os.makedirs(args.checkpoint_root) @@ -341,9 +341,9 @@ if __name__ == '__main__': test_accuracy = torch.zeros(args.num_test_batches, device=args.device) for batch, loss in train(args, train_loader, resnet, xent, optimizer, scheduler): train_loss[batch] = loss - batch_logger(writer, batch, epoch, loss, optimizer.param_groups[0]['lr']) + batch_logger(args, writer, batch, epoch, loss, optimizer.param_groups[0]['lr']) for batch, loss, accuracy in eval(args, test_loader, resnet, xent): test_loss[batch] = loss test_accuracy[batch] = accuracy - epoch_log = epoch_logger(writer, epoch, train_loss, test_loss, test_accuracy) - save_checkpoint(epoch_log, resnet, optimizer) + epoch_log = epoch_logger(args, writer, epoch, train_loss, test_loss, test_accuracy) + save_checkpoint(args, epoch_log, resnet, optimizer) -- cgit v1.2.3