Diffstat (limited to 'libs')
-rw-r--r--  libs/datautils.py    67
-rw-r--r--  libs/optimizers.py   70
-rw-r--r--  libs/schedulers.py   43
-rw-r--r--  libs/utils.py        68
4 files changed, 248 insertions, 0 deletions
diff --git a/libs/datautils.py b/libs/datautils.py
new file mode 100644
index 0000000..843f669
--- /dev/null
+++ b/libs/datautils.py
@@ -0,0 +1,67 @@
+import numpy as np
+import torch
+from torchvision.transforms import transforms
+
+
+def color_distortion(s=1.0):
+    # s is the strength of color distortion.
+    color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
+    rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
+    rnd_gray = transforms.RandomGrayscale(p=0.2)
+    color_distort = transforms.Compose([
+        rnd_color_jitter,
+        rnd_gray
+    ])
+    return color_distort
+
+
+class Clip(object):
+    def __call__(self, x):
+        return torch.clamp(x, 0, 1)
+
+
+class RandomGaussianBlur(object):
+    """
+    PyTorch version of
+    https://github.com/google-research/simclr/blob/244e7128004c5fd3c7805cf3135c79baa6c3bb96/data_util.py#L311
+    """
+
+    def gaussian_blur(self, image, sigma):
+        image = image.reshape(1, 3, 224, 224)
+        radius = self.kernel_size // 2
+        kernel_size = radius * 2 + 1
+        x = np.arange(-radius, radius + 1)
+
+        blur_filter = np.exp(
+            -np.power(x, 2.0) / (2.0 * np.power(float(sigma), 2.0)))
+        blur_filter /= np.sum(blur_filter)
+
+        conv1 = torch.nn.Conv2d(3, 3, kernel_size=(kernel_size, 1), groups=3,
+                                padding=[kernel_size // 2, 0], bias=False)
+        conv1.weight = torch.nn.Parameter(torch.Tensor(np.tile(
+            blur_filter.reshape(kernel_size, 1, 1, 1), 3
+        ).transpose([3, 2, 0, 1])))
+
+        conv2 = torch.nn.Conv2d(3, 3, kernel_size=(1, kernel_size), groups=3,
+                                padding=[0, kernel_size // 2], bias=False)
+        conv2.weight = torch.nn.Parameter(torch.Tensor(np.tile(
+            blur_filter.reshape(kernel_size, 1, 1, 1), 3
+        ).transpose([3, 2, 1, 0])))
+
+        res = conv2(conv1(image))
+        assert res.shape == image.shape
+        return res[0]
+
+    def __init__(self, kernel_size, sigma_range=(0.1, 2), p=0.5):
+        self.kernel_size = kernel_size
+        self.sigma_range = sigma_range
+        self.p = p
+
+    def __call__(self, img):
+        with torch.no_grad():
+            assert isinstance(img, torch.Tensor)
+            if np.random.uniform() < self.p:
+                return self.gaussian_blur(
+                    img, sigma=np.random.uniform(*self.sigma_range)
+                )
+            return img
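
A minimal usage sketch for these transforms (not part of this commit; the libs.datautils import path, the 224x224 crop size, and the kernel size are assumptions): color_distortion operates on PIL images before ToTensor, while RandomGaussianBlur and Clip expect tensors.

    # Sketch: SimCLR-style augmentation pipeline built from the transforms above.
    # RandomGaussianBlur reshapes its input to (1, 3, 224, 224), so 224x224 crops are assumed.
    from torchvision.transforms import transforms

    from libs.datautils import Clip, RandomGaussianBlur, color_distortion

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        color_distortion(s=1.0),                    # applied to PIL images
        transforms.ToTensor(),
        RandomGaussianBlur(kernel_size=22, p=0.5),  # kernel roughly 10% of the image side
        Clip(),                                     # blur can push values slightly outside [0, 1]
    ])
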
diff --git a/libs/optimizers.py b/libs/optimizers.py
new file mode 100644
index 0000000..1904e8d
--- /dev/null
+++ b/libs/optimizers.py
@@ -0,0 +1,70 @@
+import torch
+
+
+class LARS(object):
+    """
+    Slight modification of LARC optimizer from https://github.com/NVIDIA/apex/blob/d74fda260c403f775817470d87f810f816f3d615/apex/parallel/LARC.py
+    Matches one from SimCLR implementation https://github.com/google-research/simclr/blob/master/lars_optimizer.py
+
+    Args:
+        optimizer: PyTorch optimizer to wrap and modify learning rate for.
+        trust_coefficient: Trust coefficient for calculating the adaptive lr. See https://arxiv.org/abs/1708.03888
+    """
+
+    def __init__(self,
+                 optimizer,
+                 trust_coefficient=0.001,
+                 ):
+        self.param_groups = optimizer.param_groups
+        self.optim = optimizer
+        self.trust_coefficient = trust_coefficient
+
+    def __getstate__(self):
+        return self.optim.__getstate__()
+
+    def __setstate__(self, state):
+        self.optim.__setstate__(state)
+
+    def __repr__(self):
+        return self.optim.__repr__()
+
+    def state_dict(self):
+        return self.optim.state_dict()
+
+    def load_state_dict(self, state_dict):
+        self.optim.load_state_dict(state_dict)
+
+    def zero_grad(self):
+        self.optim.zero_grad()
+
+    def add_param_group(self, param_group):
+        self.optim.add_param_group(param_group)
+
+    def step(self):
+        with torch.no_grad():
+            weight_decays = []
+            for group in self.optim.param_groups:
+                # absorb weight decay control from optimizer
+                weight_decay = group['weight_decay'] if 'weight_decay' in group else 0
+                weight_decays.append(weight_decay)
+                group['weight_decay'] = 0
+                for p in group['params']:
+                    if p.grad is None:
+                        continue
+
+                    if weight_decay != 0:
+                        p.grad.data += weight_decay * p.data
+
+                    param_norm = torch.norm(p.data)
+                    grad_norm = torch.norm(p.grad.data)
+                    adaptive_lr = 1.
+
+                    if param_norm != 0 and grad_norm != 0 and group['layer_adaptation']:
+                        adaptive_lr = self.trust_coefficient * param_norm / grad_norm
+
+                    p.grad.data *= adaptive_lr
+
+        self.optim.step()
+        # return weight decay control to optimizer
+        for i, group in enumerate(self.optim.param_groups):
+            group['weight_decay'] = weight_decays[i]
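
A usage sketch (not part of this commit): LARS wraps an existing optimizer and reads a per-group 'layer_adaptation' flag, so every param group must define it. The split below, which excludes biases and batch-norm parameters from adaptation and weight decay, mirrors the SimCLR recipe; the model, hyperparameters, and needs_adaptation helper are placeholders.

    import torch

    from libs.optimizers import LARS

    def needs_adaptation(name):
        # Exclude biases and batch-norm parameters from layer-wise adaptation.
        return not ('bn' in name or 'bias' in name)

    model = torch.nn.Sequential(torch.nn.Linear(128, 64), torch.nn.Linear(64, 10))  # stand-in encoder
    param_groups = [
        {'params': [p for n, p in model.named_parameters() if needs_adaptation(n)],
         'weight_decay': 1e-6, 'layer_adaptation': True},
        {'params': [p for n, p in model.named_parameters() if not needs_adaptation(n)],
         'weight_decay': 0., 'layer_adaptation': False},
    ]
    optimizer = LARS(torch.optim.SGD(param_groups, lr=1.0, momentum=0.9))

    x, target = torch.randn(32, 128), torch.randint(0, 10, (32,))
    loss = torch.nn.functional.cross_entropy(model(x), target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
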
diff --git a/libs/schedulers.py b/libs/schedulers.py
new file mode 100644
index 0000000..7580bf3
--- /dev/null
+++ b/libs/schedulers.py
@@ -0,0 +1,43 @@
+import warnings
+
+import numpy as np
+import torch
+
+
+class LinearLR(torch.optim.lr_scheduler._LRScheduler):
+    def __init__(self, optimizer, num_epochs, last_epoch=-1):
+        self.num_epochs = max(num_epochs, 1)
+        super().__init__(optimizer, last_epoch)
+
+    def get_lr(self):
+        res = []
+        for lr in self.base_lrs:
+            res.append(np.maximum(lr * np.minimum(
+                -self.last_epoch * 1. / self.num_epochs + 1., 1.
+            ), 0.))
+        return res
+
+
+class LinearWarmupAndCosineAnneal(torch.optim.lr_scheduler._LRScheduler):
+    def __init__(self, optimizer, warm_up, T_max, last_epoch=-1):
+        self.warm_up = int(warm_up * T_max)
+        self.T_max = T_max - self.warm_up
+        super().__init__(optimizer, last_epoch=last_epoch)
+
+    def get_lr(self):
+        if not self._get_lr_called_within_step:
+            warnings.warn("To get the last learning rate computed by the scheduler, "
+                          "please use `get_last_lr()`.")
+
+        if self.last_epoch == 0:
+            return [lr / (self.warm_up + 1) for lr in self.base_lrs]
+        elif self.last_epoch <= self.warm_up:
+            c = (self.last_epoch + 1) / self.last_epoch
+            return [group['lr'] * c for group in self.optimizer.param_groups]
+        else:
+            # ref: https://github.com/pytorch/pytorch/blob/2de4f245c6b1e1c294a8b2a9d7f916d43380af4b/torch/optim/lr_scheduler.py#L493
+            le = self.last_epoch - self.warm_up
+            return [(1 + np.cos(np.pi * le / self.T_max)) /
+                    (1 + np.cos(np.pi * (le - 1) / self.T_max)) *
+                    group['lr']
+                    for group in self.optimizer.param_groups]
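
A usage sketch (not part of this commit): LinearWarmupAndCosineAnneal rescales the current group learning rates multiplicatively, so it is stepped once per update; T_max counts total update steps and warm_up is the fraction of those steps spent in linear warm-up. The model and step count below are placeholders.

    import torch

    from libs.schedulers import LinearWarmupAndCosineAnneal

    model = torch.nn.Linear(10, 2)                       # placeholder model
    optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
    total_steps = 1000                                   # e.g. epochs * batches_per_epoch
    scheduler = LinearWarmupAndCosineAnneal(optimizer, warm_up=0.1, T_max=total_steps)

    for step in range(total_steps):
        # forward/backward would go here
        optimizer.step()
        scheduler.step()                                 # sets the lr for the next step

If the LARS wrapper above is used, the scheduler is constructed on the inner torch optimizer (the scheduler type-checks its argument), which works because LARS shares the same param_groups.
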
diff --git a/libs/utils.py b/libs/utils.py
new file mode 100644
index 0000000..fde86eb
--- /dev/null
+++ b/libs/utils.py
@@ -0,0 +1,68 @@
+import logging
+import os
+
+EPOCH_LOGGER = 'epoch_logger'
+BATCH_LOGGER = 'batch_logger'
+
+
+class FileHandlerWithHeader(logging.FileHandler):
+
+    def __init__(self, filename, header, mode='a', encoding=None, delay=0):
+        self.header = header
+        self.file_pre_exists = os.path.exists(filename)
+
+        logging.FileHandler.__init__(self, filename, mode, encoding, delay)
+        if not delay and self.stream is not None and not self.file_pre_exists:
+            self.stream.write(f'{header}\n')
+
+    def emit(self, record):
+        if self.stream is None:
+            self.stream = self._open()
+            if not self.file_pre_exists:
+                self.stream.write(f'{self.header}\n')
+
+        logging.FileHandler.emit(self, record)
+
+
+def setup_logging(name="log",
+                  filename=None,
+                  stream_log_level="INFO",
+                  file_log_level="INFO"):
+    logger = logging.getLogger(name)
+    logger.setLevel("INFO")
+    formatter = logging.Formatter(
+        '%(asctime)s.%(msecs)03d,%(name)s,%(message)s', '%Y-%m-%d %H:%M:%S'
+    )
+    stream_handler = logging.StreamHandler()
+    stream_handler.setLevel(getattr(logging, stream_log_level))
+    stream_handler.setFormatter(formatter)
+    logger.addHandler(stream_handler)
+    if filename is not None:
+        header = 'time,logger,'
+        if name == BATCH_LOGGER:
+            header += 'batch,n_batches,global_batch,epoch,n_epochs,train_loss,lr'
+        elif name == EPOCH_LOGGER:
+            header += 'epoch,n_epochs,train_loss,test_loss,test_accuracy'
+        else:
+            raise NotImplementedError(f"Logger '{name}' is not implemented.")
+
+        os.makedirs(os.path.dirname(filename), exist_ok=True)
+        file_handler = FileHandlerWithHeader(filename, header)
+        file_handler.setLevel(getattr(logging, file_log_level))
+        file_handler.setFormatter(formatter)
+        logger.addHandler(file_handler)
+    return logger
+
+
+def training_log(name):
+    def log_this(function):
+        logger = logging.getLogger(name)
+
+        def wrapper(*args, **kwargs):
+            output = function(*args, **kwargs)
+            logger.info(','.join(map(str, output.values())))
+            return output
+
+        return wrapper
+
+    return log_this
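
A usage sketch (not part of this commit): setup_logging configures one of the two predefined CSV loggers, and training_log turns a decorated function's returned dict into one CSV row, so the dict values must line up with that logger's header. The epoch_summary function and the logs/epochs.csv path below are placeholders.

    from libs.utils import EPOCH_LOGGER, setup_logging, training_log

    setup_logging(EPOCH_LOGGER, filename='logs/epochs.csv')

    @training_log(EPOCH_LOGGER)
    def epoch_summary(epoch, n_epochs, train_loss, test_loss, test_accuracy):
        # Dict order must match the EPOCH_LOGGER header:
        # epoch,n_epochs,train_loss,test_loss,test_accuracy
        return {
            'epoch': epoch,
            'n_epochs': n_epochs,
            'train_loss': train_loss,
            'test_loss': test_loss,
            'test_accuracy': test_accuracy,
        }

    epoch_summary(1, 100, 2.31, 2.05, 0.42)  # appends one timestamped CSV row
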