aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2022-03-17 18:55:53 +0800
committerJordan Gong <jordan.gong@protonmail.com>2022-03-17 18:55:53 +0800
commit1de6b8270c34390f2000cb52a124170dbfda6dc3 (patch)
tree3ec746fab834b17f229d3b8e5d7037a4ef4c2ea4
parent48e0405bfbffa63d9e2a610c1cf982892a861e4a (diff)
Add YAML parser support
-rw-r--r--supervised/baseline.py26
-rw-r--r--supervised/config.yaml27
2 files changed, 48 insertions, 5 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py
index dc92408..d400250 100644
--- a/supervised/baseline.py
+++ b/supervised/baseline.py
@@ -3,6 +3,7 @@ import os
import random
import torch
+import yaml
from torch.backends import cudnn
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
@@ -25,6 +26,17 @@ def build_parser():
except:
raise argparse.ArgumentTypeError("Range must be 'start-end.'")
+ def merge_yaml(args):
+ if args.config:
+ config = yaml.safe_load(args.config)
+ delattr(args, 'config')
+ args_dict = args.__dict__
+ for key, value in config.items():
+ if isinstance(value, str) and 'range' in key:
+ args_dict[key] = range_parser(value)
+ else:
+ args_dict[key] = value
+
parser = argparse.ArgumentParser(description='Supervised baseline')
parser.add_argument('--codename', default='cifar10-resnet50-256-lars-warmup',
type=str, help="Model descriptor (default: "
@@ -32,6 +44,8 @@ def build_parser():
parser.add_argument('--seed', default=-1, type=int,
help='Random seed for reproducibility '
'(-1 for not set seed) (default: -1)')
+ parser.add_argument('--config', type=argparse.FileType(mode='r'),
+ help='Path to config file (optional)')
data_group = parser.add_argument_group('Dataset parameters')
data_group.add_argument('--dataset_dir', default='dataset', type=str,
@@ -40,7 +54,7 @@ def build_parser():
help="Name of dataset (default: 'cifar10')")
data_group.add_argument('--crop_size', default=32, type=int,
help='Random crop size after resize (default: 32)')
- data_group.add_argument('--crop_scale', default='0.8-1', type=range_parser,
+ data_group.add_argument('--crop_scale_range', default='0.8-1', type=range_parser,
help='Random resize scale range (default: 0.8-1)')
data_group.add_argument('--hflip_p', default=0.5, type=float,
help='Random horizontal flip probability (default: 0.5)')
@@ -49,7 +63,7 @@ def build_parser():
data_group.add_argument('--gaussian_ker_scale', default=10, type=float,
help='Gaussian kernel scale factor '
'(equals to img_size / kernel_size) (default: 10)')
- data_group.add_argument('--gaussian_sigma', default='0.1-2', type=range_parser,
+ data_group.add_argument('--gaussian_sigma_range', default='0.1-2', type=range_parser,
help='Random gaussian blur sigma range (default: 0.1-2)')
data_group.add_argument('--gaussian_p', default=0.5, type=float,
help='Random gaussian blur probability (default: 0.5)')
@@ -88,6 +102,8 @@ def build_parser():
"(default: 'checkpoints')")
args = parser.parse_args()
+ merge_yaml(args)
+
args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
args.batch_log_filename = os.path.join(args.log_dir, f'batch-{args.codename}.csv')
args.epoch_log_filename = os.path.join(args.log_dir, f'epoch-{args.codename}.csv')
@@ -111,7 +127,7 @@ def prepare_dataset(args):
train_transform = transforms.Compose([
transforms.RandomResizedCrop(
args.crop_size,
- scale=args.crop_scale,
+ scale=args.crop_scale_range,
interpolation=InterpolationMode.BICUBIC
),
transforms.RandomHorizontalFlip(args.hflip_p),
@@ -130,7 +146,7 @@ def prepare_dataset(args):
train_transform = transforms.Compose([
transforms.RandomResizedCrop(
args.crop_size,
- scale=args.crop_scale,
+ scale=args.crop_scale_range,
interpolation=InterpolationMode.BICUBIC
),
transforms.RandomHorizontalFlip(args.hflip_p),
@@ -138,7 +154,7 @@ def prepare_dataset(args):
transforms.ToTensor(),
RandomGaussianBlur(
kernel_size=args.crop_size // args.gaussian_ker_scale,
- sigma_range=args.gaussian_sigma,
+ sigma_range=args.gaussian_sigma_range,
p=args.gaussian_p
),
Clip()
diff --git a/supervised/config.yaml b/supervised/config.yaml
new file mode 100644
index 0000000..2279b3c
--- /dev/null
+++ b/supervised/config.yaml
@@ -0,0 +1,27 @@
+codename: 'cifar10-resnet50-256-lars-warmup'
+seed: -1
+
+dataset_dir: 'dataset'
+dataset: 'cifar10'
+crop_size: 32
+crop_scale_range: '0.8-1'
+hflip_p: 0.5
+distort_s: 0.5
+gaussian_ker_scale: 10
+gaussian_sigma_range: '0.1-2'
+gaussian_p: 0.5
+
+batch_size: 256
+restore_epoch: 0
+n_epochs: 1000
+warmup_epochs: 10
+n_workers: 2
+optim: 'lars'
+sched: 'warmup-anneal'
+lr: 1.
+momentum: 0.9
+weight_decay: 1.0e-06
+
+log_dir: 'logs'
+tensorboard_dir: 'runs'
+checkpoint_dir: 'checkpoints' \ No newline at end of file