Diffstat (limited to 'supervised/baseline.py')
-rw-r--r--  supervised/baseline.py  26
1 file changed, 21 insertions, 5 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py
index dc92408..d400250 100644
--- a/supervised/baseline.py
+++ b/supervised/baseline.py
@@ -3,6 +3,7 @@ import os
 import random

 import torch
+import yaml
 from torch.backends import cudnn
 from torch.nn import CrossEntropyLoss
 from torch.utils.data import DataLoader
@@ -25,6 +26,17 @@ def build_parser():
         except:
             raise argparse.ArgumentTypeError("Range must be 'start-end.'")

+    def merge_yaml(args):
+        if args.config:
+            config = yaml.safe_load(args.config)
+            delattr(args, 'config')
+            args_dict = args.__dict__
+            for key, value in config.items():
+                if isinstance(value, str) and 'range' in key:
+                    args_dict[key] = range_parser(value)
+                else:
+                    args_dict[key] = value
+
     parser = argparse.ArgumentParser(description='Supervised baseline')
     parser.add_argument('--codename', default='cifar10-resnet50-256-lars-warmup',
                         type=str, help="Model descriptor (default: "
@@ -32,6 +44,8 @@ def build_parser():
     parser.add_argument('--seed', default=-1, type=int,
                         help='Random seed for reproducibility '
                              '(-1 for not set seed) (default: -1)')
+    parser.add_argument('--config', type=argparse.FileType(mode='r'),
+                        help='Path to config file (optional)')

     data_group = parser.add_argument_group('Dataset parameters')
     data_group.add_argument('--dataset_dir', default='dataset', type=str,
@@ -40,7 +54,7 @@ def build_parser():
help="Name of dataset (default: 'cifar10')")
data_group.add_argument('--crop_size', default=32, type=int,
help='Random crop size after resize (default: 32)')
- data_group.add_argument('--crop_scale', default='0.8-1', type=range_parser,
+ data_group.add_argument('--crop_scale_range', default='0.8-1', type=range_parser,
help='Random resize scale range (default: 0.8-1)')
data_group.add_argument('--hflip_p', default=0.5, type=float,
help='Random horizontal flip probability (default: 0.5)')
@@ -49,7 +63,7 @@ def build_parser():
     data_group.add_argument('--gaussian_ker_scale', default=10, type=float,
                             help='Gaussian kernel scale factor '
                                  '(equals to img_size / kernel_size) (default: 10)')
-    data_group.add_argument('--gaussian_sigma', default='0.1-2', type=range_parser,
+    data_group.add_argument('--gaussian_sigma_range', default='0.1-2', type=range_parser,
                             help='Random gaussian blur sigma range (default: 0.1-2)')
     data_group.add_argument('--gaussian_p', default=0.5, type=float,
                             help='Random gaussian blur probability (default: 0.5)')
@@ -88,6 +102,8 @@ def build_parser():
"(default: 'checkpoints')")
args = parser.parse_args()
+ merge_yaml(args)
+
args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
args.batch_log_filename = os.path.join(args.log_dir, f'batch-{args.codename}.csv')
args.epoch_log_filename = os.path.join(args.log_dir, f'epoch-{args.codename}.csv')
@@ -111,7 +127,7 @@ def prepare_dataset(args):
         train_transform = transforms.Compose([
             transforms.RandomResizedCrop(
                 args.crop_size,
-                scale=args.crop_scale,
+                scale=args.crop_scale_range,
                 interpolation=InterpolationMode.BICUBIC
             ),
             transforms.RandomHorizontalFlip(args.hflip_p),
@@ -130,7 +146,7 @@ def prepare_dataset(args):
         train_transform = transforms.Compose([
             transforms.RandomResizedCrop(
                 args.crop_size,
-                scale=args.crop_scale,
+                scale=args.crop_scale_range,
                 interpolation=InterpolationMode.BICUBIC
             ),
             transforms.RandomHorizontalFlip(args.hflip_p),
@@ -138,7 +154,7 @@ def prepare_dataset(args):
             transforms.ToTensor(),
             RandomGaussianBlur(
                 kernel_size=args.crop_size // args.gaussian_ker_scale,
-                sigma_range=args.gaussian_sigma,
+                sigma_range=args.gaussian_sigma_range,
                 p=args.gaussian_p
             ),
             Clip()
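
For reference, a minimal sketch of how the new --config option is intended to be used, assuming a hypothetical YAML file whose keys match the argparse destinations above (e.g. crop_scale_range, gaussian_sigma_range). The snippet mirrors the merge_yaml() logic outside of baseline.py: YAML values overwrite the parsed defaults, and any string value whose key contains 'range' is converted with a range_parser-style helper. The helper and default values here are stand-ins, not code from this commit.

# sketch_config_merge.py -- hypothetical illustration, not part of the repo
import io

import yaml

example_config = io.StringIO(
    "codename: cifar10-resnet50-256-lars-sketch\n"
    "crop_scale_range: 0.2-1\n"   # stays a string in YAML, converted below
    "gaussian_p: 0.3\n"
)

def fake_range_parser(range_string):
    # Stand-in for range_parser() in baseline.py: 'start-end' -> numeric pair.
    start, end = map(float, range_string.split('-'))
    return start, end

# Stand-ins for the argparse defaults defined in build_parser().
args_dict = {
    'codename': 'cifar10-resnet50-256-lars-warmup',
    'crop_scale_range': (0.8, 1.0),
    'gaussian_p': 0.5,
}

# Same merge rule as merge_yaml(): overwrite defaults, convert '*range*' strings.
for key, value in yaml.safe_load(example_config).items():
    if isinstance(value, str) and 'range' in key:
        args_dict[key] = fake_range_parser(value)
    else:
        args_dict[key] = value

print(args_dict['crop_scale_range'])  # (0.2, 1.0)
print(args_dict['gaussian_p'])        # 0.3

On the command line this would look something like
python supervised/baseline.py --config <your-config>.yaml; since merge_yaml()
calls delattr(args, 'config') before merging, the remaining YAML keys simply
replace the parsed argparse values.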