diff options
Diffstat (limited to 'supervised/baseline.py')
-rw-r--r-- | supervised/baseline.py | 49 |
1 files changed, 24 insertions, 25 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py index 5e1b32e..e7ff8f1 100644 --- a/supervised/baseline.py +++ b/supervised/baseline.py @@ -23,59 +23,58 @@ from models import CIFARResNet50 def parse_args_and_config(): - parser = argparse.ArgumentParser(description='Supervised baseline') + parser = argparse.ArgumentParser( + description='Supervised baseline', + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) parser.add_argument('--codename', default='cifar10-resnet50-256-adam-linear', - type=str, help="Model descriptor (default: " - "'cifar10-resnet50-256-adam-linear')") + type=str, help="Model descriptor") parser.add_argument('--log_dir', default='logs', type=str, - help="Path to log directory (default: 'logs')") + help="Path to log directory") parser.add_argument('--checkpoint_dir', default='checkpoints', type=str, - help="Path to checkpoints directory (default: 'checkpoints')") - parser.add_argument('--seed', default=-1, type=int, - help='Random seed for reproducibility ' - '(-1 for not set seed) (default: -1)') + help="Path to checkpoints directory") + parser.add_argument('--seed', default=None, type=int, + help='Random seed for reproducibility') parser.add_argument('--num_iters', default=1000, type=int, - help='Number of iters (epochs) (default: 1000)') + help='Number of iters (epochs)') parser.add_argument('--config', type=argparse.FileType(mode='r'), help='Path to config file (optional)') dataset_group = parser.add_argument_group('Dataset parameters') dataset_group.add_argument('--dataset_dir', default='dataset', type=str, - help="Path to dataset directory (default: 'dataset')") + help="Path to dataset directory") dataset_group.add_argument('--dataset', default='cifar10', type=str, choices=('cifar', 'cifar10, cifar100'), - help="Name of dataset (default: 'cifar10')") + help="Name of dataset") dataset_group.add_argument('--crop_size', default=32, type=int, - help='Random crop size after resize (default: 32)') - dataset_group.add_argument('--crop_scale_range', nargs=2, default=(0.8, 1), type=float, - help='Random resize scale range (default: 0.8 1)', + help='Random crop size after resize') + dataset_group.add_argument('--crop_scale_range', nargs=2, default=(0.8, 1), + type=float, help='Random resize scale range', metavar=('start', 'stop')) dataset_group.add_argument('--hflip_prob', default=0.5, type=float, - help='Random horizontal flip probability (default: 0.5)') + help='Random horizontal flip probability') dataloader_group = parser.add_argument_group('Dataloader parameters') dataloader_group.add_argument('--batch_size', default=256, type=int, - help='Batch size (default: 256)') + help='Batch size') dataloader_group.add_argument('--num_workers', default=2, type=int, - help='Number of dataloader processes (default: 2)') + help='Number of dataloader processes') optim_group = parser.add_argument_group('Optimizer parameters') optim_group.add_argument('--optim', default='adam', type=str, - choices=('adam', 'sgd'), - help="Name of optimizer (default: 'adam')") + choices=('adam', 'sgd'), help="Name of optimizer") optim_group.add_argument('--lr', default=1e-3, type=float, - help='Learning rate (default: 1)') + help='Learning rate') optim_group.add_argument('--betas', nargs=2, default=(0.9, 0.999), type=float, - help='Adam betas (default: 0.9 0.999)', metavar=('beta1', 'beta2')) + help='Adam betas', metavar=('beta1', 'beta2')) optim_group.add_argument('--momentum', default=0.9, type=float, - help='SDG momentum (default: 0.9)') + help='SDG momentum') optim_group.add_argument('--weight_decay', default=1e-6, type=float, - help='Weight decay (l2 regularization) (default: 1e-6)') + help='Weight decay (l2 regularization)') sched_group = parser.add_argument_group('Optimizer parameters') sched_group.add_argument('--sched', default='linear', type=str, - choices=(None, '', 'linear'), - help="Name of scheduler (default: None)") + choices=(None, '', 'linear'), help="Name of scheduler") args = parser.parse_args() if args.config: |