diff options
Diffstat (limited to 'supervised/baseline.py')
-rw-r--r-- | supervised/baseline.py | 191 |
1 files changed, 191 insertions, 0 deletions
"""Supervised CIFAR-10 baseline (``supervised/baseline.py``).

Trains a CIFAR-sized torchvision ResNet on augmented CIFAR-10 images using SGD
wrapped by the project's ``LARS`` optimizer, with the learning rate stepped
once per batch by the project's ``LinearWarmupAndCosineAnneal`` scheduler.
Per-batch losses and per-epoch test accuracy are written to TensorBoard, and
a checkpoint is saved after every epoch.
"""
import os
import random

import torch
from torch import nn, Tensor, optim
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.datasets import CIFAR10
from torchvision.models import ResNet
from torchvision.models.resnet import BasicBlock
from torchvision.transforms import transforms, InterpolationMode
from tqdm import tqdm

from supervised.lars_optimizer import LARS
from supervised.scheduler import LinearWarmupAndCosineAnneal

CODENAME = 'cifar10-resnet50-aug-lars-sched'
DATASET_ROOT = '../dataset'
TENSORBOARD_PATH = os.path.join('../runs', CODENAME)
CHECKPOINT_PATH = os.path.join('../checkpoints', CODENAME)

BATCH_SIZE = 256
N_EPOCHS = 1000
WARMUP_EPOCHS = 10
N_WORKERS = 2
LR = 1
MOMENTUM = 0.9
WEIGHT_DECAY = 1e-6
SEED = 0

# exist_ok avoids the check-then-create race of a separate os.path.exists().
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

random.seed(SEED)
torch.manual_seed(SEED)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class CIFAR10ResNet50(ResNet):
    """torchvision ``ResNet`` adapted to 32x32 CIFAR-10 inputs.

    The stem is replaced by a 3x3, stride-1 convolution and ``forward`` skips
    the max-pool, so the image is not downsampled before ``layer1``.

    NOTE(review): despite the name, ``BasicBlock`` with ``[3, 4, 6, 3]`` is
    the layout torchvision uses for ResNet-34 (ResNet-50 uses ``Bottleneck``).
    Left unchanged so the codename and existing checkpoints stay valid --
    confirm which architecture was intended.
    """

    def __init__(self):
        super().__init__(
            block=BasicBlock, layers=[3, 4, 6, 3], num_classes=10
        )
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        """Return class logits for a batch of images (max-pool is skipped)."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x


def get_color_distortion(s=1.0):
    """Build the random color-jitter + random-grayscale transform.

    Args:
        s: strength of color distortion; scales every jitter range.

    Returns:
        A ``transforms.Compose`` applying color jitter with probability 0.8
        followed by grayscale conversion with probability 0.2.
    """
    # s is the strength of color distortion.
    color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
    rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
    rnd_gray = transforms.RandomGrayscale(p=0.2)
    color_distort = transforms.Compose([
        rnd_color_jitter,
        rnd_gray
    ])
    return color_distort


# Random augmentation is for training only.
transform = transforms.Compose([
    transforms.RandomResizedCrop(32, interpolation=InterpolationMode.BICUBIC),
    transforms.RandomHorizontalFlip(0.5),
    get_color_distortion(0.5),
    transforms.ToTensor()
])
# Evaluation must be deterministic: feeding randomly cropped / recolored test
# images makes the reported loss and accuracy noisy and non-comparable.
test_transform = transforms.ToTensor()

train_set = CIFAR10(DATASET_ROOT, train=True, transform=transform, download=True)
test_set = CIFAR10(DATASET_ROOT, train=False, transform=test_transform, download=True)

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE,
                          shuffle=True, num_workers=N_WORKERS)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE,
                         shuffle=False, num_workers=N_WORKERS)

num_train_batches = len(train_loader)
num_test_batches = len(test_loader)

resnet = CIFAR10ResNet50().to(device)
criterion = CrossEntropyLoss()


def exclude_from_wd_and_adaptation(name):
    """Return True if the named parameter skips weight decay and adaptation.

    Matches biases and batch-norm parameters.  torchvision builds each
    shortcut as ``Sequential(conv1x1, norm)``, so its batch-norm parameters
    are named ``...downsample.1.*`` and do not contain ``'bn'``; they are
    matched explicitly.

    Args:
        name: parameter name as yielded by ``Module.named_parameters()``.

    Returns:
        bool: True to exclude the parameter, False otherwise.
    """
    return 'bn' in name or 'bias' in name or 'downsample.1' in name


param_groups = [
    {
        'params': [p for name, p in resnet.named_parameters()
                   if not exclude_from_wd_and_adaptation(name)],
        'weight_decay': WEIGHT_DECAY,
        'layer_adaptation': True,
    },
    {
        'params': [p for name, p in resnet.named_parameters()
                   if exclude_from_wd_and_adaptation(name)],
        'weight_decay': 0.,
        'layer_adaptation': False,
    },
]
# The scheduler is attached to the plain SGD optimizer *before* it is wrapped
# by LARS; keep this order.
optimizer = torch.optim.SGD(param_groups, lr=LR, momentum=MOMENTUM)
scheduler = LinearWarmupAndCosineAnneal(
    optimizer,
    WARMUP_EPOCHS / N_EPOCHS,
    N_EPOCHS * num_train_batches,
    last_epoch=-1,
)
optimizer = LARS(optimizer)

writer = SummaryWriter(TENSORBOARD_PATH)

train_iters = 0
test_iters = 0
for epoch in range(N_EPOCHS):
    train_loss = 0.
    training_progress = tqdm(
        enumerate(train_loader), desc='Train loss: ', total=num_train_batches
    )

    resnet.train()
    for batch, (images, targets) in training_progress:
        images, targets = images.to(device), targets.to(device)

        resnet.zero_grad()
        output = resnet(images)
        loss = criterion(output, targets)
        loss.backward()
        optimizer.step()
        scheduler.step()  # stepped per batch, not per epoch

        # .item() keeps the running statistics as plain Python floats.
        batch_loss = loss.item()
        train_loss += batch_loss
        train_loss_mean = train_loss / (batch + 1)
        training_progress.set_description(f'Train loss: {train_loss_mean:.4f}')
        train_iters += 1
        writer.add_scalar('Loss/train', batch_loss, train_iters)

    # Test statistics are accumulated per sample so that the smaller final
    # batch does not get the same weight as a full batch.
    test_loss = 0.
    test_correct = 0
    test_seen = 0
    test_progress = tqdm(
        enumerate(test_loader), desc='Test loss: ', total=num_test_batches
    )

    resnet.eval()
    with torch.no_grad():
        for batch, (images, targets) in test_progress:
            images, targets = images.to(device), targets.to(device)

            output = resnet(images)
            loss = criterion(output, targets)
            _, prediction = output.max(-1)

            n_samples = targets.size(0)
            batch_loss = loss.item()
            test_loss += batch_loss * n_samples
            test_correct += (prediction == targets).sum().item()
            test_seen += n_samples

            test_loss_mean = test_loss / test_seen
            test_acc_mean = test_correct / test_seen
            test_progress.set_description(f'Test loss: {test_loss_mean:.4f}')
            test_iters += 1
            writer.add_scalar('Loss/test', batch_loss, test_iters)

    train_loss_mean = train_loss / num_train_batches
    test_loss_mean = test_loss / test_seen
    test_acc_mean = test_correct / test_seen
    print(f'[{epoch + 1}/{N_EPOCHS}]\t'
          f'Train loss: {train_loss_mean:.4f}\t'
          f'Test loss: {test_loss_mean:.4f}\t'
          f'Test acc: {test_acc_mean:.4f}')

    writer.add_scalar('Acc', test_acc_mean, epoch + 1)

    # Losses are stored as floats so the checkpoint can be loaded on a
    # machine without the device the run was trained on.
    # NOTE(review): the scheduler state is not saved, so resuming from a
    # checkpoint would restart the LR schedule -- confirm whether that matters.
    torch.save({'epoch': epoch,
                'resnet_state_dict': resnet.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': train_loss_mean, 'test_loss': test_loss_mean,
                }, os.path.join(CHECKPOINT_PATH, f'{epoch + 1:04d}.pt'))

# Flush any pending TensorBoard events.
writer.close()