Diffstat (limited to 'supervised')
-rw-r--r--   supervised/baseline.py         191
-rw-r--r--   supervised/lars_optimizer.py    70
-rw-r--r--   supervised/scheduler.py         29
3 files changed, 290 insertions, 0 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py
new file mode 100644
index 0000000..047d8ca
--- /dev/null
+++ b/supervised/baseline.py
@@ -0,0 +1,191 @@
+import os
+import random
+
+import torch
+from torch import nn, Tensor, optim
+from torch.nn import CrossEntropyLoss
+from torch.utils.data import DataLoader
+from torch.utils.tensorboard import SummaryWriter
+from torchvision.datasets import CIFAR10
+from torchvision.models import ResNet
+from torchvision.models.resnet import Bottleneck
+from torchvision.transforms import transforms, InterpolationMode
+from tqdm import tqdm
+
+from supervised.lars_optimizer import LARS
+from supervised.scheduler import LinearWarmupAndCosineAnneal
+
+CODENAME = 'cifar10-resnet50-aug-lars-sched'
+DATASET_ROOT = '../dataset'
+TENSORBOARD_PATH = os.path.join('../runs', CODENAME)
+CHECKPOINT_PATH = os.path.join('../checkpoints', CODENAME)
+
+BATCH_SIZE = 256
+N_EPOCHS = 1000
+WARMUP_EPOCHS = 10
+N_WORKERS = 2
+LR = 1
+MOMENTUM = 0.9
+WEIGHT_DECAY = 1e-6
+SEED = 0
+
+if not os.path.exists(CHECKPOINT_PATH):
+ os.makedirs(CHECKPOINT_PATH)
+
+random.seed(SEED)
+torch.manual_seed(SEED)
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+class CIFAR10ResNet50(ResNet):
+    def __init__(self):
+        # Bottleneck blocks with the [3, 4, 6, 3] layout give the standard ResNet-50.
+        super(CIFAR10ResNet50, self).__init__(
+            block=Bottleneck, layers=[3, 4, 6, 3], num_classes=10
+        )
+        # Replace the default 7x7 stride-2 stem with a 3x3 stride-1 conv for 32x32 CIFAR-10 inputs.
+        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
+                               stride=1, padding=1, bias=False)
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        # The stem max-pool is skipped to preserve spatial resolution on small inputs.
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+
+ x = self.avgpool(x)
+ x = torch.flatten(x, 1)
+ x = self.fc(x)
+
+ return x
+
+
+def get_color_distortion(s=1.0):
+ # s is the strength of color distortion.
+ color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
+ rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
+ rnd_gray = transforms.RandomGrayscale(p=0.2)
+ color_distort = transforms.Compose([
+ rnd_color_jitter,
+ rnd_gray
+ ])
+ return color_distort
+
+
+transform = transforms.Compose([
+ transforms.RandomResizedCrop(32, interpolation=InterpolationMode.BICUBIC),
+ transforms.RandomHorizontalFlip(0.5),
+ get_color_distortion(0.5),
+ transforms.ToTensor()
+])
+
+# Evaluate on clean images; the augmentation above is for training only.
+test_transform = transforms.Compose([
+    transforms.ToTensor()
+])
+
+train_set = CIFAR10(DATASET_ROOT, train=True, transform=transform, download=True)
+test_set = CIFAR10(DATASET_ROOT, train=False, transform=test_transform, download=True)
+
+train_loader = DataLoader(train_set, batch_size=BATCH_SIZE,
+ shuffle=True, num_workers=N_WORKERS)
+test_loader = DataLoader(test_set, batch_size=BATCH_SIZE,
+ shuffle=False, num_workers=N_WORKERS)
+
+num_train_batches = len(train_loader)
+num_test_batches = len(test_loader)
+
+resnet = CIFAR10ResNet50().to(device)
+criterion = CrossEntropyLoss()
+
+
+def exclude_from_wd_and_adaptation(name):
+    # BatchNorm and bias parameters get no weight decay and no LARS layer adaptation.
+    return 'bn' in name or 'bias' in name
+
+
+param_groups = [
+ {
+ 'params': [p for name, p in resnet.named_parameters()
+ if not exclude_from_wd_and_adaptation(name)],
+ 'weight_decay': WEIGHT_DECAY,
+ 'layer_adaptation': True,
+ },
+ {
+ 'params': [p for name, p in resnet.named_parameters()
+ if exclude_from_wd_and_adaptation(name)],
+ 'weight_decay': 0.,
+ 'layer_adaptation': False,
+ },
+]
+optimizer = optim.SGD(param_groups, lr=LR, momentum=MOMENTUM)
+scheduler = LinearWarmupAndCosineAnneal(
+ optimizer,
+ WARMUP_EPOCHS / N_EPOCHS,
+ N_EPOCHS * num_train_batches,
+ last_epoch=-1,
+)
+optimizer = LARS(optimizer)
+
+writer = SummaryWriter(TENSORBOARD_PATH)
+
+train_iters = 0
+test_iters = 0
+for epoch in range(N_EPOCHS):
+ train_loss = 0
+ training_progress = tqdm(
+ enumerate(train_loader), desc='Train loss: ', total=num_train_batches
+ )
+
+ resnet.train()
+ for batch, (images, targets) in training_progress:
+ images, targets = images.to(device), targets.to(device)
+
+ resnet.zero_grad()
+ output = resnet(images)
+ loss = criterion(output, targets)
+ loss.backward()
+ optimizer.step()
+ scheduler.step()
+
+ train_loss += loss.item()
+ train_loss_mean = train_loss / (batch + 1)
+ training_progress.set_description(f'Train loss: {train_loss_mean:.4f}')
+ writer.add_scalar('Loss/train', loss, train_iters + 1)
+ train_iters += 1
+
+ test_loss = 0
+ test_acc = 0
+ test_progress = tqdm(
+ enumerate(test_loader), desc='Test loss: ', total=num_test_batches
+ )
+
+ resnet.eval()
+ with torch.no_grad():
+ for batch, (images, targets) in test_progress:
+ images, targets = images.to(device), targets.to(device)
+
+ output = resnet(images)
+ loss = criterion(output, targets)
+ _, prediction = output.max(-1)
+
+            test_loss += loss.item()
+            test_loss_mean = test_loss / (batch + 1)
+            test_progress.set_description(f'Test loss: {test_loss_mean:.4f}')
+            test_acc += (prediction == targets).float().mean().item()
+            test_acc_mean = test_acc / (batch + 1)
+            writer.add_scalar('Loss/test', loss, test_iters + 1)
+ test_iters += 1
+
+ train_loss_mean = train_loss / num_train_batches
+ test_loss_mean = test_loss / num_test_batches
+ test_acc_mean = test_acc / num_test_batches
+    print(f'[{epoch + 1}/{N_EPOCHS}]\t'
+          f'Train loss: {train_loss_mean:.4f}\t'
+          f'Test loss: {test_loss_mean:.4f}\t'
+          f'Test acc: {test_acc_mean:.4f}')
+
+ writer.add_scalar('Acc', test_acc_mean, epoch + 1)
+
+ torch.save({'epoch': epoch,
+ 'resnet_state_dict': resnet.state_dict(),
+ 'optimizer_state_dict': optimizer.state_dict(),
+ 'train_loss': train_loss_mean, 'test_loss': test_loss_mean,
+ }, os.path.join(CHECKPOINT_PATH, f'{epoch + 1:04d}.pt'))
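For reference, a quick sanity check of the weight-decay/LARS split above, run in the context of baseline.py (a sketch, not part of the patch): the name-based test partitions the parameters cleanly, though note that torchvision names the shortcut batch-norm layers `downsample.1`, so their weight tensors are not matched by `'bn' in name` and keep weight decay and layer adaptation.

    # Illustrative sanity check, not part of the patch: the two param groups
    # should cover every model parameter exactly once.
    model = CIFAR10ResNet50()
    all_names = [name for name, _ in model.named_parameters()]
    excluded = [n for n in all_names if exclude_from_wd_and_adaptation(n)]
    decayed = [n for n in all_names if not exclude_from_wd_and_adaptation(n)]
    assert len(excluded) + len(decayed) == len(all_names)
    assert not set(excluded) & set(decayed)
    # e.g. 'layer2.0.downsample.1.weight' lands in `decayed` despite being a BN weight.
    print(f'{len(decayed)} decayed, {len(excluded)} excluded (BN/bias) parameters')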
diff --git a/supervised/lars_optimizer.py b/supervised/lars_optimizer.py
new file mode 100644
index 0000000..1904e8d
--- /dev/null
+++ b/supervised/lars_optimizer.py
@@ -0,0 +1,70 @@
+import torch
+
+
+class LARS(object):
+ """
+    Slight modification of the LARC optimizer from https://github.com/NVIDIA/apex/blob/d74fda260c403f775817470d87f810f816f3d615/apex/parallel/LARC.py
+    that matches the one in the SimCLR implementation: https://github.com/google-research/simclr/blob/master/lars_optimizer.py
+
+    Args:
+        optimizer: PyTorch optimizer to wrap and modify the learning rate for.
+        trust_coefficient: Trust coefficient for computing the adaptive lr. See https://arxiv.org/abs/1708.03888
+ """
+
+ def __init__(self,
+ optimizer,
+ trust_coefficient=0.001,
+ ):
+ self.param_groups = optimizer.param_groups
+ self.optim = optimizer
+ self.trust_coefficient = trust_coefficient
+
+ def __getstate__(self):
+ return self.optim.__getstate__()
+
+ def __setstate__(self, state):
+ self.optim.__setstate__(state)
+
+ def __repr__(self):
+ return self.optim.__repr__()
+
+ def state_dict(self):
+ return self.optim.state_dict()
+
+ def load_state_dict(self, state_dict):
+ self.optim.load_state_dict(state_dict)
+
+ def zero_grad(self):
+ self.optim.zero_grad()
+
+ def add_param_group(self, param_group):
+ self.optim.add_param_group(param_group)
+
+ def step(self):
+ with torch.no_grad():
+ weight_decays = []
+ for group in self.optim.param_groups:
+ # absorb weight decay control from optimizer
+ weight_decay = group['weight_decay'] if 'weight_decay' in group else 0
+ weight_decays.append(weight_decay)
+ group['weight_decay'] = 0
+ for p in group['params']:
+ if p.grad is None:
+ continue
+
+ if weight_decay != 0:
+ p.grad.data += weight_decay * p.data
+
+ param_norm = torch.norm(p.data)
+ grad_norm = torch.norm(p.grad.data)
+ adaptive_lr = 1.
+
+ if param_norm != 0 and grad_norm != 0 and group['layer_adaptation']:
+ adaptive_lr = self.trust_coefficient * param_norm / grad_norm
+
+ p.grad.data *= adaptive_lr
+
+ self.optim.step()
+ # return weight decay control to optimizer
+ for i, group in enumerate(self.optim.param_groups):
+ group['weight_decay'] = weight_decays[i]
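For context, a minimal usage sketch of the wrapper (illustrative only; the toy model and hyper-parameters here are assumptions, not part of the patch). Every parameter group handed to the inner optimizer needs the `layer_adaptation` flag that `step()` reads:

    import torch
    from torch import nn

    from supervised.lars_optimizer import LARS

    model = nn.Linear(10, 2)  # toy model, illustrative only
    base_optimizer = torch.optim.SGD(
        [{'params': model.parameters(),
          'weight_decay': 1e-6,
          'layer_adaptation': True}],
        lr=1.0, momentum=0.9)
    optimizer = LARS(base_optimizer, trust_coefficient=0.001)

    x, y = torch.randn(4, 10), torch.randn(4, 2)
    loss = nn.functional.mse_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    # step() folds weight decay into the gradients, rescales each parameter's
    # gradient by trust_coefficient * ||w|| / ||grad||, then runs the inner SGD step.
    optimizer.step()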
diff --git a/supervised/scheduler.py b/supervised/scheduler.py
new file mode 100644
index 0000000..828e547
--- /dev/null
+++ b/supervised/scheduler.py
@@ -0,0 +1,29 @@
+import warnings
+
+import numpy as np
+import torch
+
+
+class LinearWarmupAndCosineAnneal(torch.optim.lr_scheduler._LRScheduler):
+    """Linearly warm up the lr over the first `warm_up * T_max` steps, then cosine-anneal it over the rest."""
+
+    def __init__(self, optimizer, warm_up, T_max, last_epoch=-1):
+        self.warm_up = int(warm_up * T_max)
+        self.T_max = T_max - self.warm_up
+ super().__init__(optimizer, last_epoch=last_epoch)
+
+ def get_lr(self):
+ if not self._get_lr_called_within_step:
+ warnings.warn("To get the last learning rate computed by the scheduler, "
+ "please use `get_last_lr()`.")
+
+ if self.last_epoch == 0:
+ return [lr / (self.warm_up + 1) for lr in self.base_lrs]
+ elif self.last_epoch <= self.warm_up:
+ c = (self.last_epoch + 1) / self.last_epoch
+ return [group['lr'] * c for group in self.optimizer.param_groups]
+ else:
+ # ref: https://github.com/pytorch/pytorch/blob/2de4f245c6b1e1c294a8b2a9d7f916d43380af4b/torch/optim/lr_scheduler.py#L493
+ le = self.last_epoch - self.warm_up
+ return [(1 + np.cos(np.pi * le / self.T_max)) /
+ (1 + np.cos(np.pi * (le - 1) / self.T_max)) *
+ group['lr']
+ for group in self.optimizer.param_groups]
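Finally, a small sketch of the schedule this produces (toy optimizer and step counts, illustrative only): with `warm_up=0.1` and `T_max=100`, the learning rate climbs linearly from base_lr / 11 up to base_lr over the first 10 steps, then follows a cosine curve toward zero, passing base_lr / 2 halfway through the annealing phase. In baseline.py the scheduler is stepped once per batch, so T_max is the total number of training steps.

    import torch
    from torch import nn

    from supervised.scheduler import LinearWarmupAndCosineAnneal

    params = [nn.Parameter(torch.zeros(1))]
    opt = torch.optim.SGD(params, lr=1.0)
    sched = LinearWarmupAndCosineAnneal(opt, warm_up=0.1, T_max=100)

    lrs = []
    for _ in range(100):
        lrs.append(opt.param_groups[0]['lr'])
        opt.step()
        sched.step()

    print(lrs[0], lrs[10], lrs[55], lrs[99])
    # ~0.0909 (base_lr / 11), 1.0, 0.5, ~0.0003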