diff options
Diffstat (limited to 'libs')
| -rw-r--r-- | libs/datautils.py | 67 | ||||
| -rw-r--r-- | libs/optimizers.py | 70 | ||||
| -rw-r--r-- | libs/schedulers.py | 43 | ||||
| -rw-r--r-- | libs/utils.py | 68 | 
4 files changed, 248 insertions, 0 deletions
# ---------------------------------------------------------------------------
# libs/datautils.py  (new file, index 0000000..843f669)
# ---------------------------------------------------------------------------
import numpy as np
import torch

try:
    from torchvision.transforms import transforms
except ImportError:
    # torchvision is needed only by color_distortion(); the tensor-only
    # transforms below stay importable without it.
    transforms = None


def color_distortion(s=1.0):
    """Build a random color-distortion transform.

    Args:
        s: Strength of the color distortion. Scales the brightness,
            contrast and saturation jitter (0.8 * s) and the hue jitter
            (0.2 * s).

    Returns:
        A ``transforms.Compose`` that applies color jitter with probability
        0.8 and then converts to grayscale with probability 0.2.

    Raises:
        ImportError: If torchvision is not installed.
    """
    if transforms is None:
        raise ImportError("color_distortion requires torchvision.")
    color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
    rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
    rnd_gray = transforms.RandomGrayscale(p=0.2)
    return transforms.Compose([
        rnd_color_jitter,
        rnd_gray,
    ])


class Clip(object):
    """Transform that clamps every element of a tensor to the range [0, 1]."""

    def __call__(self, x):
        return torch.clamp(x, 0, 1)


class RandomGaussianBlur(object):
    """Randomly blur an image tensor with a separable Gaussian kernel.

    PyTorch version of
    https://github.com/google-research/simclr/blob/244e7128004c5fd3c7805cf3135c79baa6c3bb96/data_util.py#L311

    Args:
        kernel_size: Requested kernel width. The effective width is
            ``2 * int(kernel_size / 2) + 1`` so that it is always odd.
        sigma_range: ``(low, high)`` bounds of the uniform distribution the
            blur sigma is drawn from.
        p: Probability of applying the blur on a given call.
    """

    def __init__(self, kernel_size, sigma_range=(0.1, 2), p=0.5):
        self.kernel_size = kernel_size
        self.sigma_range = sigma_range
        self.p = p

    def gaussian_blur(self, image, sigma):
        """Blur a single image with a Gaussian of standard deviation ``sigma``.

        Args:
            image: Floating-point tensor holding one image, laid out as
                ``(C, H, W)`` or ``(1, C, H, W)``. Inputs with fewer than
                three dimensions are interpreted as one 3x224x224 image,
                which was the only layout previously supported.
            sigma: Standard deviation of the Gaussian kernel.

        Returns:
            The blurred image as a ``(C, H, W)`` tensor.

        Raises:
            TypeError: If ``image`` is not a floating-point tensor.
        """
        if not image.is_floating_point():
            raise TypeError(
                f"gaussian_blur expects a floating-point tensor, "
                f"got dtype {image.dtype}."
            )

        if image.dim() >= 3:
            channels, height, width = image.shape[-3:]
        else:
            # Legacy layout: a flat buffer holding one 3x224x224 image.
            channels, height, width = 3, 224, 224
        image = image.reshape(1, channels, height, width)

        # np.int / np.float were removed in NumPy 1.24; use the builtins.
        radius = int(self.kernel_size / 2)
        kernel_size = radius * 2 + 1
        x = np.arange(-radius, radius + 1)

        blur_filter = np.exp(
            -np.power(x, 2.0) / (2.0 * np.power(float(sigma), 2.0)))
        blur_filter /= np.sum(blur_filter)

        # Build the kernel on the image's own device and dtype so that GPU
        # and non-float32 inputs work too.
        kernel = torch.as_tensor(
            blur_filter, dtype=image.dtype, device=image.device)
        vertical = kernel.reshape(1, 1, kernel_size, 1).repeat(
            channels, 1, 1, 1)
        horizontal = kernel.reshape(1, 1, 1, kernel_size).repeat(
            channels, 1, 1, 1)

        # Separable depthwise blur: one vertical pass, one horizontal pass.
        res = torch.nn.functional.conv2d(
            image, vertical, padding=(radius, 0), groups=channels)
        res = torch.nn.functional.conv2d(
            res, horizontal, padding=(0, radius), groups=channels)

        assert res.shape == image.shape
        return res[0]

    def __call__(self, img):
        """Return ``img`` blurred with probability ``p``, else unchanged.

        Raises:
            TypeError: If ``img`` is not a ``torch.Tensor``.
        """
        if not isinstance(img, torch.Tensor):
            raise TypeError(
                f"RandomGaussianBlur expects a torch.Tensor, "
                f"got {type(img).__name__}."
            )
        if np.random.uniform() < self.p:
            sigma = np.random.uniform(*self.sigma_range)
            with torch.no_grad():
                return self.gaussian_blur(img, sigma=sigma)
        return img


# ---------------------------------------------------------------------------
# libs/optimizers.py  (new file, index 0000000..1904e8d)
# ---------------------------------------------------------------------------
import torch


class LARS(object):
    """
    Slight modification of LARC optimizer from https://github.com/NVIDIA/apex/blob/d74fda260c403f775817470d87f810f816f3d615/apex/parallel/LARC.py
    Matches one from SimCLR implementation https://github.com/google-research/simclr/blob/master/lars_optimizer.py

    Param groups may carry a boolean ``'layer_adaptation'`` entry; groups
    where it is false get weight decay but no adaptive rescaling. Groups
    without the entry are treated as adaptive.

    Args:
        optimizer: Pytorch optimizer to wrap and modify learning rate for.
        trust_coefficient: Trust coefficient for calculating the adaptive lr.
            See https://arxiv.org/abs/1708.03888
    """

    def __init__(self,
                 optimizer,
                 trust_coefficient=0.001,
                 ):
        self.param_groups = optimizer.param_groups
        self.optim = optimizer
        self.trust_coefficient = trust_coefficient

    def __getstate__(self):
        return self.optim.__getstate__()

    def __setstate__(self, state):
        self.optim.__setstate__(state)

    def __repr__(self):
        return self.optim.__repr__()

    def state_dict(self):
        return self.optim.state_dict()

    def load_state_dict(self, state_dict):
        self.optim.load_state_dict(state_dict)

    def zero_grad(self):
        self.optim.zero_grad()

    def add_param_group(self, param_group):
        self.optim.add_param_group(param_group)

    def step(self):
        """Rescale the gradients layer-wise, then step the wrapped optimizer.

        Weight decay is applied here (added to the gradient before the
        norms are taken) and disabled on the wrapped optimizer for the
        duration of its step, so it is not applied twice. The original
        ``weight_decay`` values are restored afterwards, even if the
        wrapped optimizer raises.
        """
        groups = self.optim.param_groups
        weight_decays = [group.get('weight_decay', 0) for group in groups]
        try:
            with torch.no_grad():
                for group, weight_decay in zip(groups, weight_decays):
                    # absorb weight decay control from optimizer
                    group['weight_decay'] = 0
                    # A missing key used to raise KeyError on the first
                    # step of any stock optimizer.
                    layer_adaptation = group.get('layer_adaptation', True)
                    for p in group['params']:
                        if p.grad is None:
                            continue

                        if weight_decay != 0:
                            p.grad += weight_decay * p

                        if not layer_adaptation:
                            continue

                        param_norm = torch.norm(p)
                        grad_norm = torch.norm(p.grad)
                        # A zero norm would give a zero or infinite ratio;
                        # leave such gradients untouched.
                        if param_norm != 0 and grad_norm != 0:
                            p.grad *= (
                                self.trust_coefficient * param_norm / grad_norm
                            )

            self.optim.step()
        finally:
            # return weight decay control to optimizer
            for group, weight_decay in zip(groups, weight_decays):
                group['weight_decay'] = weight_decay


# ---------------------------------------------------------------------------
# libs/schedulers.py  (new file, index 0000000..7580bf3)
# ---------------------------------------------------------------------------
import warnings

import numpy as np
import torch


class LinearLR(torch.optim.lr_scheduler._LRScheduler):
    """Decay every base learning rate linearly to zero over ``num_epochs``.

    Args:
        optimizer: Wrapped optimizer.
        num_epochs: Number of scheduler steps over which the rate reaches
            zero; values below 1 are treated as 1.
        last_epoch: Index of the last epoch, passed to the base scheduler.
    """

    def __init__(self, optimizer, num_epochs, last_epoch=-1):
        self.num_epochs = max(num_epochs, 1)
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return ``base_lr * (1 - last_epoch / num_epochs)`` per group.

        The multiplier is capped at 1 and the result floored at 0.
        """
        remaining = 1. - self.last_epoch / self.num_epochs
        return [
            np.maximum(lr * np.minimum(remaining, 1.), 0.)
            for lr in self.base_lrs
        ]


class LinearWarmupAndCosineAnneal(torch.optim.lr_scheduler._LRScheduler):
    """Linear warm-up followed by cosine annealing.

    Args:
        optimizer: Wrapped optimizer.
        warm_up: Fraction of ``T_max`` spent in linear warm-up.
        T_max: Total number of scheduler steps, warm-up included.
        last_epoch: Index of the last epoch, passed to the base scheduler.
    """

    def __init__(self, optimizer, warm_up, T_max, last_epoch=-1):
        self.warm_up = int(warm_up * T_max)
        # From here on T_max counts only the cosine phase.
        self.T_max = T_max - self.warm_up
        super().__init__(optimizer, last_epoch=last_epoch)

    def get_lr(self):
        """Return the next learning rate for every param group.

        Outside the first step, each value is derived from the group's
        current ``'lr'``, so the schedule is only meaningful when called
        once per scheduler step.
        """
        # The flag is private to the base class; do not fail if it is absent.
        if not getattr(self, '_get_lr_called_within_step', True):
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.")

        if self.last_epoch == 0:
            return [lr / (self.warm_up + 1) for lr in self.base_lrs]

        if self.last_epoch <= self.warm_up:
            # Ratio of consecutive points on the linear ramp.
            c = (self.last_epoch + 1) / self.last_epoch
            return [group['lr'] * c for group in self.optimizer.param_groups]

        # ref: https://github.com/pytorch/pytorch/blob/2de4f245c6b1e1c294a8b2a9d7f916d43380af4b/torch/optim/lr_scheduler.py#L493
        le = self.last_epoch - self.warm_up
        return [(1 + np.cos(np.pi * le / self.T_max)) /
                (1 + np.cos(np.pi * (le - 1) / self.T_max)) *
                group['lr']
                for group in self.optimizer.param_groups]


# ---------------------------------------------------------------------------
# libs/utils.py  (new file, index 0000000..fde86eb)
# ---------------------------------------------------------------------------
import functools
import logging
import os

EPOCH_LOGGER = 'epoch_logger'
BATCH_LOGGER = 'batch_logger'

# CSV columns written after the common 'time,logger,' prefix, per logger.
_LOG_HEADERS = {
    BATCH_LOGGER: 'batch,n_batches,global_batch,epoch,n_epochs,train_loss,lr',
    EPOCH_LOGGER: 'epoch,n_epochs,train_loss,test_loss,test_accuracy',
}


class FileHandlerWithHeader(logging.FileHandler):
    """File handler that writes a header line before the first record.

    The header is written once, and only when the log file starts out
    empty: that is, when the file is new, has no content yet, or is
    truncated by a ``'w'`` mode. Appending to an existing log therefore
    no longer repeats the header in the middle of the file.

    Args:
        filename: Path of the log file.
        header: Header line, written without the trailing newline.
        mode, encoding, delay: Forwarded to ``logging.FileHandler``.
    """

    def __init__(self, filename, header, mode='a', encoding=None, delay=0):
        self.header = header
        # Must be sampled before the base class opens (and creates) the file.
        self.file_pre_exists = os.path.exists(filename)
        has_content = (
            self.file_pre_exists and os.path.getsize(filename) > 0
        )
        self._needs_header = 'w' in mode or not has_content

        logging.FileHandler.__init__(self, filename, mode, encoding, delay)
        if not delay and self.stream is not None:
            self._write_header()

    def _write_header(self):
        """Write the header if it is still pending, then mark it done."""
        if self._needs_header:
            self.stream.write(f'{self.header}\n')
            self._needs_header = False

    def emit(self, record):
        # With delay=True the file is opened lazily, so the header has to
        # be written here, just before the first record.
        if self.stream is None and self._needs_header:
            self.stream = self._open()
            self._write_header()

        logging.FileHandler.emit(self, record)


def setup_logging(name="log",
                  filename=None,
                  stream_log_level="INFO",
                  file_log_level="INFO"):
    """Configure and return the logger called ``name``.

    A stream handler is always attached. When ``filename`` is given, a
    CSV file handler with a header matching the logger is attached too.

    Args:
        name: Logger name. Must be ``BATCH_LOGGER`` or ``EPOCH_LOGGER``
            when ``filename`` is given.
        filename: Optional path of the CSV log file; missing parent
            directories are created.
        stream_log_level: Level name for the stream handler.
        file_log_level: Level name for the file handler.

    Returns:
        The configured ``logging.Logger``.

    Raises:
        NotImplementedError: If ``filename`` is given and ``name`` has no
            known CSV header.
    """
    # Validate first so a failed call does not leave a handler attached.
    if filename is not None and name not in _LOG_HEADERS:
        raise NotImplementedError(f"Logger '{name}' is not implemented.")

    logger = logging.getLogger(name)
    logger.setLevel("INFO")
    formatter = logging.Formatter(
        '%(asctime)s.%(msecs)03d,%(name)s,%(message)s', '%Y-%m-%d %H:%M:%S'
    )
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(getattr(logging, stream_log_level))
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)

    if filename is not None:
        header = 'time,logger,' + _LOG_HEADERS[name]

        # dirname is '' for a bare filename, and os.makedirs('') raises.
        directory = os.path.dirname(filename)
        if directory:
            os.makedirs(directory, exist_ok=True)

        file_handler = FileHandlerWithHeader(filename, header)
        file_handler.setLevel(getattr(logging, file_log_level))
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
    return logger


def training_log(name):
    """Return a decorator that logs a function's result to logger ``name``.

    The decorated function must return a mapping. Its values are joined
    with commas, logged at INFO level, and the mapping is returned
    unchanged.
    """
    def log_this(function):
        logger = logging.getLogger(name)

        @functools.wraps(function)
        def wrapper(*args, **kwargs):
            output = function(*args, **kwargs)
            logger.info(','.join(map(str, output.values())))
            return output

        return wrapper

    return log_this
