diff options
-rw-r--r-- | supervised/baseline.py | 49 | ||||
-rw-r--r-- | supervised/datautils.py | 13 | ||||
-rw-r--r-- | supervised/models.py | 29 | ||||
-rw-r--r-- | supervised/optimizers.py (renamed from supervised/lars_optimizer.py) | 0 | ||||
-rw-r--r-- | supervised/schedulers.py (renamed from supervised/scheduler.py) | 0 |
5 files changed, 47 insertions, 44 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py index d5671b1..92d8d30 100644 --- a/supervised/baseline.py +++ b/supervised/baseline.py @@ -2,18 +2,17 @@ import os import random import torch -from torch import nn, Tensor, optim from torch.nn import CrossEntropyLoss from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter from torchvision.datasets import CIFAR10 -from torchvision.models import ResNet -from torchvision.models.resnet import BasicBlock from torchvision.transforms import transforms, InterpolationMode from tqdm import tqdm -from lars_optimizer import LARS -from scheduler import LinearWarmupAndCosineAnneal +from optimizers import LARS +from schedulers import LinearWarmupAndCosineAnneal +from supervised.datautils import color_distortion +from supervised.models import CIFAR10ResNet50 CODENAME = 'cifar10-resnet50-aug-lars-sched' DATASET_ROOT = 'dataset' @@ -36,48 +35,10 @@ random.seed(SEED) torch.manual_seed(SEED) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - -class CIFAR10ResNet50(ResNet): - def __init__(self): - super(CIFAR10ResNet50, self).__init__( - block=BasicBlock, layers=[3, 4, 6, 3], num_classes=10 - ) - self.conv1 = nn.Conv2d(3, 64, kernel_size=3, - stride=1, padding=1, bias=False) - - def forward(self, x: Tensor) -> Tensor: - x = self.conv1(x) - x = self.bn1(x) - x = self.relu(x) - - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - x = self.layer4(x) - - x = self.avgpool(x) - x = torch.flatten(x, 1) - x = self.fc(x) - - return x - - -def get_color_distortion(s=1.0): - # s is the strength of color distortion. - color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s) - rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8) - rnd_gray = transforms.RandomGrayscale(p=0.2) - color_distort = transforms.Compose([ - rnd_color_jitter, - rnd_gray - ]) - return color_distort - - train_transform = transforms.Compose([ transforms.RandomResizedCrop(32, interpolation=InterpolationMode.BICUBIC), transforms.RandomHorizontalFlip(0.5), - get_color_distortion(0.5), + color_distortion(0.5), transforms.ToTensor() ]) diff --git a/supervised/datautils.py b/supervised/datautils.py new file mode 100644 index 0000000..196fca7 --- /dev/null +++ b/supervised/datautils.py @@ -0,0 +1,13 @@ +from torchvision.transforms import transforms + + +def color_distortion(s=1.0): + # s is the strength of color distortion. + color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s) + rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8) + rnd_gray = transforms.RandomGrayscale(p=0.2) + color_distort = transforms.Compose([ + rnd_color_jitter, + rnd_gray + ]) + return color_distort diff --git a/supervised/models.py b/supervised/models.py new file mode 100644 index 0000000..47a0dcf --- /dev/null +++ b/supervised/models.py @@ -0,0 +1,29 @@ +import torch +from torch import nn, Tensor +from torchvision.models import ResNet +from torchvision.models.resnet import BasicBlock + + +class CIFAR10ResNet50(ResNet): + def __init__(self): + super(CIFAR10ResNet50, self).__init__( + block=BasicBlock, layers=[3, 4, 6, 3], num_classes=10 + ) + self.conv1 = nn.Conv2d(3, 64, kernel_size=3, + stride=1, padding=1, bias=False) + + def forward(self, x: Tensor) -> Tensor: + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = torch.flatten(x, 1) + x = self.fc(x) + + return x diff --git a/supervised/lars_optimizer.py b/supervised/optimizers.py index 1904e8d..1904e8d 100644 --- a/supervised/lars_optimizer.py +++ b/supervised/optimizers.py diff --git a/supervised/scheduler.py b/supervised/schedulers.py index 828e547..828e547 100644 --- a/supervised/scheduler.py +++ b/supervised/schedulers.py |