diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2022-08-08 12:31:26 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2022-08-08 12:31:26 +0800 |
commit | da1ccde004f9a264f7ada1397751fffce9336b95 (patch) | |
tree | 9718d372a2d2f25e414791cc824754bee7dffcfe /simclr/evaluate.py | |
parent | 57857a49c7fd95a9eaea208bb565df0e8b793149 (diff) |
Add evaluation script
Diffstat (limited to 'simclr/evaluate.py')
-rw-r--r-- | simclr/evaluate.py | 268 |
1 files changed, 268 insertions, 0 deletions
diff --git a/simclr/evaluate.py b/simclr/evaluate.py new file mode 100644 index 0000000..c029ffe --- /dev/null +++ b/simclr/evaluate.py @@ -0,0 +1,268 @@ +import argparse +import os.path +from dataclasses import dataclass +from pathlib import Path +from typing import Iterable, Callable + +import sys +import torch +import yaml +from torch.utils.data import Dataset +from torchvision.datasets import CIFAR10, CIFAR100, ImageNet +from torchvision.transforms import transforms + +path = str(Path(Path(__file__).parent.absolute()).parent.absolute()) +sys.path.insert(0, path) + +from libs.optimizers import LARS +from libs.logging import Loggers, BaseBatchLogRecord, BaseEpochLogRecord +from libs.utils import BaseConfig +from simclr.main import SimCLRTrainer, SimCLRConfig +from simclr.models import CIFARSimCLRResNet50, ImageNetSimCLRResNet50 + + +def parse_args_and_config(): + parser = argparse.ArgumentParser( + description='Contrastive baseline SimCLR (evaluation)', + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument('--codename', default='cifar10-simclr-linear', + type=str, help="Model descriptor") + parser.add_argument('--log-dir', default='logs', type=str, + help="Path to log directory") + parser.add_argument('--checkpoint-dir', default='checkpoints', type=str, + help="Path to checkpoints directory") + parser.add_argument('--seed', default=None, type=int, + help='Random seed for reproducibility') + parser.add_argument('--num-iters', default=50, type=int, + help='Number of iters') + parser.add_argument('--config', type=argparse.FileType(mode='r'), + help='Path to config file (optional)') + + # TODO: Add model hyperparams dataclass + parser.add_argument('--hid-dim', default=2048, type=int, + help='Number of dimension of embedding') + parser.add_argument('--out-dim', default=128, type=int, + help='Number of dimension after projection') + parser.add_argument('--pretrained-checkpoint', type=str, + help='Pretrained checkpoint location') + parser.add_argument('--finetune', default=False, + action=argparse.BooleanOptionalAction, + help='Finetune backbone or linear head only') + + dataset_group = parser.add_argument_group('Dataset parameters') + dataset_group.add_argument('--dataset-dir', default='dataset', type=str, + help="Path to dataset directory") + dataset_group.add_argument('--dataset', default='cifar10', type=str, + choices=('cifar10, cifar100', 'imagenet'), + help="Name of dataset") + dataset_group.add_argument('--train-size', default=32, type=int, + help='Resize during training') + dataset_group.add_argument('--test-size', default=32, type=int, + help='Resize during testing') + dataset_group.add_argument('--test-crop-size', default=32, type=int, + help='Center crop size during testing') + dataset_group.add_argument('--hflip-prob', default=0.5, type=float, + help='Random horizontal flip probability') + + dataloader_group = parser.add_argument_group('Dataloader parameters') + dataloader_group.add_argument('--batch-size', default=256, type=int, + help='Batch size') + dataloader_group.add_argument('--num-workers', default=2, type=int, + help='Number of dataloader processes') + + optim_group = parser.add_argument_group('Optimizer parameters') + optim_group.add_argument('--optim', default='sgd', type=str, + choices=('adam', 'sgd', 'lars'), + help="Name of optimizer") + optim_group.add_argument('--lr', default=1e-3, type=float, + help='Learning rate') + optim_group.add_argument('--betas', nargs=2, default=(0.9, 0.999), type=float, + help='Adam betas', metavar=('beta1', 'beta2')) + optim_group.add_argument('--momentum', default=0.9, type=float, + help='SDG momentum') + optim_group.add_argument('--weight-decay', default=0., type=float, + help='Weight decay (l2 regularization)') + + sched_group = parser.add_argument_group('Optimizer parameters') + sched_group.add_argument('--sched', default=None, type=str, + help="Name of scheduler") + sched_group.add_argument('--warmup-iters', default=5, type=int, + help='Epochs for warmup (`warmup-anneal` scheduler only)') + + args = parser.parse_args() + if args.config: + config = yaml.safe_load(args.config) + args.__dict__ |= { + k: tuple(v) if isinstance(v, list) else v + for k, v in config.items() + } + args.checkpoint_dir = os.path.join(args.checkpoint_dir, args.codename) + args.log_dir = os.path.join(args.log_dir, args.codename) + + return args + + +@dataclass +class SimCLREvalConfig(SimCLRConfig): + @dataclass + class DatasetConfig(BaseConfig.DatasetConfig): + dataset_dir: str + train_size: int + test_size: int + test_crop_size: int + hflip_prob: float + + +class SimCLREvalTrainer(SimCLRTrainer): + def __init__(self, pretrained_checkpoint, finetune, **kwargs): + self.pretrained_checkpoint = pretrained_checkpoint + self.finetune = finetune + super(SimCLREvalTrainer, self).__init__(**kwargs) + + @dataclass + class BatchLogRecord(BaseBatchLogRecord): + lr: float + train_loss: float + + @dataclass + class EpochLogRecord(BaseEpochLogRecord): + eval_loss: float + eval_accuracy: float + + @staticmethod + def _prepare_dataset(dataset_config: SimCLREvalConfig.DatasetConfig) -> tuple[Dataset, Dataset]: + train_transform = transforms.Compose([ + transforms.RandomHorizontalFlip(dataset_config.hflip_prob), + transforms.Resize(dataset_config.train_size), + transforms.ToTensor(), + ]) + test_transform = transforms.Compose([ + transforms.Resize(dataset_config.test_size), + transforms.CenterCrop(dataset_config.test_crop_size), + transforms.ToTensor(), + ]) + if dataset_config.dataset in {'cifar10', 'cifar100', 'cifar'}: + if dataset_config.dataset in {'cifar10', 'cifar'}: + train_set = CIFAR10(dataset_config.dataset_dir, train=True, + transform=train_transform, download=True) + test_set = CIFAR10(dataset_config.dataset_dir, train=False, + transform=test_transform) + else: # CIFAR-100 + train_set = CIFAR100(dataset_config.dataset_dir, train=True, + transform=train_transform, download=True) + test_set = CIFAR100(dataset_config.dataset_dir, train=False, + transform=test_transform) + elif dataset_config.dataset in {'imagenet1k', 'imagenet'}: + train_set = ImageNet(dataset_config.dataset_dir, 'train', + transform=train_transform) + test_set = ImageNet(dataset_config.dataset_dir, 'val', + transform=test_transform) + else: + raise NotImplementedError(f"Unimplemented dataset: '{dataset_config.dataset}") + + return train_set, test_set + + def _init_models(self, dataset: str) -> Iterable[tuple[str, torch.nn.Module]]: + if dataset in {'cifar10', 'cifar100', 'cifar'}: + backbone = CIFARSimCLRResNet50(self.hid_dim, pretrain=False) + if dataset in {'cifar10', 'cifar'}: + classifier = torch.nn.Linear(self.hid_dim, 10) + else: + classifier = torch.nn.Linear(self.hid_dim, 100) + elif dataset in {'imagenet1k', 'imagenet'}: + backbone = ImageNetSimCLRResNet50(self.hid_dim, pretrain=False) + classifier = torch.nn.Linear(self.hid_dim, 1000) + else: + raise NotImplementedError(f"Unimplemented dataset: '{dataset}") + + yield 'backbone', backbone + yield 'classifier', classifier + + def _custom_init_fn(self, config: SimCLREvalConfig): + self.optims = {n: LARS(o) if config.optim_config.optim == 'lars' else o + for n, o in self.optims.items()} + if self.restore_iter == 0: + pretrained_checkpoint = torch.load(self.pretrained_checkpoint) + backbone_checkpoint = pretrained_checkpoint['model_state_dict'] + backbone_state_dict = {k: v for k, v in backbone_checkpoint.items() + if k in self.models['backbone'].state_dict()} + self.models['backbone'].load_state_dict(backbone_state_dict) + + def train(self, num_iters: int, loss_fn: Callable, logger: Loggers, device: torch.device): + backbone, classifier = self.models.values() + optim_b, optim_c = self.optims.values() + sched_b, sched_c = self.scheds.values() + loader_size = len(self.train_loader) + num_batches = num_iters * loader_size + for iter_ in range(self.restore_iter, num_iters): + if self.finetune: + backbone.train() + else: + backbone.eval() + classifier.train() + for batch, (images, targets) in enumerate(self.train_loader): + global_batch = iter_ * loader_size + batch + images, targets = images.to(device), targets.to(device) + classifier.zero_grad() + if self.finetune: + backbone.zero_grad() + embedding = backbone(images) + else: + with torch.no_grad(): + embedding = backbone(images) + logits = classifier(embedding) + train_loss = loss_fn(logits, targets) + train_loss.backward() + if self.finetune: + optim_b.step() + optim_c.step() + self.log(logger, self.BatchLogRecord( + batch, num_batches, global_batch, iter_, num_iters, + optim_c.param_groups[0]['lr'], train_loss.item() + )) + metrics = torch.Tensor(list(self.eval(loss_fn, device))).mean(0) + eval_loss = metrics[0].item() + eval_accuracy = metrics[1].item() + epoch_log = self.EpochLogRecord(iter_, num_iters, eval_loss, eval_accuracy) + self.log(logger, epoch_log) + self.save_checkpoint(epoch_log) + if sched_b is not None and self.finetune: + sched_b.step() + if sched_c is not None: + sched_c.step() + + def eval(self, loss_fn: Callable, device: torch.device): + backbone, classifier = self.models.values() + backbone.eval() + classifier.eval() + with torch.no_grad(): + for images, targets in self.test_loader: + images, targets = images.to(device), targets.to(device) + embedding = backbone(images) + logits = classifier(embedding) + loss = loss_fn(logits, targets) + prediction = logits.argmax(1) + accuracy = (prediction == targets).float().mean() + yield loss.item(), accuracy.item() + + +if __name__ == '__main__': + args = parse_args_and_config() + config = SimCLREvalConfig.from_args(args) + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + trainer = SimCLREvalTrainer( + seed=args.seed, + checkpoint_dir=args.checkpoint_dir, + device=device, + inf_mode=False, + num_iters=args.num_iters, + config=config, + hid_dim=args.hid_dim, + out_dim=args.out_dim, + pretrained_checkpoint=args.pretrained_checkpoint, + finetune=args.finetune, + ) + + loggers = trainer.init_logger(args.log_dir) + trainer.train(args.num_iters, torch.nn.CrossEntropyLoss(), loggers, device) |