diff options
Diffstat (limited to 'supervised')
-rw-r--r-- | supervised/baseline.py | 86 | ||||
-rw-r--r-- | supervised/datautils.py | 54 | ||||
-rw-r--r-- | supervised/models.py | 11 |
3 files changed, 129 insertions, 22 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py index 8a1b567..c8bbb37 100644 --- a/supervised/baseline.py +++ b/supervised/baseline.py @@ -5,20 +5,29 @@ import torch from torch.nn import CrossEntropyLoss from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter -from torchvision.datasets import CIFAR10 +from torchvision.datasets import CIFAR10, ImageNet from torchvision.transforms import transforms, InterpolationMode from tqdm import tqdm +from datautils import color_distortion, Clip, RandomGaussianBlur +from models import CIFARResNet50, ImageNetResNet50 from optimizers import LARS from schedulers import LinearWarmupAndCosineAnneal, LinearLR -from supervised.datautils import color_distortion -from supervised.models import CIFAR10ResNet50 -CODENAME = 'cifar10-resnet50-aug-lars-warmup' +CODENAME = 'cifar10-resnet50-256-aug-lars-warmup' DATASET_ROOT = 'dataset' TENSORBOARD_PATH = os.path.join('runs', CODENAME) CHECKPOINT_PATH = os.path.join('checkpoints', CODENAME) +DATASET = 'cifar10' +CROP_SIZE = 32 +CROP_SCALE = (0.8, 1) +HFLIP_P = 0.5 +DISTORT_S = 0.5 +GAUSSIAN_KER_SCALE = 10 +GAUSSIAN_P = 0.5 +GAUSSIAN_SIGMA = (0.1, 2) + BATCH_SIZE = 256 RESTORE_EPOCH = 0 N_EPOCHS = 1000 @@ -36,20 +45,58 @@ random.seed(SEED) torch.manual_seed(SEED) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') -train_transform = transforms.Compose([ - transforms.RandomResizedCrop(32, interpolation=InterpolationMode.BICUBIC), - transforms.RandomHorizontalFlip(0.5), - color_distortion(0.5), - transforms.ToTensor() -]) - -test_transform = transforms.Compose([ - transforms.ToTensor() -]) +if DATASET == 'cifar10' or DATASET == 'cifar': + train_transform = transforms.Compose([ + transforms.RandomResizedCrop( + CROP_SIZE, + scale=CROP_SCALE, + interpolation=InterpolationMode.BICUBIC + ), + transforms.RandomHorizontalFlip(HFLIP_P), + color_distortion(DISTORT_S), + transforms.ToTensor(), + Clip() + ]) + test_transform = transforms.Compose([ + transforms.ToTensor() + ]) + + train_set = CIFAR10(DATASET_ROOT, train=True, transform=train_transform, + download=True) + test_set = CIFAR10(DATASET_ROOT, train=False, transform=test_transform) + + resnet = CIFARResNet50() +elif DATASET == 'imagenet1k' or DATASET == 'imagenet1k': + train_transform = transforms.Compose([ + transforms.RandomResizedCrop( + CROP_SIZE, + scale=CROP_SCALE, + interpolation=InterpolationMode.BICUBIC + ), + transforms.RandomHorizontalFlip(HFLIP_P), + color_distortion(DISTORT_S), + transforms.ToTensor(), + RandomGaussianBlur( + kernel_size=CROP_SIZE // GAUSSIAN_KER_SCALE, + sigma_range=GAUSSIAN_SIGMA, + p=GAUSSIAN_P + ), + Clip() + ]) + test_transform = transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(CROP_SIZE), + transforms.ToTensor(), + ]) + + train_set = ImageNet(DATASET_ROOT, 'train', transform=train_transform) + test_set = ImageNet(DATASET_ROOT, 'val', transform=test_transform) + + resnet = ImageNetResNet50() +else: + raise NotImplementedError(f"Dataset '{DATASET}' is not implemented.") -train_set = CIFAR10(DATASET_ROOT, train=True, transform=train_transform, - download=True) -test_set = CIFAR10(DATASET_ROOT, train=False, transform=test_transform) +resnet = resnet.to(device) train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=N_WORKERS) @@ -59,9 +106,6 @@ test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, num_train_batches = len(train_loader) num_test_batches = len(test_loader) -resnet = CIFAR10ResNet50().to(device) -criterion = CrossEntropyLoss() - def exclude_from_wd_and_adaptation(name): if 'bn' in name: @@ -122,6 +166,8 @@ else: if OPTIM == 'lars': optimizer = LARS(optimizer) +criterion = CrossEntropyLoss() + if not os.path.exists(CHECKPOINT_PATH): os.makedirs(CHECKPOINT_PATH) writer = SummaryWriter(TENSORBOARD_PATH) diff --git a/supervised/datautils.py b/supervised/datautils.py index 196fca7..843f669 100644 --- a/supervised/datautils.py +++ b/supervised/datautils.py @@ -1,3 +1,5 @@ +import numpy as np +import torch from torchvision.transforms import transforms @@ -11,3 +13,55 @@ def color_distortion(s=1.0): rnd_gray ]) return color_distort + + +class Clip(object): + def __call__(self, x): + return torch.clamp(x, 0, 1) + + +class RandomGaussianBlur(object): + """ + PyTorch version of + https://github.com/google-research/simclr/blob/244e7128004c5fd3c7805cf3135c79baa6c3bb96/data_util.py#L311 + """ + + def gaussian_blur(self, image, sigma): + image = image.reshape(1, 3, 224, 224) + radius = np.int(self.kernel_size / 2) + kernel_size = radius * 2 + 1 + x = np.arange(-radius, radius + 1) + + blur_filter = np.exp( + -np.power(x, 2.0) / (2.0 * np.power(np.float(sigma), 2.0))) + blur_filter /= np.sum(blur_filter) + + conv1 = torch.nn.Conv2d(3, 3, kernel_size=(kernel_size, 1), groups=3, + padding=[kernel_size // 2, 0], bias=False) + conv1.weight = torch.nn.Parameter(torch.Tensor(np.tile( + blur_filter.reshape(kernel_size, 1, 1, 1), 3 + ).transpose([3, 2, 0, 1]))) + + conv2 = torch.nn.Conv2d(3, 3, kernel_size=(1, kernel_size), groups=3, + padding=[0, kernel_size // 2], bias=False) + conv2.weight = torch.nn.Parameter(torch.Tensor(np.tile( + blur_filter.reshape(kernel_size, 1, 1, 1), 3 + ).transpose([3, 2, 1, 0]))) + + res = conv2(conv1(image)) + assert res.shape == image.shape + return res[0] + + def __init__(self, kernel_size, sigma_range=(0.1, 2), p=0.5): + self.kernel_size = kernel_size + self.sigma_range = sigma_range + self.p = p + + def __call__(self, img): + with torch.no_grad(): + assert isinstance(img, torch.Tensor) + if np.random.uniform() < self.p: + return self.gaussian_blur( + img, sigma=np.random.uniform(*self.sigma_range) + ) + return img diff --git a/supervised/models.py b/supervised/models.py index 47a0dcf..f3ba35a 100644 --- a/supervised/models.py +++ b/supervised/models.py @@ -4,9 +4,9 @@ from torchvision.models import ResNet from torchvision.models.resnet import BasicBlock -class CIFAR10ResNet50(ResNet): +class CIFARResNet50(ResNet): def __init__(self): - super(CIFAR10ResNet50, self).__init__( + super(CIFARResNet50, self).__init__( block=BasicBlock, layers=[3, 4, 6, 3], num_classes=10 ) self.conv1 = nn.Conv2d(3, 64, kernel_size=3, @@ -27,3 +27,10 @@ class CIFAR10ResNet50(ResNet): x = self.fc(x) return x + + +class ImageNetResNet50(ResNet): + def __init__(self): + super(ImageNetResNet50, self).__init__( + block=BasicBlock, layers=[3, 4, 6, 3], num_classes=1000 + ) |