aboutsummaryrefslogtreecommitdiff
path: root/supervised/baseline.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2022-03-16 19:42:05 +0800
committerJordan Gong <jordan.gong@protonmail.com>2022-03-16 19:42:05 +0800
commit8ddb482b8d3c79009e77bbd15c37f311c6e72aad (patch)
tree4b6967c400e1b1b27011f97f19073892306a048c /supervised/baseline.py
parent35525c0bc6b85c06dda1e88e1addd9a1cfd5a675 (diff)
Add ImageNet support
Diffstat (limited to 'supervised/baseline.py')
-rw-r--r--supervised/baseline.py86
1 file changed, 66 insertions, 20 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py
index 8a1b567..c8bbb37 100644
--- a/supervised/baseline.py
+++ b/supervised/baseline.py
@@ -5,20 +5,29 @@ import torch
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
-from torchvision.datasets import CIFAR10
+from torchvision.datasets import CIFAR10, ImageNet
from torchvision.transforms import transforms, InterpolationMode
from tqdm import tqdm
+from datautils import color_distortion, Clip, RandomGaussianBlur
+from models import CIFARResNet50, ImageNetResNet50
from optimizers import LARS
from schedulers import LinearWarmupAndCosineAnneal, LinearLR
-from supervised.datautils import color_distortion
-from supervised.models import CIFAR10ResNet50
-CODENAME = 'cifar10-resnet50-aug-lars-warmup'
+CODENAME = 'cifar10-resnet50-256-aug-lars-warmup'
DATASET_ROOT = 'dataset'
TENSORBOARD_PATH = os.path.join('runs', CODENAME)
CHECKPOINT_PATH = os.path.join('checkpoints', CODENAME)
+DATASET = 'cifar10'
+CROP_SIZE = 32
+CROP_SCALE = (0.8, 1)
+HFLIP_P = 0.5
+DISTORT_S = 0.5
+GAUSSIAN_KER_SCALE = 10
+GAUSSIAN_P = 0.5
+GAUSSIAN_SIGMA = (0.1, 2)
+
BATCH_SIZE = 256
RESTORE_EPOCH = 0
N_EPOCHS = 1000
@@ -36,20 +45,58 @@ random.seed(SEED)
torch.manual_seed(SEED)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-train_transform = transforms.Compose([
- transforms.RandomResizedCrop(32, interpolation=InterpolationMode.BICUBIC),
- transforms.RandomHorizontalFlip(0.5),
- color_distortion(0.5),
- transforms.ToTensor()
-])
-
-test_transform = transforms.Compose([
- transforms.ToTensor()
-])
+if DATASET == 'cifar10' or DATASET == 'cifar':
+ train_transform = transforms.Compose([
+ transforms.RandomResizedCrop(
+ CROP_SIZE,
+ scale=CROP_SCALE,
+ interpolation=InterpolationMode.BICUBIC
+ ),
+ transforms.RandomHorizontalFlip(HFLIP_P),
+ color_distortion(DISTORT_S),
+ transforms.ToTensor(),
+ Clip()
+ ])
+ test_transform = transforms.Compose([
+ transforms.ToTensor()
+ ])
+
+ train_set = CIFAR10(DATASET_ROOT, train=True, transform=train_transform,
+ download=True)
+ test_set = CIFAR10(DATASET_ROOT, train=False, transform=test_transform)
+
+ resnet = CIFARResNet50()
+elif DATASET == 'imagenet' or DATASET == 'imagenet1k':
+ train_transform = transforms.Compose([
+ transforms.RandomResizedCrop(
+ CROP_SIZE,
+ scale=CROP_SCALE,
+ interpolation=InterpolationMode.BICUBIC
+ ),
+ transforms.RandomHorizontalFlip(HFLIP_P),
+ color_distortion(DISTORT_S),
+ transforms.ToTensor(),
+ RandomGaussianBlur(
+ kernel_size=CROP_SIZE // GAUSSIAN_KER_SCALE,
+ sigma_range=GAUSSIAN_SIGMA,
+ p=GAUSSIAN_P
+ ),
+ Clip()
+ ])
+ test_transform = transforms.Compose([
+ transforms.Resize(256),
+ transforms.CenterCrop(CROP_SIZE),
+ transforms.ToTensor(),
+ ])
+
+ train_set = ImageNet(DATASET_ROOT, 'train', transform=train_transform)
+ test_set = ImageNet(DATASET_ROOT, 'val', transform=test_transform)
+
+ resnet = ImageNetResNet50()
+else:
+ raise NotImplementedError(f"Dataset '{DATASET}' is not implemented.")
-train_set = CIFAR10(DATASET_ROOT, train=True, transform=train_transform,
- download=True)
-test_set = CIFAR10(DATASET_ROOT, train=False, transform=test_transform)
+resnet = resnet.to(device)
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE,
shuffle=True, num_workers=N_WORKERS)
@@ -59,9 +106,6 @@ test_loader = DataLoader(test_set, batch_size=BATCH_SIZE,
num_train_batches = len(train_loader)
num_test_batches = len(test_loader)
-resnet = CIFAR10ResNet50().to(device)
-criterion = CrossEntropyLoss()
-
def exclude_from_wd_and_adaptation(name):
if 'bn' in name:
@@ -122,6 +166,8 @@ else:
if OPTIM == 'lars':
optimizer = LARS(optimizer)
+criterion = CrossEntropyLoss()
+
if not os.path.exists(CHECKPOINT_PATH):
os.makedirs(CHECKPOINT_PATH)
writer = SummaryWriter(TENSORBOARD_PATH)