aboutsummaryrefslogtreecommitdiff
path: root/supervised
diff options
context:
space:
mode:
Diffstat (limited to 'supervised')
-rw-r--r--supervised/baseline.py12
1 files changed, 6 insertions, 6 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py
index 96017e5..3347cf6 100644
--- a/supervised/baseline.py
+++ b/supervised/baseline.py
@@ -271,7 +271,7 @@ def eval(args, test_loader, model, loss_fn):
@training_log
-def batch_logger(writer, batch, epoch, loss, lr):
+def batch_logger(args, writer, batch, epoch, loss, lr):
global_batch = epoch * args.num_train_batches + batch
writer.add_scalar('Batch loss/train', loss, global_batch + 1)
writer.add_scalar('Batch lr/train', lr, global_batch + 1)
@@ -288,7 +288,7 @@ def batch_logger(writer, batch, epoch, loss, lr):
@training_log
-def epoch_logger(writer, epoch, train_loss, test_loss, test_accuracy):
+def epoch_logger(args, writer, epoch, train_loss, test_loss, test_accuracy):
train_loss_mean = train_loss.mean().item()
test_loss_mean = test_loss.mean().item()
test_accuracy_mean = test_accuracy.mean().item()
@@ -305,7 +305,7 @@ def epoch_logger(writer, epoch, train_loss, test_loss, test_accuracy):
}
-def save_checkpoint(epoch_log, model, optimizer):
+def save_checkpoint(args, epoch_log, model, optimizer):
if not os.path.exists(args.checkpoint_root):
os.makedirs(args.checkpoint_root)
@@ -341,9 +341,9 @@ if __name__ == '__main__':
test_accuracy = torch.zeros(args.num_test_batches, device=args.device)
for batch, loss in train(args, train_loader, resnet, xent, optimizer, scheduler):
train_loss[batch] = loss
- batch_logger(writer, batch, epoch, loss, optimizer.param_groups[0]['lr'])
+ batch_logger(args, writer, batch, epoch, loss, optimizer.param_groups[0]['lr'])
for batch, loss, accuracy in eval(args, test_loader, resnet, xent):
test_loss[batch] = loss
test_accuracy[batch] = accuracy
- epoch_log = epoch_logger(writer, epoch, train_loss, test_loss, test_accuracy)
- save_checkpoint(epoch_log, resnet, optimizer)
+ epoch_log = epoch_logger(args, writer, epoch, train_loss, test_loss, test_accuracy)
+ save_checkpoint(args, epoch_log, resnet, optimizer)