about summary refs log tree commit diff
path: root/simclr/evaluate.py
diff options
context:
space:
mode:
author Jordan Gong <jordan.gong@protonmail.com> 2022-08-20 10:45:04 +0800
committer Jordan Gong <jordan.gong@protonmail.com> 2022-08-20 10:45:04 +0800
commit 5ccb892ae59cbe183ef91c1648751ff0085cc0da (patch)
tree e06b55270e885f119361559511cac0ab36ff32f9 /simclr/evaluate.py
parent 7661a84e4ff8c846fd8649b61afbea2cd4c6431e (diff)
Add label smoothing option
Diffstat (limited to 'simclr/evaluate.py')
-rw-r--r--simclr/evaluate.py5
1 file changed, 4 insertions, 1 deletion
diff --git a/simclr/evaluate.py b/simclr/evaluate.py
index 91395d1..f18c417 100644
--- a/simclr/evaluate.py
+++ b/simclr/evaluate.py
@@ -43,6 +43,8 @@ def parse_args_and_config():
parser.add_argument('--encoder', default='resnet', type=str,
choices=('resnet', 'vit'),
help='Backbone of encoder')
+ parser.add_argument('--label-smooth', default=0., type=float,
+ help='Label smoothing in cross entropy')
parser.add_argument('--hid-dim', default=2048, type=int,
help='Number of dimension of embedding')
parser.add_argument('--out-dim', default=128, type=int,
@@ -280,4 +282,5 @@ if __name__ == '__main__':
)
loggers = trainer.init_logger(args.log_dir)
- trainer.train(args.num_iters, torch.nn.CrossEntropyLoss(), loggers, device)
+ loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smooth)
+ trainer.train(args.num_iters, loss_fn, loggers, device)