about summary refs log tree commit diff
path: root/simclr
diff options
context:
space:
mode:
Diffstat (limited to 'simclr')
-rw-r--r--simclr/main.py8
-rw-r--r--simclr/models.py91
2 files changed, 28 insertions, 71 deletions
diff --git a/simclr/main.py b/simclr/main.py
index b990b6e..7a86d7d 100644
--- a/simclr/main.py
+++ b/simclr/main.py
@@ -19,7 +19,7 @@ from libs.datautils import color_distortion, Clip, RandomGaussianBlur, TwinTrans
from libs.optimizers import LARS
from libs.utils import Trainer, BaseConfig
from libs.logging import BaseBatchLogRecord, Loggers
-from simclr.models import CIFARSimCLRResNet50, ImageNetSimCLRResNet50, CIFARSimCLRViTTiny
+from simclr.models import cifar_simclr_resnet50, cifar_simclr_vit_tiny, imagenet_simclr_resnet50
def parse_args_and_config():
@@ -203,14 +203,14 @@ class SimCLRTrainer(Trainer):
def _init_models(self, dataset: str) -> Iterable[tuple[str, torch.nn.Module]]:
if dataset in {'cifar10', 'cifar100', 'cifar'}:
if self.encoder == 'resnet':
- model = CIFARSimCLRResNet50(self.hid_dim, self.out_dim)
+ model = cifar_simclr_resnet50(self.hid_dim, self.out_dim)
elif self.encoder == 'vit':
- model = CIFARSimCLRViTTiny(self.hid_dim, self.out_dim)
+ model = cifar_simclr_vit_tiny(self.hid_dim, self.out_dim)
else:
raise NotImplementedError(f"Unimplemented encoder: '{self.encoder}")
elif dataset in {'imagenet1k', 'imagenet'}:
if self.encoder == 'resnet':
- model = ImageNetSimCLRResNet50(self.hid_dim, self.out_dim)
+ model = imagenet_simclr_resnet50(self.hid_dim, self.out_dim)
else:
raise NotImplementedError(f"Unimplemented encoder: '{self.encoder}")
else:
diff --git a/simclr/models.py b/simclr/models.py
index 689e0d3..0c3febb 100644
--- a/simclr/models.py
+++ b/simclr/models.py
@@ -1,26 +1,27 @@
from pathlib import Path
import sys
-import torch
-from torch import nn, Tensor
-from torchvision.models.resnet import BasicBlock, ResNet
+from torch import nn
+from torchvision.models.resnet import resnet50
path = str(Path(Path(__file__).parent.absolute()).parent.absolute())
sys.path.insert(0, path)
-from supervised.models import CIFARViTTiny
+from supervised.models import CIFARViTTiny, CIFARResNet50
-# TODO Make a SimCLR base class
-
-class CIFARSimCLRResNet50(ResNet):
- def __init__(self, hid_dim, out_dim=128, pretrain=True):
- super(CIFARSimCLRResNet50, self).__init__(
- block=BasicBlock, layers=[3, 4, 6, 3], num_classes=hid_dim
- )
+class SimCLRBase(nn.Module):
+ def __init__(
+ self,
+ backbone: nn.Module,
+ hid_dim: int = 2048,
+ out_dim: int = 128,
+ pretrain: bool = True
+ ):
+ super().__init__()
+ self.backbone = backbone
self.pretrain = pretrain
- self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
- stride=1, padding=1, bias=False)
+
if pretrain:
self.projector = nn.Sequential(
nn.Linear(hid_dim, hid_dim),
@@ -28,23 +29,7 @@ class CIFARSimCLRResNet50(ResNet):
nn.Linear(hid_dim, out_dim),
)
- def backbone(self, x: Tensor) -> Tensor:
- x = self.conv1(x)
- x = self.bn1(x)
- x = self.relu(x)
-
- x = self.layer1(x)
- x = self.layer2(x)
- x = self.layer3(x)
- x = self.layer4(x)
-
- x = self.avgpool(x)
- x = torch.flatten(x, 1)
- x = self.fc(x)
-
- return x
-
- def forward(self, x: Tensor) -> Tensor:
+ def forward(self, x):
h = self.backbone(x)
if self.pretrain:
z = self.projector(h)
@@ -53,44 +38,16 @@ class CIFARSimCLRResNet50(ResNet):
return h
-class CIFARSimCLRViTTiny(CIFARViTTiny):
- def __init__(self, hid_dim, out_dim=128, pretrain=True):
- super().__init__(num_classes=hid_dim)
+def cifar_simclr_resnet50(hid_dim, *args, **kwargs):
+ backbone = CIFARResNet50(num_classes=hid_dim)
+ return SimCLRBase(backbone, hid_dim, *args, **kwargs)
- self.pretrain = pretrain
- if pretrain:
- self.projector = nn.Sequential(
- nn.Linear(hid_dim, hid_dim),
- nn.ReLU(inplace=True),
- nn.Linear(hid_dim, out_dim),
- )
- def forward(self, x: torch.Tensor) -> Tensor:
- h = super(CIFARSimCLRViTTiny, self).forward(x)
- if self.pretrain:
- z = self.projector(h)
- return z
- else:
- return h
+def cifar_simclr_vit_tiny(hid_dim, *args, **kwargs):
+ backbone = CIFARViTTiny(num_classes=hid_dim)
+ return SimCLRBase(backbone, hid_dim, *args, **kwargs)
-class ImageNetSimCLRResNet50(ResNet):
- def __init__(self, hid_dim, out_dim=128, pretrain=True):
- super(ImageNetSimCLRResNet50, self).__init__(
- block=BasicBlock, layers=[3, 4, 6, 3], num_classes=hid_dim
- )
- self.pretrain = pretrain
- if pretrain:
- self.projector = nn.Sequential(
- nn.Linear(hid_dim, hid_dim),
- nn.ReLU(inplace=True),
- nn.Linear(hid_dim, out_dim),
- )
-
- def forward(self, x: Tensor) -> Tensor:
- h = super(ImageNetSimCLRResNet50, self).forward(x)
- if self.pretrain:
- z = self.projector(h)
- return z
- else:
- return h
+def imagenet_simclr_resnet50(hid_dim, *args, **kwargs):
+ backbone = resnet50(num_classes=hid_dim)
+ return SimCLRBase(backbone, hid_dim, *args, **kwargs)