aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--supervised/baseline.py19
-rw-r--r--supervised/config.yaml24
2 files changed, 26 insertions, 17 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py
index d400250..221b90d 100644
--- a/supervised/baseline.py
+++ b/supervised/baseline.py
@@ -32,8 +32,8 @@ def build_parser():
delattr(args, 'config')
args_dict = args.__dict__
for key, value in config.items():
- if isinstance(value, str) and 'range' in key:
- args_dict[key] = range_parser(value)
+ if isinstance(value, list):
+ args_dict[key] = tuple(value)
else:
args_dict[key] = value
@@ -104,9 +104,8 @@ def build_parser():
args = parser.parse_args()
merge_yaml(args)
- args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
- args.batch_log_filename = os.path.join(args.log_dir, f'batch-{args.codename}.csv')
- args.epoch_log_filename = os.path.join(args.log_dir, f'epoch-{args.codename}.csv')
+ args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ args.log_root = os.path.join(args.log_dir, args.codename)
args.tensorboard_root = os.path.join(args.tensorboard_dir, args.codename)
args.checkpoint_root = os.path.join(args.checkpoint_dir, args.codename)
@@ -122,6 +121,13 @@ def set_seed(args):
cudnn.deterministic = True
+def init_logging(args):
+ setup_logging(BATCH_LOGGER, os.path.join(args.log_root, 'batch-log.csv'))
+ setup_logging(EPOCH_LOGGER, os.path.join(args.log_root, 'epoch-log.csv'))
+ with open(os.path.join(args.log_root, 'params.yaml'), 'w') as params:
+ yaml.safe_dump(args.__dict__, params)
+
+
def prepare_dataset(args):
if args.dataset == 'cifar10' or args.dataset == 'cifar':
train_transform = transforms.Compose([
@@ -349,8 +355,7 @@ def save_checkpoint(args, epoch_log, model, optimizer):
if __name__ == '__main__':
args = build_parser()
set_seed(args)
- setup_logging(BATCH_LOGGER, args.batch_log_filename)
- setup_logging(EPOCH_LOGGER, args.epoch_log_filename)
+ init_logging(args)
train_set, test_set = prepare_dataset(args)
train_loader, test_loader = create_dataloader(args, train_set, test_set)
diff --git a/supervised/config.yaml b/supervised/config.yaml
index 2279b3c..4fb37fb 100644
--- a/supervised/config.yaml
+++ b/supervised/config.yaml
@@ -1,14 +1,18 @@
-codename: 'cifar10-resnet50-256-lars-warmup'
+codename: cifar10-resnet50-256-lars-warmup
seed: -1
-dataset_dir: 'dataset'
-dataset: 'cifar10'
+dataset_dir: dataset
+dataset: cifar10
crop_size: 32
-crop_scale_range: '0.8-1'
+crop_scale_range:
+- 0.8
+- 1
hflip_p: 0.5
distort_s: 0.5
gaussian_ker_scale: 10
-gaussian_sigma_range: '0.1-2'
+gaussian_sigma_range:
+- 0.1
+- 2
gaussian_p: 0.5
batch_size: 256
@@ -16,12 +20,12 @@ restore_epoch: 0
n_epochs: 1000
warmup_epochs: 10
n_workers: 2
-optim: 'lars'
-sched: 'warmup-anneal'
+optim: lars
+sched: warmup-anneal
lr: 1.
momentum: 0.9
weight_decay: 1.0e-06
-log_dir: 'logs'
-tensorboard_dir: 'runs'
-checkpoint_dir: 'checkpoints' \ No newline at end of file
+log_dir: logs
+tensorboard_dir: runs
+checkpoint_dir: checkpoints \ No newline at end of file