aboutsummaryrefslogtreecommitdiff
path: root/supervised/baseline.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2022-03-16 17:49:51 +0800
committerJordan Gong <jordan.gong@protonmail.com>2022-03-16 17:49:51 +0800
commit5869d0248fa958acd3447e6bffa8761b91e8e921 (patch)
tree4e2c0744400d9204bdfd23c58bafcf534c2119fb /supervised/baseline.py
parent608178533e93dc7e6fac6059fa139233ab046b63 (diff)
Regular refactoring
Diffstat (limited to 'supervised/baseline.py')
-rw-r--r--supervised/baseline.py49
1 files changed, 5 insertions, 44 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py
index d5671b1..92d8d30 100644
--- a/supervised/baseline.py
+++ b/supervised/baseline.py
@@ -2,18 +2,17 @@ import os
import random
import torch
-from torch import nn, Tensor, optim
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.datasets import CIFAR10
-from torchvision.models import ResNet
-from torchvision.models.resnet import BasicBlock
from torchvision.transforms import transforms, InterpolationMode
from tqdm import tqdm
-from lars_optimizer import LARS
-from scheduler import LinearWarmupAndCosineAnneal
+from optimizers import LARS
+from schedulers import LinearWarmupAndCosineAnneal
+from supervised.datautils import color_distortion
+from supervised.models import CIFAR10ResNet50
CODENAME = 'cifar10-resnet50-aug-lars-sched'
DATASET_ROOT = 'dataset'
@@ -36,48 +35,10 @@ random.seed(SEED)
torch.manual_seed(SEED)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
-class CIFAR10ResNet50(ResNet):
- def __init__(self):
- super(CIFAR10ResNet50, self).__init__(
- block=BasicBlock, layers=[3, 4, 6, 3], num_classes=10
- )
- self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
- stride=1, padding=1, bias=False)
-
- def forward(self, x: Tensor) -> Tensor:
- x = self.conv1(x)
- x = self.bn1(x)
- x = self.relu(x)
-
- x = self.layer1(x)
- x = self.layer2(x)
- x = self.layer3(x)
- x = self.layer4(x)
-
- x = self.avgpool(x)
- x = torch.flatten(x, 1)
- x = self.fc(x)
-
- return x
-
-
-def get_color_distortion(s=1.0):
- # s is the strength of color distortion.
- color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
- rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
- rnd_gray = transforms.RandomGrayscale(p=0.2)
- color_distort = transforms.Compose([
- rnd_color_jitter,
- rnd_gray
- ])
- return color_distort
-
-
train_transform = transforms.Compose([
transforms.RandomResizedCrop(32, interpolation=InterpolationMode.BICUBIC),
transforms.RandomHorizontalFlip(0.5),
- get_color_distortion(0.5),
+ color_distortion(0.5),
transforms.ToTensor()
])