aboutsummaryrefslogtreecommitdiff
path: root/simclr/main.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2022-08-18 12:07:30 +0800
committerJordan Gong <jordan.gong@protonmail.com>2022-08-18 12:07:30 +0800
commitd01affeafed02701f128a06899fb658441475792 (patch)
tree5c41b549a7595afe0fde70962f348a834e5d0b3e /simclr/main.py
parent1b8f01ce9706905c36c6f11ed9deac8548ad7341 (diff)
Add SimCLR ViT variant for CIFAR
Diffstat (limited to 'simclr/main.py')
-rw-r--r--simclr/main.py24
1 files changed, 18 insertions, 6 deletions
diff --git a/simclr/main.py b/simclr/main.py
index 69e2ab2..b990b6e 100644
--- a/simclr/main.py
+++ b/simclr/main.py
@@ -1,10 +1,10 @@
import argparse
import os
-import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, Callable
+import sys
import torch
import yaml
from torch.utils.data import Dataset
@@ -17,10 +17,9 @@ sys.path.insert(0, path)
from libs.criteria import InfoNCELoss
from libs.datautils import color_distortion, Clip, RandomGaussianBlur, TwinTransform
from libs.optimizers import LARS
-from libs.schedulers import LinearWarmupAndCosineAnneal, LinearLR
from libs.utils import Trainer, BaseConfig
from libs.logging import BaseBatchLogRecord, Loggers
-from simclr.models import CIFARSimCLRResNet50, ImageNetSimCLRResNet50
+from simclr.models import CIFARSimCLRResNet50, ImageNetSimCLRResNet50, CIFARSimCLRViTTiny
def parse_args_and_config():
@@ -43,6 +42,9 @@ def parse_args_and_config():
help='Path to config file (optional)')
# TODO: Add model hyperparams dataclass
+ parser.add_argument('--encoder', default='resnet', type=str,
+ choices=('resnet', 'vit'),
+ help='Backbone of encoder')
parser.add_argument('--hid-dim', default=2048, type=int,
help='Number of dimension of embedding')
parser.add_argument('--out-dim', default=128, type=int,
@@ -134,7 +136,8 @@ class SimCLRConfig(BaseConfig):
class SimCLRTrainer(Trainer):
- def __init__(self, hid_dim, out_dim, **kwargs):
+ def __init__(self, encoder, hid_dim, out_dim, **kwargs):
+ self.encoder = encoder
self.hid_dim = hid_dim
self.out_dim = out_dim
super(SimCLRTrainer, self).__init__(**kwargs)
@@ -199,9 +202,17 @@ class SimCLRTrainer(Trainer):
def _init_models(self, dataset: str) -> Iterable[tuple[str, torch.nn.Module]]:
if dataset in {'cifar10', 'cifar100', 'cifar'}:
- model = CIFARSimCLRResNet50(self.hid_dim, self.out_dim)
+ if self.encoder == 'resnet':
+ model = CIFARSimCLRResNet50(self.hid_dim, self.out_dim)
+ elif self.encoder == 'vit':
+ model = CIFARSimCLRViTTiny(self.hid_dim, self.out_dim)
+ else:
+ raise NotImplementedError(f"Unimplemented encoder: '{self.encoder}")
elif dataset in {'imagenet1k', 'imagenet'}:
- model = ImageNetSimCLRResNet50(self.hid_dim, self.out_dim)
+ if self.encoder == 'resnet':
+ model = ImageNetSimCLRResNet50(self.hid_dim, self.out_dim)
+ else:
+ raise NotImplementedError(f"Unimplemented encoder: '{self.encoder}")
else:
raise NotImplementedError(f"Unimplemented dataset: '{dataset}")
@@ -311,6 +322,7 @@ if __name__ == '__main__':
inf_mode=True,
num_iters=args.num_iters,
config=config,
+ encoder=args.encoder,
hid_dim=args.hid_dim,
out_dim=args.out_dim,
)