From d01affeafed02701f128a06899fb658441475792 Mon Sep 17 00:00:00 2001
From: Jordan Gong
Date: Thu, 18 Aug 2022 12:07:30 +0800
Subject: Add SimCLR ViT variant for CIFAR

---
 simclr/models.py | 32 ++++++++++++++++++++++++++++++--
 1 file changed, 30 insertions(+), 2 deletions(-)

(limited to 'simclr/models.py')

diff --git a/simclr/models.py b/simclr/models.py
index 8b270b9..689e0d3 100644
--- a/simclr/models.py
+++ b/simclr/models.py
@@ -1,7 +1,14 @@
+from pathlib import Path
+
+import sys
 import torch
 from torch import nn, Tensor
-from torchvision.models import ResNet
-from torchvision.models.resnet import BasicBlock
+from torchvision.models.resnet import BasicBlock, ResNet
+
+path = str(Path(Path(__file__).parent.absolute()).parent.absolute())
+sys.path.insert(0, path)
+
+from supervised.models import CIFARViTTiny
 
 
 # TODO Make a SimCLR base class
@@ -46,6 +53,27 @@ class CIFARSimCLRResNet50(ResNet):
         return h
 
 
+class CIFARSimCLRViTTiny(CIFARViTTiny):
+    def __init__(self, hid_dim, out_dim=128, pretrain=True):
+        super().__init__(num_classes=hid_dim)
+
+        self.pretrain = pretrain
+        if pretrain:
+            self.projector = nn.Sequential(
+                nn.Linear(hid_dim, hid_dim),
+                nn.ReLU(inplace=True),
+                nn.Linear(hid_dim, out_dim),
+            )
+
+    def forward(self, x: torch.Tensor) -> Tensor:
+        h = super(CIFARSimCLRViTTiny, self).forward(x)
+        if self.pretrain:
+            z = self.projector(h)
+            return z
+        else:
+            return h
+
+
 class ImageNetSimCLRResNet50(ResNet):
     def __init__(self, hid_dim, out_dim=128, pretrain=True):
         super(ImageNetSimCLRResNet50, self).__init__(
-- 
cgit v1.2.3
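
A minimal usage sketch of the class added by this patch (not part of the commit). It assumes CIFARViTTiny is a standard classifier backbone that takes 3x32x32 CIFAR inputs and returns num_classes-dimensional outputs, that hid_dim=192 (ViT-Tiny's usual embedding width) is an acceptable value, and that the script is run from the repository root so simclr.models is importable.

# hypothetical example, not from the patch
import torch

from simclr.models import CIFARSimCLRViTTiny

# hid_dim=192 is an assumed value; the patch leaves it to the caller
model = CIFARSimCLRViTTiny(hid_dim=192, out_dim=128, pretrain=True)

x = torch.randn(4, 3, 32, 32)   # a batch of CIFAR-sized images
z = model(x)                    # projected embeddings for the SimCLR loss
print(z.shape)                  # expected: torch.Size([4, 128])

# With pretrain=False the projection head is skipped and the backbone
# features (hid_dim-dimensional) are returned instead, e.g. for linear probing.
encoder = CIFARSimCLRViTTiny(hid_dim=192, pretrain=False)
h = encoder(x)                  # expected: torch.Size([4, 192])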