path: root/supervised/baseline.py
Diffstat (limited to 'supervised/baseline.py')
-rw-r--r--  supervised/baseline.py  191
1 file changed, 191 insertions, 0 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py
new file mode 100644
index 0000000..047d8ca
--- /dev/null
+++ b/supervised/baseline.py
@@ -0,0 +1,191 @@
+import os
+import random
+
+import torch
+from torch import nn, Tensor, optim
+from torch.nn import CrossEntropyLoss
+from torch.utils.data import DataLoader
+from torch.utils.tensorboard import SummaryWriter
+from torchvision.datasets import CIFAR10
+from torchvision.models import ResNet
+from torchvision.models.resnet import Bottleneck
+from torchvision.transforms import transforms, InterpolationMode
+from tqdm import tqdm
+
+from supervised.lars_optimizer import LARS
+from supervised.scheduler import LinearWarmupAndCosineAnneal
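+# LARS and LinearWarmupAndCosineAnneal are project-local modules under
+# supervised/, not part of torch or torchvision.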
+
+CODENAME = 'cifar10-resnet50-aug-lars-sched'
+DATASET_ROOT = '../dataset'
+TENSORBOARD_PATH = os.path.join('../runs', CODENAME)
+CHECKPOINT_PATH = os.path.join('../checkpoints', CODENAME)
+
+BATCH_SIZE = 256
+N_EPOCHS = 1000
+WARMUP_EPOCHS = 10
+N_WORKERS = 2
+LR = 1
+MOMENTUM = 0.9
+WEIGHT_DECAY = 1e-6
+SEED = 0
+
+if not os.path.exists(CHECKPOINT_PATH):
+    os.makedirs(CHECKPOINT_PATH)
+
+random.seed(SEED)
+torch.manual_seed(SEED)
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+class CIFAR10ResNet50(ResNet):
+    def __init__(self):
+        # ResNet-50 layout: Bottleneck blocks in a [3, 4, 6, 3] stage config.
+        super(CIFAR10ResNet50, self).__init__(
+            block=Bottleneck, layers=[3, 4, 6, 3], num_classes=10
+        )
+        # CIFAR-10 stem: a 3x3/stride-1 conv replaces the usual 7x7/stride-2
+        # conv, and forward() below skips max-pooling, since inputs are 32x32.
+        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
+                               stride=1, padding=1, bias=False)
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        x = self.avgpool(x)
+        x = torch.flatten(x, 1)
+        x = self.fc(x)
+
+        return x
+
+
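+# Color-distortion augmentation in the style used by SimCLR: random color
+# jitter followed by random grayscale.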
+def get_color_distortion(s=1.0):
+    # s is the strength of color distortion.
+    color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
+    rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
+    rnd_gray = transforms.RandomGrayscale(p=0.2)
+    color_distort = transforms.Compose([
+        rnd_color_jitter,
+        rnd_gray
+    ])
+    return color_distort
+
+
+transform = transforms.Compose([
+    transforms.RandomResizedCrop(32, interpolation=InterpolationMode.BICUBIC),
+    transforms.RandomHorizontalFlip(0.5),
+    get_color_distortion(0.5),
+    transforms.ToTensor()
+])
+
+train_set = CIFAR10(DATASET_ROOT, train=True, transform=transform, download=True)
+# Evaluate on un-augmented images; random crops and color jitter at test time
+# would make the reported loss and accuracy noisy.
+test_set = CIFAR10(DATASET_ROOT, train=False, transform=transforms.ToTensor(),
+                   download=True)
+
+train_loader = DataLoader(train_set, batch_size=BATCH_SIZE,
+                          shuffle=True, num_workers=N_WORKERS)
+test_loader = DataLoader(test_set, batch_size=BATCH_SIZE,
+                         shuffle=False, num_workers=N_WORKERS)
+
+num_train_batches = len(train_loader)
+num_test_batches = len(test_loader)
+
+resnet = CIFAR10ResNet50().to(device)
+criterion = CrossEntropyLoss()
+
+
+def exclude_from_wd_and_adaptation(name):
+    # BN parameters and biases get neither weight decay nor LARS layer-wise
+    # adaptation.
+    return 'bn' in name or 'bias' in name
+
+
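+# Two parameter groups: weights get weight decay and layer adaptation;
+# BN parameters and biases get neither.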
+param_groups = [
+    {
+        'params': [p for name, p in resnet.named_parameters()
+                   if not exclude_from_wd_and_adaptation(name)],
+        'weight_decay': WEIGHT_DECAY,
+        'layer_adaptation': True,
+    },
+    {
+        'params': [p for name, p in resnet.named_parameters()
+                   if exclude_from_wd_and_adaptation(name)],
+        'weight_decay': 0.,
+        'layer_adaptation': False,
+    },
+]
+optimizer = torch.optim.SGD(param_groups, lr=LR, momentum=MOMENTUM)
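+# The schedule covers N_EPOCHS * num_train_batches steps in total, with the
+# first WARMUP_EPOCHS epochs spent in linear warm-up; scheduler.step() is
+# called once per batch in the training loop.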
+scheduler = LinearWarmupAndCosineAnneal(
+    optimizer,
+    WARMUP_EPOCHS / N_EPOCHS,
+    N_EPOCHS * num_train_batches,
+    last_epoch=-1,
+)
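+# LARS is a project-local wrapper around the base SGD optimizer; the scheduler
+# above was constructed against the base optimizer before wrapping, so it
+# adjusts the underlying SGD learning rates.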
+optimizer = LARS(optimizer)
+
+writer = SummaryWriter(TENSORBOARD_PATH)
+
+train_iters = 0
+test_iters = 0
+for epoch in range(N_EPOCHS):
+    train_loss = 0
+    training_progress = tqdm(
+        enumerate(train_loader), desc='Train loss: ', total=num_train_batches
+    )
+
+    resnet.train()
+    for batch, (images, targets) in training_progress:
+        images, targets = images.to(device), targets.to(device)
+
+        resnet.zero_grad()
+        output = resnet(images)
+        loss = criterion(output, targets)
+        loss.backward()
+        optimizer.step()
+        # The LR schedule is stepped every batch, not once per epoch.
+        scheduler.step()
+
+        train_loss += loss.item()
+        train_loss_mean = train_loss / (batch + 1)
+        training_progress.set_description(f'Train loss: {train_loss_mean:.4f}')
+        writer.add_scalar('Loss/train', loss.item(), train_iters + 1)
+        train_iters += 1
+
+    test_loss = 0
+    test_acc = 0
+    test_progress = tqdm(
+        enumerate(test_loader), desc='Test loss: ', total=num_test_batches
+    )
+
+    resnet.eval()
+    with torch.no_grad():
+        for batch, (images, targets) in test_progress:
+            images, targets = images.to(device), targets.to(device)
+
+            output = resnet(images)
+            loss = criterion(output, targets)
+            _, prediction = output.max(-1)
+
+            # Accumulate Python floats rather than GPU tensors.
+            test_loss += loss.item()
+            test_loss_mean = test_loss / (batch + 1)
+            test_progress.set_description(f'Test loss: {test_loss_mean:.4f}')
+            test_acc += (prediction == targets).float().mean().item()
+            test_acc_mean = test_acc / (batch + 1)
+            writer.add_scalar('Loss/test', loss.item(), test_iters + 1)
+            test_iters += 1
+
+    train_loss_mean = train_loss / num_train_batches
+    test_loss_mean = test_loss / num_test_batches
+    test_acc_mean = test_acc / num_test_batches
+    print(f'[{epoch + 1}/{N_EPOCHS}]\t'
+          f'Train loss: {train_loss_mean:.4f}\t'
+          f'Test loss: {test_loss_mean:.4f}\t'
+          f'Test acc: {test_acc_mean:.4f}')
+
+    writer.add_scalar('Acc', test_acc_mean, epoch + 1)
+
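+    # Note: the scheduler state is not included in the checkpoint, so the LR
+    # schedule cannot be resumed exactly from a saved epoch.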
+    torch.save({'epoch': epoch,
+                'resnet_state_dict': resnet.state_dict(),
+                'optimizer_state_dict': optimizer.state_dict(),
+                'train_loss': train_loss_mean, 'test_loss': test_loss_mean,
+                }, os.path.join(CHECKPOINT_PATH, f'{epoch + 1:04d}.pt'))