aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--simclr/evaluate.py14
1 files changed, 11 insertions, 3 deletions
diff --git a/simclr/evaluate.py b/simclr/evaluate.py
index fff4eb2..8cbd454 100644
--- a/simclr/evaluate.py
+++ b/simclr/evaluate.py
@@ -18,7 +18,7 @@ from libs.optimizers import LARS
from libs.logging import Loggers, BaseBatchLogRecord, BaseEpochLogRecord
from libs.utils import BaseConfig
from simclr.main import SimCLRTrainer, SimCLRConfig
-from simclr.models import CIFARSimCLRResNet50, ImageNetSimCLRResNet50
+from simclr.models import CIFARSimCLRResNet50, ImageNetSimCLRResNet50, CIFARSimCLRViTTiny
def parse_args_and_config():
@@ -169,13 +169,21 @@ class SimCLREvalTrainer(SimCLRTrainer):
def _init_models(self, dataset: str) -> Iterable[tuple[str, torch.nn.Module]]:
if dataset in {'cifar10', 'cifar100', 'cifar'}:
- backbone = CIFARSimCLRResNet50(self.hid_dim, pretrain=False)
+ if self.encoder == 'resnet':
+ backbone = CIFARSimCLRResNet50(self.hid_dim, pretrain=False)
+ elif self.encoder == 'vit':
+ backbone = CIFARSimCLRViTTiny(self.hid_dim, pretrain=False)
+ else:
+ raise NotImplementedError(f"Unimplemented encoder: '{self.encoder}")
if dataset in {'cifar10', 'cifar'}:
classifier = torch.nn.Linear(self.hid_dim, 10)
else:
classifier = torch.nn.Linear(self.hid_dim, 100)
elif dataset in {'imagenet1k', 'imagenet'}:
- backbone = ImageNetSimCLRResNet50(self.hid_dim, pretrain=False)
+ if self.encoder == 'resnet':
+ backbone = ImageNetSimCLRResNet50(self.hid_dim, pretrain=False)
+ else:
+ raise NotImplementedError(f"Unimplemented encoder: '{self.encoder}")
classifier = torch.nn.Linear(self.hid_dim, 1000)
else:
raise NotImplementedError(f"Unimplemented dataset: '{dataset}")