diff options
Diffstat (limited to 'supervised/baseline.py')
-rw-r--r-- | supervised/baseline.py | 26 |
1 files changed, 21 insertions, 5 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py index dc92408..d400250 100644 --- a/supervised/baseline.py +++ b/supervised/baseline.py @@ -3,6 +3,7 @@ import os import random import torch +import yaml from torch.backends import cudnn from torch.nn import CrossEntropyLoss from torch.utils.data import DataLoader @@ -25,6 +26,17 @@ def build_parser(): except: raise argparse.ArgumentTypeError("Range must be 'start-end.'") + def merge_yaml(args): + if args.config: + config = yaml.safe_load(args.config) + delattr(args, 'config') + args_dict = args.__dict__ + for key, value in config.items(): + if isinstance(value, str) and 'range' in key: + args_dict[key] = range_parser(value) + else: + args_dict[key] = value + parser = argparse.ArgumentParser(description='Supervised baseline') parser.add_argument('--codename', default='cifar10-resnet50-256-lars-warmup', type=str, help="Model descriptor (default: " @@ -32,6 +44,8 @@ def build_parser(): parser.add_argument('--seed', default=-1, type=int, help='Random seed for reproducibility ' '(-1 for not set seed) (default: -1)') + parser.add_argument('--config', type=argparse.FileType(mode='r'), + help='Path to config file (optional)') data_group = parser.add_argument_group('Dataset parameters') data_group.add_argument('--dataset_dir', default='dataset', type=str, @@ -40,7 +54,7 @@ def build_parser(): help="Name of dataset (default: 'cifar10')") data_group.add_argument('--crop_size', default=32, type=int, help='Random crop size after resize (default: 32)') - data_group.add_argument('--crop_scale', default='0.8-1', type=range_parser, + data_group.add_argument('--crop_scale_range', default='0.8-1', type=range_parser, help='Random resize scale range (default: 0.8-1)') data_group.add_argument('--hflip_p', default=0.5, type=float, help='Random horizontal flip probability (default: 0.5)') @@ -49,7 +63,7 @@ def build_parser(): data_group.add_argument('--gaussian_ker_scale', default=10, type=float, help='Gaussian kernel scale factor ' '(equals to img_size / kernel_size) (default: 10)') - data_group.add_argument('--gaussian_sigma', default='0.1-2', type=range_parser, + data_group.add_argument('--gaussian_sigma_range', default='0.1-2', type=range_parser, help='Random gaussian blur sigma range (default: 0.1-2)') data_group.add_argument('--gaussian_p', default=0.5, type=float, help='Random gaussian blur probability (default: 0.5)') @@ -88,6 +102,8 @@ def build_parser(): "(default: 'checkpoints')") args = parser.parse_args() + merge_yaml(args) + args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') args.batch_log_filename = os.path.join(args.log_dir, f'batch-{args.codename}.csv') args.epoch_log_filename = os.path.join(args.log_dir, f'epoch-{args.codename}.csv') @@ -111,7 +127,7 @@ def prepare_dataset(args): train_transform = transforms.Compose([ transforms.RandomResizedCrop( args.crop_size, - scale=args.crop_scale, + scale=args.crop_scale_range, interpolation=InterpolationMode.BICUBIC ), transforms.RandomHorizontalFlip(args.hflip_p), @@ -130,7 +146,7 @@ def prepare_dataset(args): train_transform = transforms.Compose([ transforms.RandomResizedCrop( args.crop_size, - scale=args.crop_scale, + scale=args.crop_scale_range, interpolation=InterpolationMode.BICUBIC ), transforms.RandomHorizontalFlip(args.hflip_p), @@ -138,7 +154,7 @@ def prepare_dataset(args): transforms.ToTensor(), RandomGaussianBlur( kernel_size=args.crop_size // args.gaussian_ker_scale, - sigma_range=args.gaussian_sigma, + sigma_range=args.gaussian_sigma_range, p=args.gaussian_p ), Clip() |