author     Jordan Gong <jordan.gong@protonmail.com>  2022-08-07 22:06:37 +0800
committer  Jordan Gong <jordan.gong@protonmail.com>  2022-08-07 22:06:37 +0800
commit     6920d0c033b18d4d27cdadce17f4b56347ed2a5a (patch)
tree       5dd7b1d5a6ef51cd29b19d14139109139b337135 /simclr
parent     af0b70c013be4baab0bd6b199f0854107aa24513 (diff)
Add a projector to SimCLR
Diffstat (limited to 'simclr')
-rw-r--r--  simclr/main.py    16
-rw-r--r--  simclr/models.py  32
2 files changed, 37 insertions(+), 11 deletions(-)
diff --git a/simclr/main.py b/simclr/main.py
index 0f93716..9701a02 100644
--- a/simclr/main.py
+++ b/simclr/main.py
@@ -42,9 +42,11 @@ def parse_args_and_config():
                         help='Path to config file (optional)')
 
     # TODO: Add model hyperparams dataclass
-    parser.add_argument('--out-dim', default=256, type=int,
+    parser.add_argument('--hid-dim', default=2048, type=int,
                         help='Number of dimension of embedding')
-    parser.add_argument('--temp', default=0.01, type=float,
+    parser.add_argument('--out-dim', default=128, type=int,
+                        help='Number of dimension after projection')
+    parser.add_argument('--temp', default=0.5, type=float,
                         help='Temperature in InfoNCE loss')
 
     dataset_group = parser.add_argument_group('Dataset parameters')
@@ -135,7 +137,8 @@ class SimCLRConfig(BaseConfig):
 
 
 class SimCLRTrainer(Trainer):
-    def __init__(self, out_dim, **kwargs):
+    def __init__(self, hid_dim, out_dim, **kwargs):
+        self.hid_dim = hid_dim
         self.out_dim = out_dim
         super(SimCLRTrainer, self).__init__(**kwargs)
 
@@ -199,9 +202,9 @@ class SimCLRTrainer(Trainer):
     def _init_models(self, dataset: str) -> Iterable[tuple[str, torch.nn.Module]]:
         if dataset in {'cifar10', 'cifar100', 'cifar'}:
-            model = CIFARSimCLRResNet50(self.out_dim)
+            model = CIFARSimCLRResNet50(self.hid_dim, self.out_dim)
         elif dataset in {'imagenet1k', 'imagenet'}:
-            model = ImageNetSimCLRResNet50(self.out_dim)
+            model = ImageNetSimCLRResNet50(self.hid_dim, self.out_dim)
         else:
             raise NotImplementedError(f"Unimplemented dataset: '{dataset}")
 
 
@@ -334,7 +337,7 @@ class SimCLRTrainer(Trainer):
 if __name__ == '__main__':
     args = parse_args_and_config()
     config = SimCLRConfig.from_args(args)
-    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     trainer = SimCLRTrainer(
         seed=args.seed,
         checkpoint_dir=args.checkpoint_dir,
@@ -342,6 +345,7 @@ if __name__ == '__main__':
         inf_mode=True,
         num_iters=args.num_iters,
         config=config,
+        hid_dim=args.hid_dim,
         out_dim=args.out_dim,
     )
 
diff --git a/simclr/models.py b/simclr/models.py
index 216a993..9996188 100644
--- a/simclr/models.py
+++ b/simclr/models.py
@@ -4,15 +4,22 @@
 from torchvision.models import ResNet
 from torchvision.models.resnet import BasicBlock
 
 
+# TODO Make a SimCLR base class
+
 class CIFARSimCLRResNet50(ResNet):
-    def __init__(self, out_dim):
+    def __init__(self, hid_dim, out_dim):
         super(CIFARSimCLRResNet50, self).__init__(
-            block=BasicBlock, layers=[3, 4, 6, 3], num_classes=out_dim
+            block=BasicBlock, layers=[3, 4, 6, 3], num_classes=hid_dim
         )
         self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
+        self.projector = nn.Sequential(
+            nn.Linear(hid_dim, hid_dim),
+            nn.ReLU(inplace=True),
+            nn.Linear(hid_dim, out_dim),
+        )
 
-    def forward(self, x: Tensor) -> Tensor:
+    def backbone(self, x: Tensor) -> Tensor:
         x = self.conv1(x)
         x = self.bn1(x)
         x = self.relu(x)
@@ -28,9 +35,24 @@ class CIFARSimCLRResNet50(ResNet):
 
         return x
 
+    def forward(self, x: Tensor) -> Tensor:
+        h = self.backbone(x)
+        z = self.projector(h)
+        return z
+
 
 class ImageNetSimCLRResNet50(ResNet):
-    def __init__(self, out_dim):
+    def __init__(self, hid_dim, out_dim):
         super(ImageNetSimCLRResNet50, self).__init__(
-            block=BasicBlock, layers=[3, 4, 6, 3], num_classes=out_dim
+            block=BasicBlock, layers=[3, 4, 6, 3], num_classes=hid_dim
+        )
+        self.projector = nn.Sequential(
+            nn.Linear(hid_dim, hid_dim),
+            nn.ReLU(inplace=True),
+            nn.Linear(hid_dim, out_dim),
         )
+
+    def forward(self, x: Tensor) -> Tensor:
+        h = super(ImageNetSimCLRResNet50, self).forward(x)
+        z = self.projector(h)
+        return z
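Editorial note: the sketch below is not part of this commit. It illustrates how the projections z produced by the new projector head typically feed the temperature-scaled InfoNCE (NT-Xent) loss, consistent with the new --temp default of 0.5. The nt_xent helper and the usage lines are illustrative assumptions; the repository's actual loss code may differ.

    # Illustrative sketch (not from this commit): NT-Xent over projections z.
    import torch
    import torch.nn.functional as F

    def nt_xent(z1: torch.Tensor, z2: torch.Tensor, temp: float = 0.5) -> torch.Tensor:
        """NT-Xent loss for two batches of projections of the same N images."""
        n = z1.size(0)
        z = F.normalize(torch.cat([z1, z2]), dim=1)  # (2N, out_dim), unit norm
        sim = z @ z.t() / temp                       # cosine similarities / temperature
        sim.fill_diagonal_(float('-inf'))            # exclude self-similarity
        # The positive for sample i is its other view at index (i + N) mod 2N
        targets = torch.cat([torch.arange(n) + n, torch.arange(n)])
        return F.cross_entropy(sim, targets)

    # Hypothetical usage with the updated CIFAR model (32x32 inputs):
    #   model = CIFARSimCLRResNet50(hid_dim=2048, out_dim=128)
    #   z1, z2 = model(view1), model(view2)  # two augmented views of one batch
    #   loss = nt_xent(z1, z2, temp=0.5)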