diff options
Diffstat (limited to 'simclr')
-rw-r--r-- | simclr/evaluate.py | 14 |
1 files changed, 11 insertions, 3 deletions
diff --git a/simclr/evaluate.py b/simclr/evaluate.py index fff4eb2..8cbd454 100644 --- a/simclr/evaluate.py +++ b/simclr/evaluate.py @@ -18,7 +18,7 @@ from libs.optimizers import LARS from libs.logging import Loggers, BaseBatchLogRecord, BaseEpochLogRecord from libs.utils import BaseConfig from simclr.main import SimCLRTrainer, SimCLRConfig -from simclr.models import CIFARSimCLRResNet50, ImageNetSimCLRResNet50 +from simclr.models import CIFARSimCLRResNet50, ImageNetSimCLRResNet50, CIFARSimCLRViTTiny def parse_args_and_config(): @@ -169,13 +169,21 @@ class SimCLREvalTrainer(SimCLRTrainer): def _init_models(self, dataset: str) -> Iterable[tuple[str, torch.nn.Module]]: if dataset in {'cifar10', 'cifar100', 'cifar'}: - backbone = CIFARSimCLRResNet50(self.hid_dim, pretrain=False) + if self.encoder == 'resnet': + backbone = CIFARSimCLRResNet50(self.hid_dim, pretrain=False) + elif self.encoder == 'vit': + backbone = CIFARSimCLRViTTiny(self.hid_dim, pretrain=False) + else: + raise NotImplementedError(f"Unimplemented encoder: '{self.encoder}") if dataset in {'cifar10', 'cifar'}: classifier = torch.nn.Linear(self.hid_dim, 10) else: classifier = torch.nn.Linear(self.hid_dim, 100) elif dataset in {'imagenet1k', 'imagenet'}: - backbone = ImageNetSimCLRResNet50(self.hid_dim, pretrain=False) + if self.encoder == 'resnet': + backbone = ImageNetSimCLRResNet50(self.hid_dim, pretrain=False) + else: + raise NotImplementedError(f"Unimplemented encoder: '{self.encoder}") classifier = torch.nn.Linear(self.hid_dim, 1000) else: raise NotImplementedError(f"Unimplemented dataset: '{dataset}") |