diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2022-08-20 10:45:04 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2022-08-20 10:45:04 +0800 |
commit | 5ccb892ae59cbe183ef91c1648751ff0085cc0da (patch) | |
tree | e06b55270e885f119361559511cac0ab36ff32f9 | |
parent | 7661a84e4ff8c846fd8649b61afbea2cd4c6431e (diff) |
Add label smoothing option
-rw-r--r-- | simclr/evaluate.py | 5 |
1 files changed, 4 insertions, 1 deletions
diff --git a/simclr/evaluate.py b/simclr/evaluate.py index 91395d1..f18c417 100644 --- a/simclr/evaluate.py +++ b/simclr/evaluate.py @@ -43,6 +43,8 @@ def parse_args_and_config(): parser.add_argument('--encoder', default='resnet', type=str, choices=('resnet', 'vit'), help='Backbone of encoder') + parser.add_argument('--label-smooth', default=0., type=float, + help='Label smoothing in cross entropy') parser.add_argument('--hid-dim', default=2048, type=int, help='Number of dimension of embedding') parser.add_argument('--out-dim', default=128, type=int, @@ -280,4 +282,5 @@ if __name__ == '__main__': ) loggers = trainer.init_logger(args.log_dir) - trainer.train(args.num_iters, torch.nn.CrossEntropyLoss(), loggers, device) + loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smooth) + trainer.train(args.num_iters, loss_fn, loggers, device) |