diff options
Diffstat (limited to 'simclr/models.py')
-rw-r--r-- | simclr/models.py | 32 |
1 files changed, 30 insertions, 2 deletions
diff --git a/simclr/models.py b/simclr/models.py index 8b270b9..689e0d3 100644 --- a/simclr/models.py +++ b/simclr/models.py @@ -1,7 +1,14 @@ +from pathlib import Path + +import sys import torch from torch import nn, Tensor -from torchvision.models import ResNet -from torchvision.models.resnet import BasicBlock +from torchvision.models.resnet import BasicBlock, ResNet + +path = str(Path(Path(__file__).parent.absolute()).parent.absolute()) +sys.path.insert(0, path) + +from supervised.models import CIFARViTTiny # TODO Make a SimCLR base class @@ -46,6 +53,27 @@ class CIFARSimCLRResNet50(ResNet): return h +class CIFARSimCLRViTTiny(CIFARViTTiny): + def __init__(self, hid_dim, out_dim=128, pretrain=True): + super().__init__(num_classes=hid_dim) + + self.pretrain = pretrain + if pretrain: + self.projector = nn.Sequential( + nn.Linear(hid_dim, hid_dim), + nn.ReLU(inplace=True), + nn.Linear(hid_dim, out_dim), + ) + + def forward(self, x: torch.Tensor) -> Tensor: + h = super(CIFARSimCLRViTTiny, self).forward(x) + if self.pretrain: + z = self.projector(h) + return z + else: + return h + + class ImageNetSimCLRResNet50(ResNet): def __init__(self, hid_dim, out_dim=128, pretrain=True): super(ImageNetSimCLRResNet50, self).__init__( |