diff options
Diffstat (limited to 'simclr')
-rw-r--r-- | simclr/config-vit.example.yaml | 38 | ||||
-rw-r--r-- | simclr/config.example.yaml | 1 | ||||
-rw-r--r-- | simclr/main.py | 24 | ||||
-rw-r--r-- | simclr/models.py | 32 |
4 files changed, 87 insertions, 8 deletions
diff --git a/simclr/config-vit.example.yaml b/simclr/config-vit.example.yaml new file mode 100644 index 0000000..6246a19 --- /dev/null +++ b/simclr/config-vit.example.yaml @@ -0,0 +1,38 @@ +codename: cifar10-simclr-vit-128-lars-warmup-example +seed: -1 +num_iters: 19531 +log_dir: logs +checkpoint_dir: checkpoints + +encoder: vit +hid_dim: 2048 +out_dim: 128 +temp: 0.5 + +dataset: cifar10 +dataset_dir: dataset +crop_size: 32 +crop_scale_range: + - 0.8 + - 1 +hflip_prob: 0.5 +distort_strength: 0.5 +#gauss_ker_scale: 10 +#gauss_sigma_range: +# - 0.1 +# - 2 +#gauss_prob: 0.5 + +batch_size: 128 +num_workers: 2 + +optim: lars +lr: 1 +momentum: 0.9 +#betas: +# - 0.9 +# - 0.999 +weight_decay: 1.0e-06 + +sched: warmup-anneal +warmup_iters: 1953
\ No newline at end of file diff --git a/simclr/config.example.yaml b/simclr/config.example.yaml index 20bed04..ea9b9e3 100644 --- a/simclr/config.example.yaml +++ b/simclr/config.example.yaml @@ -4,6 +4,7 @@ num_iters: 19531 log_dir: logs checkpoint_dir: checkpoints +encoder: resnet hid_dim: 2048 out_dim: 128 temp: 0.5 diff --git a/simclr/main.py b/simclr/main.py index 69e2ab2..b990b6e 100644 --- a/simclr/main.py +++ b/simclr/main.py @@ -1,10 +1,10 @@ import argparse import os -import sys from dataclasses import dataclass from pathlib import Path from typing import Iterable, Callable +import sys import torch import yaml from torch.utils.data import Dataset @@ -17,10 +17,9 @@ sys.path.insert(0, path) from libs.criteria import InfoNCELoss from libs.datautils import color_distortion, Clip, RandomGaussianBlur, TwinTransform from libs.optimizers import LARS -from libs.schedulers import LinearWarmupAndCosineAnneal, LinearLR from libs.utils import Trainer, BaseConfig from libs.logging import BaseBatchLogRecord, Loggers -from simclr.models import CIFARSimCLRResNet50, ImageNetSimCLRResNet50 +from simclr.models import CIFARSimCLRResNet50, ImageNetSimCLRResNet50, CIFARSimCLRViTTiny def parse_args_and_config(): @@ -43,6 +42,9 @@ def parse_args_and_config(): help='Path to config file (optional)') # TODO: Add model hyperparams dataclass + parser.add_argument('--encoder', default='resnet', type=str, + choices=('resnet', 'vit'), + help='Backbone of encoder') parser.add_argument('--hid-dim', default=2048, type=int, help='Number of dimension of embedding') parser.add_argument('--out-dim', default=128, type=int, @@ -134,7 +136,8 @@ class SimCLRConfig(BaseConfig): class SimCLRTrainer(Trainer): - def __init__(self, hid_dim, out_dim, **kwargs): + def __init__(self, encoder, hid_dim, out_dim, **kwargs): + self.encoder = encoder self.hid_dim = hid_dim self.out_dim = out_dim super(SimCLRTrainer, self).__init__(**kwargs) @@ -199,9 +202,17 @@ class SimCLRTrainer(Trainer): def _init_models(self, dataset: str) -> Iterable[tuple[str, torch.nn.Module]]: if dataset in {'cifar10', 'cifar100', 'cifar'}: - model = CIFARSimCLRResNet50(self.hid_dim, self.out_dim) + if self.encoder == 'resnet': + model = CIFARSimCLRResNet50(self.hid_dim, self.out_dim) + elif self.encoder == 'vit': + model = CIFARSimCLRViTTiny(self.hid_dim, self.out_dim) + else: + raise NotImplementedError(f"Unimplemented encoder: '{self.encoder}") elif dataset in {'imagenet1k', 'imagenet'}: - model = ImageNetSimCLRResNet50(self.hid_dim, self.out_dim) + if self.encoder == 'resnet': + model = ImageNetSimCLRResNet50(self.hid_dim, self.out_dim) + else: + raise NotImplementedError(f"Unimplemented encoder: '{self.encoder}") else: raise NotImplementedError(f"Unimplemented dataset: '{dataset}") @@ -311,6 +322,7 @@ if __name__ == '__main__': inf_mode=True, num_iters=args.num_iters, config=config, + encoder=args.encoder, hid_dim=args.hid_dim, out_dim=args.out_dim, ) diff --git a/simclr/models.py b/simclr/models.py index 8b270b9..689e0d3 100644 --- a/simclr/models.py +++ b/simclr/models.py @@ -1,7 +1,14 @@ +from pathlib import Path + +import sys import torch from torch import nn, Tensor -from torchvision.models import ResNet -from torchvision.models.resnet import BasicBlock +from torchvision.models.resnet import BasicBlock, ResNet + +path = str(Path(Path(__file__).parent.absolute()).parent.absolute()) +sys.path.insert(0, path) + +from supervised.models import CIFARViTTiny # TODO Make a SimCLR base class @@ -46,6 +53,27 @@ class CIFARSimCLRResNet50(ResNet): return h +class CIFARSimCLRViTTiny(CIFARViTTiny): + def __init__(self, hid_dim, out_dim=128, pretrain=True): + super().__init__(num_classes=hid_dim) + + self.pretrain = pretrain + if pretrain: + self.projector = nn.Sequential( + nn.Linear(hid_dim, hid_dim), + nn.ReLU(inplace=True), + nn.Linear(hid_dim, out_dim), + ) + + def forward(self, x: torch.Tensor) -> Tensor: + h = super(CIFARSimCLRViTTiny, self).forward(x) + if self.pretrain: + z = self.projector(h) + return z + else: + return h + + class ImageNetSimCLRResNet50(ResNet): def __init__(self, hid_dim, out_dim=128, pretrain=True): super(ImageNetSimCLRResNet50, self).__init__( |