aboutsummaryrefslogtreecommitdiff
path: root/simclr/main.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2022-08-07 22:06:37 +0800
committerJordan Gong <jordan.gong@protonmail.com>2022-08-07 22:06:37 +0800
commit6920d0c033b18d4d27cdadce17f4b56347ed2a5a (patch)
tree5dd7b1d5a6ef51cd29b19d14139109139b337135 /simclr/main.py
parentaf0b70c013be4baab0bd6b199f0854107aa24513 (diff)
Add a projector to SimCLR
Diffstat (limited to 'simclr/main.py')
-rw-r--r--simclr/main.py16
1 files changed, 10 insertions, 6 deletions
diff --git a/simclr/main.py b/simclr/main.py
index 0f93716..9701a02 100644
--- a/simclr/main.py
+++ b/simclr/main.py
@@ -42,9 +42,11 @@ def parse_args_and_config():
help='Path to config file (optional)')
# TODO: Add model hyperparams dataclass
- parser.add_argument('--out-dim', default=256, type=int,
+ parser.add_argument('--hid-dim', default=2048, type=int,
help='Number of dimension of embedding')
- parser.add_argument('--temp', default=0.01, type=float,
+ parser.add_argument('--out-dim', default=128, type=int,
+ help='Number of dimension after projection')
+ parser.add_argument('--temp', default=0.5, type=float,
help='Temperature in InfoNCE loss')
dataset_group = parser.add_argument_group('Dataset parameters')
@@ -135,7 +137,8 @@ class SimCLRConfig(BaseConfig):
class SimCLRTrainer(Trainer):
- def __init__(self, out_dim, **kwargs):
+ def __init__(self, hid_dim, out_dim, **kwargs):
+ self.hid_dim = hid_dim
self.out_dim = out_dim
super(SimCLRTrainer, self).__init__(**kwargs)
@@ -199,9 +202,9 @@ class SimCLRTrainer(Trainer):
def _init_models(self, dataset: str) -> Iterable[tuple[str, torch.nn.Module]]:
if dataset in {'cifar10', 'cifar100', 'cifar'}:
- model = CIFARSimCLRResNet50(self.out_dim)
+ model = CIFARSimCLRResNet50(self.hid_dim, self.out_dim)
elif dataset in {'imagenet1k', 'imagenet'}:
- model = ImageNetSimCLRResNet50(self.out_dim)
+ model = ImageNetSimCLRResNet50(self.hid_dim, self.out_dim)
else:
raise NotImplementedError(f"Unimplemented dataset: '{dataset}")
@@ -334,7 +337,7 @@ class SimCLRTrainer(Trainer):
if __name__ == '__main__':
args = parse_args_and_config()
config = SimCLRConfig.from_args(args)
- device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
trainer = SimCLRTrainer(
seed=args.seed,
checkpoint_dir=args.checkpoint_dir,
@@ -342,6 +345,7 @@ if __name__ == '__main__':
inf_mode=True,
num_iters=args.num_iters,
config=config,
+ hid_dim=args.hid_dim,
out_dim=args.out_dim,
)