diff options
Diffstat (limited to 'supervised/baseline.py')
-rw-r--r-- | supervised/baseline.py | 18 |
1 files changed, 14 insertions, 4 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py index 23272c3..e21aeee 100644 --- a/supervised/baseline.py +++ b/supervised/baseline.py @@ -3,6 +3,7 @@ import os import random import torch +from torch.backends import cudnn from torch.nn import CrossEntropyLoss from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter @@ -28,8 +29,9 @@ def build_parser(): parser.add_argument('--codename', default='cifar10-resnet50-256-lars-warmup', type=str, help="Model descriptor (default: " "'cifar10-resnet50-256-lars-warmup')") - parser.add_argument('--seed', default=0, type=int, - help='Random seed for reproducibility (default: 0)') + parser.add_argument('--seed', default=-1, type=int, + help='Random seed for reproducibility ' + '(-1 for not set seed) (default: -1)') data_group = parser.add_argument_group('Dataset parameters') data_group.add_argument('--dataset_dir', default='dataset', type=str, @@ -84,6 +86,15 @@ def build_parser(): return args +def set_seed(args): + if args.seed == -1 or args.seed is None or args.seed == '': + cudnn.benchmark = True + else: + random.seed(args.seed) + torch.manual_seed(args.seed) + cudnn.deterministic = True + + def prepare_dataset(args): if args.dataset == 'cifar10' or args.dataset == 'cifar': train_transform = transforms.Compose([ @@ -311,8 +322,7 @@ def save_checkpoint(args, epoch_log, model, optimizer): if __name__ == '__main__': args = build_parser() - random.seed(args.seed) - torch.manual_seed(args.seed) + set_seed(args) train_set, test_set = prepare_dataset(args) train_loader, test_loader = create_dataloader(args, train_set, test_set) |