From 49822d3234cb67e4996ad13fdbc3c44e1a0bbf29 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Sun, 21 Aug 2022 17:28:07 +0800 Subject: Some modifications for PosRecon trainer --- simclr/evaluate.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'simclr') diff --git a/simclr/evaluate.py b/simclr/evaluate.py index f18c417..1abb5ce 100644 --- a/simclr/evaluate.py +++ b/simclr/evaluate.py @@ -18,7 +18,7 @@ from libs.optimizers import LARS from libs.logging import Loggers, BaseBatchLogRecord, BaseEpochLogRecord from libs.utils import BaseConfig from simclr.main import SimCLRTrainer, SimCLRConfig -from simclr.models import CIFARSimCLRResNet50, ImageNetSimCLRResNet50, CIFARSimCLRViTTiny +from simclr.models import cifar_simclr_resnet50, cifar_simclr_vit_tiny, imagenet_simclr_resnet50 def parse_args_and_config(): @@ -172,9 +172,9 @@ class SimCLREvalTrainer(SimCLRTrainer): def _init_models(self, dataset: str) -> Iterable[tuple[str, torch.nn.Module]]: if dataset in {'cifar10', 'cifar100', 'cifar'}: if self.encoder == 'resnet': - backbone = CIFARSimCLRResNet50(self.hid_dim, pretrain=False) + backbone = cifar_simclr_resnet50(self.hid_dim, pretrain=False) elif self.encoder == 'vit': - backbone = CIFARSimCLRViTTiny(self.hid_dim, pretrain=False) + backbone = cifar_simclr_vit_tiny(self.hid_dim, pretrain=False) else: raise NotImplementedError(f"Unimplemented encoder: '{self.encoder}") if dataset in {'cifar10', 'cifar'}: @@ -183,7 +183,7 @@ class SimCLREvalTrainer(SimCLRTrainer): classifier = torch.nn.Linear(self.hid_dim, 100) elif dataset in {'imagenet1k', 'imagenet'}: if self.encoder == 'resnet': - backbone = ImageNetSimCLRResNet50(self.hid_dim, pretrain=False) + backbone = imagenet_simclr_resnet50(self.hid_dim, pretrain=False) else: raise NotImplementedError(f"Unimplemented encoder: '{self.encoder}") classifier = torch.nn.Linear(self.hid_dim, 1000) -- cgit v1.2.3