Diffstat (limited to 'supervised/baseline.py')
-rw-r--r-- | supervised/baseline.py | 202
1 file changed, 125 insertions, 77 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py
index c8bbb37..4b4f9e1 100644
--- a/supervised/baseline.py
+++ b/supervised/baseline.py
@@ -1,3 +1,4 @@
+import argparse
 import os
 import random

@@ -14,46 +15,85 @@ from models import CIFARResNet50, ImageNetResNet50
 from optimizers import LARS
 from schedulers import LinearWarmupAndCosineAnneal, LinearLR

-CODENAME = 'cifar10-resnet50-256-aug-lars-warmup'
-DATASET_ROOT = 'dataset'
-TENSORBOARD_PATH = os.path.join('runs', CODENAME)
-CHECKPOINT_PATH = os.path.join('checkpoints', CODENAME)
-
-DATASET = 'cifar10'
-CROP_SIZE = 32
-CROP_SCALE = (0.8, 1)
-HFLIP_P = 0.5
-DISTORT_S = 0.5
-GAUSSIAN_KER_SCALE = 10
-GAUSSIAN_P = 0.5
-GAUSSIAN_SIGMA = (0.1, 2)
-
-BATCH_SIZE = 256
-RESTORE_EPOCH = 0
-N_EPOCHS = 1000
-WARMUP_EPOCHS = 10
-N_WORKERS = 2
-SEED = 0
-
-OPTIM = 'lars'
-SCHED = 'warmup-anneal'
-LR = 1
-MOMENTUM = 0.9
-WEIGHT_DECAY = 1e-6
-
-random.seed(SEED)
-torch.manual_seed(SEED)
+
+def range_parser(range_string: str):
+    try:
+        range_ = tuple(map(float, range_string.split('-')))
+        return range_
+    except:
+        raise argparse.ArgumentTypeError("Range must be 'start-end.'")
+
+
+parser = argparse.ArgumentParser(description='Supervised baseline')
+parser.add_argument('--codename', default='cifar10-resnet50-256-lars-warmup',
+                    type=str, help="Model descriptor (default: "
+                                   "'cifar10-resnet50-256-lars-warmup')")
+parser.add_argument('--seed', default=0, type=int,
+                    help='Random seed for reproducibility (default: 0)')
+
+data_group = parser.add_argument_group('Dataset parameters')
+data_group.add_argument('--dataset_dir', default='dataset', type=str,
+                        help="Path to dataset directory (default: 'dataset')")
+data_group.add_argument('--dataset', default='cifar10', type=str,
+                        help="Name of dataset (default: 'cifar10')")
+data_group.add_argument('--crop_size', default=32, type=int,
+                        help='Random crop size after resize (default: 32)')
+data_group.add_argument('--crop_scale', default='0.8-1', type=range_parser,
+                        help='Random resize scale range (default: 0.8-1)')
+data_group.add_argument('--hflip_p', default=0.5, type=float,
+                        help='Random horizontal flip probability (default: 0.5)')
+data_group.add_argument('--distort_s', default=0.5, type=float,
+                        help='Distortion strength (default: 0.5)')
+data_group.add_argument('--gaussian_ker_scale', default=10, type=float,
+                        help='Gaussian kernel scale factor '
+                             '(equals to img_size / kernel_size) (default: 10)')
+data_group.add_argument('--gaussian_sigma', default='0.1-2', type=range_parser,
+                        help='Random gaussian blur sigma range (default: 0.1-2)')
+data_group.add_argument('--gaussian_p', default=0.5, type=float,
+                        help='Random gaussian blur probability (default: 0.5)')
+
+train_group = parser.add_argument_group('Training parameters')
+train_group.add_argument('--batch_size', default=256, type=int,
+                         help='Batch size (default: 256)')
+train_group.add_argument('--restore_epoch', default=0, type=int,
+                         help='Restore epoch, 0 for training from scratch '
+                              '(default: 0)')
+train_group.add_argument('--n_epochs', default=1000, type=int,
+                         help='Number of epochs (default: 1000)')
+train_group.add_argument('--warmup_epochs', default=10, type=int,
+                         help='Epochs for warmup '
+                              '(only for `warmup-anneal` scheduler) (default: 10)')
+train_group.add_argument('--n_workers', default=2, type=int,
+                         help='Number of dataloader processes (default: 2)')
+train_group.add_argument('--optim', default='lars', type=str,
+                         help="Name of optimizer (default: 'lars')")
+train_group.add_argument('--sched', default='warmup-anneal', type=str,
+                         help="Name of scheduler (default: 'warmup-anneal')")
+train_group.add_argument('--lr', default=1, type=float,
+                         help='Learning rate (default: 1)')
+train_group.add_argument('--momentum', default=0.9, type=float,
+                         help='Momentum (default: 0.9')
+train_group.add_argument('--weight_decay', default=1e-6, type=float,
+                         help='Weight decay (l2 regularization) (default: 1e-6)')
+
+args = parser.parse_args()
+
+TENSORBOARD_PATH = os.path.join('runs', args.codename)
+CHECKPOINT_PATH = os.path.join('checkpoints', args.codename)
+
+random.seed(args.seed)
+torch.manual_seed(args.seed)

 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

-if DATASET == 'cifar10' or DATASET == 'cifar':
+if args.dataset == 'cifar10' or args.dataset == 'cifar':
     train_transform = transforms.Compose([
         transforms.RandomResizedCrop(
-            CROP_SIZE,
-            scale=CROP_SCALE,
+            args.crop_size,
+            scale=args.crop_scale,
             interpolation=InterpolationMode.BICUBIC
         ),
-        transforms.RandomHorizontalFlip(HFLIP_P),
-        color_distortion(DISTORT_S),
+        transforms.RandomHorizontalFlip(args.hflip_p),
+        color_distortion(args.distort_s),
         transforms.ToTensor(),
         Clip()
     ])
@@ -61,47 +101,47 @@ if DATASET == 'cifar10' or DATASET == 'cifar':
         transforms.ToTensor()
     ])

-    train_set = CIFAR10(DATASET_ROOT, train=True, transform=train_transform,
+    train_set = CIFAR10(args.dataset_dir, train=True, transform=train_transform,
                         download=True)
-    test_set = CIFAR10(DATASET_ROOT, train=False, transform=test_transform)
+    test_set = CIFAR10(args.dataset_dir, train=False, transform=test_transform)

     resnet = CIFARResNet50()
-elif DATASET == 'imagenet1k' or DATASET == 'imagenet1k':
+elif args.dataset == 'imagenet1k' or args.dataset == 'imagenet1k':
     train_transform = transforms.Compose([
         transforms.RandomResizedCrop(
-            CROP_SIZE,
-            scale=CROP_SCALE,
+            args.crop_size,
+            scale=args.crop_scale,
             interpolation=InterpolationMode.BICUBIC
         ),
-        transforms.RandomHorizontalFlip(HFLIP_P),
-        color_distortion(DISTORT_S),
+        transforms.RandomHorizontalFlip(args.hflip_p),
+        color_distortion(args.distort_s),
         transforms.ToTensor(),
         RandomGaussianBlur(
-            kernel_size=CROP_SIZE // GAUSSIAN_KER_SCALE,
-            sigma_range=GAUSSIAN_SIGMA,
-            p=GAUSSIAN_P
+            kernel_size=args.crop_size // args.gaussian_ker_scale,
+            sigma_range=args.gaussian_sigma,
+            p=args.gaussian_p
         ),
         Clip()
     ])
     test_transform = transforms.Compose([
         transforms.Resize(256),
-        transforms.CenterCrop(CROP_SIZE),
+        transforms.CenterCrop(args.crop_size),
         transforms.ToTensor(),
     ])

-    train_set = ImageNet(DATASET_ROOT, 'train', transform=train_transform)
-    test_set = ImageNet(DATASET_ROOT, 'val', transform=test_transform)
+    train_set = ImageNet(args.dataset_dir, 'train', transform=train_transform)
+    test_set = ImageNet(args.dataset_dir, 'val', transform=test_transform)

     resnet = ImageNetResNet50()
 else:
-    raise NotImplementedError(f"Dataset '{DATASET}' is not implemented.")
+    raise NotImplementedError(f"Dataset '{args.dataset}' is not implemented.")

 resnet = resnet.to(device)

-train_loader = DataLoader(train_set, batch_size=BATCH_SIZE,
-                          shuffle=True, num_workers=N_WORKERS)
-test_loader = DataLoader(test_set, batch_size=BATCH_SIZE,
-                         shuffle=False, num_workers=N_WORKERS)
+train_loader = DataLoader(train_set, batch_size=args.batch_size,
+                          shuffle=True, num_workers=args.n_workers)
+test_loader = DataLoader(test_set, batch_size=args.batch_size,
+                         shuffle=False, num_workers=args.n_workers)

 num_train_batches = len(train_loader)
 num_test_batches = len(test_loader)
@@ -110,7 +150,7 @@ num_test_batches = len(test_loader)


 def exclude_from_wd_and_adaptation(name):
     if 'bn' in name:
         return True
-    if OPTIM == 'lars' and 'bias' in name:
+    if args.optim == 'lars' and 'bias' in name:
         return True

@@ -118,7 +158,7 @@ param_groups = [
     {
         'params': [p for name, p in resnet.named_parameters()
                    if not exclude_from_wd_and_adaptation(name)],
-        'weight_decay': WEIGHT_DECAY,
+        'weight_decay': args.weight_decay,
         'layer_adaptation': True,
     },
     {
@@ -128,42 +168,50 @@ param_groups = [
         'layer_adaptation': False,
     },
 ]
-if OPTIM == 'adam':
-    optimizer = torch.optim.Adam(param_groups, lr=LR, betas=(MOMENTUM, 0.999))
-elif OPTIM == 'sdg' or OPTIM == 'lars':
-    optimizer = torch.optim.SGD(param_groups, lr=LR, momentum=MOMENTUM)
+if args.optim == 'adam':
+    optimizer = torch.optim.Adam(
+        param_groups,
+        lr=args.lr,
+        betas=(args.momentum, 0.999)
+    )
+elif args.optim == 'sdg' or args.optim == 'lars':
+    optimizer = torch.optim.SGD(
+        param_groups,
+        lr=args.lr,
+        momentum=args.momentum
+    )
 else:
-    raise NotImplementedError(f"Optimizer '{OPTIM}' is not implemented.")
+    raise NotImplementedError(f"Optimizer '{args.optim}' is not implemented.")

 # Restore checkpoint
-if RESTORE_EPOCH > 0:
-    checkpoint_path = os.path.join(CHECKPOINT_PATH, f'{RESTORE_EPOCH:04d}.pt')
+if args.restore_epoch > 0:
+    checkpoint_path = os.path.join(CHECKPOINT_PATH, f'{args.restore_epoch:04d}.pt')
     checkpoint = torch.load(checkpoint_path)
     resnet.load_state_dict(checkpoint['resnet_state_dict'])
     optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
-    print(f'[RESTORED][{RESTORE_EPOCH}/{N_EPOCHS}]\t'
+    print(f'[RESTORED][{args.restore_epoch}/{args.n_epochs}]\t'
           f'Train loss: {checkpoint["train_loss"]:.4f}\t'
          f'Test loss: {checkpoint["test_loss"]:.4f}')

-if SCHED == 'warmup-anneal':
+if args.sched == 'warmup-anneal':
     scheduler = LinearWarmupAndCosineAnneal(
         optimizer,
-        warm_up=WARMUP_EPOCHS / N_EPOCHS,
-        T_max=N_EPOCHS * num_train_batches,
-        last_epoch=RESTORE_EPOCH * num_train_batches - 1
+        warm_up=args.warmup_epochs / args.n_epochs,
+        T_max=args.n_epochs * num_train_batches,
+        last_epoch=args.restore_epoch * num_train_batches - 1
     )
-elif SCHED == 'linear':
+elif args.sched == 'linear':
     scheduler = LinearLR(
         optimizer,
-        num_epochs=N_EPOCHS * num_train_batches,
-        last_epoch=RESTORE_EPOCH * num_train_batches - 1
+        num_epochs=args.n_epochs * num_train_batches,
+        last_epoch=args.restore_epoch * num_train_batches - 1
     )
-elif SCHED is None or SCHED == '' or SCHED == 'const':
+elif args.sched is None or args.sched == '' or args.sched == 'const':
     scheduler = None
 else:
-    raise NotImplementedError(f"Scheduler '{SCHED}' is not implemented.")
+    raise NotImplementedError(f"Scheduler '{args.sched}' is not implemented.")

-if OPTIM == 'lars':
+if args.optim == 'lars':
     optimizer = LARS(optimizer)

 criterion = CrossEntropyLoss()
@@ -172,9 +220,9 @@ if not os.path.exists(CHECKPOINT_PATH):
     os.makedirs(CHECKPOINT_PATH)
 writer = SummaryWriter(TENSORBOARD_PATH)

-curr_train_iters = RESTORE_EPOCH * num_train_batches
-curr_test_iters = RESTORE_EPOCH * num_test_batches
-for epoch in range(RESTORE_EPOCH, N_EPOCHS):
+curr_train_iters = args.restore_epoch * num_train_batches
+curr_test_iters = args.restore_epoch * num_test_batches
+for epoch in range(args.restore_epoch, args.n_epochs):
     train_loss = 0
     training_progress = tqdm(
         enumerate(train_loader), desc='Train loss: ', total=num_train_batches
@@ -189,7 +237,7 @@ for epoch in range(RESTORE_EPOCH, N_EPOCHS):
         loss = criterion(output, targets)
         loss.backward()
         optimizer.step()
-        if SCHED:
+        if args.sched:
             scheduler.step()

         train_loss += loss.item()
@@ -224,7 +272,7 @@ for epoch in range(RESTORE_EPOCH, N_EPOCHS):
     train_loss_mean = train_loss / num_train_batches
     test_loss_mean = test_loss / num_test_batches
     test_acc_mean = test_acc / num_test_batches
-    print(f'[{epoch + 1}/{N_EPOCHS}]\t'
+    print(f'[{epoch + 1}/{args.n_epochs}]\t'
          f'Train loss: {train_loss_mean:.4f}\t'
          f'Test loss: {test_loss_mean:.4f}\t',
          f'Test acc: {test_acc_mean:.4f}')
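
The commit replaces the module-level constants with an argparse CLI, so a run is configured on the command line (for example, python supervised/baseline.py --dataset cifar10 --optim lars --sched warmup-anneal, a hypothetical invocation) and the argument defaults reproduce the old constants. The two range-valued options go through range_parser, which splits a 'start-end' string into a tuple of floats. A minimal sketch of that behaviour follows; it is not part of the commit, the argv lists are made up, and the bare except from the diff is narrowed to ValueError here for clarity:

    # Editorial sketch, not part of the commit: how range_parser and the string defaults behave.
    import argparse

    def range_parser(range_string: str):
        try:
            return tuple(map(float, range_string.split('-')))
        except ValueError:
            raise argparse.ArgumentTypeError("Range must be 'start-end'.")

    parser = argparse.ArgumentParser()
    parser.add_argument('--crop_scale', default='0.8-1', type=range_parser)
    parser.add_argument('--gaussian_sigma', default='0.1-2', type=range_parser)

    args = parser.parse_args([])                  # defaults only
    assert args.crop_scale == (0.8, 1.0)          # stands in for CROP_SCALE = (0.8, 1)
    assert args.gaussian_sigma == (0.1, 2.0)      # stands in for GAUSSIAN_SIGMA = (0.1, 2)

    args = parser.parse_args(['--crop_scale', '0.5-1'])
    assert args.crop_scale == (0.5, 1.0)

Note that a string such as '-0.5-1' splits into more than two fields, so negative endpoints are not representable with this scheme, and nothing stops the parser from returning a one- or three-element tuple if the argument contains a different number of dashes.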
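
One behavioural detail worth noting: --gaussian_ker_scale is declared with type=float, while the old GAUSSIAN_KER_SCALE was the integer 10, so the kernel_size passed to RandomGaussianBlur changes from an int to a float. A small illustration, using a crop size that is hypothetical rather than taken from this file:

    # Illustration only (crop size 224 is a hypothetical value, not from this script).
    crop_size = 224
    old_kernel = crop_size // 10      # old constant GAUSSIAN_KER_SCALE = 10   -> 22 (int)
    new_kernel = crop_size // 10.0    # new args.gaussian_ker_scale = 10.0     -> 22.0 (float)
    assert old_kernel == 22 and isinstance(old_kernel, int)
    assert new_kernel == 22.0 and isinstance(new_kernel, float)

Whether that matters depends on RandomGaussianBlur; if it requires an integer (or odd) kernel size, an explicit int(...) cast or type=int on the flag would restore the old behaviour.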
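
exclude_from_wd_and_adaptation returns True for batch-norm parameters (and, under LARS, for biases) but falls through with an implicit None otherwise; the param_groups comprehensions only use its truthiness, so this works, though an explicit return False would be clearer. A small self-contained illustration, with parameter names that are hypothetical but in the usual torchvision naming style:

    # Illustration only: splitting parameters the way the param_groups comprehensions do.
    optim = 'lars'

    def exclude_from_wd_and_adaptation(name):
        if 'bn' in name:
            return True
        if optim == 'lars' and 'bias' in name:
            return True
        # implicit None (falsy) for everything else, as in the script

    names = ['conv1.weight', 'bn1.weight', 'bn1.bias', 'fc.weight', 'fc.bias']
    decayed = [n for n in names if not exclude_from_wd_and_adaptation(n)]
    excluded = [n for n in names if exclude_from_wd_and_adaptation(n)]
    assert decayed == ['conv1.weight', 'fc.weight']
    assert excluded == ['bn1.weight', 'bn1.bias', 'fc.bias']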
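
Both schedulers are stepped once per batch (scheduler.step() sits inside the inner loop), so the epoch-level settings are converted to iteration counts: T_max is n_epochs * num_train_batches, and last_epoch is restore_epoch * num_train_batches - 1 so that a restored run continues the learning-rate curve at the right iteration. A worked example with hypothetical numbers:

    # Hypothetical numbers: CIFAR-10 with batch size 256 -> ceil(50000 / 256) = 196 batches/epoch.
    n_epochs = 1000
    num_train_batches = 196
    restore_epoch = 100

    warm_up = 10 / n_epochs                              # warmup as a fraction of the schedule
    T_max = n_epochs * num_train_batches                 # 196000 total scheduler steps
    last_epoch = restore_epoch * num_train_batches - 1   # 19599 = last completed step
    assert T_max == 196000 and last_epoch == 19599
    # After restoring, the first scheduler.step() of the resumed run is step 19600,
    # so the warmup/cosine curve picks up exactly where the checkpointed run stopped.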
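
The restore branch loads CHECKPOINT_PATH/{restore_epoch:04d}.pt and expects the keys 'resnet_state_dict', 'optimizer_state_dict', 'train_loss' and 'test_loss'. The saving side lies outside the hunks shown above, so the following is only a sketch of a torch.save call consistent with those keys (using the names defined in the script at the end of an epoch), not the file's actual code:

    # Sketch only: a save call matching the keys read back in the restore branch.
    checkpoint_file = os.path.join(CHECKPOINT_PATH, f'{epoch + 1:04d}.pt')
    torch.save({
        'resnet_state_dict': resnet.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss_mean,
        'test_loss': test_loss_mean,
    }, checkpoint_file)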