diff options
| author | Jordan Gong <jordan.gong@protonmail.com> | 2022-08-18 12:07:30 +0800 | 
|---|---|---|
| committer | Jordan Gong <jordan.gong@protonmail.com> | 2022-08-18 12:07:30 +0800 | 
| commit | d01affeafed02701f128a06899fb658441475792 (patch) | |
| tree | 5c41b549a7595afe0fde70962f348a834e5d0b3e /simclr/models.py | |
| parent | 1b8f01ce9706905c36c6f11ed9deac8548ad7341 (diff) | |
Add SimCLR ViT variant for CIFAR
Diffstat (limited to 'simclr/models.py')
| -rw-r--r-- | simclr/models.py | 32 | 
1 files changed, 30 insertions, 2 deletions
| diff --git a/simclr/models.py b/simclr/models.py index 8b270b9..689e0d3 100644 --- a/simclr/models.py +++ b/simclr/models.py @@ -1,7 +1,14 @@ +from pathlib import Path + +import sys  import torch  from torch import nn, Tensor -from torchvision.models import ResNet -from torchvision.models.resnet import BasicBlock +from torchvision.models.resnet import BasicBlock, ResNet + +path = str(Path(Path(__file__).parent.absolute()).parent.absolute()) +sys.path.insert(0, path) + +from supervised.models import CIFARViTTiny  # TODO Make a SimCLR base class @@ -46,6 +53,27 @@ class CIFARSimCLRResNet50(ResNet):              return h +class CIFARSimCLRViTTiny(CIFARViTTiny): +    def __init__(self, hid_dim, out_dim=128, pretrain=True): +        super().__init__(num_classes=hid_dim) + +        self.pretrain = pretrain +        if pretrain: +            self.projector = nn.Sequential( +                nn.Linear(hid_dim, hid_dim), +                nn.ReLU(inplace=True), +                nn.Linear(hid_dim, out_dim), +            ) + +    def forward(self, x: torch.Tensor) -> Tensor: +        h = super(CIFARSimCLRViTTiny, self).forward(x) +        if self.pretrain: +            z = self.projector(h) +            return z +        else: +            return h + +  class ImageNetSimCLRResNet50(ResNet):      def __init__(self, hid_dim, out_dim=128, pretrain=True):          super(ImageNetSimCLRResNet50, self).__init__( | 
