diff options
Diffstat (limited to 'supervised')
-rw-r--r-- | supervised/baseline.py | 191 | ||||
-rw-r--r-- | supervised/lars_optimizer.py | 70 | ||||
-rw-r--r-- | supervised/scheduler.py | 29 |
3 files changed, 290 insertions, 0 deletions
# =============================================================================
# supervised/baseline.py  (new file)
#
# Supervised CIFAR-10 baseline: a ResNet with a CIFAR-sized stem, SimCLR-style
# colour/crop augmentation, SGD wrapped in LARS, and a per-iteration
# linear-warmup + cosine-annealing learning-rate schedule.
# =============================================================================
import os
import random

import torch
from torch import nn, Tensor, optim
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.datasets import CIFAR10
from torchvision.models import ResNet
from torchvision.models.resnet import BasicBlock
from torchvision.transforms import transforms, InterpolationMode
from tqdm import tqdm

from supervised.lars_optimizer import LARS
from supervised.scheduler import LinearWarmupAndCosineAnneal

CODENAME = 'cifar10-resnet50-aug-lars-sched'
DATASET_ROOT = '../dataset'
TENSORBOARD_PATH = os.path.join('../runs', CODENAME)
CHECKPOINT_PATH = os.path.join('../checkpoints', CODENAME)

BATCH_SIZE = 256
N_EPOCHS = 1000
WARMUP_EPOCHS = 10
N_WORKERS = 2
LR = 1
MOMENTUM = 0.9
WEIGHT_DECAY = 1e-6
SEED = 0

# exist_ok avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

random.seed(SEED)
torch.manual_seed(SEED)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class CIFAR10ResNet50(ResNet):
    """ResNet classifier with a stem sized for 32x32 CIFAR-10 images.

    ``conv1`` is replaced by a 3x3, stride-1, padding-1 convolution and
    ``forward`` goes straight from the stem ReLU into ``layer1`` (it never
    calls a max-pool), so the spatial resolution is preserved through the stem.

    NOTE(review): despite the class name, ``BasicBlock`` with
    ``layers=[3, 4, 6, 3]`` is the layout torchvision uses for ResNet-34;
    ResNet-50 uses ``Bottleneck`` blocks. The block type is left as-is because
    swapping it changes the architecture -- confirm which one is intended.
    """

    def __init__(self):
        super(CIFAR10ResNet50, self).__init__(
            block=BasicBlock, layers=[3, 4, 6, 3], num_classes=10
        )
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        """Return class logits for a batch of images ``x``."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x


def get_color_distortion(s=1.0):
    """Build the random colour-distortion transform.

    Args:
        s: Strength of the colour distortion; scales every ColorJitter
            argument.

    Returns:
        A ``Compose`` that applies ColorJitter with probability 0.8 and then
        RandomGrayscale with probability 0.2.
    """
    color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
    rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
    rnd_gray = transforms.RandomGrayscale(p=0.2)
    color_distort = transforms.Compose([
        rnd_color_jitter,
        rnd_gray
    ])
    return color_distort


# Training-time augmentation pipeline.
transform = transforms.Compose([
    transforms.RandomResizedCrop(32, interpolation=InterpolationMode.BICUBIC),
    transforms.RandomHorizontalFlip(0.5),
    get_color_distortion(0.5),
    transforms.ToTensor()
])
# Evaluation must be deterministic: random crops/flips/colour jitter on the
# test split would make the reported loss and accuracy noisy and pessimistic.
test_transform = transforms.ToTensor()

train_set = CIFAR10(DATASET_ROOT, train=True, transform=transform, download=True)
test_set = CIFAR10(DATASET_ROOT, train=False, transform=test_transform, download=True)

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE,
                          shuffle=True, num_workers=N_WORKERS)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE,
                         shuffle=False, num_workers=N_WORKERS)

num_train_batches = len(train_loader)
num_test_batches = len(test_loader)

resnet = CIFAR10ResNet50().to(device)
criterion = CrossEntropyLoss()


def exclude_from_wd_and_adaptation(name):
    """Return True if the parameter called ``name`` should get neither weight
    decay nor LARS layer adaptation (its name contains 'bn' or 'bias').

    NOTE(review): this is a pure substring match on the parameter name, so a
    normalisation layer whose name does not contain 'bn' (e.g. one inside a
    ``downsample`` Sequential) would only have its bias excluded -- verify
    against ``resnet.named_parameters()``.
    """
    return 'bn' in name or 'bias' in name


param_groups = [
    {
        'params': [p for name, p in resnet.named_parameters()
                   if not exclude_from_wd_and_adaptation(name)],
        'weight_decay': WEIGHT_DECAY,
        'layer_adaptation': True,
    },
    {
        'params': [p for name, p in resnet.named_parameters()
                   if exclude_from_wd_and_adaptation(name)],
        'weight_decay': 0.,
        'layer_adaptation': False,
    },
]
optimizer = torch.optim.SGD(param_groups, lr=LR, momentum=MOMENTUM)
# The scheduler is stepped once per batch, so its horizon is expressed in
# iterations and the warm-up is passed as a fraction of that horizon.
scheduler = LinearWarmupAndCosineAnneal(
    optimizer,
    WARMUP_EPOCHS / N_EPOCHS,
    N_EPOCHS * num_train_batches,
    last_epoch=-1,
)
# Wrap only after the scheduler is built: the scheduler holds the inner SGD
# optimizer, and LARS shares that optimizer's param_groups.
optimizer = LARS(optimizer)

writer = SummaryWriter(TENSORBOARD_PATH)

train_iters = 0
test_iters = 0
for epoch in range(N_EPOCHS):
    train_loss = 0
    training_progress = tqdm(
        enumerate(train_loader), desc='Train loss: ', total=num_train_batches
    )

    resnet.train()
    for batch, (images, targets) in training_progress:
        images, targets = images.to(device), targets.to(device)

        resnet.zero_grad()
        output = resnet(images)
        loss = criterion(output, targets)
        loss.backward()
        optimizer.step()
        scheduler.step()

        batch_loss = loss.item()
        train_loss += batch_loss
        train_loss_mean = train_loss / (batch + 1)
        training_progress.set_description(f'Train loss: {train_loss_mean:.4f}')
        writer.add_scalar('Loss/train', batch_loss, train_iters + 1)
        train_iters += 1

    test_loss = 0
    test_acc = 0
    test_progress = tqdm(
        enumerate(test_loader), desc='Test loss: ', total=num_test_batches
    )

    resnet.eval()
    with torch.no_grad():
        for batch, (images, targets) in test_progress:
            images, targets = images.to(device), targets.to(device)

            output = resnet(images)
            # .item() keeps the running sums as Python floats, so nothing
            # device-resident leaks into the checkpoint written below.
            batch_loss = criterion(output, targets).item()
            _, prediction = output.max(-1)

            test_loss += batch_loss
            test_loss_mean = test_loss / (batch + 1)
            test_progress.set_description(f'Test loss: {test_loss_mean:.4f}')
            test_acc += (prediction == targets).float().mean().item()
            writer.add_scalar('Loss/test', batch_loss, test_iters + 1)
            test_iters += 1

    train_loss_mean = train_loss / num_train_batches
    test_loss_mean = test_loss / num_test_batches
    test_acc_mean = test_acc / num_test_batches
    print(f'[{epoch + 1}/{N_EPOCHS}]\t'
          f'Train loss: {train_loss_mean:.4f}\t'
          f'Test loss: {test_loss_mean:.4f}\t'
          f'Test acc: {test_acc_mean:.4f}')

    writer.add_scalar('Acc', test_acc_mean, epoch + 1)

    # The scheduler state is saved as well so that a resumed run continues
    # from the correct point of the warm-up / cosine schedule.
    torch.save({'epoch': epoch,
                'resnet_state_dict': resnet.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'train_loss': train_loss_mean, 'test_loss': test_loss_mean,
                }, os.path.join(CHECKPOINT_PATH, f'{epoch + 1:04d}.pt'))

writer.close()


# =============================================================================
# supervised/lars_optimizer.py  (new file)
# =============================================================================
import torch


class LARS(object):
    """
    Slight modification of LARC optimizer from
    https://github.com/NVIDIA/apex/blob/d74fda260c403f775817470d87f810f816f3d615/apex/parallel/LARC.py
    Matches one from SimCLR implementation
    https://github.com/google-research/simclr/blob/master/lars_optimizer.py

    The wrapper rescales each parameter's gradient in place and then delegates
    the actual update to the wrapped optimizer. Every other method simply
    forwards to the wrapped optimizer.

    Param-group keys read by ``step``:
        weight_decay: applied by this wrapper (added to the gradient) instead
            of by the wrapped optimizer; treated as 0 when absent.
        layer_adaptation: when truthy, the gradient is scaled by
            ``trust_coefficient * ||p|| / ||grad||``; treated as True when
            absent.

    Args:
        optimizer: Pytorch optimizer to wrap and modify learning rate for.
        trust_coefficient: Trust coefficient for calculating the adaptive lr.
            See https://arxiv.org/abs/1708.03888
    """

    def __init__(self,
                 optimizer,
                 trust_coefficient=0.001,
                 ):
        # Same list object as the wrapped optimizer's, so edits are shared.
        self.param_groups = optimizer.param_groups
        self.optim = optimizer
        self.trust_coefficient = trust_coefficient

    def __getstate__(self):
        return self.optim.__getstate__()

    def __setstate__(self, state):
        self.optim.__setstate__(state)

    def __repr__(self):
        return self.optim.__repr__()

    def state_dict(self):
        """Return the wrapped optimizer's state dict."""
        return self.optim.state_dict()

    def load_state_dict(self, state_dict):
        """Load ``state_dict`` into the wrapped optimizer."""
        self.optim.load_state_dict(state_dict)

    def zero_grad(self):
        """Clear the gradients held by the wrapped optimizer."""
        self.optim.zero_grad()

    def add_param_group(self, param_group):
        """Add ``param_group`` to the wrapped optimizer."""
        self.optim.add_param_group(param_group)

    def step(self):
        """Apply weight decay and layer-wise scaling to the gradients, then
        run one step of the wrapped optimizer.

        Each group's ``weight_decay`` is set to 0 while the wrapped optimizer
        steps (so decay is not applied twice) and is restored afterwards, even
        if the step raises.
        """
        weight_decays = []
        try:
            with torch.no_grad():
                for group in self.optim.param_groups:
                    # absorb weight decay control from optimizer
                    weight_decay = group.get('weight_decay', 0)
                    weight_decays.append(weight_decay)
                    group['weight_decay'] = 0
                    adapt = group.get('layer_adaptation', True)

                    for p in group['params']:
                        if p.grad is None:
                            continue

                        if weight_decay != 0:
                            p.grad += weight_decay * p

                        # Norms are only needed for adapted groups.
                        if not adapt:
                            continue

                        param_norm = torch.norm(p)
                        grad_norm = torch.norm(p.grad)
                        # A zero norm would give a 0 or infinite ratio; leave
                        # the gradient unscaled in that case.
                        if param_norm != 0 and grad_norm != 0:
                            p.grad *= (self.trust_coefficient
                                       * param_norm / grad_norm)

            self.optim.step()
        finally:
            # return weight decay control to optimizer; zip stops at the
            # groups that were actually modified above.
            for group, weight_decay in zip(self.optim.param_groups,
                                           weight_decays):
                group['weight_decay'] = weight_decay


# =============================================================================
# supervised/scheduler.py  (new file)
# =============================================================================
import warnings

import numpy as np
import torch


class LinearWarmupAndCosineAnneal(torch.optim.lr_scheduler._LRScheduler):
    """Linear learning-rate warm-up followed by cosine annealing.

    Both phases are computed recursively from the optimizer's current
    ``group['lr']`` rather than from a closed form.

    Args:
        optimizer: Optimizer whose param-group learning rates are scheduled.
        warm_up: Fraction of ``T_max`` spent warming up; the number of warm-up
            steps is ``int(warm_up * T_max)``.
        T_max: Total number of scheduler steps (warm-up included).
        last_epoch: Index of the last step, forwarded to the base class.
    """

    def __init__(self, optimizer, warm_up, T_max, last_epoch=-1):
        self.warm_up = int(warm_up * T_max)
        # From here on T_max is the length of the cosine phase only.
        self.T_max = T_max - self.warm_up
        super().__init__(optimizer, last_epoch=last_epoch)

    def get_lr(self):
        """Return the learning rate of every param group for the current step."""
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.")

        if self.last_epoch == 0:
            # First value of the ramp: base_lr / (warm_up + 1).
            return [lr / (self.warm_up + 1) for lr in self.base_lrs]
        elif self.last_epoch <= self.warm_up:
            # Multiplying by (k + 1) / k at step k turns the previous value
            # base_lr * k / (warm_up + 1) into base_lr * (k + 1) / (warm_up + 1),
            # which reaches base_lr at k == warm_up.
            c = (self.last_epoch + 1) / self.last_epoch
            return [group['lr'] * c for group in self.optimizer.param_groups]
        else:
            # ref: https://github.com/pytorch/pytorch/blob/2de4f245c6b1e1c294a8b2a9d7f916d43380af4b/torch/optim/lr_scheduler.py#L493
            # Ratio of consecutive cosine factors applied to the current lr.
            le = self.last_epoch - self.warm_up
            return [(1 + np.cos(np.pi * le / self.T_max)) /
                    (1 + np.cos(np.pi * (le - 1) / self.T_max)) *
                    group['lr']
                    for group in self.optimizer.param_groups]