aboutsummaryrefslogtreecommitdiff
path: root/simclr/evaluate.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2022-08-20 09:12:01 +0800
committerJordan Gong <jordan.gong@protonmail.com>2022-08-20 09:12:01 +0800
commit9302fa032c107ab4a368a644e1f639daf11e9cc9 (patch)
treef59d5807103f8cb60e4c3f7c186b1a02b5cd2be1 /simclr/evaluate.py
parentea13f7e3fc3aec1b48d61896dcb3032897cc4b7a (diff)
Add encoder option to SimCLR evaluation script
Diffstat (limited to 'simclr/evaluate.py')
-rw-r--r--simclr/evaluate.py4
1 files changed, 4 insertions, 0 deletions
diff --git a/simclr/evaluate.py b/simclr/evaluate.py
index 5c41b84..fff4eb2 100644
--- a/simclr/evaluate.py
+++ b/simclr/evaluate.py
@@ -40,6 +40,9 @@ def parse_args_and_config():
help='Path to config file (optional)')
# TODO: Add model hyperparams dataclass
+ parser.add_argument('--encoder', default='resnet', type=str,
+ choices=('resnet', 'vit'),
+ help='Backbone of encoder')
parser.add_argument('--hid-dim', default=2048, type=int,
help='Number of dimension of embedding')
parser.add_argument('--out-dim', default=128, type=int,
@@ -261,6 +264,7 @@ if __name__ == '__main__':
inf_mode=False,
num_iters=args.num_iters,
config=config,
+ encoder=args.encoder,
hid_dim=args.hid_dim,
out_dim=args.out_dim,
pretrained_checkpoint=args.pretrained_checkpoint,