diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2022-08-07 22:06:37 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2022-08-07 22:06:37 +0800 |
commit | 6920d0c033b18d4d27cdadce17f4b56347ed2a5a (patch) | |
tree | 5dd7b1d5a6ef51cd29b19d14139109139b337135 /simclr/models.py | |
parent | af0b70c013be4baab0bd6b199f0854107aa24513 (diff) |
Add a projector to SimCLR
Diffstat (limited to 'simclr/models.py')
-rw-r--r-- | simclr/models.py | 32 |
1 files changed, 27 insertions, 5 deletions
diff --git a/simclr/models.py b/simclr/models.py index 216a993..9996188 100644 --- a/simclr/models.py +++ b/simclr/models.py @@ -4,15 +4,22 @@ from torchvision.models import ResNet from torchvision.models.resnet import BasicBlock +# TODO Make a SimCLR base class + class CIFARSimCLRResNet50(ResNet): - def __init__(self, out_dim): + def __init__(self, hid_dim, out_dim): super(CIFARSimCLRResNet50, self).__init__( - block=BasicBlock, layers=[3, 4, 6, 3], num_classes=out_dim + block=BasicBlock, layers=[3, 4, 6, 3], num_classes=hid_dim ) self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) + self.projector = nn.Sequential( + nn.Linear(hid_dim, hid_dim), + nn.ReLU(inplace=True), + nn.Linear(hid_dim, out_dim), + ) - def forward(self, x: Tensor) -> Tensor: + def backbone(self, x: Tensor) -> Tensor: x = self.conv1(x) x = self.bn1(x) x = self.relu(x) @@ -28,9 +35,24 @@ class CIFARSimCLRResNet50(ResNet): return x + def forward(self, x: Tensor) -> Tensor: + h = self.backbone(x) + z = self.projector(h) + return z + class ImageNetSimCLRResNet50(ResNet): - def __init__(self, out_dim): + def __init__(self, hid_dim, out_dim): super(ImageNetSimCLRResNet50, self).__init__( - block=BasicBlock, layers=[3, 4, 6, 3], num_classes=out_dim + block=BasicBlock, layers=[3, 4, 6, 3], num_classes=hid_dim + ) + self.projector = nn.Sequential( + nn.Linear(hid_dim, hid_dim), + nn.ReLU(inplace=True), + nn.Linear(hid_dim, out_dim), ) + + def forward(self, x: Tensor) -> Tensor: + h = super(ImageNetSimCLRResNet50, self).forward(x) + z = self.projector(h) + return z |