diff options
Diffstat (limited to 'simclr/main.py')
-rw-r--r-- | simclr/main.py | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/simclr/main.py b/simclr/main.py index b990b6e..7a86d7d 100644 --- a/simclr/main.py +++ b/simclr/main.py @@ -19,7 +19,7 @@ from libs.datautils import color_distortion, Clip, RandomGaussianBlur, TwinTrans from libs.optimizers import LARS from libs.utils import Trainer, BaseConfig from libs.logging import BaseBatchLogRecord, Loggers -from simclr.models import CIFARSimCLRResNet50, ImageNetSimCLRResNet50, CIFARSimCLRViTTiny +from simclr.models import cifar_simclr_resnet50, cifar_simclr_vit_tiny, imagenet_simclr_resnet50 def parse_args_and_config(): @@ -203,14 +203,14 @@ class SimCLRTrainer(Trainer): def _init_models(self, dataset: str) -> Iterable[tuple[str, torch.nn.Module]]: if dataset in {'cifar10', 'cifar100', 'cifar'}: if self.encoder == 'resnet': - model = CIFARSimCLRResNet50(self.hid_dim, self.out_dim) + model = cifar_simclr_resnet50(self.hid_dim, self.out_dim) elif self.encoder == 'vit': - model = CIFARSimCLRViTTiny(self.hid_dim, self.out_dim) + model = cifar_simclr_vit_tiny(self.hid_dim, self.out_dim) else: raise NotImplementedError(f"Unimplemented encoder: '{self.encoder}") elif dataset in {'imagenet1k', 'imagenet'}: if self.encoder == 'resnet': - model = ImageNetSimCLRResNet50(self.hid_dim, self.out_dim) + model = imagenet_simclr_resnet50(self.hid_dim, self.out_dim) else: raise NotImplementedError(f"Unimplemented encoder: '{self.encoder}") else: |