aboutsummaryrefslogtreecommitdiff
path: root/simclr/main.py
diff options
context:
space:
mode:
Diffstat (limited to 'simclr/main.py')
-rw-r--r--simclr/main.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/simclr/main.py b/simclr/main.py
index b990b6e..7a86d7d 100644
--- a/simclr/main.py
+++ b/simclr/main.py
@@ -19,7 +19,7 @@ from libs.datautils import color_distortion, Clip, RandomGaussianBlur, TwinTrans
from libs.optimizers import LARS
from libs.utils import Trainer, BaseConfig
from libs.logging import BaseBatchLogRecord, Loggers
-from simclr.models import CIFARSimCLRResNet50, ImageNetSimCLRResNet50, CIFARSimCLRViTTiny
+from simclr.models import cifar_simclr_resnet50, cifar_simclr_vit_tiny, imagenet_simclr_resnet50
def parse_args_and_config():
@@ -203,14 +203,14 @@ class SimCLRTrainer(Trainer):
def _init_models(self, dataset: str) -> Iterable[tuple[str, torch.nn.Module]]:
if dataset in {'cifar10', 'cifar100', 'cifar'}:
if self.encoder == 'resnet':
- model = CIFARSimCLRResNet50(self.hid_dim, self.out_dim)
+ model = cifar_simclr_resnet50(self.hid_dim, self.out_dim)
elif self.encoder == 'vit':
- model = CIFARSimCLRViTTiny(self.hid_dim, self.out_dim)
+ model = cifar_simclr_vit_tiny(self.hid_dim, self.out_dim)
else:
raise NotImplementedError(f"Unimplemented encoder: '{self.encoder}")
elif dataset in {'imagenet1k', 'imagenet'}:
if self.encoder == 'resnet':
- model = ImageNetSimCLRResNet50(self.hid_dim, self.out_dim)
+ model = imagenet_simclr_resnet50(self.hid_dim, self.out_dim)
else:
raise NotImplementedError(f"Unimplemented encoder: '{self.encoder}")
else: