Diffstat (limited to 'simclr')
-rw-r--r--  simclr/config-vit.example.yaml  38
-rw-r--r--  simclr/config.example.yaml       1
-rw-r--r--  simclr/main.py                  24
-rw-r--r--  simclr/models.py                32
4 files changed, 87 insertions, 8 deletions
diff --git a/simclr/config-vit.example.yaml b/simclr/config-vit.example.yaml
new file mode 100644
index 0000000..6246a19
--- /dev/null
+++ b/simclr/config-vit.example.yaml
@@ -0,0 +1,38 @@
+codename: cifar10-simclr-vit-128-lars-warmup-example
+seed: -1
+num_iters: 19531
+log_dir: logs
+checkpoint_dir: checkpoints
+
+encoder: vit
+hid_dim: 2048
+out_dim: 128
+temp: 0.5
+
+dataset: cifar10
+dataset_dir: dataset
+crop_size: 32
+crop_scale_range:
+ - 0.8
+ - 1
+hflip_prob: 0.5
+distort_strength: 0.5
+#gauss_ker_scale: 10
+#gauss_sigma_range:
+# - 0.1
+# - 2
+#gauss_prob: 0.5
+
+batch_size: 128
+num_workers: 2
+
+optim: lars
+lr: 1
+momentum: 0.9
+#betas:
+# - 0.9
+# - 0.999
+weight_decay: 1.0e-06
+
+sched: warmup-anneal
+warmup_iters: 1953
\ No newline at end of file
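For reference, a minimal sketch of loading the new example config, assuming PyYAML (already imported by simclr/main.py) and the repository root as the working directory; the snippet is illustrative and not part of this change:

import yaml

# Keys correspond to the command-line flags parsed in simclr/main.py
# (e.g. --encoder, --hid-dim, --out-dim).
with open('simclr/config-vit.example.yaml') as f:
    cfg = yaml.safe_load(f)

assert cfg['encoder'] == 'vit'
print(cfg['optim'], cfg['lr'])            # lars 1
print(cfg['sched'], cfg['warmup_iters'])  # warmup-anneal 1953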
diff --git a/simclr/config.example.yaml b/simclr/config.example.yaml
index 20bed04..ea9b9e3 100644
--- a/simclr/config.example.yaml
+++ b/simclr/config.example.yaml
@@ -4,6 +4,7 @@ num_iters: 19531
log_dir: logs
checkpoint_dir: checkpoints
+encoder: resnet
hid_dim: 2048
out_dim: 128
temp: 0.5
diff --git a/simclr/main.py b/simclr/main.py
index 69e2ab2..b990b6e 100644
--- a/simclr/main.py
+++ b/simclr/main.py
@@ -1,10 +1,10 @@
import argparse
import os
-import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, Callable
+import sys
import torch
import yaml
from torch.utils.data import Dataset
@@ -17,10 +17,9 @@ sys.path.insert(0, path)
from libs.criteria import InfoNCELoss
from libs.datautils import color_distortion, Clip, RandomGaussianBlur, TwinTransform
from libs.optimizers import LARS
-from libs.schedulers import LinearWarmupAndCosineAnneal, LinearLR
from libs.utils import Trainer, BaseConfig
from libs.logging import BaseBatchLogRecord, Loggers
-from simclr.models import CIFARSimCLRResNet50, ImageNetSimCLRResNet50
+from simclr.models import CIFARSimCLRResNet50, ImageNetSimCLRResNet50, CIFARSimCLRViTTiny
def parse_args_and_config():
@@ -43,6 +42,9 @@ def parse_args_and_config():
help='Path to config file (optional)')
# TODO: Add model hyperparams dataclass
+ parser.add_argument('--encoder', default='resnet', type=str,
+ choices=('resnet', 'vit'),
+ help='Backbone of encoder')
parser.add_argument('--hid-dim', default=2048, type=int,
help='Number of dimension of embedding')
parser.add_argument('--out-dim', default=128, type=int,
@@ -134,7 +136,8 @@ class SimCLRConfig(BaseConfig):
class SimCLRTrainer(Trainer):
- def __init__(self, hid_dim, out_dim, **kwargs):
+ def __init__(self, encoder, hid_dim, out_dim, **kwargs):
+ self.encoder = encoder
self.hid_dim = hid_dim
self.out_dim = out_dim
super(SimCLRTrainer, self).__init__(**kwargs)
@@ -199,9 +202,17 @@ class SimCLRTrainer(Trainer):
def _init_models(self, dataset: str) -> Iterable[tuple[str, torch.nn.Module]]:
if dataset in {'cifar10', 'cifar100', 'cifar'}:
- model = CIFARSimCLRResNet50(self.hid_dim, self.out_dim)
+ if self.encoder == 'resnet':
+ model = CIFARSimCLRResNet50(self.hid_dim, self.out_dim)
+ elif self.encoder == 'vit':
+ model = CIFARSimCLRViTTiny(self.hid_dim, self.out_dim)
+ else:
+                raise NotImplementedError(f"Unimplemented encoder: '{self.encoder}'")
elif dataset in {'imagenet1k', 'imagenet'}:
- model = ImageNetSimCLRResNet50(self.hid_dim, self.out_dim)
+ if self.encoder == 'resnet':
+ model = ImageNetSimCLRResNet50(self.hid_dim, self.out_dim)
+ else:
+                raise NotImplementedError(f"Unimplemented encoder: '{self.encoder}'")
else:
raise NotImplementedError(f"Unimplemented dataset: '{dataset}")
@@ -311,6 +322,7 @@ if __name__ == '__main__':
inf_mode=True,
num_iters=args.num_iters,
config=config,
+ encoder=args.encoder,
hid_dim=args.hid_dim,
out_dim=args.out_dim,
)
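The branching added to _init_models above is the core of the change: the new --encoder flag selects the backbone per dataset. A standalone sketch of the same selection logic, using the class names from simclr/models.py (the build_model helper itself is hypothetical):

import torch

from simclr.models import (CIFARSimCLRResNet50, CIFARSimCLRViTTiny,
                           ImageNetSimCLRResNet50)


def build_model(dataset: str, encoder: str, hid_dim: int, out_dim: int) -> torch.nn.Module:
    # CIFAR variants accept either backbone; ImageNet currently only ResNet-50.
    if dataset in {'cifar10', 'cifar100', 'cifar'}:
        if encoder == 'resnet':
            return CIFARSimCLRResNet50(hid_dim, out_dim)
        if encoder == 'vit':
            return CIFARSimCLRViTTiny(hid_dim, out_dim)
    elif dataset in {'imagenet1k', 'imagenet'} and encoder == 'resnet':
        return ImageNetSimCLRResNet50(hid_dim, out_dim)
    raise NotImplementedError(f"Unimplemented dataset/encoder: '{dataset}'/'{encoder}'")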
diff --git a/simclr/models.py b/simclr/models.py
index 8b270b9..689e0d3 100644
--- a/simclr/models.py
+++ b/simclr/models.py
@@ -1,7 +1,14 @@
+from pathlib import Path
+
+import sys
import torch
from torch import nn, Tensor
-from torchvision.models import ResNet
-from torchvision.models.resnet import BasicBlock
+from torchvision.models.resnet import BasicBlock, ResNet
+
+path = str(Path(Path(__file__).parent.absolute()).parent.absolute())
+sys.path.insert(0, path)
+
+from supervised.models import CIFARViTTiny
# TODO Make a SimCLR base class
@@ -46,6 +53,27 @@ class CIFARSimCLRResNet50(ResNet):
return h
+class CIFARSimCLRViTTiny(CIFARViTTiny):
+ def __init__(self, hid_dim, out_dim=128, pretrain=True):
+ super().__init__(num_classes=hid_dim)
+
+ self.pretrain = pretrain
+ if pretrain:
+ self.projector = nn.Sequential(
+ nn.Linear(hid_dim, hid_dim),
+ nn.ReLU(inplace=True),
+ nn.Linear(hid_dim, out_dim),
+ )
+
+ def forward(self, x: torch.Tensor) -> Tensor:
+ h = super(CIFARSimCLRViTTiny, self).forward(x)
+ if self.pretrain:
+ z = self.projector(h)
+ return z
+ else:
+ return h
+
+
class ImageNetSimCLRResNet50(ResNet):
def __init__(self, hid_dim, out_dim=128, pretrain=True):
super(ImageNetSimCLRResNet50, self).__init__(