Diffstat (limited to 'simclr')
-rw-r--r--  simclr/evaluate.py | 18
1 file changed, 15 insertions(+), 3 deletions(-)
diff --git a/simclr/evaluate.py b/simclr/evaluate.py
index 5c41b84..8cbd454 100644
--- a/simclr/evaluate.py
+++ b/simclr/evaluate.py
@@ -18,7 +18,7 @@ from libs.optimizers import LARS
from libs.logging import Loggers, BaseBatchLogRecord, BaseEpochLogRecord
from libs.utils import BaseConfig
from simclr.main import SimCLRTrainer, SimCLRConfig
-from simclr.models import CIFARSimCLRResNet50, ImageNetSimCLRResNet50
+from simclr.models import CIFARSimCLRResNet50, ImageNetSimCLRResNet50, CIFARSimCLRViTTiny


def parse_args_and_config():
@@ -40,6 +40,9 @@ def parse_args_and_config():
                        help='Path to config file (optional)')
    # TODO: Add model hyperparams dataclass
+    parser.add_argument('--encoder', default='resnet', type=str,
+                        choices=('resnet', 'vit'),
+                        help='Backbone of encoder')
    parser.add_argument('--hid-dim', default=2048, type=int,
                        help='Number of dimension of embedding')
    parser.add_argument('--out-dim', default=128, type=int,
@@ -166,13 +169,21 @@ class SimCLREvalTrainer(SimCLRTrainer):

    def _init_models(self, dataset: str) -> Iterable[tuple[str, torch.nn.Module]]:
        if dataset in {'cifar10', 'cifar100', 'cifar'}:
-            backbone = CIFARSimCLRResNet50(self.hid_dim, pretrain=False)
+            if self.encoder == 'resnet':
+                backbone = CIFARSimCLRResNet50(self.hid_dim, pretrain=False)
+            elif self.encoder == 'vit':
+                backbone = CIFARSimCLRViTTiny(self.hid_dim, pretrain=False)
+            else:
+                raise NotImplementedError(f"Unimplemented encoder: '{self.encoder}'")
            if dataset in {'cifar10', 'cifar'}:
                classifier = torch.nn.Linear(self.hid_dim, 10)
            else:
                classifier = torch.nn.Linear(self.hid_dim, 100)
        elif dataset in {'imagenet1k', 'imagenet'}:
-            backbone = ImageNetSimCLRResNet50(self.hid_dim, pretrain=False)
+            if self.encoder == 'resnet':
+                backbone = ImageNetSimCLRResNet50(self.hid_dim, pretrain=False)
+            else:
+                raise NotImplementedError(f"Unimplemented encoder: '{self.encoder}'")
            classifier = torch.nn.Linear(self.hid_dim, 1000)
        else:
            raise NotImplementedError(f"Unimplemented dataset: '{dataset}")
@@ -261,6 +272,7 @@ if __name__ == '__main__':
        inf_mode=False,
        num_iters=args.num_iters,
        config=config,
+        encoder=args.encoder,
        hid_dim=args.hid_dim,
        out_dim=args.out_dim,
        pretrained_checkpoint=args.pretrained_checkpoint,
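
The CIFARSimCLRViTTiny backbone imported above is defined in simclr/models.py, which is outside this diff. As a rough, hypothetical sketch of the interface the evaluator assumes (a constructor taking hid_dim and a pretrain flag, and a forward pass that maps CIFAR-sized images to hid_dim-dimensional features), something along the following lines would satisfy _init_models; it uses torchvision's VisionTransformer with ViT-Tiny-style hyperparameters and is not the repository's actual implementation.

# Hypothetical stand-in for simclr/models.CIFARSimCLRViTTiny; the real class may differ.
import torch
from torchvision.models.vision_transformer import VisionTransformer


class CIFARSimCLRViTTiny(torch.nn.Module):
    def __init__(self, hid_dim: int, pretrain: bool = False):
        super().__init__()
        # `pretrain` only mirrors the constructor signature used in the diff;
        # no pretrained ViT-Tiny weights are assumed here.
        self.pretrain = pretrain
        # ViT-Tiny-style settings: 192-dim tokens, 12 layers, 3 heads,
        # with a 4x4 patch size for 32x32 CIFAR inputs.
        self.encoder = VisionTransformer(
            image_size=32, patch_size=4, num_layers=12, num_heads=3,
            hidden_dim=192, mlp_dim=768, num_classes=hid_dim,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Produces hid_dim-dimensional features, matching the
        # torch.nn.Linear(self.hid_dim, ...) classifiers built in _init_models.
        return self.encoder(x)


if __name__ == '__main__':
    backbone = CIFARSimCLRViTTiny(hid_dim=2048, pretrain=False)
    feats = backbone(torch.randn(4, 3, 32, 32))
    print(feats.shape)  # torch.Size([4, 2048])

With a definition along these lines in place, the evaluator is run with --encoder vit so that _init_models builds the ViT backbone instead of the ResNet-50 one.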