|  |  |  |
|---|---|---|
| author | Jordan Gong <jordan.gong@protonmail.com> | 2022-03-17 18:55:53 +0800 |
| committer | Jordan Gong <jordan.gong@protonmail.com> | 2022-03-17 18:55:53 +0800 |
| commit | 1de6b8270c34390f2000cb52a124170dbfda6dc3 (patch) |  |
| tree | 3ec746fab834b17f229d3b8e5d7037a4ef4c2ea4 |  |
| parent | 48e0405bfbffa63d9e2a610c1cf982892a861e4a (diff) |  |
Add YAML parser support
| -rw-r--r-- | supervised/baseline.py | 26 |
| -rw-r--r-- | supervised/config.yaml | 27 |
2 files changed, 48 insertions, 5 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py
index dc92408..d400250 100644
--- a/supervised/baseline.py
+++ b/supervised/baseline.py
@@ -3,6 +3,7 @@ import os
 import random
 
 import torch
+import yaml
 from torch.backends import cudnn
 from torch.nn import CrossEntropyLoss
 from torch.utils.data import DataLoader
@@ -25,6 +26,17 @@ def build_parser():
         except:
             raise argparse.ArgumentTypeError("Range must be 'start-end.'")
 
+    def merge_yaml(args):
+        if args.config:
+            config = yaml.safe_load(args.config)
+            delattr(args, 'config')
+            args_dict = args.__dict__
+            for key, value in config.items():
+                if isinstance(value, str) and 'range' in key:
+                    args_dict[key] = range_parser(value)
+                else:
+                    args_dict[key] = value
+
     parser = argparse.ArgumentParser(description='Supervised baseline')
     parser.add_argument('--codename', default='cifar10-resnet50-256-lars-warmup',
                         type=str, help="Model descriptor (default: "
@@ -32,6 +44,8 @@ def build_parser():
     parser.add_argument('--seed', default=-1, type=int,
                         help='Random seed for reproducibility '
                              '(-1 for not set seed) (default: -1)')
+    parser.add_argument('--config', type=argparse.FileType(mode='r'),
+                        help='Path to config file (optional)')
 
     data_group = parser.add_argument_group('Dataset parameters')
     data_group.add_argument('--dataset_dir', default='dataset', type=str,
@@ -40,7 +54,7 @@ def build_parser():
                             help="Name of dataset (default: 'cifar10')")
     data_group.add_argument('--crop_size', default=32, type=int,
                             help='Random crop size after resize (default: 32)')
-    data_group.add_argument('--crop_scale', default='0.8-1', type=range_parser,
+    data_group.add_argument('--crop_scale_range', default='0.8-1', type=range_parser,
                             help='Random resize scale range (default: 0.8-1)')
     data_group.add_argument('--hflip_p', default=0.5, type=float,
                             help='Random horizontal flip probability (default: 0.5)')
@@ -49,7 +63,7 @@ def build_parser():
     data_group.add_argument('--gaussian_ker_scale', default=10, type=float,
                             help='Gaussian kernel scale factor '
                                  '(equals to img_size / kernel_size) (default: 10)')
-    data_group.add_argument('--gaussian_sigma', default='0.1-2', type=range_parser,
+    data_group.add_argument('--gaussian_sigma_range', default='0.1-2', type=range_parser,
                             help='Random gaussian blur sigma range (default: 0.1-2)')
     data_group.add_argument('--gaussian_p', default=0.5, type=float,
                             help='Random gaussian blur probability (default: 0.5)')
@@ -88,6 +102,8 @@ def build_parser():
                                     "(default: 'checkpoints')")
 
     args = parser.parse_args()
+    merge_yaml(args)
+
     args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     args.batch_log_filename = os.path.join(args.log_dir, f'batch-{args.codename}.csv')
     args.epoch_log_filename = os.path.join(args.log_dir, f'epoch-{args.codename}.csv')
@@ -111,7 +127,7 @@ def prepare_dataset(args):
         train_transform = transforms.Compose([
             transforms.RandomResizedCrop(
                 args.crop_size,
-                scale=args.crop_scale,
+                scale=args.crop_scale_range,
                 interpolation=InterpolationMode.BICUBIC
             ),
             transforms.RandomHorizontalFlip(args.hflip_p),
@@ -130,7 +146,7 @@ def prepare_dataset(args):
         train_transform = transforms.Compose([
             transforms.RandomResizedCrop(
                 args.crop_size,
-                scale=args.crop_scale,
+                scale=args.crop_scale_range,
                 interpolation=InterpolationMode.BICUBIC
             ),
             transforms.RandomHorizontalFlip(args.hflip_p),
@@ -138,7 +154,7 @@ def prepare_dataset(args):
             transforms.ToTensor(),
             RandomGaussianBlur(
                 kernel_size=args.crop_size // args.gaussian_ker_scale,
-                sigma_range=args.gaussian_sigma,
+                sigma_range=args.gaussian_sigma_range,
                 p=args.gaussian_p
             ),
             Clip()
diff --git a/supervised/config.yaml b/supervised/config.yaml
new file mode 100644
index 0000000..2279b3c
--- /dev/null
+++ b/supervised/config.yaml
@@ -0,0 +1,27 @@
+codename: 'cifar10-resnet50-256-lars-warmup'
+seed: -1
+
+dataset_dir: 'dataset'
+dataset: 'cifar10'
+crop_size: 32
+crop_scale_range: '0.8-1'
+hflip_p: 0.5
+distort_s: 0.5
+gaussian_ker_scale: 10
+gaussian_sigma_range: '0.1-2'
+gaussian_p: 0.5
+
+batch_size: 256
+restore_epoch: 0
+n_epochs: 1000
+warmup_epochs: 10
+n_workers: 2
+optim: 'lars'
+sched: 'warmup-anneal'
+lr: 1.
+momentum: 0.9
+weight_decay: 1.0e-06
+
+log_dir: 'logs'
+tensorboard_dir: 'runs'
+checkpoint_dir: 'checkpoints'
\ No newline at end of file
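
The new `--config` flag opens a YAML file (such as the added `supervised/config.yaml`), and `merge_yaml` copies its entries into the parsed argparse namespace, routing any string value whose key contains `range` through `range_parser` so YAML strings like `'0.8-1'` become the same range objects the CLI would produce. Because the merge runs after `parse_args()`, YAML values override anything given on the command line. Below is a minimal standalone sketch of that merge pattern; the inline YAML text, the example keys, and the float-tuple `range_parser` body (which this diff does not show) are assumptions for illustration, not the repository's actual code.

```python
# Illustrative, self-contained sketch of the YAML-into-argparse merge pattern
# added in this commit. The range_parser body is an assumption; only the
# key-handling logic mirrors merge_yaml() in the diff above.
import argparse

import yaml  # PyYAML


def range_parser(text):
    # Assumed behaviour: turn 'start-end' into a (start, end) float pair.
    try:
        start, end = text.split('-')
        return float(start), float(end)
    except ValueError:
        raise argparse.ArgumentTypeError("Range must be 'start-end.'")


def merge_yaml_into(args, config_text):
    # YAML entries overwrite the parsed defaults; string values whose key
    # contains 'range' are re-parsed the same way the CLI arguments are.
    for key, value in yaml.safe_load(config_text).items():
        if isinstance(value, str) and 'range' in key:
            setattr(args, key, range_parser(value))
        else:
            setattr(args, key, value)


args = argparse.Namespace(crop_size=32, crop_scale_range=(0.8, 1.0))
merge_yaml_into(args, "crop_size: 64\ncrop_scale_range: '0.5-1'\n")
print(args)  # -> Namespace(crop_size=64, crop_scale_range=(0.5, 1.0))
```

With the real script, the config file would presumably be supplied as `python supervised/baseline.py --config supervised/config.yaml`, with any remaining flags falling back to their argparse defaults.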
