aboutsummaryrefslogtreecommitdiff
path: root/supervised/baseline.py
diff options
context:
space:
mode:
Diffstat (limited to 'supervised/baseline.py')
-rw-r--r--supervised/baseline.py19
1 files changed, 12 insertions, 7 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py
index d400250..221b90d 100644
--- a/supervised/baseline.py
+++ b/supervised/baseline.py
@@ -32,8 +32,8 @@ def build_parser():
delattr(args, 'config')
args_dict = args.__dict__
for key, value in config.items():
- if isinstance(value, str) and 'range' in key:
- args_dict[key] = range_parser(value)
+ if isinstance(value, list):
+ args_dict[key] = tuple(value)
else:
args_dict[key] = value
@@ -104,9 +104,8 @@ def build_parser():
args = parser.parse_args()
merge_yaml(args)
- args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
- args.batch_log_filename = os.path.join(args.log_dir, f'batch-{args.codename}.csv')
- args.epoch_log_filename = os.path.join(args.log_dir, f'epoch-{args.codename}.csv')
+ args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ args.log_root = os.path.join(args.log_dir, args.codename)
args.tensorboard_root = os.path.join(args.tensorboard_dir, args.codename)
args.checkpoint_root = os.path.join(args.checkpoint_dir, args.codename)
@@ -122,6 +121,13 @@ def set_seed(args):
cudnn.deterministic = True
+def init_logging(args):
+ setup_logging(BATCH_LOGGER, os.path.join(args.log_root, 'batch-log.csv'))
+ setup_logging(EPOCH_LOGGER, os.path.join(args.log_root, 'epoch-log.csv'))
+ with open(os.path.join(args.log_root, 'params.yaml'), 'w') as params:
+ yaml.safe_dump(args.__dict__, params)
+
+
def prepare_dataset(args):
if args.dataset == 'cifar10' or args.dataset == 'cifar':
train_transform = transforms.Compose([
@@ -349,8 +355,7 @@ def save_checkpoint(args, epoch_log, model, optimizer):
if __name__ == '__main__':
args = build_parser()
set_seed(args)
- setup_logging(BATCH_LOGGER, args.batch_log_filename)
- setup_logging(EPOCH_LOGGER, args.epoch_log_filename)
+ init_logging(args)
train_set, test_set = prepare_dataset(args)
train_loader, test_loader = create_dataloader(args, train_set, test_set)