author     Jordan Gong <jordan.gong@protonmail.com>  2022-03-17 20:10:42 +0800
committer  Jordan Gong <jordan.gong@protonmail.com>  2022-03-17 20:10:42 +0800
commit     9b2c25d3c927b7533e5d7d9665b67962e4c6934b (patch)
tree       b4a80004ffdfe8165f5f8d077afb8b054d4472ce /supervised
parent     568569c764ffdd73cd660434df50d30d26203f63 (diff)
Move some utils to libs directory
Diffstat (limited to 'supervised')
-rw-r--r--  supervised/baseline.py    14
-rw-r--r--  supervised/datautils.py   67
-rw-r--r--  supervised/optimizers.py  70
-rw-r--r--  supervised/schedulers.py  43
-rw-r--r--  supervised/utils.py       68
5 files changed, 10 insertions, 252 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py
index 221b90d..15bb716 100644
--- a/supervised/baseline.py
+++ b/supervised/baseline.py
@@ -1,3 +1,9 @@
+import sys
+from pathlib import Path
+
+path = str(Path(Path(__file__).parent.absolute()).parent.absolute())
+sys.path.insert(0, path)
+
import argparse
import os
import random
@@ -11,11 +17,11 @@ from torch.utils.tensorboard import SummaryWriter
from torchvision.datasets import CIFAR10, ImageNet
from torchvision.transforms import transforms, InterpolationMode
-from datautils import color_distortion, Clip, RandomGaussianBlur
+from libs.datautils import color_distortion, Clip, RandomGaussianBlur
+from libs.optimizers import LARS
+from libs.schedulers import LinearWarmupAndCosineAnneal, LinearLR
+from libs.utils import training_log, setup_logging, EPOCH_LOGGER, BATCH_LOGGER
from models import CIFARResNet50, ImageNetResNet50
-from optimizers import LARS
-from schedulers import LinearWarmupAndCosineAnneal, LinearLR
-from utils import training_log, setup_logging, EPOCH_LOGGER, BATCH_LOGGER
def build_parser():
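
Note: the lines prepended to baseline.py put the repository root on sys.path so that the new libs.* imports resolve when the script is run directly from supervised/. A minimal sketch of the assumed layout and the resulting import resolution follows; the libs/ module names come from the new imports, everything else is illustrative:

    # assumed layout (libs/* taken from the new imports; other entries illustrative):
    #   <repo root>/
    #     libs/{datautils,optimizers,schedulers,utils}.py
    #     supervised/{baseline,models}.py
    import sys
    from pathlib import Path

    repo_root = Path(__file__).resolve().parent.parent   # supervised/ -> repo root
    sys.path.insert(0, str(repo_root))                    # make `import libs.*` resolvable

    from libs.datautils import color_distortion           # now works when run as a script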
diff --git a/supervised/datautils.py b/supervised/datautils.py
deleted file mode 100644
index 843f669..0000000
--- a/supervised/datautils.py
+++ /dev/null
@@ -1,67 +0,0 @@
-import numpy as np
-import torch
-from torchvision.transforms import transforms
-
-
-def color_distortion(s=1.0):
- # s is the strength of color distortion.
- color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
- rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
- rnd_gray = transforms.RandomGrayscale(p=0.2)
- color_distort = transforms.Compose([
- rnd_color_jitter,
- rnd_gray
- ])
- return color_distort
-
-
-class Clip(object):
- def __call__(self, x):
- return torch.clamp(x, 0, 1)
-
-
-class RandomGaussianBlur(object):
- """
- PyTorch version of
- https://github.com/google-research/simclr/blob/244e7128004c5fd3c7805cf3135c79baa6c3bb96/data_util.py#L311
- """
-
- def gaussian_blur(self, image, sigma):
- image = image.reshape(1, 3, 224, 224)
- radius = np.int(self.kernel_size / 2)
- kernel_size = radius * 2 + 1
- x = np.arange(-radius, radius + 1)
-
- blur_filter = np.exp(
- -np.power(x, 2.0) / (2.0 * np.power(np.float(sigma), 2.0)))
- blur_filter /= np.sum(blur_filter)
-
- conv1 = torch.nn.Conv2d(3, 3, kernel_size=(kernel_size, 1), groups=3,
- padding=[kernel_size // 2, 0], bias=False)
- conv1.weight = torch.nn.Parameter(torch.Tensor(np.tile(
- blur_filter.reshape(kernel_size, 1, 1, 1), 3
- ).transpose([3, 2, 0, 1])))
-
- conv2 = torch.nn.Conv2d(3, 3, kernel_size=(1, kernel_size), groups=3,
- padding=[0, kernel_size // 2], bias=False)
- conv2.weight = torch.nn.Parameter(torch.Tensor(np.tile(
- blur_filter.reshape(kernel_size, 1, 1, 1), 3
- ).transpose([3, 2, 1, 0])))
-
- res = conv2(conv1(image))
- assert res.shape == image.shape
- return res[0]
-
- def __init__(self, kernel_size, sigma_range=(0.1, 2), p=0.5):
- self.kernel_size = kernel_size
- self.sigma_range = sigma_range
- self.p = p
-
- def __call__(self, img):
- with torch.no_grad():
- assert isinstance(img, torch.Tensor)
- if np.random.uniform() < self.p:
- return self.gaussian_blur(
- img, sigma=np.random.uniform(*self.sigma_range)
- )
- return img
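
Note: datautils.py (now libs/datautils.py) carries the SimCLR-style augmentation pieces: color_distortion() returns random color jitter plus random grayscale, Clip clamps a tensor back into [0, 1], and RandomGaussianBlur blurs a CxHxW tensor with probability p (its gaussian_blur() reshapes to 1x3x224x224, so it assumes ImageNet-sized inputs). A sketch of how they might be composed; the actual pipeline in baseline.py may differ and the parameter values here are illustrative:

    from torchvision.transforms import transforms, InterpolationMode
    from libs.datautils import color_distortion, Clip, RandomGaussianBlur

    # Illustrative SimCLR-style pipeline for 224x224 inputs; kernel_size of roughly
    # 10% of the image size follows the SimCLR paper, not this repository.
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224, interpolation=InterpolationMode.BICUBIC),
        transforms.RandomHorizontalFlip(),
        color_distortion(s=1.0),                    # jitter (p=0.8) + grayscale (p=0.2)
        transforms.ToTensor(),                      # PIL image -> float tensor in [0, 1]
        RandomGaussianBlur(kernel_size=22, p=0.5),  # operates on the CxHxW tensor
        Clip(),                                     # clamp blur output back into [0, 1]
    ])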
diff --git a/supervised/optimizers.py b/supervised/optimizers.py
deleted file mode 100644
index 1904e8d..0000000
--- a/supervised/optimizers.py
+++ /dev/null
@@ -1,70 +0,0 @@
-import torch
-
-
-class LARS(object):
- """
- Slight modification of LARC optimizer from https://github.com/NVIDIA/apex/blob/d74fda260c403f775817470d87f810f816f3d615/apex/parallel/LARC.py
- Matches one from SimCLR implementation https://github.com/google-research/simclr/blob/master/lars_optimizer.py
-
- Args:
- optimizer: Pytorch optimizer to wrap and modify learning rate for.
- trust_coefficient: Trust coefficient for calculating the adaptive lr. See https://arxiv.org/abs/1708.03888
- """
-
- def __init__(self,
- optimizer,
- trust_coefficient=0.001,
- ):
- self.param_groups = optimizer.param_groups
- self.optim = optimizer
- self.trust_coefficient = trust_coefficient
-
- def __getstate__(self):
- return self.optim.__getstate__()
-
- def __setstate__(self, state):
- self.optim.__setstate__(state)
-
- def __repr__(self):
- return self.optim.__repr__()
-
- def state_dict(self):
- return self.optim.state_dict()
-
- def load_state_dict(self, state_dict):
- self.optim.load_state_dict(state_dict)
-
- def zero_grad(self):
- self.optim.zero_grad()
-
- def add_param_group(self, param_group):
- self.optim.add_param_group(param_group)
-
- def step(self):
- with torch.no_grad():
- weight_decays = []
- for group in self.optim.param_groups:
- # absorb weight decay control from optimizer
- weight_decay = group['weight_decay'] if 'weight_decay' in group else 0
- weight_decays.append(weight_decay)
- group['weight_decay'] = 0
- for p in group['params']:
- if p.grad is None:
- continue
-
- if weight_decay != 0:
- p.grad.data += weight_decay * p.data
-
- param_norm = torch.norm(p.data)
- grad_norm = torch.norm(p.grad.data)
- adaptive_lr = 1.
-
- if param_norm != 0 and grad_norm != 0 and group['layer_adaptation']:
- adaptive_lr = self.trust_coefficient * param_norm / grad_norm
-
- p.grad.data *= adaptive_lr
-
- self.optim.step()
- # return weight decay control to optimizer
- for i, group in enumerate(self.optim.param_groups):
- group['weight_decay'] = weight_decays[i]
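
Note: LARS (now libs/optimizers.py) wraps another optimizer, rescales each parameter's gradient by trust_coefficient * ||w|| / ||g||, and absorbs weight decay into the gradient before delegating the actual update to the wrapped optimizer. A minimal sketch of how it might be wired up; the model and hyperparameters are placeholders, but note that step() reads a 'layer_adaptation' flag from every param group:

    import torch
    from libs.optimizers import LARS

    model = torch.nn.Linear(128, 10)              # stand-in model, illustrative only
    base_optimizer = torch.optim.SGD(
        # extra keys such as 'layer_adaptation' are kept on the param group and read
        # back inside LARS.step(); set it to False to skip adaptation for a group
        [{'params': model.parameters(), 'layer_adaptation': True}],
        lr=1.0, momentum=0.9, weight_decay=1e-6,
    )
    optimizer = LARS(base_optimizer, trust_coefficient=0.001)

    loss = model(torch.randn(4, 128)).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()    # scale grads layer-wise, then run the wrapped SGD step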
diff --git a/supervised/schedulers.py b/supervised/schedulers.py
deleted file mode 100644
index 7580bf3..0000000
--- a/supervised/schedulers.py
+++ /dev/null
@@ -1,43 +0,0 @@
-import warnings
-
-import numpy as np
-import torch
-
-
-class LinearLR(torch.optim.lr_scheduler._LRScheduler):
- def __init__(self, optimizer, num_epochs, last_epoch=-1):
- self.num_epochs = max(num_epochs, 1)
- super().__init__(optimizer, last_epoch)
-
- def get_lr(self):
- res = []
- for lr in self.base_lrs:
- res.append(np.maximum(lr * np.minimum(
- -self.last_epoch * 1. / self.num_epochs + 1., 1.
- ), 0.))
- return res
-
-
-class LinearWarmupAndCosineAnneal(torch.optim.lr_scheduler._LRScheduler):
- def __init__(self, optimizer, warm_up, T_max, last_epoch=-1):
- self.warm_up = int(warm_up * T_max)
- self.T_max = T_max - self.warm_up
- super().__init__(optimizer, last_epoch=last_epoch)
-
- def get_lr(self):
- if not self._get_lr_called_within_step:
- warnings.warn("To get the last learning rate computed by the scheduler, "
- "please use `get_last_lr()`.")
-
- if self.last_epoch == 0:
- return [lr / (self.warm_up + 1) for lr in self.base_lrs]
- elif self.last_epoch <= self.warm_up:
- c = (self.last_epoch + 1) / self.last_epoch
- return [group['lr'] * c for group in self.optimizer.param_groups]
- else:
- # ref: https://github.com/pytorch/pytorch/blob/2de4f245c6b1e1c294a8b2a9d7f916d43380af4b/torch/optim/lr_scheduler.py#L493
- le = self.last_epoch - self.warm_up
- return [(1 + np.cos(np.pi * le / self.T_max)) /
- (1 + np.cos(np.pi * (le - 1) / self.T_max)) *
- group['lr']
- for group in self.optimizer.param_groups]
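
Note: both schedulers (now libs/schedulers.py) are chained: LinearWarmupAndCosineAnneal ramps the lr linearly over the first warm_up * T_max steps and then cosine-anneals it, deriving each value from the previous one, so step() has to be called exactly once per optimizer update. A sketch under the assumption that T_max counts individual iterations rather than epochs; the numbers are illustrative:

    import torch
    from libs.schedulers import LinearWarmupAndCosineAnneal

    model = torch.nn.Linear(16, 2)                           # stand-in model
    optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
    # the first int(0.05 * 100) = 5 steps warm up linearly; the remaining 95 decay
    # with a cosine schedule computed recursively from the previous lr
    scheduler = LinearWarmupAndCosineAnneal(optimizer, warm_up=0.05, T_max=100)

    for _ in range(100):
        optimizer.zero_grad()
        model(torch.randn(8, 16)).sum().backward()
        optimizer.step()
        scheduler.step()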
diff --git a/supervised/utils.py b/supervised/utils.py
deleted file mode 100644
index fde86eb..0000000
--- a/supervised/utils.py
+++ /dev/null
@@ -1,68 +0,0 @@
-import logging
-import os
-
-EPOCH_LOGGER = 'epoch_logger'
-BATCH_LOGGER = 'batch_logger'
-
-
-class FileHandlerWithHeader(logging.FileHandler):
-
- def __init__(self, filename, header, mode='a', encoding=None, delay=0):
- self.header = header
- self.file_pre_exists = os.path.exists(filename)
-
- logging.FileHandler.__init__(self, filename, mode, encoding, delay)
- if not delay and self.stream is not None:
- self.stream.write(f'{header}\n')
-
- def emit(self, record):
- if self.stream is None:
- self.stream = self._open()
- if not self.file_pre_exists:
- self.stream.write(f'{self.header}\n')
-
- logging.FileHandler.emit(self, record)
-
-
-def setup_logging(name="log",
- filename=None,
- stream_log_level="INFO",
- file_log_level="INFO"):
- logger = logging.getLogger(name)
- logger.setLevel("INFO")
- formatter = logging.Formatter(
- '%(asctime)s.%(msecs)03d,%(name)s,%(message)s', '%Y-%m-%d %H:%M:%S'
- )
- stream_handler = logging.StreamHandler()
- stream_handler.setLevel(getattr(logging, stream_log_level))
- stream_handler.setFormatter(formatter)
- logger.addHandler(stream_handler)
- if filename is not None:
- header = 'time,logger,'
- if name == BATCH_LOGGER:
- header += 'batch,n_batches,global_batch,epoch,n_epochs,train_loss,lr'
- elif name == EPOCH_LOGGER:
- header += 'epoch,n_epochs,train_loss,test_loss,test_accuracy'
- else:
- raise NotImplementedError(f"Logger '{name}' is not implemented.")
-
- os.makedirs(os.path.dirname(filename), exist_ok=True)
- file_handler = FileHandlerWithHeader(filename, header)
- file_handler.setLevel(getattr(logging, file_log_level))
- file_handler.setFormatter(formatter)
- logger.addHandler(file_handler)
- return logger
-
-
-def training_log(name):
- def log_this(function):
- logger = logging.getLogger(name)
-
- def wrapper(*args, **kwargs):
- output = function(*args, **kwargs)
- logger.info(','.join(map(str, output.values())))
- return output
-
- return wrapper
-
- return log_this
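
Note: utils.py (now libs/utils.py) wires up CSV-style loggers: setup_logging() attaches a stream handler plus, when a filename is given, a FileHandlerWithHeader that writes a header row determined by the logger name (BATCH_LOGGER or EPOCH_LOGGER), and training_log() is a decorator that logs the decorated function's returned dict as one comma-separated row. A sketch with a hypothetical log path and placeholder metric values:

    from libs.utils import setup_logging, training_log, EPOCH_LOGGER

    # 'logs/epoch.csv' is a hypothetical path; the epoch-logger header is written
    # when the file handler is opened
    setup_logging(name=EPOCH_LOGGER, filename='logs/epoch.csv')

    @training_log(EPOCH_LOGGER)
    def train_one_epoch(epoch, n_epochs):
        # ... real training/evaluation would happen here ...
        return {'epoch': epoch, 'n_epochs': n_epochs, 'train_loss': 0.98,
                'test_loss': 1.02, 'test_accuracy': 0.64}    # values are placeholders

    train_one_epoch(1, 100)   # appends: <timestamp>,epoch_logger,1,100,0.98,1.02,0.64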