diff options
-rw-r--r-- | simclr/evaluate.py | 5 |
1 files changed, 4 insertions, 1 deletions
diff --git a/simclr/evaluate.py b/simclr/evaluate.py index 91395d1..f18c417 100644 --- a/simclr/evaluate.py +++ b/simclr/evaluate.py @@ -43,6 +43,8 @@ def parse_args_and_config(): parser.add_argument('--encoder', default='resnet', type=str, choices=('resnet', 'vit'), help='Backbone of encoder') + parser.add_argument('--label-smooth', default=0., type=float, + help='Label smoothing in cross entropy') parser.add_argument('--hid-dim', default=2048, type=int, help='Number of dimension of embedding') parser.add_argument('--out-dim', default=128, type=int, @@ -280,4 +282,5 @@ if __name__ == '__main__': ) loggers = trainer.init_logger(args.log_dir) - trainer.train(args.num_iters, torch.nn.CrossEntropyLoss(), loggers, device) + loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smooth) + trainer.train(args.num_iters, loss_fn, loggers, device) |