diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2022-03-17 19:12:50 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2022-03-17 19:12:50 +0800 |
commit | 568569c764ffdd73cd660434df50d30d26203f63 (patch) | |
tree | 2df6c4c38f3bc0aba374a111a358ad7ca03dd4f0 /supervised | |
parent | 1de6b8270c34390f2000cb52a124170dbfda6dc3 (diff) |
Log parameters and use list YAML syntax for tuples
Diffstat (limited to 'supervised')
-rw-r--r-- | supervised/baseline.py | 19 | ||||
-rw-r--r-- | supervised/config.yaml | 24 |
2 files changed, 26 insertions, 17 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py index d400250..221b90d 100644 --- a/supervised/baseline.py +++ b/supervised/baseline.py @@ -32,8 +32,8 @@ def build_parser(): delattr(args, 'config') args_dict = args.__dict__ for key, value in config.items(): - if isinstance(value, str) and 'range' in key: - args_dict[key] = range_parser(value) + if isinstance(value, list): + args_dict[key] = tuple(value) else: args_dict[key] = value @@ -104,9 +104,8 @@ def build_parser(): args = parser.parse_args() merge_yaml(args) - args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - args.batch_log_filename = os.path.join(args.log_dir, f'batch-{args.codename}.csv') - args.epoch_log_filename = os.path.join(args.log_dir, f'epoch-{args.codename}.csv') + args.device = 'cuda' if torch.cuda.is_available() else 'cpu' + args.log_root = os.path.join(args.log_dir, args.codename) args.tensorboard_root = os.path.join(args.tensorboard_dir, args.codename) args.checkpoint_root = os.path.join(args.checkpoint_dir, args.codename) @@ -122,6 +121,13 @@ def set_seed(args): cudnn.deterministic = True +def init_logging(args): + setup_logging(BATCH_LOGGER, os.path.join(args.log_root, 'batch-log.csv')) + setup_logging(EPOCH_LOGGER, os.path.join(args.log_root, 'epoch-log.csv')) + with open(os.path.join(args.log_root, 'params.yaml'), 'w') as params: + yaml.safe_dump(args.__dict__, params) + + def prepare_dataset(args): if args.dataset == 'cifar10' or args.dataset == 'cifar': train_transform = transforms.Compose([ @@ -349,8 +355,7 @@ def save_checkpoint(args, epoch_log, model, optimizer): if __name__ == '__main__': args = build_parser() set_seed(args) - setup_logging(BATCH_LOGGER, args.batch_log_filename) - setup_logging(EPOCH_LOGGER, args.epoch_log_filename) + init_logging(args) train_set, test_set = prepare_dataset(args) train_loader, test_loader = create_dataloader(args, train_set, test_set) diff --git a/supervised/config.yaml b/supervised/config.yaml index 2279b3c..4fb37fb 100644 --- a/supervised/config.yaml +++ b/supervised/config.yaml @@ -1,14 +1,18 @@ -codename: 'cifar10-resnet50-256-lars-warmup' +codename: cifar10-resnet50-256-lars-warmup seed: -1 -dataset_dir: 'dataset' -dataset: 'cifar10' +dataset_dir: dataset +dataset: cifar10 crop_size: 32 -crop_scale_range: '0.8-1' +crop_scale_range: +- 0.8 +- 1 hflip_p: 0.5 distort_s: 0.5 gaussian_ker_scale: 10 -gaussian_sigma_range: '0.1-2' +gaussian_sigma_range: +- 0.1 +- 2 gaussian_p: 0.5 batch_size: 256 @@ -16,12 +20,12 @@ restore_epoch: 0 n_epochs: 1000 warmup_epochs: 10 n_workers: 2 -optim: 'lars' -sched: 'warmup-anneal' +optim: lars +sched: warmup-anneal lr: 1. momentum: 0.9 weight_decay: 1.0e-06 -log_dir: 'logs' -tensorboard_dir: 'runs' -checkpoint_dir: 'checkpoints'
\ No newline at end of file +log_dir: logs +tensorboard_dir: runs +checkpoint_dir: checkpoints
\ No newline at end of file |