aboutsummaryrefslogtreecommitdiff
path: root/simclr/models.py
diff options
context:
space:
mode:
Diffstat (limited to 'simclr/models.py')
-rw-r--r--simclr/models.py32
1 files changed, 30 insertions, 2 deletions
diff --git a/simclr/models.py b/simclr/models.py
index 8b270b9..689e0d3 100644
--- a/simclr/models.py
+++ b/simclr/models.py
@@ -1,7 +1,14 @@
+from pathlib import Path
+
+import sys
import torch
from torch import nn, Tensor
-from torchvision.models import ResNet
-from torchvision.models.resnet import BasicBlock
+from torchvision.models.resnet import BasicBlock, ResNet
+
+path = str(Path(Path(__file__).parent.absolute()).parent.absolute())
+sys.path.insert(0, path)
+
+from supervised.models import CIFARViTTiny
# TODO Make a SimCLR base class
@@ -46,6 +53,27 @@ class CIFARSimCLRResNet50(ResNet):
return h
+class CIFARSimCLRViTTiny(CIFARViTTiny):
+ def __init__(self, hid_dim, out_dim=128, pretrain=True):
+ super().__init__(num_classes=hid_dim)
+
+ self.pretrain = pretrain
+ if pretrain:
+ self.projector = nn.Sequential(
+ nn.Linear(hid_dim, hid_dim),
+ nn.ReLU(inplace=True),
+ nn.Linear(hid_dim, out_dim),
+ )
+
+ def forward(self, x: torch.Tensor) -> Tensor:
+ h = super(CIFARSimCLRViTTiny, self).forward(x)
+ if self.pretrain:
+ z = self.projector(h)
+ return z
+ else:
+ return h
+
+
class ImageNetSimCLRResNet50(ResNet):
def __init__(self, hid_dim, out_dim=128, pretrain=True):
super(ImageNetSimCLRResNet50, self).__init__(