Diffstat (limited to 'supervised/baseline.py')
-rw-r--r--  supervised/baseline.py  28
1 file changed, 20 insertions(+), 8 deletions(-)
diff --git a/supervised/baseline.py b/supervised/baseline.py
index e21aeee..dc92408 100644
--- a/supervised/baseline.py
+++ b/supervised/baseline.py
@@ -14,7 +14,7 @@ from datautils import color_distortion, Clip, RandomGaussianBlur
 from models import CIFARResNet50, ImageNetResNet50
 from optimizers import LARS
 from schedulers import LinearWarmupAndCosineAnneal, LinearLR
-from utils import training_log
+from utils import training_log, setup_logging, EPOCH_LOGGER, BATCH_LOGGER
 
 
 def build_parser():
@@ -78,10 +78,21 @@ def build_parser():
     train_group.add_argument('--weight_decay', default=1e-6, type=float,
                              help='Weight decay (l2 regularization) (default: 1e-6)')
 
+    logging_group = parser.add_argument_group('Logging config')
+    logging_group.add_argument('--log_dir', default='logs', type=str,
+                               help="Path to log directory (default: 'logs')")
+    logging_group.add_argument('--tensorboard_dir', default='runs', type=str,
+                               help="Path to tensorboard directory (default: 'runs')")
+    logging_group.add_argument('--checkpoint_dir', default='checkpoints', type=str,
+                               help='Path to checkpoints directory '
+                                    "(default: 'checkpoints')")
+
     args = parser.parse_args()
     args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-    args.checkpoint_root = os.path.join('checkpoints', args.codename)
-    args.tensorboard_root = os.path.join('runs', args.codename)
+    args.batch_log_filename = os.path.join(args.log_dir, f'batch-{args.codename}.csv')
+    args.epoch_log_filename = os.path.join(args.log_dir, f'epoch-{args.codename}.csv')
+    args.tensorboard_root = os.path.join(args.tensorboard_dir, args.codename)
+    args.checkpoint_root = os.path.join(args.checkpoint_dir, args.codename)
 
     return args
 
@@ -208,7 +219,7 @@ def configure_optimizer(args, model):
     return optimizer
 
 
-@training_log
+@training_log(EPOCH_LOGGER)
 def load_checkpoint(args, model, optimizer):
     checkpoint_path = os.path.join(args.checkpoint_root, f'{args.restore_epoch:04d}.pt')
     checkpoint = torch.load(checkpoint_path)
@@ -275,7 +286,7 @@ def eval(args, test_loader, model, loss_fn):
         yield batch, loss.item(), accuracy.item()
 
 
-@training_log
+@training_log(BATCH_LOGGER)
 def batch_logger(args, writer, batch, epoch, loss, lr):
     global_batch = epoch * args.num_train_batches + batch
     writer.add_scalar('Batch loss/train', loss, global_batch + 1)
@@ -292,7 +303,7 @@ def batch_logger(args, writer, batch, epoch, loss, lr):
     }
 
 
-@training_log
+@training_log(EPOCH_LOGGER)
 def epoch_logger(args, writer, epoch, train_loss, test_loss, test_accuracy):
     train_loss_mean = train_loss.mean().item()
     test_loss_mean = test_loss.mean().item()
@@ -311,8 +322,7 @@ def epoch_logger(args, writer, epoch, train_loss, test_loss, test_accuracy):
 
 
 def save_checkpoint(args, epoch_log, model, optimizer):
-    if not os.path.exists(args.checkpoint_root):
-        os.makedirs(args.checkpoint_root)
+    os.makedirs(args.checkpoint_root, exist_ok=True)
 
     torch.save(epoch_log | {
         'model_state_dict': model.state_dict(),
@@ -323,6 +333,8 @@
 if __name__ == '__main__':
     args = build_parser()
     set_seed(args)
+    setup_logging(BATCH_LOGGER, args.batch_log_filename)
+    setup_logging(EPOCH_LOGGER, args.epoch_log_filename)
 
     train_set, test_set = prepare_dataset(args)
     train_loader, test_loader = create_dataloader(args, train_set, test_set)
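This commit presumes a matching change in utils that is not part of the diff: training_log turns from a plain decorator into a decorator factory keyed by a logger name, and setup_logging binds each named logger to a CSV file. Below is a minimal sketch of what that counterpart could look like. Only the names training_log, setup_logging, EPOCH_LOGGER, and BATCH_LOGGER come from the diff; the logger-name values, the row format, and all other details are assumptions.

# Hypothetical utils.py counterpart -- a sketch, not this repository's code.
import functools
import logging
import os

# Assumed logger-name constants imported by baseline.py.
BATCH_LOGGER = 'batch_logger'
EPOCH_LOGGER = 'epoch_logger'


def setup_logging(logger_name, log_filename):
    # Assumed signature: attach a file handler for the CSV log to the named logger.
    os.makedirs(os.path.dirname(log_filename) or '.', exist_ok=True)
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.INFO)
    logger.addHandler(logging.FileHandler(log_filename))


def training_log(logger_name):
    # Decorator factory: write the dict returned by the wrapped function
    # (e.g. the record batch_logger builds) to the named logger as one
    # comma-separated row.
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            result = fn(*args, **kwargs)
            if isinstance(result, dict):
                logging.getLogger(logger_name).info(
                    ','.join(str(v) for v in result.values()))
            return result
        return wrapper
    return decorator

Under these assumptions, the two setup_logging calls added to __main__ route every decorated batch_logger call into logs/batch-<codename>.csv and every epoch_logger and load_checkpoint call into logs/epoch-<codename>.csv.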