aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2022-03-16 19:42:05 +0800
committerJordan Gong <jordan.gong@protonmail.com>2022-03-16 19:42:05 +0800
commit8ddb482b8d3c79009e77bbd15c37f311c6e72aad (patch)
tree4b6967c400e1b1b27011f97f19073892306a048c
parent35525c0bc6b85c06dda1e88e1addd9a1cfd5a675 (diff)
Add ImageNet support
-rw-r--r--supervised/baseline.py86
-rw-r--r--supervised/datautils.py54
-rw-r--r--supervised/models.py11
3 files changed, 129 insertions, 22 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py
index 8a1b567..c8bbb37 100644
--- a/supervised/baseline.py
+++ b/supervised/baseline.py
@@ -5,20 +5,29 @@ import torch
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
-from torchvision.datasets import CIFAR10
+from torchvision.datasets import CIFAR10, ImageNet
from torchvision.transforms import transforms, InterpolationMode
from tqdm import tqdm
+from datautils import color_distortion, Clip, RandomGaussianBlur
+from models import CIFARResNet50, ImageNetResNet50
from optimizers import LARS
from schedulers import LinearWarmupAndCosineAnneal, LinearLR
-from supervised.datautils import color_distortion
-from supervised.models import CIFAR10ResNet50
-CODENAME = 'cifar10-resnet50-aug-lars-warmup'
+CODENAME = 'cifar10-resnet50-256-aug-lars-warmup'
DATASET_ROOT = 'dataset'
TENSORBOARD_PATH = os.path.join('runs', CODENAME)
CHECKPOINT_PATH = os.path.join('checkpoints', CODENAME)
+DATASET = 'cifar10'
+CROP_SIZE = 32
+CROP_SCALE = (0.8, 1)
+HFLIP_P = 0.5
+DISTORT_S = 0.5
+GAUSSIAN_KER_SCALE = 10
+GAUSSIAN_P = 0.5
+GAUSSIAN_SIGMA = (0.1, 2)
+
BATCH_SIZE = 256
RESTORE_EPOCH = 0
N_EPOCHS = 1000
@@ -36,20 +45,58 @@ random.seed(SEED)
torch.manual_seed(SEED)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-train_transform = transforms.Compose([
- transforms.RandomResizedCrop(32, interpolation=InterpolationMode.BICUBIC),
- transforms.RandomHorizontalFlip(0.5),
- color_distortion(0.5),
- transforms.ToTensor()
-])
-
-test_transform = transforms.Compose([
- transforms.ToTensor()
-])
+if DATASET == 'cifar10' or DATASET == 'cifar':
+ train_transform = transforms.Compose([
+ transforms.RandomResizedCrop(
+ CROP_SIZE,
+ scale=CROP_SCALE,
+ interpolation=InterpolationMode.BICUBIC
+ ),
+ transforms.RandomHorizontalFlip(HFLIP_P),
+ color_distortion(DISTORT_S),
+ transforms.ToTensor(),
+ Clip()
+ ])
+ test_transform = transforms.Compose([
+ transforms.ToTensor()
+ ])
+
+ train_set = CIFAR10(DATASET_ROOT, train=True, transform=train_transform,
+ download=True)
+ test_set = CIFAR10(DATASET_ROOT, train=False, transform=test_transform)
+
+ resnet = CIFARResNet50()
+elif DATASET == 'imagenet1k' or DATASET == 'imagenet1k':
+ train_transform = transforms.Compose([
+ transforms.RandomResizedCrop(
+ CROP_SIZE,
+ scale=CROP_SCALE,
+ interpolation=InterpolationMode.BICUBIC
+ ),
+ transforms.RandomHorizontalFlip(HFLIP_P),
+ color_distortion(DISTORT_S),
+ transforms.ToTensor(),
+ RandomGaussianBlur(
+ kernel_size=CROP_SIZE // GAUSSIAN_KER_SCALE,
+ sigma_range=GAUSSIAN_SIGMA,
+ p=GAUSSIAN_P
+ ),
+ Clip()
+ ])
+ test_transform = transforms.Compose([
+ transforms.Resize(256),
+ transforms.CenterCrop(CROP_SIZE),
+ transforms.ToTensor(),
+ ])
+
+ train_set = ImageNet(DATASET_ROOT, 'train', transform=train_transform)
+ test_set = ImageNet(DATASET_ROOT, 'val', transform=test_transform)
+
+ resnet = ImageNetResNet50()
+else:
+ raise NotImplementedError(f"Dataset '{DATASET}' is not implemented.")
-train_set = CIFAR10(DATASET_ROOT, train=True, transform=train_transform,
- download=True)
-test_set = CIFAR10(DATASET_ROOT, train=False, transform=test_transform)
+resnet = resnet.to(device)
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE,
shuffle=True, num_workers=N_WORKERS)
@@ -59,9 +106,6 @@ test_loader = DataLoader(test_set, batch_size=BATCH_SIZE,
num_train_batches = len(train_loader)
num_test_batches = len(test_loader)
-resnet = CIFAR10ResNet50().to(device)
-criterion = CrossEntropyLoss()
-
def exclude_from_wd_and_adaptation(name):
if 'bn' in name:
@@ -122,6 +166,8 @@ else:
if OPTIM == 'lars':
optimizer = LARS(optimizer)
+criterion = CrossEntropyLoss()
+
if not os.path.exists(CHECKPOINT_PATH):
os.makedirs(CHECKPOINT_PATH)
writer = SummaryWriter(TENSORBOARD_PATH)
diff --git a/supervised/datautils.py b/supervised/datautils.py
index 196fca7..843f669 100644
--- a/supervised/datautils.py
+++ b/supervised/datautils.py
@@ -1,3 +1,5 @@
+import numpy as np
+import torch
from torchvision.transforms import transforms
@@ -11,3 +13,55 @@ def color_distortion(s=1.0):
rnd_gray
])
return color_distort
+
+
+class Clip(object):
+ def __call__(self, x):
+ return torch.clamp(x, 0, 1)
+
+
+class RandomGaussianBlur(object):
+ """
+ PyTorch version of
+ https://github.com/google-research/simclr/blob/244e7128004c5fd3c7805cf3135c79baa6c3bb96/data_util.py#L311
+ """
+
+ def gaussian_blur(self, image, sigma):
+ image = image.reshape(1, 3, 224, 224)
+ radius = np.int(self.kernel_size / 2)
+ kernel_size = radius * 2 + 1
+ x = np.arange(-radius, radius + 1)
+
+ blur_filter = np.exp(
+ -np.power(x, 2.0) / (2.0 * np.power(np.float(sigma), 2.0)))
+ blur_filter /= np.sum(blur_filter)
+
+ conv1 = torch.nn.Conv2d(3, 3, kernel_size=(kernel_size, 1), groups=3,
+ padding=[kernel_size // 2, 0], bias=False)
+ conv1.weight = torch.nn.Parameter(torch.Tensor(np.tile(
+ blur_filter.reshape(kernel_size, 1, 1, 1), 3
+ ).transpose([3, 2, 0, 1])))
+
+ conv2 = torch.nn.Conv2d(3, 3, kernel_size=(1, kernel_size), groups=3,
+ padding=[0, kernel_size // 2], bias=False)
+ conv2.weight = torch.nn.Parameter(torch.Tensor(np.tile(
+ blur_filter.reshape(kernel_size, 1, 1, 1), 3
+ ).transpose([3, 2, 1, 0])))
+
+ res = conv2(conv1(image))
+ assert res.shape == image.shape
+ return res[0]
+
+ def __init__(self, kernel_size, sigma_range=(0.1, 2), p=0.5):
+ self.kernel_size = kernel_size
+ self.sigma_range = sigma_range
+ self.p = p
+
+ def __call__(self, img):
+ with torch.no_grad():
+ assert isinstance(img, torch.Tensor)
+ if np.random.uniform() < self.p:
+ return self.gaussian_blur(
+ img, sigma=np.random.uniform(*self.sigma_range)
+ )
+ return img
diff --git a/supervised/models.py b/supervised/models.py
index 47a0dcf..f3ba35a 100644
--- a/supervised/models.py
+++ b/supervised/models.py
@@ -4,9 +4,9 @@ from torchvision.models import ResNet
from torchvision.models.resnet import BasicBlock
-class CIFAR10ResNet50(ResNet):
+class CIFARResNet50(ResNet):
def __init__(self):
- super(CIFAR10ResNet50, self).__init__(
+ super(CIFARResNet50, self).__init__(
block=BasicBlock, layers=[3, 4, 6, 3], num_classes=10
)
self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
@@ -27,3 +27,10 @@ class CIFAR10ResNet50(ResNet):
x = self.fc(x)
return x
+
+
+class ImageNetResNet50(ResNet):
+ def __init__(self):
+ super(ImageNetResNet50, self).__init__(
+ block=BasicBlock, layers=[3, 4, 6, 3], num_classes=1000
+ )