From d3cef6cf4d9a8c9afd2875bf76f072826a050f9b Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Thu, 17 Mar 2022 14:18:51 +0800 Subject: Make setting seed deterministic and not set seed by default --- supervised/baseline.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) (limited to 'supervised/baseline.py') diff --git a/supervised/baseline.py b/supervised/baseline.py index 23272c3..e21aeee 100644 --- a/supervised/baseline.py +++ b/supervised/baseline.py @@ -3,6 +3,7 @@ import os import random import torch +from torch.backends import cudnn from torch.nn import CrossEntropyLoss from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter @@ -28,8 +29,9 @@ def build_parser(): parser.add_argument('--codename', default='cifar10-resnet50-256-lars-warmup', type=str, help="Model descriptor (default: " "'cifar10-resnet50-256-lars-warmup')") - parser.add_argument('--seed', default=0, type=int, - help='Random seed for reproducibility (default: 0)') + parser.add_argument('--seed', default=-1, type=int, + help='Random seed for reproducibility ' + '(-1 for not set seed) (default: -1)') data_group = parser.add_argument_group('Dataset parameters') data_group.add_argument('--dataset_dir', default='dataset', type=str, @@ -84,6 +86,15 @@ def build_parser(): return args +def set_seed(args): + if args.seed == -1 or args.seed is None or args.seed == '': + cudnn.benchmark = True + else: + random.seed(args.seed) + torch.manual_seed(args.seed) + cudnn.deterministic = True + + def prepare_dataset(args): if args.dataset == 'cifar10' or args.dataset == 'cifar': train_transform = transforms.Compose([ @@ -311,8 +322,7 @@ def save_checkpoint(args, epoch_log, model, optimizer): if __name__ == '__main__': args = build_parser() - random.seed(args.seed) - torch.manual_seed(args.seed) + set_seed(args) train_set, test_set = prepare_dataset(args) train_loader, test_loader = create_dataloader(args, train_set, test_set) -- cgit v1.2.3