diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2022-08-20 15:35:56 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2022-08-20 15:35:56 +0800 |
commit | 4c242c1383afb8072ce6d2904f51cdb005eced4c (patch) | |
tree | e4a06a39c9175650bab6c492fbad10726fd69467 /simclr/main.py | |
parent | 5ccb892ae59cbe183ef91c1648751ff0085cc0da (diff) |
Refactor SimCLR base class
Diffstat (limited to 'simclr/main.py')
-rw-r--r-- | simclr/main.py | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/simclr/main.py b/simclr/main.py index b990b6e..7a86d7d 100644 --- a/simclr/main.py +++ b/simclr/main.py @@ -19,7 +19,7 @@ from libs.datautils import color_distortion, Clip, RandomGaussianBlur, TwinTrans from libs.optimizers import LARS from libs.utils import Trainer, BaseConfig from libs.logging import BaseBatchLogRecord, Loggers -from simclr.models import CIFARSimCLRResNet50, ImageNetSimCLRResNet50, CIFARSimCLRViTTiny +from simclr.models import cifar_simclr_resnet50, cifar_simclr_vit_tiny, imagenet_simclr_resnet50 def parse_args_and_config(): @@ -203,14 +203,14 @@ class SimCLRTrainer(Trainer): def _init_models(self, dataset: str) -> Iterable[tuple[str, torch.nn.Module]]: if dataset in {'cifar10', 'cifar100', 'cifar'}: if self.encoder == 'resnet': - model = CIFARSimCLRResNet50(self.hid_dim, self.out_dim) + model = cifar_simclr_resnet50(self.hid_dim, self.out_dim) elif self.encoder == 'vit': - model = CIFARSimCLRViTTiny(self.hid_dim, self.out_dim) + model = cifar_simclr_vit_tiny(self.hid_dim, self.out_dim) else: raise NotImplementedError(f"Unimplemented encoder: '{self.encoder}") elif dataset in {'imagenet1k', 'imagenet'}: if self.encoder == 'resnet': - model = ImageNetSimCLRResNet50(self.hid_dim, self.out_dim) + model = imagenet_simclr_resnet50(self.hid_dim, self.out_dim) else: raise NotImplementedError(f"Unimplemented encoder: '{self.encoder}") else: |