diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2022-08-18 12:07:30 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2022-08-18 12:07:30 +0800 |
commit | d01affeafed02701f128a06899fb658441475792 (patch) | |
tree | 5c41b549a7595afe0fde70962f348a834e5d0b3e /simclr/main.py | |
parent | 1b8f01ce9706905c36c6f11ed9deac8548ad7341 (diff) |
Add SimCLR ViT variant for CIFAR
Diffstat (limited to 'simclr/main.py')
-rw-r--r-- | simclr/main.py | 24 |
1 files changed, 18 insertions, 6 deletions
diff --git a/simclr/main.py b/simclr/main.py index 69e2ab2..b990b6e 100644 --- a/simclr/main.py +++ b/simclr/main.py @@ -1,10 +1,10 @@ import argparse import os -import sys from dataclasses import dataclass from pathlib import Path from typing import Iterable, Callable +import sys import torch import yaml from torch.utils.data import Dataset @@ -17,10 +17,9 @@ sys.path.insert(0, path) from libs.criteria import InfoNCELoss from libs.datautils import color_distortion, Clip, RandomGaussianBlur, TwinTransform from libs.optimizers import LARS -from libs.schedulers import LinearWarmupAndCosineAnneal, LinearLR from libs.utils import Trainer, BaseConfig from libs.logging import BaseBatchLogRecord, Loggers -from simclr.models import CIFARSimCLRResNet50, ImageNetSimCLRResNet50 +from simclr.models import CIFARSimCLRResNet50, ImageNetSimCLRResNet50, CIFARSimCLRViTTiny def parse_args_and_config(): @@ -43,6 +42,9 @@ def parse_args_and_config(): help='Path to config file (optional)') # TODO: Add model hyperparams dataclass + parser.add_argument('--encoder', default='resnet', type=str, + choices=('resnet', 'vit'), + help='Backbone of encoder') parser.add_argument('--hid-dim', default=2048, type=int, help='Number of dimension of embedding') parser.add_argument('--out-dim', default=128, type=int, @@ -134,7 +136,8 @@ class SimCLRConfig(BaseConfig): class SimCLRTrainer(Trainer): - def __init__(self, hid_dim, out_dim, **kwargs): + def __init__(self, encoder, hid_dim, out_dim, **kwargs): + self.encoder = encoder self.hid_dim = hid_dim self.out_dim = out_dim super(SimCLRTrainer, self).__init__(**kwargs) @@ -199,9 +202,17 @@ class SimCLRTrainer(Trainer): def _init_models(self, dataset: str) -> Iterable[tuple[str, torch.nn.Module]]: if dataset in {'cifar10', 'cifar100', 'cifar'}: - model = CIFARSimCLRResNet50(self.hid_dim, self.out_dim) + if self.encoder == 'resnet': + model = CIFARSimCLRResNet50(self.hid_dim, self.out_dim) + elif self.encoder == 'vit': + model = CIFARSimCLRViTTiny(self.hid_dim, self.out_dim) + else: + raise NotImplementedError(f"Unimplemented encoder: '{self.encoder}") elif dataset in {'imagenet1k', 'imagenet'}: - model = ImageNetSimCLRResNet50(self.hid_dim, self.out_dim) + if self.encoder == 'resnet': + model = ImageNetSimCLRResNet50(self.hid_dim, self.out_dim) + else: + raise NotImplementedError(f"Unimplemented encoder: '{self.encoder}") else: raise NotImplementedError(f"Unimplemented dataset: '{dataset}") @@ -311,6 +322,7 @@ if __name__ == '__main__': inf_mode=True, num_iters=args.num_iters, config=config, + encoder=args.encoder, hid_dim=args.hid_dim, out_dim=args.out_dim, ) |