aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--simclr/evaluate.py5
1 files changed, 4 insertions, 1 deletions
diff --git a/simclr/evaluate.py b/simclr/evaluate.py
index 91395d1..f18c417 100644
--- a/simclr/evaluate.py
+++ b/simclr/evaluate.py
@@ -43,6 +43,8 @@ def parse_args_and_config():
parser.add_argument('--encoder', default='resnet', type=str,
choices=('resnet', 'vit'),
help='Backbone of encoder')
+ parser.add_argument('--label-smooth', default=0., type=float,
+ help='Label smoothing in cross entropy')
parser.add_argument('--hid-dim', default=2048, type=int,
help='Number of dimension of embedding')
parser.add_argument('--out-dim', default=128, type=int,
@@ -280,4 +282,5 @@ if __name__ == '__main__':
)
loggers = trainer.init_logger(args.log_dir)
- trainer.train(args.num_iters, torch.nn.CrossEntropyLoss(), loggers, device)
+ loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smooth)
+ trainer.train(args.num_iters, loss_fn, loggers, device)