From 4c242c1383afb8072ce6d2904f51cdb005eced4c Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Sat, 20 Aug 2022 15:35:56 +0800 Subject: Refactor SimCLR base class --- simclr/main.py | 8 ++--- simclr/models.py | 91 +++++++++++++++----------------------------------------- 2 files changed, 28 insertions(+), 71 deletions(-) diff --git a/simclr/main.py b/simclr/main.py index b990b6e..7a86d7d 100644 --- a/simclr/main.py +++ b/simclr/main.py @@ -19,7 +19,7 @@ from libs.datautils import color_distortion, Clip, RandomGaussianBlur, TwinTrans from libs.optimizers import LARS from libs.utils import Trainer, BaseConfig from libs.logging import BaseBatchLogRecord, Loggers -from simclr.models import CIFARSimCLRResNet50, ImageNetSimCLRResNet50, CIFARSimCLRViTTiny +from simclr.models import cifar_simclr_resnet50, cifar_simclr_vit_tiny, imagenet_simclr_resnet50 def parse_args_and_config(): @@ -203,14 +203,14 @@ class SimCLRTrainer(Trainer): def _init_models(self, dataset: str) -> Iterable[tuple[str, torch.nn.Module]]: if dataset in {'cifar10', 'cifar100', 'cifar'}: if self.encoder == 'resnet': - model = CIFARSimCLRResNet50(self.hid_dim, self.out_dim) + model = cifar_simclr_resnet50(self.hid_dim, self.out_dim) elif self.encoder == 'vit': - model = CIFARSimCLRViTTiny(self.hid_dim, self.out_dim) + model = cifar_simclr_vit_tiny(self.hid_dim, self.out_dim) else: raise NotImplementedError(f"Unimplemented encoder: '{self.encoder}") elif dataset in {'imagenet1k', 'imagenet'}: if self.encoder == 'resnet': - model = ImageNetSimCLRResNet50(self.hid_dim, self.out_dim) + model = imagenet_simclr_resnet50(self.hid_dim, self.out_dim) else: raise NotImplementedError(f"Unimplemented encoder: '{self.encoder}") else: diff --git a/simclr/models.py b/simclr/models.py index 689e0d3..0c3febb 100644 --- a/simclr/models.py +++ b/simclr/models.py @@ -1,26 +1,27 @@ from pathlib import Path import sys -import torch -from torch import nn, Tensor -from torchvision.models.resnet import BasicBlock, ResNet +from torch import nn +from torchvision.models.resnet import resnet50 path = str(Path(Path(__file__).parent.absolute()).parent.absolute()) sys.path.insert(0, path) -from supervised.models import CIFARViTTiny +from supervised.models import CIFARViTTiny, CIFARResNet50 -# TODO Make a SimCLR base class - -class CIFARSimCLRResNet50(ResNet): - def __init__(self, hid_dim, out_dim=128, pretrain=True): - super(CIFARSimCLRResNet50, self).__init__( - block=BasicBlock, layers=[3, 4, 6, 3], num_classes=hid_dim - ) +class SimCLRBase(nn.Module): + def __init__( + self, + backbone: nn.Module, + hid_dim: int = 2048, + out_dim: int = 128, + pretrain: bool = True + ): + super().__init__() + self.backbone = backbone self.pretrain = pretrain - self.conv1 = nn.Conv2d(3, 64, kernel_size=3, - stride=1, padding=1, bias=False) + if pretrain: self.projector = nn.Sequential( nn.Linear(hid_dim, hid_dim), @@ -28,23 +29,7 @@ class CIFARSimCLRResNet50(ResNet): nn.Linear(hid_dim, out_dim), ) - def backbone(self, x: Tensor) -> Tensor: - x = self.conv1(x) - x = self.bn1(x) - x = self.relu(x) - - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - x = self.layer4(x) - - x = self.avgpool(x) - x = torch.flatten(x, 1) - x = self.fc(x) - - return x - - def forward(self, x: Tensor) -> Tensor: + def forward(self, x): h = self.backbone(x) if self.pretrain: z = self.projector(h) @@ -53,44 +38,16 @@ class CIFARSimCLRResNet50(ResNet): return h -class CIFARSimCLRViTTiny(CIFARViTTiny): - def __init__(self, hid_dim, out_dim=128, pretrain=True): - super().__init__(num_classes=hid_dim) +def cifar_simclr_resnet50(hid_dim, *args, **kwargs): + backbone = CIFARResNet50(num_classes=hid_dim) + return SimCLRBase(backbone, hid_dim, *args, **kwargs) - self.pretrain = pretrain - if pretrain: - self.projector = nn.Sequential( - nn.Linear(hid_dim, hid_dim), - nn.ReLU(inplace=True), - nn.Linear(hid_dim, out_dim), - ) - def forward(self, x: torch.Tensor) -> Tensor: - h = super(CIFARSimCLRViTTiny, self).forward(x) - if self.pretrain: - z = self.projector(h) - return z - else: - return h +def cifar_simclr_vit_tiny(hid_dim, *args, **kwargs): + backbone = CIFARViTTiny(num_classes=hid_dim) + return SimCLRBase(backbone, hid_dim, *args, **kwargs) -class ImageNetSimCLRResNet50(ResNet): - def __init__(self, hid_dim, out_dim=128, pretrain=True): - super(ImageNetSimCLRResNet50, self).__init__( - block=BasicBlock, layers=[3, 4, 6, 3], num_classes=hid_dim - ) - self.pretrain = pretrain - if pretrain: - self.projector = nn.Sequential( - nn.Linear(hid_dim, hid_dim), - nn.ReLU(inplace=True), - nn.Linear(hid_dim, out_dim), - ) - - def forward(self, x: Tensor) -> Tensor: - h = super(ImageNetSimCLRResNet50, self).forward(x) - if self.pretrain: - z = self.projector(h) - return z - else: - return h +def imagenet_simclr_resnet50(hid_dim, *args, **kwargs): + backbone = resnet50(num_classes=hid_dim) + return SimCLRBase(backbone, hid_dim, *args, **kwargs) -- cgit v1.2.3