From 9b2c25d3c927b7533e5d7d9665b67962e4c6934b Mon Sep 17 00:00:00 2001
From: Jordan Gong
Date: Thu, 17 Mar 2022 20:10:42 +0800
Subject: Move some utils to libs directory

---
 .idea/contrastive-learning.iml |  7 +++++
 libs/datautils.py              | 67 ++++++++++++++++++++++++++++++++++++++++
 libs/optimizers.py             | 70 ++++++++++++++++++++++++++++++++++++++++++
 libs/schedulers.py             | 43 ++++++++++++++++++++++++++
 libs/utils.py                  | 68 ++++++++++++++++++++++++++++++++++++++++
 supervised/baseline.py         | 14 ++++---
 supervised/datautils.py        | 67 ----------------------------------------
 supervised/optimizers.py       | 70 ------------------------------------------
 supervised/schedulers.py       | 43 --------------------------
 supervised/utils.py            | 68 ----------------------------------------
 10 files changed, 265 insertions(+), 252 deletions(-)
 create mode 100644 libs/datautils.py
 create mode 100644 libs/optimizers.py
 create mode 100644 libs/schedulers.py
 create mode 100644 libs/utils.py
 delete mode 100644 supervised/datautils.py
 delete mode 100644 supervised/optimizers.py
 delete mode 100644 supervised/schedulers.py
 delete mode 100644 supervised/utils.py

diff --git a/.idea/contrastive-learning.iml b/.idea/contrastive-learning.iml
index af873e6..728ebac 100644
--- a/.idea/contrastive-learning.iml
+++ b/.idea/contrastive-learning.iml
@@ -12,4 +12,11 @@
+
+
+
\ No newline at end of file
diff --git a/libs/datautils.py b/libs/datautils.py
new file mode 100644
index 0000000..843f669
--- /dev/null
+++ b/libs/datautils.py
@@ -0,0 +1,67 @@
+import numpy as np
+import torch
+from torchvision.transforms import transforms
+
+
+def color_distortion(s=1.0):
+    # s is the strength of color distortion.
+    color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
+    rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
+    rnd_gray = transforms.RandomGrayscale(p=0.2)
+    color_distort = transforms.Compose([
+        rnd_color_jitter,
+        rnd_gray
+    ])
+    return color_distort
+
+
+class Clip(object):
+    def __call__(self, x):
+        return torch.clamp(x, 0, 1)
+
+
+class RandomGaussianBlur(object):
+    """
+    PyTorch version of
+    https://github.com/google-research/simclr/blob/244e7128004c5fd3c7805cf3135c79baa6c3bb96/data_util.py#L311
+    """
+
+    def gaussian_blur(self, image, sigma):
+        image = image.reshape(1, 3, 224, 224)
+        radius = np.int(self.kernel_size / 2)
+        kernel_size = radius * 2 + 1
+        x = np.arange(-radius, radius + 1)
+
+        blur_filter = np.exp(
+            -np.power(x, 2.0) / (2.0 * np.power(np.float(sigma), 2.0)))
+        blur_filter /= np.sum(blur_filter)
+
+        conv1 = torch.nn.Conv2d(3, 3, kernel_size=(kernel_size, 1), groups=3,
+                                padding=[kernel_size // 2, 0], bias=False)
+        conv1.weight = torch.nn.Parameter(torch.Tensor(np.tile(
+            blur_filter.reshape(kernel_size, 1, 1, 1), 3
+        ).transpose([3, 2, 0, 1])))
+
+        conv2 = torch.nn.Conv2d(3, 3, kernel_size=(1, kernel_size), groups=3,
+                                padding=[0, kernel_size // 2], bias=False)
+        conv2.weight = torch.nn.Parameter(torch.Tensor(np.tile(
+            blur_filter.reshape(kernel_size, 1, 1, 1), 3
+        ).transpose([3, 2, 1, 0])))
+
+        res = conv2(conv1(image))
+        assert res.shape == image.shape
+        return res[0]
+
+    def __init__(self, kernel_size, sigma_range=(0.1, 2), p=0.5):
+        self.kernel_size = kernel_size
+        self.sigma_range = sigma_range
+        self.p = p
+
+    def __call__(self, img):
+        with torch.no_grad():
+            assert isinstance(img, torch.Tensor)
+            if np.random.uniform() < self.p:
+                return self.gaussian_blur(
+                    img, sigma=np.random.uniform(*self.sigma_range)
+                )
+            return img
diff --git a/libs/optimizers.py b/libs/optimizers.py
new file mode 100644
index 0000000..1904e8d
--- /dev/null
+++ b/libs/optimizers.py
@@ -0,0 +1,70 @@
+import torch
+
+
+class LARS(object):
+    """
+    Slight modification of LARC optimizer from https://github.com/NVIDIA/apex/blob/d74fda260c403f775817470d87f810f816f3d615/apex/parallel/LARC.py
+    Matches one from SimCLR implementation https://github.com/google-research/simclr/blob/master/lars_optimizer.py
+
+    Args:
+        optimizer: Pytorch optimizer to wrap and modify learning rate for.
+        trust_coefficient: Trust coefficient for calculating the adaptive lr. See https://arxiv.org/abs/1708.03888
+    """
+
+    def __init__(self,
+                 optimizer,
+                 trust_coefficient=0.001,
+                 ):
+        self.param_groups = optimizer.param_groups
+        self.optim = optimizer
+        self.trust_coefficient = trust_coefficient
+
+    def __getstate__(self):
+        return self.optim.__getstate__()
+
+    def __setstate__(self, state):
+        self.optim.__setstate__(state)
+
+    def __repr__(self):
+        return self.optim.__repr__()
+
+    def state_dict(self):
+        return self.optim.state_dict()
+
+    def load_state_dict(self, state_dict):
+        self.optim.load_state_dict(state_dict)
+
+    def zero_grad(self):
+        self.optim.zero_grad()
+
+    def add_param_group(self, param_group):
+        self.optim.add_param_group(param_group)
+
+    def step(self):
+        with torch.no_grad():
+            weight_decays = []
+            for group in self.optim.param_groups:
+                # absorb weight decay control from optimizer
+                weight_decay = group['weight_decay'] if 'weight_decay' in group else 0
+                weight_decays.append(weight_decay)
+                group['weight_decay'] = 0
+                for p in group['params']:
+                    if p.grad is None:
+                        continue
+
+                    if weight_decay != 0:
+                        p.grad.data += weight_decay * p.data
+
+                    param_norm = torch.norm(p.data)
+                    grad_norm = torch.norm(p.grad.data)
+                    adaptive_lr = 1.
+
+                    if param_norm != 0 and grad_norm != 0 and group['layer_adaptation']:
+                        adaptive_lr = self.trust_coefficient * param_norm / grad_norm
+
+                    p.grad.data *= adaptive_lr
+
+        self.optim.step()
+        # return weight decay control to optimizer
+        for i, group in enumerate(self.optim.param_groups):
+            group['weight_decay'] = weight_decays[i]
diff --git a/libs/schedulers.py b/libs/schedulers.py
new file mode 100644
index 0000000..7580bf3
--- /dev/null
+++ b/libs/schedulers.py
@@ -0,0 +1,43 @@
+import warnings
+
+import numpy as np
+import torch
+
+
+class LinearLR(torch.optim.lr_scheduler._LRScheduler):
+    def __init__(self, optimizer, num_epochs, last_epoch=-1):
+        self.num_epochs = max(num_epochs, 1)
+        super().__init__(optimizer, last_epoch)
+
+    def get_lr(self):
+        res = []
+        for lr in self.base_lrs:
+            res.append(np.maximum(lr * np.minimum(
+                -self.last_epoch * 1. / self.num_epochs + 1., 1.
+            ), 0.))
+        return res
+
+
+class LinearWarmupAndCosineAnneal(torch.optim.lr_scheduler._LRScheduler):
+    def __init__(self, optimizer, warm_up, T_max, last_epoch=-1):
+        self.warm_up = int(warm_up * T_max)
+        self.T_max = T_max - self.warm_up
+        super().__init__(optimizer, last_epoch=last_epoch)
+
+    def get_lr(self):
+        if not self._get_lr_called_within_step:
+            warnings.warn("To get the last learning rate computed by the scheduler, "
+                          "please use `get_last_lr()`.")
+
+        if self.last_epoch == 0:
+            return [lr / (self.warm_up + 1) for lr in self.base_lrs]
+        elif self.last_epoch <= self.warm_up:
+            c = (self.last_epoch + 1) / self.last_epoch
+            return [group['lr'] * c for group in self.optimizer.param_groups]
+        else:
+            # ref: https://github.com/pytorch/pytorch/blob/2de4f245c6b1e1c294a8b2a9d7f916d43380af4b/torch/optim/lr_scheduler.py#L493
+            le = self.last_epoch - self.warm_up
+            return [(1 + np.cos(np.pi * le / self.T_max)) /
+                    (1 + np.cos(np.pi * (le - 1) / self.T_max)) *
+                    group['lr']
+                    for group in self.optimizer.param_groups]
diff --git a/libs/utils.py b/libs/utils.py
new file mode 100644
index 0000000..fde86eb
--- /dev/null
+++ b/libs/utils.py
@@ -0,0 +1,68 @@
+import logging
+import os
+
+EPOCH_LOGGER = 'epoch_logger'
+BATCH_LOGGER = 'batch_logger'
+
+
+class FileHandlerWithHeader(logging.FileHandler):
+
+    def __init__(self, filename, header, mode='a', encoding=None, delay=0):
+        self.header = header
+        self.file_pre_exists = os.path.exists(filename)
+
+        logging.FileHandler.__init__(self, filename, mode, encoding, delay)
+        if not delay and self.stream is not None:
+            self.stream.write(f'{header}\n')
+
+    def emit(self, record):
+        if self.stream is None:
+            self.stream = self._open()
+            if not self.file_pre_exists:
+                self.stream.write(f'{self.header}\n')
+
+        logging.FileHandler.emit(self, record)
+
+
+def setup_logging(name="log",
+                  filename=None,
+                  stream_log_level="INFO",
+                  file_log_level="INFO"):
+    logger = logging.getLogger(name)
+    logger.setLevel("INFO")
+    formatter = logging.Formatter(
+        '%(asctime)s.%(msecs)03d,%(name)s,%(message)s', '%Y-%m-%d %H:%M:%S'
+    )
+    stream_handler = logging.StreamHandler()
+    stream_handler.setLevel(getattr(logging, stream_log_level))
+    stream_handler.setFormatter(formatter)
+    logger.addHandler(stream_handler)
+    if filename is not None:
+        header = 'time,logger,'
+        if name == BATCH_LOGGER:
+            header += 'batch,n_batches,global_batch,epoch,n_epochs,train_loss,lr'
+        elif name == EPOCH_LOGGER:
+            header += 'epoch,n_epochs,train_loss,test_loss,test_accuracy'
+        else:
+            raise NotImplementedError(f"Logger '{name}' is not implemented.")
+
+        os.makedirs(os.path.dirname(filename), exist_ok=True)
+        file_handler = FileHandlerWithHeader(filename, header)
+        file_handler.setLevel(getattr(logging, file_log_level))
+        file_handler.setFormatter(formatter)
+        logger.addHandler(file_handler)
+    return logger
+
+
+def training_log(name):
+    def log_this(function):
+        logger = logging.getLogger(name)
+
+        def wrapper(*args, **kwargs):
+            output = function(*args, **kwargs)
+            logger.info(','.join(map(str, output.values())))
+            return output
+
+        return wrapper
+
+    return log_this
diff --git a/supervised/baseline.py b/supervised/baseline.py
index 221b90d..15bb716 100644
--- a/supervised/baseline.py
+++ b/supervised/baseline.py
@@ -1,3 +1,9 @@
+import sys
+from pathlib import Path
+
+path = str(Path(Path(__file__).parent.absolute()).parent.absolute())
+sys.path.insert(0, path)
+
 import argparse
 import os
 import random
@@ -11,11 +17,11 @@
 from torch.utils.tensorboard import SummaryWriter
 from torchvision.datasets import CIFAR10, ImageNet
 from torchvision.transforms import transforms, InterpolationMode
-from datautils import color_distortion, Clip, RandomGaussianBlur
+from libs.datautils import color_distortion, Clip, RandomGaussianBlur
+from libs.optimizers import LARS
+from libs.schedulers import LinearWarmupAndCosineAnneal, LinearLR
+from libs.utils import training_log, setup_logging, EPOCH_LOGGER, BATCH_LOGGER
 from models import CIFARResNet50, ImageNetResNet50
-from optimizers import LARS
-from schedulers import LinearWarmupAndCosineAnneal, LinearLR
-from utils import training_log, setup_logging, EPOCH_LOGGER, BATCH_LOGGER


 def build_parser():
diff --git a/supervised/datautils.py b/supervised/datautils.py
deleted file mode 100644
index 843f669..0000000
--- a/supervised/datautils.py
+++ /dev/null
@@ -1,67 +0,0 @@
-import numpy as np
-import torch
-from torchvision.transforms import transforms
-
-
-def color_distortion(s=1.0):
-    # s is the strength of color distortion.
-    color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
-    rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
-    rnd_gray = transforms.RandomGrayscale(p=0.2)
-    color_distort = transforms.Compose([
-        rnd_color_jitter,
-        rnd_gray
-    ])
-    return color_distort
-
-
-class Clip(object):
-    def __call__(self, x):
-        return torch.clamp(x, 0, 1)
-
-
-class RandomGaussianBlur(object):
-    """
-    PyTorch version of
-    https://github.com/google-research/simclr/blob/244e7128004c5fd3c7805cf3135c79baa6c3bb96/data_util.py#L311
-    """
-
-    def gaussian_blur(self, image, sigma):
-        image = image.reshape(1, 3, 224, 224)
-        radius = np.int(self.kernel_size / 2)
-        kernel_size = radius * 2 + 1
-        x = np.arange(-radius, radius + 1)
-
-        blur_filter = np.exp(
-            -np.power(x, 2.0) / (2.0 * np.power(np.float(sigma), 2.0)))
-        blur_filter /= np.sum(blur_filter)
-
-        conv1 = torch.nn.Conv2d(3, 3, kernel_size=(kernel_size, 1), groups=3,
-                                padding=[kernel_size // 2, 0], bias=False)
-        conv1.weight = torch.nn.Parameter(torch.Tensor(np.tile(
-            blur_filter.reshape(kernel_size, 1, 1, 1), 3
-        ).transpose([3, 2, 0, 1])))
-
-        conv2 = torch.nn.Conv2d(3, 3, kernel_size=(1, kernel_size), groups=3,
-                                padding=[0, kernel_size // 2], bias=False)
-        conv2.weight = torch.nn.Parameter(torch.Tensor(np.tile(
-            blur_filter.reshape(kernel_size, 1, 1, 1), 3
-        ).transpose([3, 2, 1, 0])))
-
-        res = conv2(conv1(image))
-        assert res.shape == image.shape
-        return res[0]
-
-    def __init__(self, kernel_size, sigma_range=(0.1, 2), p=0.5):
-        self.kernel_size = kernel_size
-        self.sigma_range = sigma_range
-        self.p = p
-
-    def __call__(self, img):
-        with torch.no_grad():
-            assert isinstance(img, torch.Tensor)
-            if np.random.uniform() < self.p:
-                return self.gaussian_blur(
-                    img, sigma=np.random.uniform(*self.sigma_range)
-                )
-            return img
diff --git a/supervised/optimizers.py b/supervised/optimizers.py
deleted file mode 100644
index 1904e8d..0000000
--- a/supervised/optimizers.py
+++ /dev/null
@@ -1,70 +0,0 @@
-import torch
-
-
-class LARS(object):
-    """
-    Slight modification of LARC optimizer from https://github.com/NVIDIA/apex/blob/d74fda260c403f775817470d87f810f816f3d615/apex/parallel/LARC.py
-    Matches one from SimCLR implementation https://github.com/google-research/simclr/blob/master/lars_optimizer.py
-
-    Args:
-        optimizer: Pytorch optimizer to wrap and modify learning rate for.
-        trust_coefficient: Trust coefficient for calculating the adaptive lr. See https://arxiv.org/abs/1708.03888
-    """
-
-    def __init__(self,
-                 optimizer,
-                 trust_coefficient=0.001,
-                 ):
-        self.param_groups = optimizer.param_groups
-        self.optim = optimizer
-        self.trust_coefficient = trust_coefficient
-
-    def __getstate__(self):
-        return self.optim.__getstate__()
-
-    def __setstate__(self, state):
-        self.optim.__setstate__(state)
-
-    def __repr__(self):
-        return self.optim.__repr__()
-
-    def state_dict(self):
-        return self.optim.state_dict()
-
-    def load_state_dict(self, state_dict):
-        self.optim.load_state_dict(state_dict)
-
-    def zero_grad(self):
-        self.optim.zero_grad()
-
-    def add_param_group(self, param_group):
-        self.optim.add_param_group(param_group)
-
-    def step(self):
-        with torch.no_grad():
-            weight_decays = []
-            for group in self.optim.param_groups:
-                # absorb weight decay control from optimizer
-                weight_decay = group['weight_decay'] if 'weight_decay' in group else 0
-                weight_decays.append(weight_decay)
-                group['weight_decay'] = 0
-                for p in group['params']:
-                    if p.grad is None:
-                        continue
-
-                    if weight_decay != 0:
-                        p.grad.data += weight_decay * p.data
-
-                    param_norm = torch.norm(p.data)
-                    grad_norm = torch.norm(p.grad.data)
-                    adaptive_lr = 1.
-
-                    if param_norm != 0 and grad_norm != 0 and group['layer_adaptation']:
-                        adaptive_lr = self.trust_coefficient * param_norm / grad_norm
-
-                    p.grad.data *= adaptive_lr
-
-        self.optim.step()
-        # return weight decay control to optimizer
-        for i, group in enumerate(self.optim.param_groups):
-            group['weight_decay'] = weight_decays[i]
diff --git a/supervised/schedulers.py b/supervised/schedulers.py
deleted file mode 100644
index 7580bf3..0000000
--- a/supervised/schedulers.py
+++ /dev/null
@@ -1,43 +0,0 @@
-import warnings
-
-import numpy as np
-import torch
-
-
-class LinearLR(torch.optim.lr_scheduler._LRScheduler):
-    def __init__(self, optimizer, num_epochs, last_epoch=-1):
-        self.num_epochs = max(num_epochs, 1)
-        super().__init__(optimizer, last_epoch)
-
-    def get_lr(self):
-        res = []
-        for lr in self.base_lrs:
-            res.append(np.maximum(lr * np.minimum(
-                -self.last_epoch * 1. / self.num_epochs + 1., 1.
-            ), 0.))
-        return res
-
-
-class LinearWarmupAndCosineAnneal(torch.optim.lr_scheduler._LRScheduler):
-    def __init__(self, optimizer, warm_up, T_max, last_epoch=-1):
-        self.warm_up = int(warm_up * T_max)
-        self.T_max = T_max - self.warm_up
-        super().__init__(optimizer, last_epoch=last_epoch)
-
-    def get_lr(self):
-        if not self._get_lr_called_within_step:
-            warnings.warn("To get the last learning rate computed by the scheduler, "
-                          "please use `get_last_lr()`.")
-
-        if self.last_epoch == 0:
-            return [lr / (self.warm_up + 1) for lr in self.base_lrs]
-        elif self.last_epoch <= self.warm_up:
-            c = (self.last_epoch + 1) / self.last_epoch
-            return [group['lr'] * c for group in self.optimizer.param_groups]
-        else:
-            # ref: https://github.com/pytorch/pytorch/blob/2de4f245c6b1e1c294a8b2a9d7f916d43380af4b/torch/optim/lr_scheduler.py#L493
-            le = self.last_epoch - self.warm_up
-            return [(1 + np.cos(np.pi * le / self.T_max)) /
-                    (1 + np.cos(np.pi * (le - 1) / self.T_max)) *
-                    group['lr']
-                    for group in self.optimizer.param_groups]
diff --git a/supervised/utils.py b/supervised/utils.py
deleted file mode 100644
index fde86eb..0000000
--- a/supervised/utils.py
+++ /dev/null
@@ -1,68 +0,0 @@
-import logging
-import os
-
-EPOCH_LOGGER = 'epoch_logger'
-BATCH_LOGGER = 'batch_logger'
-
-
-class FileHandlerWithHeader(logging.FileHandler):
-
-    def __init__(self, filename, header, mode='a', encoding=None, delay=0):
-        self.header = header
-        self.file_pre_exists = os.path.exists(filename)
-
-        logging.FileHandler.__init__(self, filename, mode, encoding, delay)
-        if not delay and self.stream is not None:
-            self.stream.write(f'{header}\n')
-
-    def emit(self, record):
-        if self.stream is None:
-            self.stream = self._open()
-            if not self.file_pre_exists:
-                self.stream.write(f'{self.header}\n')
-
-        logging.FileHandler.emit(self, record)
-
-
-def setup_logging(name="log",
-                  filename=None,
-                  stream_log_level="INFO",
-                  file_log_level="INFO"):
-    logger = logging.getLogger(name)
-    logger.setLevel("INFO")
-    formatter = logging.Formatter(
-        '%(asctime)s.%(msecs)03d,%(name)s,%(message)s', '%Y-%m-%d %H:%M:%S'
-    )
-    stream_handler = logging.StreamHandler()
-    stream_handler.setLevel(getattr(logging, stream_log_level))
-    stream_handler.setFormatter(formatter)
-    logger.addHandler(stream_handler)
-    if filename is not None:
-        header = 'time,logger,'
-        if name == BATCH_LOGGER:
-            header += 'batch,n_batches,global_batch,epoch,n_epochs,train_loss,lr'
-        elif name == EPOCH_LOGGER:
-            header += 'epoch,n_epochs,train_loss,test_loss,test_accuracy'
-        else:
-            raise NotImplementedError(f"Logger '{name}' is not implemented.")
-
-        os.makedirs(os.path.dirname(filename), exist_ok=True)
-        file_handler = FileHandlerWithHeader(filename, header)
-        file_handler.setLevel(getattr(logging, file_log_level))
-        file_handler.setFormatter(formatter)
-        logger.addHandler(file_handler)
-    return logger
-
-
-def training_log(name):
-    def log_this(function):
-        logger = logging.getLogger(name)
-
-        def wrapper(*args, **kwargs):
-            output = function(*args, **kwargs)
-            logger.info(','.join(map(str, output.values())))
-            return output
-
-        return wrapper
-
-    return log_this
--
cgit v1.2.3
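
Since the commit only relocates code, callers now import these helpers from libs/ (as the updated supervised/baseline.py does). What follows is a minimal, hypothetical sketch of how the relocated pieces could be wired together in a training script; it is not taken from the repository. The nn.Linear stand-in model, the dummy batch and loss, the hyper-parameter values (N_EPOCHS, WARM_UP, learning rate), the logs/epoch.csv path, and the parameter-group split are all illustrative assumptions; only the imported classes and their signatures come from the diff above. Two caveats visible in the diff itself: LARS.step() reads group['layer_adaptation'], so every parameter group must define that key, and gaussian_blur() still calls np.int/np.float, which NumPy 1.24 removed.

# Hypothetical usage sketch (not part of the commit); run from the repository
# root so that the new libs/ package is importable.
import torch
from torch import nn
from torchvision.transforms import transforms

from libs.datautils import color_distortion, Clip, RandomGaussianBlur
from libs.optimizers import LARS
from libs.schedulers import LinearWarmupAndCosineAnneal
from libs.utils import setup_logging, training_log, EPOCH_LOGGER

N_EPOCHS = 100   # illustrative value
WARM_UP = 0.1    # fraction of T_max spent in linear warm-up

# Augmentation pipeline for 224x224 inputs; RandomGaussianBlur reshapes its
# input to (1, 3, 224, 224), so it only fits ImageNet-sized crops.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    color_distortion(s=0.5),
    transforms.ToTensor(),
    Clip(),
    RandomGaussianBlur(kernel_size=224 // 10),
])

model = nn.Linear(2048, 10)  # stand-in for CIFARResNet50 / ImageNetResNet50

# LARS.step() looks up group['layer_adaptation'], so set it on every group.
param_groups = [
    {'params': [p for n, p in model.named_parameters() if 'bias' not in n],
     'layer_adaptation': True},
    {'params': [p for n, p in model.named_parameters() if 'bias' in n],
     'weight_decay': 0., 'layer_adaptation': False},
]
inner_optim = torch.optim.SGD(param_groups, lr=1.0, momentum=0.9,
                              weight_decay=1e-6)
optimizer = LARS(inner_optim, trust_coefficient=0.001)
# The scheduler wraps the inner SGD; LARS shares its param_groups, so the
# scheduled learning rate is picked up automatically.
scheduler = LinearWarmupAndCosineAnneal(inner_optim, WARM_UP, T_max=N_EPOCHS)

setup_logging(EPOCH_LOGGER, filename='logs/epoch.csv')


@training_log(EPOCH_LOGGER)
def epoch_log(epoch, train_loss):
    # training_log() joins the dict values into one CSV row, matching the
    # header that setup_logging() wrote for EPOCH_LOGGER.
    return {'epoch': epoch, 'n_epochs': N_EPOCHS, 'train_loss': train_loss,
            'test_loss': float('nan'), 'test_accuracy': float('nan')}


for epoch in range(N_EPOCHS):
    # A real run would iterate a DataLoader built with train_transform;
    # a dummy batch keeps the sketch self-contained.
    x = torch.randn(4, 2048)
    loss = model(x).pow(2).mean()
    loss.backward()
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()
    epoch_log(epoch, loss.item())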