diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2022-03-16 17:49:51 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2022-03-16 17:49:51 +0800 |
commit | 5869d0248fa958acd3447e6bffa8761b91e8e921 (patch) | |
tree | 4e2c0744400d9204bdfd23c58bafcf534c2119fb /supervised/baseline.py | |
parent | 608178533e93dc7e6fac6059fa139233ab046b63 (diff) |
Regular refactoring
Diffstat (limited to 'supervised/baseline.py')
-rw-r--r-- | supervised/baseline.py | 49 |
1 files changed, 5 insertions, 44 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py index d5671b1..92d8d30 100644 --- a/supervised/baseline.py +++ b/supervised/baseline.py @@ -2,18 +2,17 @@ import os import random import torch -from torch import nn, Tensor, optim from torch.nn import CrossEntropyLoss from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter from torchvision.datasets import CIFAR10 -from torchvision.models import ResNet -from torchvision.models.resnet import BasicBlock from torchvision.transforms import transforms, InterpolationMode from tqdm import tqdm -from lars_optimizer import LARS -from scheduler import LinearWarmupAndCosineAnneal +from optimizers import LARS +from schedulers import LinearWarmupAndCosineAnneal +from supervised.datautils import color_distortion +from supervised.models import CIFAR10ResNet50 CODENAME = 'cifar10-resnet50-aug-lars-sched' DATASET_ROOT = 'dataset' @@ -36,48 +35,10 @@ random.seed(SEED) torch.manual_seed(SEED) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - -class CIFAR10ResNet50(ResNet): - def __init__(self): - super(CIFAR10ResNet50, self).__init__( - block=BasicBlock, layers=[3, 4, 6, 3], num_classes=10 - ) - self.conv1 = nn.Conv2d(3, 64, kernel_size=3, - stride=1, padding=1, bias=False) - - def forward(self, x: Tensor) -> Tensor: - x = self.conv1(x) - x = self.bn1(x) - x = self.relu(x) - - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - x = self.layer4(x) - - x = self.avgpool(x) - x = torch.flatten(x, 1) - x = self.fc(x) - - return x - - -def get_color_distortion(s=1.0): - # s is the strength of color distortion. - color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s) - rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8) - rnd_gray = transforms.RandomGrayscale(p=0.2) - color_distort = transforms.Compose([ - rnd_color_jitter, - rnd_gray - ]) - return color_distort - - train_transform = transforms.Compose([ transforms.RandomResizedCrop(32, interpolation=InterpolationMode.BICUBIC), transforms.RandomHorizontalFlip(0.5), - get_color_distortion(0.5), + color_distortion(0.5), transforms.ToTensor() ]) |