author     Jordan Gong <jordan.gong@protonmail.com>  2022-08-07 22:06:37 +0800
committer  Jordan Gong <jordan.gong@protonmail.com>  2022-08-07 22:06:37 +0800
commit     6920d0c033b18d4d27cdadce17f4b56347ed2a5a
tree       5dd7b1d5a6ef51cd29b19d14139109139b337135 /simclr
parent     af0b70c013be4baab0bd6b199f0854107aa24513
Add a projector to SimCLR
Diffstat (limited to 'simclr')
 simclr/main.py   | 16 ++++++++++------
 simclr/models.py | 32 +++++++++++++++++++++++++++++++++++-----
 2 files changed, 37 insertions(+), 11 deletions(-)
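
In SimCLR (Chen et al., 2020), the contrastive loss is applied not to the backbone representation h but to z = g(h), where g is a small MLP projection head; this commit adds that head and exposes its sizes as --hid-dim and --out-dim. As a free-standing sketch of the head being added, using the new defaults (hid_dim=2048, out_dim=128; the batch and tensors below are illustrative only):

import torch
from torch import nn

# Two-layer projection head, z = W2 . ReLU(W1 . h), mirroring the
# nn.Sequential added in simclr/models.py below.
projector = nn.Sequential(
    nn.Linear(2048, 2048),  # hid_dim -> hid_dim
    nn.ReLU(inplace=True),
    nn.Linear(2048, 128),   # hid_dim -> out_dim
)

h = torch.randn(32, 2048)   # stand-in for backbone features (batch of 32)
z = projector(h)            # what the contrastive loss consumes
assert z.shape == (32, 128)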
diff --git a/simclr/main.py b/simclr/main.py
index 0f93716..9701a02 100644
--- a/simclr/main.py
+++ b/simclr/main.py
@@ -42,9 +42,11 @@ def parse_args_and_config():
help='Path to config file (optional)')
# TODO: Add model hyperparams dataclass
- parser.add_argument('--out-dim', default=256, type=int,
+ parser.add_argument('--hid-dim', default=2048, type=int,
help='Dimension of the embedding')
- parser.add_argument('--temp', default=0.01, type=float,
+ parser.add_argument('--out-dim', default=128, type=int,
+ help='Dimension after projection')
+ parser.add_argument('--temp', default=0.5, type=float,
help='Temperature in InfoNCE loss')
dataset_group = parser.add_argument_group('Dataset parameters')
@@ -135,7 +137,8 @@ class SimCLRConfig(BaseConfig):
class SimCLRTrainer(Trainer):
- def __init__(self, out_dim, **kwargs):
+ def __init__(self, hid_dim, out_dim, **kwargs):
+ self.hid_dim = hid_dim
self.out_dim = out_dim
super(SimCLRTrainer, self).__init__(**kwargs)
@@ -199,9 +202,9 @@ class SimCLRTrainer(Trainer):
def _init_models(self, dataset: str) -> Iterable[tuple[str, torch.nn.Module]]:
if dataset in {'cifar10', 'cifar100', 'cifar'}:
- model = CIFARSimCLRResNet50(self.out_dim)
+ model = CIFARSimCLRResNet50(self.hid_dim, self.out_dim)
elif dataset in {'imagenet1k', 'imagenet'}:
- model = ImageNetSimCLRResNet50(self.out_dim)
+ model = ImageNetSimCLRResNet50(self.hid_dim, self.out_dim)
else:
raise NotImplementedError(f"Unimplemented dataset: '{dataset}'")
@@ -334,7 +337,7 @@ class SimCLRTrainer(Trainer):
if __name__ == '__main__':
args = parse_args_and_config()
config = SimCLRConfig.from_args(args)
- device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
trainer = SimCLRTrainer(
seed=args.seed,
checkpoint_dir=args.checkpoint_dir,
@@ -342,6 +345,7 @@ if __name__ == '__main__':
inf_mode=True,
num_iters=args.num_iters,
config=config,
+ hid_dim=args.hid_dim,
out_dim=args.out_dim,
)
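
The --temp flag scales similarities in the InfoNCE (NT-Xent) loss; the loss itself lies outside this diff. For reference, a minimal NT-Xent sketch, assuming z1 and z2 are the (N, out_dim) projector outputs for two augmented views of the same batch (the function name and structure are illustrative, not this repo's implementation):

import torch
import torch.nn.functional as F

def nt_xent(z1: torch.Tensor, z2: torch.Tensor, temp: float = 0.5) -> torch.Tensor:
    # z1, z2: (N, d) projections of two views of the same N images.
    n = z1.size(0)
    z = F.normalize(torch.cat([z1, z2]), dim=1)  # (2N, d), unit-norm rows
    sim = z @ z.t() / temp                       # temperature-scaled cosine similarity
    sim.fill_diagonal_(float('-inf'))            # drop self-pairs from the softmax
    # Row i's positive is the other view of the same image: i+N, or i-N.
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(n)])
    return F.cross_entropy(sim, targets)

A lower temperature sharpens the softmax over negatives; raising the default from 0.01 to 0.5 moves it into the range used in the SimCLR paper.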
diff --git a/simclr/models.py b/simclr/models.py
index 216a993..9996188 100644
--- a/simclr/models.py
+++ b/simclr/models.py
@@ -4,15 +4,22 @@ from torchvision.models import ResNet
from torchvision.models.resnet import BasicBlock
+# TODO Make a SimCLR base class
+
class CIFARSimCLRResNet50(ResNet):
- def __init__(self, out_dim):
+ def __init__(self, hid_dim, out_dim):
super(CIFARSimCLRResNet50, self).__init__(
- block=BasicBlock, layers=[3, 4, 6, 3], num_classes=out_dim
+ block=BasicBlock, layers=[3, 4, 6, 3], num_classes=hid_dim
)
self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
stride=1, padding=1, bias=False)
+ self.projector = nn.Sequential(
+ nn.Linear(hid_dim, hid_dim),
+ nn.ReLU(inplace=True),
+ nn.Linear(hid_dim, out_dim),
+ )
- def forward(self, x: Tensor) -> Tensor:
+ def backbone(self, x: Tensor) -> Tensor:
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
@@ -28,9 +35,24 @@ class CIFARSimCLRResNet50(ResNet):
return x
+ def forward(self, x: Tensor) -> Tensor:
+ h = self.backbone(x)
+ z = self.projector(h)
+ return z
+
class ImageNetSimCLRResNet50(ResNet):
- def __init__(self, out_dim):
+ def __init__(self, hid_dim, out_dim):
super(ImageNetSimCLRResNet50, self).__init__(
- block=BasicBlock, layers=[3, 4, 6, 3], num_classes=out_dim
+ block=BasicBlock, layers=[3, 4, 6, 3], num_classes=hid_dim
+ )
+ self.projector = nn.Sequential(
+ nn.Linear(hid_dim, hid_dim),
+ nn.ReLU(inplace=True),
+ nn.Linear(hid_dim, out_dim),
)
+
+ def forward(self, x: Tensor) -> Tensor:
+ h = super(ImageNetSimCLRResNet50, self).forward(x)
+ z = self.projector(h)
+ return z
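
End to end, forward() now returns the projection z for the contrastive loss, while the hid_dim-dimensional representation h remains accessible via backbone() on the CIFAR model. An illustrative forward pass under the new defaults (shapes assumed from the diff; note, incidentally, that block=BasicBlock with layers=[3, 4, 6, 3] is torchvision's ResNet-34 topology, so a true ResNet-50 would pass Bottleneck instead):

import torch

model = CIFARSimCLRResNet50(hid_dim=2048, out_dim=128)
x1 = torch.randn(8, 3, 32, 32)  # one augmented view of a CIFAR batch
x2 = torch.randn(8, 3, 32, 32)  # the second augmented view
z1, z2 = model(x1), model(x2)   # (8, 128) projections for the loss
h = model.backbone(x1)          # (8, 2048) features for linear evaluation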