From 726e59f030e278ba7ab52d5c48c78a9ceeb7dd8d Mon Sep 17 00:00:00 2001
From: Jordan Gong <jordan.gong@protonmail.com>
Date: Sat, 20 Aug 2022 09:12:01 +0800
Subject: Add encoder option to SimCLR evaluation script

---
 simclr/evaluate.py | 18 +++++++++++++++---
 1 file changed, 15 insertions(+), 3 deletions(-)

(limited to 'simclr')

diff --git a/simclr/evaluate.py b/simclr/evaluate.py
index 5c41b84..8cbd454 100644
--- a/simclr/evaluate.py
+++ b/simclr/evaluate.py
@@ -18,7 +18,7 @@ from libs.optimizers import LARS
 from libs.logging import Loggers, BaseBatchLogRecord, BaseEpochLogRecord
 from libs.utils import BaseConfig
 from simclr.main import SimCLRTrainer, SimCLRConfig
-from simclr.models import CIFARSimCLRResNet50, ImageNetSimCLRResNet50
+from simclr.models import CIFARSimCLRResNet50, ImageNetSimCLRResNet50, CIFARSimCLRViTTiny
 
 
 def parse_args_and_config():
@@ -40,6 +40,9 @@ def parse_args_and_config():
                         help='Path to config file (optional)')
 
     # TODO: Add model hyperparams dataclass
+    parser.add_argument('--encoder', default='resnet', type=str,
+                        choices=('resnet', 'vit'),
+                        help='Backbone of encoder')
     parser.add_argument('--hid-dim', default=2048, type=int,
                         help='Number of dimension of embedding')
     parser.add_argument('--out-dim', default=128, type=int,
@@ -166,13 +169,21 @@ class SimCLREvalTrainer(SimCLRTrainer):
 
     def _init_models(self, dataset: str) -> Iterable[tuple[str, torch.nn.Module]]:
         if dataset in {'cifar10', 'cifar100', 'cifar'}:
-            backbone = CIFARSimCLRResNet50(self.hid_dim, pretrain=False)
+            if self.encoder == 'resnet':
+                backbone = CIFARSimCLRResNet50(self.hid_dim, pretrain=False)
+            elif self.encoder == 'vit':
+                backbone = CIFARSimCLRViTTiny(self.hid_dim, pretrain=False)
+            else:
+                raise NotImplementedError(f"Unimplemented encoder: '{self.encoder}")
             if dataset in {'cifar10', 'cifar'}:
                 classifier = torch.nn.Linear(self.hid_dim, 10)
             else:
                 classifier = torch.nn.Linear(self.hid_dim, 100)
         elif dataset in {'imagenet1k', 'imagenet'}:
-            backbone = ImageNetSimCLRResNet50(self.hid_dim, pretrain=False)
+            if self.encoder == 'resnet':
+                backbone = ImageNetSimCLRResNet50(self.hid_dim, pretrain=False)
+            else:
+                raise NotImplementedError(f"Unimplemented encoder: '{self.encoder}")
             classifier = torch.nn.Linear(self.hid_dim, 1000)
         else:
             raise NotImplementedError(f"Unimplemented dataset: '{dataset}")
@@ -261,6 +272,7 @@ if __name__ == '__main__':
         inf_mode=False,
         num_iters=args.num_iters,
         config=config,
+        encoder=args.encoder,
         hid_dim=args.hid_dim,
         out_dim=args.out_dim,
         pretrained_checkpoint=args.pretrained_checkpoint,
-- 
cgit v1.2.3