From 5869d0248fa958acd3447e6bffa8761b91e8e921 Mon Sep 17 00:00:00 2001
From: Jordan Gong <jordan.gong@protonmail.com>
Date: Wed, 16 Mar 2022 17:49:51 +0800
Subject: Refactor: split baseline.py into separate modules

Move CIFAR10ResNet50 into models.py and the color-distortion transform
into datautils.py; rename lars_optimizer.py to optimizers.py and
scheduler.py to schedulers.py. No functional change intended.

---
 supervised/baseline.py       | 49 ++++---------------------------
 supervised/datautils.py      | 13 ++++++++
 supervised/lars_optimizer.py | 70 --------------------------------------------
 supervised/models.py         | 29 ++++++++++++++++++
 supervised/optimizers.py     | 70 ++++++++++++++++++++++++++++++++++++++++++++
 supervised/scheduler.py      | 29 ------------------
 supervised/schedulers.py     | 29 ++++++++++++++++++
 7 files changed, 146 insertions(+), 143 deletions(-)
 create mode 100644 supervised/datautils.py
 delete mode 100644 supervised/lars_optimizer.py
 create mode 100644 supervised/models.py
 create mode 100644 supervised/optimizers.py
 delete mode 100644 supervised/scheduler.py
 create mode 100644 supervised/schedulers.py

diff --git a/supervised/baseline.py b/supervised/baseline.py
index d5671b1..92d8d30 100644
--- a/supervised/baseline.py
+++ b/supervised/baseline.py
@@ -2,18 +2,17 @@ import os
 import random
 
 import torch
-from torch import nn, Tensor, optim
 from torch.nn import CrossEntropyLoss
 from torch.utils.data import DataLoader
 from torch.utils.tensorboard import SummaryWriter
 from torchvision.datasets import CIFAR10
-from torchvision.models import ResNet
-from torchvision.models.resnet import BasicBlock
 from torchvision.transforms import transforms, InterpolationMode
 from tqdm import tqdm
 
-from lars_optimizer import LARS
-from scheduler import LinearWarmupAndCosineAnneal
+from optimizers import LARS
+from schedulers import LinearWarmupAndCosineAnneal
+from datautils import color_distortion
+from models import CIFAR10ResNet50
 
 CODENAME = 'cifar10-resnet50-aug-lars-sched'
 DATASET_ROOT = 'dataset'
@@ -36,48 +35,10 @@ random.seed(SEED)
 torch.manual_seed(SEED)
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
-
-class CIFAR10ResNet50(ResNet):
-    def __init__(self):
-        super(CIFAR10ResNet50, self).__init__(
-            block=BasicBlock, layers=[3, 4, 6, 3], num_classes=10
-        )
-        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
-                               stride=1, padding=1, bias=False)
-
-    def forward(self, x: Tensor) -> Tensor:
-        x = self.conv1(x)
-        x = self.bn1(x)
-        x = self.relu(x)
-
-        x = self.layer1(x)
-        x = self.layer2(x)
-        x = self.layer3(x)
-        x = self.layer4(x)
-
-        x = self.avgpool(x)
-        x = torch.flatten(x, 1)
-        x = self.fc(x)
-
-        return x
-
-
-def get_color_distortion(s=1.0):
-    # s is the strength of color distortion.
-    color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
-    rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
-    rnd_gray = transforms.RandomGrayscale(p=0.2)
-    color_distort = transforms.Compose([
-        rnd_color_jitter,
-        rnd_gray
-    ])
-    return color_distort
-
-
 train_transform = transforms.Compose([
     transforms.RandomResizedCrop(32, interpolation=InterpolationMode.BICUBIC),
     transforms.RandomHorizontalFlip(0.5),
-    get_color_distortion(0.5),
+    color_distortion(0.5),
     transforms.ToTensor()
 ])
 
diff --git a/supervised/datautils.py b/supervised/datautils.py
new file mode 100644
index 0000000..196fca7
--- /dev/null
+++ b/supervised/datautils.py
@@ -0,0 +1,13 @@
+from torchvision.transforms import transforms
+
+
+def color_distortion(s=1.0):
+    # s is the strength of color distortion.
+    color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
+    rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
+    rnd_gray = transforms.RandomGrayscale(p=0.2)
+    color_distort = transforms.Compose([
+        rnd_color_jitter,
+        rnd_gray
+    ])
+    return color_distort
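
The relocated helper keeps the SimCLR-style color augmentation in one place. A
minimal usage sketch (illustrative only, assuming the caller sits next to
supervised/datautils.py so the local import resolves):

    from PIL import Image
    from torchvision.transforms import transforms

    from datautils import color_distortion

    transform = transforms.Compose([
        color_distortion(0.5),    # random color jitter (p=0.8) + random grayscale (p=0.2)
        transforms.ToTensor(),
    ])
    img = Image.new('RGB', (32, 32))   # stand-in for a CIFAR-10 image
    print(transform(img).shape)        # torch.Size([3, 32, 32])
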
diff --git a/supervised/lars_optimizer.py b/supervised/lars_optimizer.py
deleted file mode 100644
index 1904e8d..0000000
--- a/supervised/lars_optimizer.py
+++ /dev/null
@@ -1,70 +0,0 @@
-import torch
-
-
-class LARS(object):
-    """
-    Slight modification of LARC optimizer from https://github.com/NVIDIA/apex/blob/d74fda260c403f775817470d87f810f816f3d615/apex/parallel/LARC.py
-    Matches one from SimCLR implementation https://github.com/google-research/simclr/blob/master/lars_optimizer.py
-
-    Args:
-        optimizer: Pytorch optimizer to wrap and modify learning rate for.
-        trust_coefficient: Trust coefficient for calculating the adaptive lr. See https://arxiv.org/abs/1708.03888
-    """
-
-    def __init__(self,
-                 optimizer,
-                 trust_coefficient=0.001,
-                 ):
-        self.param_groups = optimizer.param_groups
-        self.optim = optimizer
-        self.trust_coefficient = trust_coefficient
-
-    def __getstate__(self):
-        return self.optim.__getstate__()
-
-    def __setstate__(self, state):
-        self.optim.__setstate__(state)
-
-    def __repr__(self):
-        return self.optim.__repr__()
-
-    def state_dict(self):
-        return self.optim.state_dict()
-
-    def load_state_dict(self, state_dict):
-        self.optim.load_state_dict(state_dict)
-
-    def zero_grad(self):
-        self.optim.zero_grad()
-
-    def add_param_group(self, param_group):
-        self.optim.add_param_group(param_group)
-
-    def step(self):
-        with torch.no_grad():
-            weight_decays = []
-            for group in self.optim.param_groups:
-                # absorb weight decay control from optimizer
-                weight_decay = group['weight_decay'] if 'weight_decay' in group else 0
-                weight_decays.append(weight_decay)
-                group['weight_decay'] = 0
-                for p in group['params']:
-                    if p.grad is None:
-                        continue
-
-                    if weight_decay != 0:
-                        p.grad.data += weight_decay * p.data
-
-                    param_norm = torch.norm(p.data)
-                    grad_norm = torch.norm(p.grad.data)
-                    adaptive_lr = 1.
-
-                    if param_norm != 0 and grad_norm != 0 and group['layer_adaptation']:
-                        adaptive_lr = self.trust_coefficient * param_norm / grad_norm
-
-                    p.grad.data *= adaptive_lr
-
-        self.optim.step()
-        # return weight decay control to optimizer
-        for i, group in enumerate(self.optim.param_groups):
-            group['weight_decay'] = weight_decays[i]
diff --git a/supervised/models.py b/supervised/models.py
new file mode 100644
index 0000000..47a0dcf
--- /dev/null
+++ b/supervised/models.py
@@ -0,0 +1,29 @@
+import torch
+from torch import nn, Tensor
+from torchvision.models import ResNet
+from torchvision.models.resnet import BasicBlock
+
+
+class CIFAR10ResNet50(ResNet):
+    def __init__(self):
+        super(CIFAR10ResNet50, self).__init__(
+            block=BasicBlock, layers=[3, 4, 6, 3], num_classes=10
+        )
+        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
+                               stride=1, padding=1, bias=False)
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        x = self.avgpool(x)
+        x = torch.flatten(x, 1)
+        x = self.fc(x)
+
+        return x
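
A quick shape check for the extracted model (a sketch; note the class builds on
BasicBlock with layers [3, 4, 6, 3], i.e. a ResNet-34-style layout despite the
ResNet50 name, and its 3x3 stride-1 stem with no max-pool keeps 32x32 CIFAR-10
inputs at full resolution):

    import torch

    from models import CIFAR10ResNet50

    model = CIFAR10ResNet50()
    x = torch.randn(2, 3, 32, 32)   # CIFAR-10-sized batch
    logits = model(x)
    print(logits.shape)             # torch.Size([2, 10])
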
diff --git a/supervised/optimizers.py b/supervised/optimizers.py
new file mode 100644
index 0000000..1904e8d
--- /dev/null
+++ b/supervised/optimizers.py
@@ -0,0 +1,70 @@
+import torch
+
+
+class LARS(object):
+    """
+    Slight modification of LARC optimizer from https://github.com/NVIDIA/apex/blob/d74fda260c403f775817470d87f810f816f3d615/apex/parallel/LARC.py
+    Matches one from SimCLR implementation https://github.com/google-research/simclr/blob/master/lars_optimizer.py
+
+    Args:
+        optimizer: Pytorch optimizer to wrap and modify learning rate for.
+        trust_coefficient: Trust coefficient for calculating the adaptive lr. See https://arxiv.org/abs/1708.03888
+    """
+
+    def __init__(self,
+                 optimizer,
+                 trust_coefficient=0.001,
+                 ):
+        self.param_groups = optimizer.param_groups
+        self.optim = optimizer
+        self.trust_coefficient = trust_coefficient
+
+    def __getstate__(self):
+        return self.optim.__getstate__()
+
+    def __setstate__(self, state):
+        self.optim.__setstate__(state)
+
+    def __repr__(self):
+        return self.optim.__repr__()
+
+    def state_dict(self):
+        return self.optim.state_dict()
+
+    def load_state_dict(self, state_dict):
+        self.optim.load_state_dict(state_dict)
+
+    def zero_grad(self):
+        self.optim.zero_grad()
+
+    def add_param_group(self, param_group):
+        self.optim.add_param_group(param_group)
+
+    def step(self):
+        with torch.no_grad():
+            weight_decays = []
+            for group in self.optim.param_groups:
+                # absorb weight decay control from optimizer
+                weight_decay = group['weight_decay'] if 'weight_decay' in group else 0
+                weight_decays.append(weight_decay)
+                group['weight_decay'] = 0
+                for p in group['params']:
+                    if p.grad is None:
+                        continue
+
+                    if weight_decay != 0:
+                        p.grad.data += weight_decay * p.data
+
+                    param_norm = torch.norm(p.data)
+                    grad_norm = torch.norm(p.grad.data)
+                    adaptive_lr = 1.
+
+                    if param_norm != 0 and grad_norm != 0 and group['layer_adaptation']:
+                        adaptive_lr = self.trust_coefficient * param_norm / grad_norm
+
+                    p.grad.data *= adaptive_lr
+
+        self.optim.step()
+        # return weight decay control to optimizer
+        for i, group in enumerate(self.optim.param_groups):
+            group['weight_decay'] = weight_decays[i]
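
A hedged usage sketch for the renamed LARS wrapper: it wraps an existing
optimizer, and step() reads a 'layer_adaptation' flag from every param group
(the group construction below is illustrative, not taken from this repo):

    import torch
    from torch import nn

    from optimizers import LARS

    model = nn.Linear(32, 10)
    weights = [p for n, p in model.named_parameters() if 'bias' not in n]
    biases = [p for n, p in model.named_parameters() if 'bias' in n]
    param_groups = [
        {'params': weights, 'layer_adaptation': True},   # layer-wise adaptive LR
        {'params': biases, 'layer_adaptation': False},   # plain SGD update
    ]
    base_optimizer = torch.optim.SGD(param_groups, lr=1.0, weight_decay=1e-6)
    optimizer = LARS(base_optimizer, trust_coefficient=1e-3)

    loss = model(torch.randn(4, 32)).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
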
diff --git a/supervised/scheduler.py b/supervised/scheduler.py
deleted file mode 100644
index 828e547..0000000
--- a/supervised/scheduler.py
+++ /dev/null
@@ -1,29 +0,0 @@
-import warnings
-
-import numpy as np
-import torch
-
-
-class LinearWarmupAndCosineAnneal(torch.optim.lr_scheduler._LRScheduler):
-    def __init__(self, optimizer, warm_up, T_max, last_epoch=-1):
-        self.warm_up = int(warm_up * T_max)
-        self.T_max = T_max - self.warm_up
-        super().__init__(optimizer, last_epoch=last_epoch)
-
-    def get_lr(self):
-        if not self._get_lr_called_within_step:
-            warnings.warn("To get the last learning rate computed by the scheduler, "
-                          "please use `get_last_lr()`.")
-
-        if self.last_epoch == 0:
-            return [lr / (self.warm_up + 1) for lr in self.base_lrs]
-        elif self.last_epoch <= self.warm_up:
-            c = (self.last_epoch + 1) / self.last_epoch
-            return [group['lr'] * c for group in self.optimizer.param_groups]
-        else:
-            # ref: https://github.com/pytorch/pytorch/blob/2de4f245c6b1e1c294a8b2a9d7f916d43380af4b/torch/optim/lr_scheduler.py#L493
-            le = self.last_epoch - self.warm_up
-            return [(1 + np.cos(np.pi * le / self.T_max)) /
-                    (1 + np.cos(np.pi * (le - 1) / self.T_max)) *
-                    group['lr']
-                    for group in self.optimizer.param_groups]
diff --git a/supervised/schedulers.py b/supervised/schedulers.py
new file mode 100644
index 0000000..828e547
--- /dev/null
+++ b/supervised/schedulers.py
@@ -0,0 +1,29 @@
+import warnings
+
+import numpy as np
+import torch
+
+
+class LinearWarmupAndCosineAnneal(torch.optim.lr_scheduler._LRScheduler):
+    def __init__(self, optimizer, warm_up, T_max, last_epoch=-1):
+        self.warm_up = int(warm_up * T_max)
+        self.T_max = T_max - self.warm_up
+        super().__init__(optimizer, last_epoch=last_epoch)
+
+    def get_lr(self):
+        if not self._get_lr_called_within_step:
+            warnings.warn("To get the last learning rate computed by the scheduler, "
+                          "please use `get_last_lr()`.")
+
+        if self.last_epoch == 0:
+            return [lr / (self.warm_up + 1) for lr in self.base_lrs]
+        elif self.last_epoch <= self.warm_up:
+            c = (self.last_epoch + 1) / self.last_epoch
+            return [group['lr'] * c for group in self.optimizer.param_groups]
+        else:
+            # ref: https://github.com/pytorch/pytorch/blob/2de4f245c6b1e1c294a8b2a9d7f916d43380af4b/torch/optim/lr_scheduler.py#L493
+            le = self.last_epoch - self.warm_up
+            return [(1 + np.cos(np.pi * le / self.T_max)) /
+                    (1 + np.cos(np.pi * (le - 1) / self.T_max)) *
+                    group['lr']
+                    for group in self.optimizer.param_groups]
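
A sketch of driving the renamed scheduler, stepped once per optimization step
(warm_up is a fraction of T_max, so 0.1 here means a 100-step linear warm-up
followed by cosine annealing; the toy model and step count are illustrative):

    import torch

    from schedulers import LinearWarmupAndCosineAnneal

    model = torch.nn.Linear(8, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    N_STEPS = 1000
    scheduler = LinearWarmupAndCosineAnneal(optimizer, warm_up=0.1, T_max=N_STEPS)

    for _ in range(N_STEPS):
        loss = model(torch.randn(4, 8)).sum()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()   # LR ramps 0.1/101 -> 0.1 over 100 steps, then cosine decay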