From 4c242c1383afb8072ce6d2904f51cdb005eced4c Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Sat, 20 Aug 2022 15:35:56 +0800 Subject: Refactor SimCLR base class --- simclr/main.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'simclr/main.py') diff --git a/simclr/main.py b/simclr/main.py index b990b6e..7a86d7d 100644 --- a/simclr/main.py +++ b/simclr/main.py @@ -19,7 +19,7 @@ from libs.datautils import color_distortion, Clip, RandomGaussianBlur, TwinTrans from libs.optimizers import LARS from libs.utils import Trainer, BaseConfig from libs.logging import BaseBatchLogRecord, Loggers -from simclr.models import CIFARSimCLRResNet50, ImageNetSimCLRResNet50, CIFARSimCLRViTTiny +from simclr.models import cifar_simclr_resnet50, cifar_simclr_vit_tiny, imagenet_simclr_resnet50 def parse_args_and_config(): @@ -203,14 +203,14 @@ class SimCLRTrainer(Trainer): def _init_models(self, dataset: str) -> Iterable[tuple[str, torch.nn.Module]]: if dataset in {'cifar10', 'cifar100', 'cifar'}: if self.encoder == 'resnet': - model = CIFARSimCLRResNet50(self.hid_dim, self.out_dim) + model = cifar_simclr_resnet50(self.hid_dim, self.out_dim) elif self.encoder == 'vit': - model = CIFARSimCLRViTTiny(self.hid_dim, self.out_dim) + model = cifar_simclr_vit_tiny(self.hid_dim, self.out_dim) else: raise NotImplementedError(f"Unimplemented encoder: '{self.encoder}") elif dataset in {'imagenet1k', 'imagenet'}: if self.encoder == 'resnet': - model = ImageNetSimCLRResNet50(self.hid_dim, self.out_dim) + model = imagenet_simclr_resnet50(self.hid_dim, self.out_dim) else: raise NotImplementedError(f"Unimplemented encoder: '{self.encoder}") else: -- cgit v1.2.3