aboutsummaryrefslogtreecommitdiff
path: root/simclr/models.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2022-08-20 15:35:56 +0800
committerJordan Gong <jordan.gong@protonmail.com>2022-08-20 15:35:56 +0800
commit4c242c1383afb8072ce6d2904f51cdb005eced4c (patch)
treee4a06a39c9175650bab6c492fbad10726fd69467 /simclr/models.py
parent5ccb892ae59cbe183ef91c1648751ff0085cc0da (diff)
Refactor SimCLR base class
Diffstat (limited to 'simclr/models.py')
-rw-r--r--simclr/models.py91
1 files changed, 24 insertions, 67 deletions
diff --git a/simclr/models.py b/simclr/models.py
index 689e0d3..0c3febb 100644
--- a/simclr/models.py
+++ b/simclr/models.py
@@ -1,26 +1,27 @@
from pathlib import Path
import sys
-import torch
-from torch import nn, Tensor
-from torchvision.models.resnet import BasicBlock, ResNet
+from torch import nn
+from torchvision.models.resnet import resnet50
path = str(Path(Path(__file__).parent.absolute()).parent.absolute())
sys.path.insert(0, path)
-from supervised.models import CIFARViTTiny
+from supervised.models import CIFARViTTiny, CIFARResNet50
-# TODO Make a SimCLR base class
-
-class CIFARSimCLRResNet50(ResNet):
- def __init__(self, hid_dim, out_dim=128, pretrain=True):
- super(CIFARSimCLRResNet50, self).__init__(
- block=BasicBlock, layers=[3, 4, 6, 3], num_classes=hid_dim
- )
+class SimCLRBase(nn.Module):
+ def __init__(
+ self,
+ backbone: nn.Module,
+ hid_dim: int = 2048,
+ out_dim: int = 128,
+ pretrain: bool = True
+ ):
+ super().__init__()
+ self.backbone = backbone
self.pretrain = pretrain
- self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
- stride=1, padding=1, bias=False)
+
if pretrain:
self.projector = nn.Sequential(
nn.Linear(hid_dim, hid_dim),
@@ -28,23 +29,7 @@ class CIFARSimCLRResNet50(ResNet):
nn.Linear(hid_dim, out_dim),
)
- def backbone(self, x: Tensor) -> Tensor:
- x = self.conv1(x)
- x = self.bn1(x)
- x = self.relu(x)
-
- x = self.layer1(x)
- x = self.layer2(x)
- x = self.layer3(x)
- x = self.layer4(x)
-
- x = self.avgpool(x)
- x = torch.flatten(x, 1)
- x = self.fc(x)
-
- return x
-
- def forward(self, x: Tensor) -> Tensor:
+ def forward(self, x):
h = self.backbone(x)
if self.pretrain:
z = self.projector(h)
@@ -53,44 +38,16 @@ class CIFARSimCLRResNet50(ResNet):
return h
-class CIFARSimCLRViTTiny(CIFARViTTiny):
- def __init__(self, hid_dim, out_dim=128, pretrain=True):
- super().__init__(num_classes=hid_dim)
+def cifar_simclr_resnet50(hid_dim, *args, **kwargs):
+ backbone = CIFARResNet50(num_classes=hid_dim)
+ return SimCLRBase(backbone, hid_dim, *args, **kwargs)
- self.pretrain = pretrain
- if pretrain:
- self.projector = nn.Sequential(
- nn.Linear(hid_dim, hid_dim),
- nn.ReLU(inplace=True),
- nn.Linear(hid_dim, out_dim),
- )
- def forward(self, x: torch.Tensor) -> Tensor:
- h = super(CIFARSimCLRViTTiny, self).forward(x)
- if self.pretrain:
- z = self.projector(h)
- return z
- else:
- return h
+def cifar_simclr_vit_tiny(hid_dim, *args, **kwargs):
+ backbone = CIFARViTTiny(num_classes=hid_dim)
+ return SimCLRBase(backbone, hid_dim, *args, **kwargs)
-class ImageNetSimCLRResNet50(ResNet):
- def __init__(self, hid_dim, out_dim=128, pretrain=True):
- super(ImageNetSimCLRResNet50, self).__init__(
- block=BasicBlock, layers=[3, 4, 6, 3], num_classes=hid_dim
- )
- self.pretrain = pretrain
- if pretrain:
- self.projector = nn.Sequential(
- nn.Linear(hid_dim, hid_dim),
- nn.ReLU(inplace=True),
- nn.Linear(hid_dim, out_dim),
- )
-
- def forward(self, x: Tensor) -> Tensor:
- h = super(ImageNetSimCLRResNet50, self).forward(x)
- if self.pretrain:
- z = self.projector(h)
- return z
- else:
- return h
+def imagenet_simclr_resnet50(hid_dim, *args, **kwargs):
+ backbone = resnet50(num_classes=hid_dim)
+ return SimCLRBase(backbone, hid_dim, *args, **kwargs)