diff options
Diffstat (limited to 'simclr/models.py')
-rw-r--r-- | simclr/models.py | 42 |
1 files changed, 26 insertions, 16 deletions
diff --git a/simclr/models.py b/simclr/models.py index 9996188..8b270b9 100644 --- a/simclr/models.py +++ b/simclr/models.py @@ -7,17 +7,19 @@ from torchvision.models.resnet import BasicBlock # TODO Make a SimCLR base class class CIFARSimCLRResNet50(ResNet): - def __init__(self, hid_dim, out_dim): + def __init__(self, hid_dim, out_dim=128, pretrain=True): super(CIFARSimCLRResNet50, self).__init__( block=BasicBlock, layers=[3, 4, 6, 3], num_classes=hid_dim ) + self.pretrain = pretrain self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) - self.projector = nn.Sequential( - nn.Linear(hid_dim, hid_dim), - nn.ReLU(inplace=True), - nn.Linear(hid_dim, out_dim), - ) + if pretrain: + self.projector = nn.Sequential( + nn.Linear(hid_dim, hid_dim), + nn.ReLU(inplace=True), + nn.Linear(hid_dim, out_dim), + ) def backbone(self, x: Tensor) -> Tensor: x = self.conv1(x) @@ -37,22 +39,30 @@ class CIFARSimCLRResNet50(ResNet): def forward(self, x: Tensor) -> Tensor: h = self.backbone(x) - z = self.projector(h) - return z + if self.pretrain: + z = self.projector(h) + return z + else: + return h class ImageNetSimCLRResNet50(ResNet): - def __init__(self, hid_dim, out_dim): + def __init__(self, hid_dim, out_dim=128, pretrain=True): super(ImageNetSimCLRResNet50, self).__init__( block=BasicBlock, layers=[3, 4, 6, 3], num_classes=hid_dim ) - self.projector = nn.Sequential( - nn.Linear(hid_dim, hid_dim), - nn.ReLU(inplace=True), - nn.Linear(hid_dim, out_dim), - ) + self.pretrain = pretrain + if pretrain: + self.projector = nn.Sequential( + nn.Linear(hid_dim, hid_dim), + nn.ReLU(inplace=True), + nn.Linear(hid_dim, out_dim), + ) def forward(self, x: Tensor) -> Tensor: h = super(ImageNetSimCLRResNet50, self).forward(x) - z = self.projector(h) - return z + if self.pretrain: + z = self.projector(h) + return z + else: + return h |