From 0d282a5deac482045037a61063c1e48cb3668cff Mon Sep 17 00:00:00 2001
From: Jordan Gong <jordan.gong@protonmail.com>
Date: Wed, 16 Mar 2022 20:59:12 +0800
Subject: Add an argument parser

---
 supervised/baseline.py | 202 ++++++++++++++++++++++++++++++-------------------
 1 file changed, 125 insertions(+), 77 deletions(-)

diff --git a/supervised/baseline.py b/supervised/baseline.py
index c8bbb37..4b4f9e1 100644
--- a/supervised/baseline.py
+++ b/supervised/baseline.py
@@ -1,3 +1,4 @@
+import argparse
 import os
 import random
 
@@ -14,46 +15,85 @@ from models import CIFARResNet50, ImageNetResNet50
 from optimizers import LARS
 from schedulers import LinearWarmupAndCosineAnneal, LinearLR
 
-CODENAME = 'cifar10-resnet50-256-aug-lars-warmup'
-DATASET_ROOT = 'dataset'
-TENSORBOARD_PATH = os.path.join('runs', CODENAME)
-CHECKPOINT_PATH = os.path.join('checkpoints', CODENAME)
-
-DATASET = 'cifar10'
-CROP_SIZE = 32
-CROP_SCALE = (0.8, 1)
-HFLIP_P = 0.5
-DISTORT_S = 0.5
-GAUSSIAN_KER_SCALE = 10
-GAUSSIAN_P = 0.5
-GAUSSIAN_SIGMA = (0.1, 2)
-
-BATCH_SIZE = 256
-RESTORE_EPOCH = 0
-N_EPOCHS = 1000
-WARMUP_EPOCHS = 10
-N_WORKERS = 2
-SEED = 0
-
-OPTIM = 'lars'
-SCHED = 'warmup-anneal'
-LR = 1
-MOMENTUM = 0.9
-WEIGHT_DECAY = 1e-6
-
-random.seed(SEED)
-torch.manual_seed(SEED)
+
+def range_parser(range_string: str):
+    try:
+        range_ = tuple(map(float, range_string.split('-')))
+        return range_
+    except:
+        raise argparse.ArgumentTypeError("Range must be 'start-end.'")
+
+
+parser = argparse.ArgumentParser(description='Supervised baseline')
+parser.add_argument('--codename', default='cifar10-resnet50-256-lars-warmup',
+                    type=str, help="Model descriptor (default: "
+                                   "'cifar10-resnet50-256-lars-warmup')")
+parser.add_argument('--seed', default=0, type=int,
+                    help='Random seed for reproducibility (default: 0)')
+
+data_group = parser.add_argument_group('Dataset parameters')
+data_group.add_argument('--dataset_dir', default='dataset', type=str,
+                        help="Path to dataset directory (default: 'dataset')")
+data_group.add_argument('--dataset', default='cifar10', type=str,
+                        help="Name of dataset (default: 'cifar10')")
+data_group.add_argument('--crop_size', default=32, type=int,
+                        help='Random crop size after resize (default: 32)')
+data_group.add_argument('--crop_scale', default='0.8-1', type=range_parser,
+                        help='Random resize scale range (default: 0.8-1)')
+data_group.add_argument('--hflip_p', default=0.5, type=float,
+                        help='Random horizontal flip probability (default: 0.5)')
+data_group.add_argument('--distort_s', default=0.5, type=float,
+                        help='Distortion strength (default: 0.5)')
+data_group.add_argument('--gaussian_ker_scale', default=10, type=float,
+                        help='Gaussian kernel scale factor '
+                             '(equals to img_size / kernel_size) (default: 10)')
+data_group.add_argument('--gaussian_sigma', default='0.1-2', type=range_parser,
+                        help='Random gaussian blur sigma range (default: 0.1-2)')
+data_group.add_argument('--gaussian_p', default=0.5, type=float,
+                        help='Random gaussian blur probability (default: 0.5)')
+
+train_group = parser.add_argument_group('Training parameters')
+train_group.add_argument('--batch_size', default=256, type=int,
+                         help='Batch size (default: 256)')
+train_group.add_argument('--restore_epoch', default=0, type=int,
+                         help='Restore epoch, 0 for training from scratch '
+                              '(default: 0)')
+train_group.add_argument('--n_epochs', default=1000, type=int,
+                         help='Number of epochs (default: 1000)')
+train_group.add_argument('--warmup_epochs', default=10, type=int,
+                         help='Epochs for warmup '
+                              '(only for `warmup-anneal` scheduler) (default: 10)')
+train_group.add_argument('--n_workers', default=2, type=int,
+                         help='Number of dataloader processes (default: 2)')
+train_group.add_argument('--optim', default='lars', type=str,
+                         help="Name of optimizer (default: 'lars')")
+train_group.add_argument('--sched', default='warmup-anneal', type=str,
+                         help="Name of scheduler (default: 'warmup-anneal')")
+train_group.add_argument('--lr', default=1, type=float,
+                         help='Learning rate (default: 1)')
+train_group.add_argument('--momentum', default=0.9, type=float,
+                         help='Momentum (default: 0.9')
+train_group.add_argument('--weight_decay', default=1e-6, type=float,
+                         help='Weight decay (l2 regularization) (default: 1e-6)')
+
+args = parser.parse_args()
+
+TENSORBOARD_PATH = os.path.join('runs', args.codename)
+CHECKPOINT_PATH = os.path.join('checkpoints', args.codename)
+
+random.seed(args.seed)
+torch.manual_seed(args.seed)
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
-if DATASET == 'cifar10' or DATASET == 'cifar':
+if args.dataset == 'cifar10' or args.dataset == 'cifar':
     train_transform = transforms.Compose([
         transforms.RandomResizedCrop(
-            CROP_SIZE,
-            scale=CROP_SCALE,
+            args.crop_size,
+            scale=args.crop_scale,
             interpolation=InterpolationMode.BICUBIC
         ),
-        transforms.RandomHorizontalFlip(HFLIP_P),
-        color_distortion(DISTORT_S),
+        transforms.RandomHorizontalFlip(args.hflip_p),
+        color_distortion(args.distort_s),
         transforms.ToTensor(),
         Clip()
     ])
@@ -61,47 +101,47 @@ if DATASET == 'cifar10' or DATASET == 'cifar':
         transforms.ToTensor()
     ])
 
-    train_set = CIFAR10(DATASET_ROOT, train=True, transform=train_transform,
+    train_set = CIFAR10(args.dataset_dir, train=True, transform=train_transform,
                         download=True)
-    test_set = CIFAR10(DATASET_ROOT, train=False, transform=test_transform)
+    test_set = CIFAR10(args.dataset_dir, train=False, transform=test_transform)
 
     resnet = CIFARResNet50()
-elif DATASET == 'imagenet1k' or DATASET == 'imagenet1k':
+elif args.dataset == 'imagenet1k' or args.dataset == 'imagenet1k':
     train_transform = transforms.Compose([
         transforms.RandomResizedCrop(
-            CROP_SIZE,
-            scale=CROP_SCALE,
+            args.crop_size,
+            scale=args.crop_scale,
             interpolation=InterpolationMode.BICUBIC
         ),
-        transforms.RandomHorizontalFlip(HFLIP_P),
-        color_distortion(DISTORT_S),
+        transforms.RandomHorizontalFlip(args.hflip_p),
+        color_distortion(args.distort_s),
         transforms.ToTensor(),
         RandomGaussianBlur(
-            kernel_size=CROP_SIZE // GAUSSIAN_KER_SCALE,
-            sigma_range=GAUSSIAN_SIGMA,
-            p=GAUSSIAN_P
+            kernel_size=args.crop_size // args.gaussian_ker_scale,
+            sigma_range=args.gaussian_sigma,
+            p=args.gaussian_p
         ),
         Clip()
     ])
     test_transform = transforms.Compose([
         transforms.Resize(256),
-        transforms.CenterCrop(CROP_SIZE),
+        transforms.CenterCrop(args.crop_size),
         transforms.ToTensor(),
     ])
 
-    train_set = ImageNet(DATASET_ROOT, 'train', transform=train_transform)
-    test_set = ImageNet(DATASET_ROOT, 'val', transform=test_transform)
+    train_set = ImageNet(args.dataset_dir, 'train', transform=train_transform)
+    test_set = ImageNet(args.dataset_dir, 'val', transform=test_transform)
 
     resnet = ImageNetResNet50()
 else:
-    raise NotImplementedError(f"Dataset '{DATASET}' is not implemented.")
+    raise NotImplementedError(f"Dataset '{args.dataset}' is not implemented.")
 
 resnet = resnet.to(device)
 
-train_loader = DataLoader(train_set, batch_size=BATCH_SIZE,
-                          shuffle=True, num_workers=N_WORKERS)
-test_loader = DataLoader(test_set, batch_size=BATCH_SIZE,
-                         shuffle=False, num_workers=N_WORKERS)
+train_loader = DataLoader(train_set, batch_size=args.batch_size,
+                          shuffle=True, num_workers=args.n_workers)
+test_loader = DataLoader(test_set, batch_size=args.batch_size,
+                         shuffle=False, num_workers=args.n_workers)
 
 num_train_batches = len(train_loader)
 num_test_batches = len(test_loader)
@@ -110,7 +150,7 @@ num_test_batches = len(test_loader)
 def exclude_from_wd_and_adaptation(name):
     if 'bn' in name:
         return True
-    if OPTIM == 'lars' and 'bias' in name:
+    if args.optim == 'lars' and 'bias' in name:
         return True
 
 
@@ -118,7 +158,7 @@ param_groups = [
     {
         'params': [p for name, p in resnet.named_parameters()
                    if not exclude_from_wd_and_adaptation(name)],
-        'weight_decay': WEIGHT_DECAY,
+        'weight_decay': args.weight_decay,
         'layer_adaptation': True,
     },
     {
@@ -128,42 +168,50 @@ param_groups = [
         'layer_adaptation': False,
     },
 ]
-if OPTIM == 'adam':
-    optimizer = torch.optim.Adam(param_groups, lr=LR, betas=(MOMENTUM, 0.999))
-elif OPTIM == 'sdg' or OPTIM == 'lars':
-    optimizer = torch.optim.SGD(param_groups, lr=LR, momentum=MOMENTUM)
+if args.optim == 'adam':
+    optimizer = torch.optim.Adam(
+        param_groups,
+        lr=args.lr,
+        betas=(args.momentum, 0.999)
+    )
+elif args.optim == 'sdg' or args.optim == 'lars':
+    optimizer = torch.optim.SGD(
+        param_groups,
+        lr=args.lr,
+        momentum=args.momentum
+    )
 else:
-    raise NotImplementedError(f"Optimizer '{OPTIM}' is not implemented.")
+    raise NotImplementedError(f"Optimizer '{args.optim}' is not implemented.")
 
 # Restore checkpoint
-if RESTORE_EPOCH > 0:
-    checkpoint_path = os.path.join(CHECKPOINT_PATH, f'{RESTORE_EPOCH:04d}.pt')
+if args.restore_epoch > 0:
+    checkpoint_path = os.path.join(CHECKPOINT_PATH, f'{args.restore_epoch:04d}.pt')
     checkpoint = torch.load(checkpoint_path)
     resnet.load_state_dict(checkpoint['resnet_state_dict'])
     optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
-    print(f'[RESTORED][{RESTORE_EPOCH}/{N_EPOCHS}]\t'
+    print(f'[RESTORED][{args.restore_epoch}/{args.n_epochs}]\t'
           f'Train loss: {checkpoint["train_loss"]:.4f}\t'
           f'Test loss: {checkpoint["test_loss"]:.4f}')
 
-if SCHED == 'warmup-anneal':
+if args.sched == 'warmup-anneal':
     scheduler = LinearWarmupAndCosineAnneal(
         optimizer,
-        warm_up=WARMUP_EPOCHS / N_EPOCHS,
-        T_max=N_EPOCHS * num_train_batches,
-        last_epoch=RESTORE_EPOCH * num_train_batches - 1
+        warm_up=args.warmup_epochs / args.n_epochs,
+        T_max=args.n_epochs * num_train_batches,
+        last_epoch=args.restore_epoch * num_train_batches - 1
     )
-elif SCHED == 'linear':
+elif args.sched == 'linear':
     scheduler = LinearLR(
         optimizer,
-        num_epochs=N_EPOCHS * num_train_batches,
-        last_epoch=RESTORE_EPOCH * num_train_batches - 1
+        num_epochs=args.n_epochs * num_train_batches,
+        last_epoch=args.restore_epoch * num_train_batches - 1
     )
-elif SCHED is None or SCHED == '' or SCHED == 'const':
+elif args.sched is None or args.sched == '' or args.sched == 'const':
     scheduler = None
 else:
-    raise NotImplementedError(f"Scheduler '{SCHED}' is not implemented.")
+    raise NotImplementedError(f"Scheduler '{args.sched}' is not implemented.")
 
-if OPTIM == 'lars':
+if args.optim == 'lars':
     optimizer = LARS(optimizer)
 
 criterion = CrossEntropyLoss()
@@ -172,9 +220,9 @@ if not os.path.exists(CHECKPOINT_PATH):
     os.makedirs(CHECKPOINT_PATH)
 writer = SummaryWriter(TENSORBOARD_PATH)
 
-curr_train_iters = RESTORE_EPOCH * num_train_batches
-curr_test_iters = RESTORE_EPOCH * num_test_batches
-for epoch in range(RESTORE_EPOCH, N_EPOCHS):
+curr_train_iters = args.restore_epoch * num_train_batches
+curr_test_iters = args.restore_epoch * num_test_batches
+for epoch in range(args.restore_epoch, args.n_epochs):
     train_loss = 0
     training_progress = tqdm(
         enumerate(train_loader), desc='Train loss: ', total=num_train_batches
@@ -189,7 +237,7 @@ for epoch in range(RESTORE_EPOCH, N_EPOCHS):
         loss = criterion(output, targets)
         loss.backward()
         optimizer.step()
-        if SCHED:
+        if args.sched:
             scheduler.step()
 
         train_loss += loss.item()
@@ -224,7 +272,7 @@ for epoch in range(RESTORE_EPOCH, N_EPOCHS):
     train_loss_mean = train_loss / num_train_batches
     test_loss_mean = test_loss / num_test_batches
     test_acc_mean = test_acc / num_test_batches
-    print(f'[{epoch + 1}/{N_EPOCHS}]\t'
+    print(f'[{epoch + 1}/{args.n_epochs}]\t'
           f'Train loss: {train_loss_mean:.4f}\t'
           f'Test loss: {test_loss_mean:.4f}\t',
           f'Test acc: {test_acc_mean:.4f}')
-- 
cgit v1.2.3