-rw-r--r--  supervised/baseline.py  202
1 file changed, 125 insertions, 77 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py
index c8bbb37..4b4f9e1 100644
--- a/supervised/baseline.py
+++ b/supervised/baseline.py
@@ -1,3 +1,4 @@
+import argparse
import os
import random
@@ -14,46 +15,85 @@ from models import CIFARResNet50, ImageNetResNet50
from optimizers import LARS
from schedulers import LinearWarmupAndCosineAnneal, LinearLR
-CODENAME = 'cifar10-resnet50-256-aug-lars-warmup'
-DATASET_ROOT = 'dataset'
-TENSORBOARD_PATH = os.path.join('runs', CODENAME)
-CHECKPOINT_PATH = os.path.join('checkpoints', CODENAME)
-
-DATASET = 'cifar10'
-CROP_SIZE = 32
-CROP_SCALE = (0.8, 1)
-HFLIP_P = 0.5
-DISTORT_S = 0.5
-GAUSSIAN_KER_SCALE = 10
-GAUSSIAN_P = 0.5
-GAUSSIAN_SIGMA = (0.1, 2)
-
-BATCH_SIZE = 256
-RESTORE_EPOCH = 0
-N_EPOCHS = 1000
-WARMUP_EPOCHS = 10
-N_WORKERS = 2
-SEED = 0
-
-OPTIM = 'lars'
-SCHED = 'warmup-anneal'
-LR = 1
-MOMENTUM = 0.9
-WEIGHT_DECAY = 1e-6
-
-random.seed(SEED)
-torch.manual_seed(SEED)
+
+def range_parser(range_string: str):
+ try:
+ range_ = tuple(map(float, range_string.split('-')))
+ return range_
+    except ValueError:
+        raise argparse.ArgumentTypeError("Range must be 'start-end'.")
+
+
+parser = argparse.ArgumentParser(description='Supervised baseline')
+parser.add_argument('--codename', default='cifar10-resnet50-256-lars-warmup',
+ type=str, help="Model descriptor (default: "
+ "'cifar10-resnet50-256-lars-warmup')")
+parser.add_argument('--seed', default=0, type=int,
+ help='Random seed for reproducibility (default: 0)')
+
+data_group = parser.add_argument_group('Dataset parameters')
+data_group.add_argument('--dataset_dir', default='dataset', type=str,
+ help="Path to dataset directory (default: 'dataset')")
+data_group.add_argument('--dataset', default='cifar10', type=str,
+ help="Name of dataset (default: 'cifar10')")
+data_group.add_argument('--crop_size', default=32, type=int,
+ help='Random crop size after resize (default: 32)')
+data_group.add_argument('--crop_scale', default='0.8-1', type=range_parser,
+ help='Random resize scale range (default: 0.8-1)')
+data_group.add_argument('--hflip_p', default=0.5, type=float,
+ help='Random horizontal flip probability (default: 0.5)')
+data_group.add_argument('--distort_s', default=0.5, type=float,
+ help='Distortion strength (default: 0.5)')
+data_group.add_argument('--gaussian_ker_scale', default=10, type=float,
+                        help='Gaussian kernel scale factor '
+                             '(img_size / kernel_size) (default: 10)')
+data_group.add_argument('--gaussian_sigma', default='0.1-2', type=range_parser,
+ help='Random gaussian blur sigma range (default: 0.1-2)')
+data_group.add_argument('--gaussian_p', default=0.5, type=float,
+ help='Random gaussian blur probability (default: 0.5)')
+
+train_group = parser.add_argument_group('Training parameters')
+train_group.add_argument('--batch_size', default=256, type=int,
+ help='Batch size (default: 256)')
+train_group.add_argument('--restore_epoch', default=0, type=int,
+ help='Restore epoch, 0 for training from scratch '
+ '(default: 0)')
+train_group.add_argument('--n_epochs', default=1000, type=int,
+ help='Number of epochs (default: 1000)')
+train_group.add_argument('--warmup_epochs', default=10, type=int,
+ help='Epochs for warmup '
+ '(only for `warmup-anneal` scheduler) (default: 10)')
+train_group.add_argument('--n_workers', default=2, type=int,
+ help='Number of dataloader processes (default: 2)')
+train_group.add_argument('--optim', default='lars', type=str,
+ help="Name of optimizer (default: 'lars')")
+train_group.add_argument('--sched', default='warmup-anneal', type=str,
+ help="Name of scheduler (default: 'warmup-anneal')")
+train_group.add_argument('--lr', default=1, type=float,
+ help='Learning rate (default: 1)')
+train_group.add_argument('--momentum', default=0.9, type=float,
+                         help='Momentum (default: 0.9)')
+train_group.add_argument('--weight_decay', default=1e-6, type=float,
+ help='Weight decay (l2 regularization) (default: 1e-6)')
+
+args = parser.parse_args()
+
+TENSORBOARD_PATH = os.path.join('runs', args.codename)
+CHECKPOINT_PATH = os.path.join('checkpoints', args.codename)
+
+random.seed(args.seed)
+torch.manual_seed(args.seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-if DATASET == 'cifar10' or DATASET == 'cifar':
+if args.dataset == 'cifar10' or args.dataset == 'cifar':
train_transform = transforms.Compose([
transforms.RandomResizedCrop(
- CROP_SIZE,
- scale=CROP_SCALE,
+ args.crop_size,
+ scale=args.crop_scale,
interpolation=InterpolationMode.BICUBIC
),
- transforms.RandomHorizontalFlip(HFLIP_P),
- color_distortion(DISTORT_S),
+ transforms.RandomHorizontalFlip(args.hflip_p),
+ color_distortion(args.distort_s),
transforms.ToTensor(),
Clip()
])
@@ -61,47 +101,47 @@ if DATASET == 'cifar10' or DATASET == 'cifar':
transforms.ToTensor()
])
- train_set = CIFAR10(DATASET_ROOT, train=True, transform=train_transform,
+ train_set = CIFAR10(args.dataset_dir, train=True, transform=train_transform,
download=True)
- test_set = CIFAR10(DATASET_ROOT, train=False, transform=test_transform)
+ test_set = CIFAR10(args.dataset_dir, train=False, transform=test_transform)
resnet = CIFARResNet50()
-elif DATASET == 'imagenet1k' or DATASET == 'imagenet1k':
+elif args.dataset == 'imagenet' or args.dataset == 'imagenet1k':
train_transform = transforms.Compose([
transforms.RandomResizedCrop(
- CROP_SIZE,
- scale=CROP_SCALE,
+ args.crop_size,
+ scale=args.crop_scale,
interpolation=InterpolationMode.BICUBIC
),
- transforms.RandomHorizontalFlip(HFLIP_P),
- color_distortion(DISTORT_S),
+ transforms.RandomHorizontalFlip(args.hflip_p),
+ color_distortion(args.distort_s),
transforms.ToTensor(),
RandomGaussianBlur(
- kernel_size=CROP_SIZE // GAUSSIAN_KER_SCALE,
- sigma_range=GAUSSIAN_SIGMA,
- p=GAUSSIAN_P
+            kernel_size=int(args.crop_size // args.gaussian_ker_scale),
+ sigma_range=args.gaussian_sigma,
+ p=args.gaussian_p
),
Clip()
])
test_transform = transforms.Compose([
transforms.Resize(256),
- transforms.CenterCrop(CROP_SIZE),
+ transforms.CenterCrop(args.crop_size),
transforms.ToTensor(),
])
- train_set = ImageNet(DATASET_ROOT, 'train', transform=train_transform)
- test_set = ImageNet(DATASET_ROOT, 'val', transform=test_transform)
+ train_set = ImageNet(args.dataset_dir, 'train', transform=train_transform)
+ test_set = ImageNet(args.dataset_dir, 'val', transform=test_transform)
resnet = ImageNetResNet50()
else:
- raise NotImplementedError(f"Dataset '{DATASET}' is not implemented.")
+ raise NotImplementedError(f"Dataset '{args.dataset}' is not implemented.")
resnet = resnet.to(device)
-train_loader = DataLoader(train_set, batch_size=BATCH_SIZE,
- shuffle=True, num_workers=N_WORKERS)
-test_loader = DataLoader(test_set, batch_size=BATCH_SIZE,
- shuffle=False, num_workers=N_WORKERS)
+train_loader = DataLoader(train_set, batch_size=args.batch_size,
+ shuffle=True, num_workers=args.n_workers)
+test_loader = DataLoader(test_set, batch_size=args.batch_size,
+ shuffle=False, num_workers=args.n_workers)
num_train_batches = len(train_loader)
num_test_batches = len(test_loader)
@@ -110,7 +150,7 @@ num_test_batches = len(test_loader)
def exclude_from_wd_and_adaptation(name):
if 'bn' in name:
return True
- if OPTIM == 'lars' and 'bias' in name:
+ if args.optim == 'lars' and 'bias' in name:
return True
@@ -118,7 +158,7 @@ param_groups = [
{
'params': [p for name, p in resnet.named_parameters()
if not exclude_from_wd_and_adaptation(name)],
- 'weight_decay': WEIGHT_DECAY,
+ 'weight_decay': args.weight_decay,
'layer_adaptation': True,
},
{
@@ -128,42 +168,50 @@ param_groups = [
'layer_adaptation': False,
},
]
-if OPTIM == 'adam':
- optimizer = torch.optim.Adam(param_groups, lr=LR, betas=(MOMENTUM, 0.999))
-elif OPTIM == 'sdg' or OPTIM == 'lars':
- optimizer = torch.optim.SGD(param_groups, lr=LR, momentum=MOMENTUM)
+if args.optim == 'adam':
+ optimizer = torch.optim.Adam(
+ param_groups,
+ lr=args.lr,
+ betas=(args.momentum, 0.999)
+ )
+elif args.optim == 'sgd' or args.optim == 'lars':
+ optimizer = torch.optim.SGD(
+ param_groups,
+ lr=args.lr,
+ momentum=args.momentum
+ )
else:
- raise NotImplementedError(f"Optimizer '{OPTIM}' is not implemented.")
+ raise NotImplementedError(f"Optimizer '{args.optim}' is not implemented.")
# Restore checkpoint
-if RESTORE_EPOCH > 0:
- checkpoint_path = os.path.join(CHECKPOINT_PATH, f'{RESTORE_EPOCH:04d}.pt')
+if args.restore_epoch > 0:
+ checkpoint_path = os.path.join(CHECKPOINT_PATH, f'{args.restore_epoch:04d}.pt')
checkpoint = torch.load(checkpoint_path)
resnet.load_state_dict(checkpoint['resnet_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
- print(f'[RESTORED][{RESTORE_EPOCH}/{N_EPOCHS}]\t'
+ print(f'[RESTORED][{args.restore_epoch}/{args.n_epochs}]\t'
f'Train loss: {checkpoint["train_loss"]:.4f}\t'
f'Test loss: {checkpoint["test_loss"]:.4f}')
-if SCHED == 'warmup-anneal':
+if args.sched == 'warmup-anneal':
scheduler = LinearWarmupAndCosineAnneal(
optimizer,
- warm_up=WARMUP_EPOCHS / N_EPOCHS,
- T_max=N_EPOCHS * num_train_batches,
- last_epoch=RESTORE_EPOCH * num_train_batches - 1
+ warm_up=args.warmup_epochs / args.n_epochs,
+ T_max=args.n_epochs * num_train_batches,
+ last_epoch=args.restore_epoch * num_train_batches - 1
)
-elif SCHED == 'linear':
+elif args.sched == 'linear':
scheduler = LinearLR(
optimizer,
- num_epochs=N_EPOCHS * num_train_batches,
- last_epoch=RESTORE_EPOCH * num_train_batches - 1
+ num_epochs=args.n_epochs * num_train_batches,
+ last_epoch=args.restore_epoch * num_train_batches - 1
)
-elif SCHED is None or SCHED == '' or SCHED == 'const':
+elif args.sched is None or args.sched == '' or args.sched == 'const':
scheduler = None
else:
- raise NotImplementedError(f"Scheduler '{SCHED}' is not implemented.")
+ raise NotImplementedError(f"Scheduler '{args.sched}' is not implemented.")
-if OPTIM == 'lars':
+if args.optim == 'lars':
optimizer = LARS(optimizer)
criterion = CrossEntropyLoss()
@@ -172,9 +220,9 @@ if not os.path.exists(CHECKPOINT_PATH):
os.makedirs(CHECKPOINT_PATH)
writer = SummaryWriter(TENSORBOARD_PATH)
-curr_train_iters = RESTORE_EPOCH * num_train_batches
-curr_test_iters = RESTORE_EPOCH * num_test_batches
-for epoch in range(RESTORE_EPOCH, N_EPOCHS):
+curr_train_iters = args.restore_epoch * num_train_batches
+curr_test_iters = args.restore_epoch * num_test_batches
+for epoch in range(args.restore_epoch, args.n_epochs):
train_loss = 0
training_progress = tqdm(
enumerate(train_loader), desc='Train loss: ', total=num_train_batches
@@ -189,7 +237,7 @@ for epoch in range(RESTORE_EPOCH, N_EPOCHS):
loss = criterion(output, targets)
loss.backward()
optimizer.step()
- if SCHED:
+ if args.sched:
scheduler.step()
train_loss += loss.item()
@@ -224,7 +272,7 @@ for epoch in range(RESTORE_EPOCH, N_EPOCHS):
train_loss_mean = train_loss / num_train_batches
test_loss_mean = test_loss / num_test_batches
test_acc_mean = test_acc / num_test_batches
- print(f'[{epoch + 1}/{N_EPOCHS}]\t'
+ print(f'[{epoch + 1}/{args.n_epochs}]\t'
f'Train loss: {train_loss_mean:.4f}\t'
f'Test loss: {test_loss_mean:.4f}\t',
f'Test acc: {test_acc_mean:.4f}')
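
For reference, a minimal invocation sketch using the new flags (flag names and values mirror the parser defaults defined above; the script path and working directory are assumptions):

    python supervised/baseline.py --codename cifar10-resnet50-256-lars-warmup \
        --dataset cifar10 --dataset_dir dataset \
        --crop_size 32 --crop_scale 0.8-1 --hflip_p 0.5 --distort_s 0.5 \
        --batch_size 256 --n_epochs 1000 --warmup_epochs 10 \
        --optim lars --sched warmup-anneal --lr 1 --momentum 0.9 --weight_decay 1e-6

Range-valued flags such as --crop_scale and --gaussian_sigma take 'start-end' strings, which range_parser converts to float tuples, e.g. '0.8-1' becomes (0.8, 1.0).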