diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2022-03-17 14:18:51 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2022-03-17 14:18:51 +0800 |
commit | d3cef6cf4d9a8c9afd2875bf76f072826a050f9b (patch) | |
tree | 9370ad70cead422d77586b2b0b1ac8187d1cef66 | |
parent | d40e8a0de05739c6d07f3da0c8c2c367f6875e02 (diff) |
Make setting seed deterministic and not set seed by default
-rw-r--r-- | supervised/baseline.py | 18 |
1 files changed, 14 insertions, 4 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py index 23272c3..e21aeee 100644 --- a/supervised/baseline.py +++ b/supervised/baseline.py @@ -3,6 +3,7 @@ import os import random import torch +from torch.backends import cudnn from torch.nn import CrossEntropyLoss from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter @@ -28,8 +29,9 @@ def build_parser(): parser.add_argument('--codename', default='cifar10-resnet50-256-lars-warmup', type=str, help="Model descriptor (default: " "'cifar10-resnet50-256-lars-warmup')") - parser.add_argument('--seed', default=0, type=int, - help='Random seed for reproducibility (default: 0)') + parser.add_argument('--seed', default=-1, type=int, + help='Random seed for reproducibility ' + '(-1 for not set seed) (default: -1)') data_group = parser.add_argument_group('Dataset parameters') data_group.add_argument('--dataset_dir', default='dataset', type=str, @@ -84,6 +86,15 @@ def build_parser(): return args +def set_seed(args): + if args.seed == -1 or args.seed is None or args.seed == '': + cudnn.benchmark = True + else: + random.seed(args.seed) + torch.manual_seed(args.seed) + cudnn.deterministic = True + + def prepare_dataset(args): if args.dataset == 'cifar10' or args.dataset == 'cifar': train_transform = transforms.Compose([ @@ -311,8 +322,7 @@ def save_checkpoint(args, epoch_log, model, optimizer): if __name__ == '__main__': args = build_parser() - random.seed(args.seed) - torch.manual_seed(args.seed) + set_seed(args) train_set, test_set = prepare_dataset(args) train_loader, test_loader = create_dataloader(args, train_set, test_set) |