From 8ddb482b8d3c79009e77bbd15c37f311c6e72aad Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Wed, 16 Mar 2022 19:42:05 +0800 Subject: Add ImageNet support --- supervised/baseline.py | 86 ++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 66 insertions(+), 20 deletions(-) (limited to 'supervised/baseline.py') diff --git a/supervised/baseline.py b/supervised/baseline.py index 8a1b567..c8bbb37 100644 --- a/supervised/baseline.py +++ b/supervised/baseline.py @@ -5,20 +5,29 @@ import torch from torch.nn import CrossEntropyLoss from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter -from torchvision.datasets import CIFAR10 +from torchvision.datasets import CIFAR10, ImageNet from torchvision.transforms import transforms, InterpolationMode from tqdm import tqdm +from datautils import color_distortion, Clip, RandomGaussianBlur +from models import CIFARResNet50, ImageNetResNet50 from optimizers import LARS from schedulers import LinearWarmupAndCosineAnneal, LinearLR -from supervised.datautils import color_distortion -from supervised.models import CIFAR10ResNet50 -CODENAME = 'cifar10-resnet50-aug-lars-warmup' +CODENAME = 'cifar10-resnet50-256-aug-lars-warmup' DATASET_ROOT = 'dataset' TENSORBOARD_PATH = os.path.join('runs', CODENAME) CHECKPOINT_PATH = os.path.join('checkpoints', CODENAME) +DATASET = 'cifar10' +CROP_SIZE = 32 +CROP_SCALE = (0.8, 1) +HFLIP_P = 0.5 +DISTORT_S = 0.5 +GAUSSIAN_KER_SCALE = 10 +GAUSSIAN_P = 0.5 +GAUSSIAN_SIGMA = (0.1, 2) + BATCH_SIZE = 256 RESTORE_EPOCH = 0 N_EPOCHS = 1000 @@ -36,20 +45,58 @@ random.seed(SEED) torch.manual_seed(SEED) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') -train_transform = transforms.Compose([ - transforms.RandomResizedCrop(32, interpolation=InterpolationMode.BICUBIC), - transforms.RandomHorizontalFlip(0.5), - color_distortion(0.5), - transforms.ToTensor() -]) - -test_transform = transforms.Compose([ - transforms.ToTensor() -]) +if DATASET == 'cifar10' or DATASET == 'cifar': + train_transform = transforms.Compose([ + transforms.RandomResizedCrop( + CROP_SIZE, + scale=CROP_SCALE, + interpolation=InterpolationMode.BICUBIC + ), + transforms.RandomHorizontalFlip(HFLIP_P), + color_distortion(DISTORT_S), + transforms.ToTensor(), + Clip() + ]) + test_transform = transforms.Compose([ + transforms.ToTensor() + ]) + + train_set = CIFAR10(DATASET_ROOT, train=True, transform=train_transform, + download=True) + test_set = CIFAR10(DATASET_ROOT, train=False, transform=test_transform) + + resnet = CIFARResNet50() +elif DATASET == 'imagenet1k' or DATASET == 'imagenet1k': + train_transform = transforms.Compose([ + transforms.RandomResizedCrop( + CROP_SIZE, + scale=CROP_SCALE, + interpolation=InterpolationMode.BICUBIC + ), + transforms.RandomHorizontalFlip(HFLIP_P), + color_distortion(DISTORT_S), + transforms.ToTensor(), + RandomGaussianBlur( + kernel_size=CROP_SIZE // GAUSSIAN_KER_SCALE, + sigma_range=GAUSSIAN_SIGMA, + p=GAUSSIAN_P + ), + Clip() + ]) + test_transform = transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(CROP_SIZE), + transforms.ToTensor(), + ]) + + train_set = ImageNet(DATASET_ROOT, 'train', transform=train_transform) + test_set = ImageNet(DATASET_ROOT, 'val', transform=test_transform) + + resnet = ImageNetResNet50() +else: + raise NotImplementedError(f"Dataset '{DATASET}' is not implemented.") -train_set = CIFAR10(DATASET_ROOT, train=True, transform=train_transform, - download=True) -test_set = CIFAR10(DATASET_ROOT, train=False, transform=test_transform) +resnet = resnet.to(device) train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=N_WORKERS) @@ -59,9 +106,6 @@ test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, num_train_batches = len(train_loader) num_test_batches = len(test_loader) -resnet = CIFAR10ResNet50().to(device) -criterion = CrossEntropyLoss() - def exclude_from_wd_and_adaptation(name): if 'bn' in name: @@ -122,6 +166,8 @@ else: if OPTIM == 'lars': optimizer = LARS(optimizer) +criterion = CrossEntropyLoss() + if not os.path.exists(CHECKPOINT_PATH): os.makedirs(CHECKPOINT_PATH) writer = SummaryWriter(TENSORBOARD_PATH) -- cgit v1.2.3