aboutsummaryrefslogtreecommitdiff
path: root/supervised
diff options
context:
space:
mode:
Diffstat (limited to 'supervised')
-rw-r--r--supervised/baseline.py49
-rw-r--r--supervised/datautils.py13
-rw-r--r--supervised/models.py29
-rw-r--r--supervised/optimizers.py (renamed from supervised/lars_optimizer.py)0
-rw-r--r--supervised/schedulers.py (renamed from supervised/scheduler.py)0
5 files changed, 47 insertions, 44 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py
index d5671b1..92d8d30 100644
--- a/supervised/baseline.py
+++ b/supervised/baseline.py
@@ -2,18 +2,17 @@ import os
import random
import torch
-from torch import nn, Tensor, optim
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.datasets import CIFAR10
-from torchvision.models import ResNet
-from torchvision.models.resnet import BasicBlock
from torchvision.transforms import transforms, InterpolationMode
from tqdm import tqdm
-from lars_optimizer import LARS
-from scheduler import LinearWarmupAndCosineAnneal
+from optimizers import LARS
+from schedulers import LinearWarmupAndCosineAnneal
+from supervised.datautils import color_distortion
+from supervised.models import CIFAR10ResNet50
CODENAME = 'cifar10-resnet50-aug-lars-sched'
DATASET_ROOT = 'dataset'
@@ -36,48 +35,10 @@ random.seed(SEED)
torch.manual_seed(SEED)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
-class CIFAR10ResNet50(ResNet):
- def __init__(self):
- super(CIFAR10ResNet50, self).__init__(
- block=BasicBlock, layers=[3, 4, 6, 3], num_classes=10
- )
- self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
- stride=1, padding=1, bias=False)
-
- def forward(self, x: Tensor) -> Tensor:
- x = self.conv1(x)
- x = self.bn1(x)
- x = self.relu(x)
-
- x = self.layer1(x)
- x = self.layer2(x)
- x = self.layer3(x)
- x = self.layer4(x)
-
- x = self.avgpool(x)
- x = torch.flatten(x, 1)
- x = self.fc(x)
-
- return x
-
-
-def get_color_distortion(s=1.0):
- # s is the strength of color distortion.
- color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
- rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
- rnd_gray = transforms.RandomGrayscale(p=0.2)
- color_distort = transforms.Compose([
- rnd_color_jitter,
- rnd_gray
- ])
- return color_distort
-
-
train_transform = transforms.Compose([
transforms.RandomResizedCrop(32, interpolation=InterpolationMode.BICUBIC),
transforms.RandomHorizontalFlip(0.5),
- get_color_distortion(0.5),
+ color_distortion(0.5),
transforms.ToTensor()
])
diff --git a/supervised/datautils.py b/supervised/datautils.py
new file mode 100644
index 0000000..196fca7
--- /dev/null
+++ b/supervised/datautils.py
@@ -0,0 +1,13 @@
+from torchvision.transforms import transforms
+
+
+def color_distortion(s=1.0):
+ # s is the strength of color distortion.
+ color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
+ rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
+ rnd_gray = transforms.RandomGrayscale(p=0.2)
+ color_distort = transforms.Compose([
+ rnd_color_jitter,
+ rnd_gray
+ ])
+ return color_distort
diff --git a/supervised/models.py b/supervised/models.py
new file mode 100644
index 0000000..47a0dcf
--- /dev/null
+++ b/supervised/models.py
@@ -0,0 +1,29 @@
+import torch
+from torch import nn, Tensor
+from torchvision.models import ResNet
+from torchvision.models.resnet import BasicBlock
+
+
+class CIFAR10ResNet50(ResNet):
+ def __init__(self):
+ super(CIFAR10ResNet50, self).__init__(
+ block=BasicBlock, layers=[3, 4, 6, 3], num_classes=10
+ )
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
+ stride=1, padding=1, bias=False)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+
+ x = self.avgpool(x)
+ x = torch.flatten(x, 1)
+ x = self.fc(x)
+
+ return x
diff --git a/supervised/lars_optimizer.py b/supervised/optimizers.py
index 1904e8d..1904e8d 100644
--- a/supervised/lars_optimizer.py
+++ b/supervised/optimizers.py
diff --git a/supervised/scheduler.py b/supervised/schedulers.py
index 828e547..828e547 100644
--- a/supervised/scheduler.py
+++ b/supervised/schedulers.py