aboutsummaryrefslogtreecommitdiff
path: root/simclr/models.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2022-08-07 22:06:37 +0800
committerJordan Gong <jordan.gong@protonmail.com>2022-08-07 22:06:37 +0800
commit6920d0c033b18d4d27cdadce17f4b56347ed2a5a (patch)
tree5dd7b1d5a6ef51cd29b19d14139109139b337135 /simclr/models.py
parentaf0b70c013be4baab0bd6b199f0854107aa24513 (diff)
Add a projector to SimCLR
Diffstat (limited to 'simclr/models.py')
-rw-r--r--simclr/models.py32
1 files changed, 27 insertions, 5 deletions
diff --git a/simclr/models.py b/simclr/models.py
index 216a993..9996188 100644
--- a/simclr/models.py
+++ b/simclr/models.py
@@ -4,15 +4,22 @@ from torchvision.models import ResNet
from torchvision.models.resnet import BasicBlock
+# TODO Make a SimCLR base class
+
class CIFARSimCLRResNet50(ResNet):
- def __init__(self, out_dim):
+ def __init__(self, hid_dim, out_dim):
super(CIFARSimCLRResNet50, self).__init__(
- block=BasicBlock, layers=[3, 4, 6, 3], num_classes=out_dim
+ block=BasicBlock, layers=[3, 4, 6, 3], num_classes=hid_dim
)
self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
stride=1, padding=1, bias=False)
+ self.projector = nn.Sequential(
+ nn.Linear(hid_dim, hid_dim),
+ nn.ReLU(inplace=True),
+ nn.Linear(hid_dim, out_dim),
+ )
- def forward(self, x: Tensor) -> Tensor:
+ def backbone(self, x: Tensor) -> Tensor:
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
@@ -28,9 +35,24 @@ class CIFARSimCLRResNet50(ResNet):
return x
+ def forward(self, x: Tensor) -> Tensor:
+ h = self.backbone(x)
+ z = self.projector(h)
+ return z
+
class ImageNetSimCLRResNet50(ResNet):
- def __init__(self, out_dim):
+ def __init__(self, hid_dim, out_dim):
super(ImageNetSimCLRResNet50, self).__init__(
- block=BasicBlock, layers=[3, 4, 6, 3], num_classes=out_dim
+ block=BasicBlock, layers=[3, 4, 6, 3], num_classes=hid_dim
+ )
+ self.projector = nn.Sequential(
+ nn.Linear(hid_dim, hid_dim),
+ nn.ReLU(inplace=True),
+ nn.Linear(hid_dim, out_dim),
)
+
+ def forward(self, x: Tensor) -> Tensor:
+ h = super(ImageNetSimCLRResNet50, self).forward(x)
+ z = self.projector(h)
+ return z