aboutsummaryrefslogtreecommitdiff
path: root/simclr
diff options
context:
space:
mode:
Diffstat (limited to 'simclr')
-rw-r--r--simclr/evaluate.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/simclr/evaluate.py b/simclr/evaluate.py
index f18c417..1abb5ce 100644
--- a/simclr/evaluate.py
+++ b/simclr/evaluate.py
@@ -18,7 +18,7 @@ from libs.optimizers import LARS
from libs.logging import Loggers, BaseBatchLogRecord, BaseEpochLogRecord
from libs.utils import BaseConfig
from simclr.main import SimCLRTrainer, SimCLRConfig
-from simclr.models import CIFARSimCLRResNet50, ImageNetSimCLRResNet50, CIFARSimCLRViTTiny
+from simclr.models import cifar_simclr_resnet50, cifar_simclr_vit_tiny, imagenet_simclr_resnet50
def parse_args_and_config():
@@ -172,9 +172,9 @@ class SimCLREvalTrainer(SimCLRTrainer):
def _init_models(self, dataset: str) -> Iterable[tuple[str, torch.nn.Module]]:
if dataset in {'cifar10', 'cifar100', 'cifar'}:
if self.encoder == 'resnet':
- backbone = CIFARSimCLRResNet50(self.hid_dim, pretrain=False)
+ backbone = cifar_simclr_resnet50(self.hid_dim, pretrain=False)
elif self.encoder == 'vit':
- backbone = CIFARSimCLRViTTiny(self.hid_dim, pretrain=False)
+ backbone = cifar_simclr_vit_tiny(self.hid_dim, pretrain=False)
else:
raise NotImplementedError(f"Unimplemented encoder: '{self.encoder}")
if dataset in {'cifar10', 'cifar'}:
@@ -183,7 +183,7 @@ class SimCLREvalTrainer(SimCLRTrainer):
classifier = torch.nn.Linear(self.hid_dim, 100)
elif dataset in {'imagenet1k', 'imagenet'}:
if self.encoder == 'resnet':
- backbone = ImageNetSimCLRResNet50(self.hid_dim, pretrain=False)
+ backbone = imagenet_simclr_resnet50(self.hid_dim, pretrain=False)
else:
raise NotImplementedError(f"Unimplemented encoder: '{self.encoder}")
classifier = torch.nn.Linear(self.hid_dim, 1000)