diff options
Diffstat (limited to 'simclr/main.py')
-rw-r--r-- | simclr/main.py | 16 |
1 files changed, 10 insertions, 6 deletions
diff --git a/simclr/main.py b/simclr/main.py index 0f93716..9701a02 100644 --- a/simclr/main.py +++ b/simclr/main.py @@ -42,9 +42,11 @@ def parse_args_and_config(): help='Path to config file (optional)') # TODO: Add model hyperparams dataclass - parser.add_argument('--out-dim', default=256, type=int, + parser.add_argument('--hid-dim', default=2048, type=int, help='Number of dimension of embedding') - parser.add_argument('--temp', default=0.01, type=float, + parser.add_argument('--out-dim', default=128, type=int, + help='Number of dimension after projection') + parser.add_argument('--temp', default=0.5, type=float, help='Temperature in InfoNCE loss') dataset_group = parser.add_argument_group('Dataset parameters') @@ -135,7 +137,8 @@ class SimCLRConfig(BaseConfig): class SimCLRTrainer(Trainer): - def __init__(self, out_dim, **kwargs): + def __init__(self, hid_dim, out_dim, **kwargs): + self.hid_dim = hid_dim self.out_dim = out_dim super(SimCLRTrainer, self).__init__(**kwargs) @@ -199,9 +202,9 @@ class SimCLRTrainer(Trainer): def _init_models(self, dataset: str) -> Iterable[tuple[str, torch.nn.Module]]: if dataset in {'cifar10', 'cifar100', 'cifar'}: - model = CIFARSimCLRResNet50(self.out_dim) + model = CIFARSimCLRResNet50(self.hid_dim, self.out_dim) elif dataset in {'imagenet1k', 'imagenet'}: - model = ImageNetSimCLRResNet50(self.out_dim) + model = ImageNetSimCLRResNet50(self.hid_dim, self.out_dim) else: raise NotImplementedError(f"Unimplemented dataset: '{dataset}") @@ -334,7 +337,7 @@ class SimCLRTrainer(Trainer): if __name__ == '__main__': args = parse_args_and_config() config = SimCLRConfig.from_args(args) - device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') trainer = SimCLRTrainer( seed=args.seed, checkpoint_dir=args.checkpoint_dir, @@ -342,6 +345,7 @@ if __name__ == '__main__': inf_mode=True, num_iters=args.num_iters, config=config, + hid_dim=args.hid_dim, out_dim=args.out_dim, ) |