diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2022-03-17 18:55:53 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2022-03-17 18:55:53 +0800 |
commit | 1de6b8270c34390f2000cb52a124170dbfda6dc3 (patch) | |
tree | 3ec746fab834b17f229d3b8e5d7037a4ef4c2ea4 /supervised | |
parent | 48e0405bfbffa63d9e2a610c1cf982892a861e4a (diff) |
Add YAML parser support
Diffstat (limited to 'supervised')
-rw-r--r-- | supervised/baseline.py | 26 | ||||
-rw-r--r-- | supervised/config.yaml | 27 |
2 files changed, 48 insertions, 5 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py index dc92408..d400250 100644 --- a/supervised/baseline.py +++ b/supervised/baseline.py @@ -3,6 +3,7 @@ import os import random import torch +import yaml from torch.backends import cudnn from torch.nn import CrossEntropyLoss from torch.utils.data import DataLoader @@ -25,6 +26,17 @@ def build_parser(): except: raise argparse.ArgumentTypeError("Range must be 'start-end.'") + def merge_yaml(args): + if args.config: + config = yaml.safe_load(args.config) + delattr(args, 'config') + args_dict = args.__dict__ + for key, value in config.items(): + if isinstance(value, str) and 'range' in key: + args_dict[key] = range_parser(value) + else: + args_dict[key] = value + parser = argparse.ArgumentParser(description='Supervised baseline') parser.add_argument('--codename', default='cifar10-resnet50-256-lars-warmup', type=str, help="Model descriptor (default: " @@ -32,6 +44,8 @@ def build_parser(): parser.add_argument('--seed', default=-1, type=int, help='Random seed for reproducibility ' '(-1 for not set seed) (default: -1)') + parser.add_argument('--config', type=argparse.FileType(mode='r'), + help='Path to config file (optional)') data_group = parser.add_argument_group('Dataset parameters') data_group.add_argument('--dataset_dir', default='dataset', type=str, @@ -40,7 +54,7 @@ def build_parser(): help="Name of dataset (default: 'cifar10')") data_group.add_argument('--crop_size', default=32, type=int, help='Random crop size after resize (default: 32)') - data_group.add_argument('--crop_scale', default='0.8-1', type=range_parser, + data_group.add_argument('--crop_scale_range', default='0.8-1', type=range_parser, help='Random resize scale range (default: 0.8-1)') data_group.add_argument('--hflip_p', default=0.5, type=float, help='Random horizontal flip probability (default: 0.5)') @@ -49,7 +63,7 @@ def build_parser(): data_group.add_argument('--gaussian_ker_scale', default=10, type=float, help='Gaussian kernel scale factor ' '(equals to img_size / kernel_size) (default: 10)') - data_group.add_argument('--gaussian_sigma', default='0.1-2', type=range_parser, + data_group.add_argument('--gaussian_sigma_range', default='0.1-2', type=range_parser, help='Random gaussian blur sigma range (default: 0.1-2)') data_group.add_argument('--gaussian_p', default=0.5, type=float, help='Random gaussian blur probability (default: 0.5)') @@ -88,6 +102,8 @@ def build_parser(): "(default: 'checkpoints')") args = parser.parse_args() + merge_yaml(args) + args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') args.batch_log_filename = os.path.join(args.log_dir, f'batch-{args.codename}.csv') args.epoch_log_filename = os.path.join(args.log_dir, f'epoch-{args.codename}.csv') @@ -111,7 +127,7 @@ def prepare_dataset(args): train_transform = transforms.Compose([ transforms.RandomResizedCrop( args.crop_size, - scale=args.crop_scale, + scale=args.crop_scale_range, interpolation=InterpolationMode.BICUBIC ), transforms.RandomHorizontalFlip(args.hflip_p), @@ -130,7 +146,7 @@ def prepare_dataset(args): train_transform = transforms.Compose([ transforms.RandomResizedCrop( args.crop_size, - scale=args.crop_scale, + scale=args.crop_scale_range, interpolation=InterpolationMode.BICUBIC ), transforms.RandomHorizontalFlip(args.hflip_p), @@ -138,7 +154,7 @@ def prepare_dataset(args): transforms.ToTensor(), RandomGaussianBlur( kernel_size=args.crop_size // args.gaussian_ker_scale, - sigma_range=args.gaussian_sigma, + sigma_range=args.gaussian_sigma_range, p=args.gaussian_p ), Clip() diff --git a/supervised/config.yaml b/supervised/config.yaml new file mode 100644 index 0000000..2279b3c --- /dev/null +++ b/supervised/config.yaml @@ -0,0 +1,27 @@ +codename: 'cifar10-resnet50-256-lars-warmup' +seed: -1 + +dataset_dir: 'dataset' +dataset: 'cifar10' +crop_size: 32 +crop_scale_range: '0.8-1' +hflip_p: 0.5 +distort_s: 0.5 +gaussian_ker_scale: 10 +gaussian_sigma_range: '0.1-2' +gaussian_p: 0.5 + +batch_size: 256 +restore_epoch: 0 +n_epochs: 1000 +warmup_epochs: 10 +n_workers: 2 +optim: 'lars' +sched: 'warmup-anneal' +lr: 1. +momentum: 0.9 +weight_decay: 1.0e-06 + +log_dir: 'logs' +tensorboard_dir: 'runs' +checkpoint_dir: 'checkpoints'
\ No newline at end of file |