aboutsummaryrefslogtreecommitdiff
path: root/simclr
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2022-08-21 17:28:07 +0800
committerJordan Gong <jordan.gong@protonmail.com>2022-08-21 17:28:07 +0800
commit49822d3234cb67e4996ad13fdbc3c44e1a0bbf29 (patch)
tree6f6286045cd68de054a602587631a283c64aeb7d /simclr
parent4c242c1383afb8072ce6d2904f51cdb005eced4c (diff)
Some modifications for PosRecon trainer
Diffstat (limited to 'simclr')
-rw-r--r--simclr/evaluate.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/simclr/evaluate.py b/simclr/evaluate.py
index f18c417..1abb5ce 100644
--- a/simclr/evaluate.py
+++ b/simclr/evaluate.py
@@ -18,7 +18,7 @@ from libs.optimizers import LARS
from libs.logging import Loggers, BaseBatchLogRecord, BaseEpochLogRecord
from libs.utils import BaseConfig
from simclr.main import SimCLRTrainer, SimCLRConfig
-from simclr.models import CIFARSimCLRResNet50, ImageNetSimCLRResNet50, CIFARSimCLRViTTiny
+from simclr.models import cifar_simclr_resnet50, cifar_simclr_vit_tiny, imagenet_simclr_resnet50
def parse_args_and_config():
@@ -172,9 +172,9 @@ class SimCLREvalTrainer(SimCLRTrainer):
def _init_models(self, dataset: str) -> Iterable[tuple[str, torch.nn.Module]]:
if dataset in {'cifar10', 'cifar100', 'cifar'}:
if self.encoder == 'resnet':
- backbone = CIFARSimCLRResNet50(self.hid_dim, pretrain=False)
+ backbone = cifar_simclr_resnet50(self.hid_dim, pretrain=False)
elif self.encoder == 'vit':
- backbone = CIFARSimCLRViTTiny(self.hid_dim, pretrain=False)
+ backbone = cifar_simclr_vit_tiny(self.hid_dim, pretrain=False)
else:
raise NotImplementedError(f"Unimplemented encoder: '{self.encoder}")
if dataset in {'cifar10', 'cifar'}:
@@ -183,7 +183,7 @@ class SimCLREvalTrainer(SimCLRTrainer):
classifier = torch.nn.Linear(self.hid_dim, 100)
elif dataset in {'imagenet1k', 'imagenet'}:
if self.encoder == 'resnet':
- backbone = ImageNetSimCLRResNet50(self.hid_dim, pretrain=False)
+ backbone = imagenet_simclr_resnet50(self.hid_dim, pretrain=False)
else:
raise NotImplementedError(f"Unimplemented encoder: '{self.encoder}")
classifier = torch.nn.Linear(self.hid_dim, 1000)