aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2022-03-17 14:18:51 +0800
committerJordan Gong <jordan.gong@protonmail.com>2022-03-17 14:18:51 +0800
commitd3cef6cf4d9a8c9afd2875bf76f072826a050f9b (patch)
tree9370ad70cead422d77586b2b0b1ac8187d1cef66
parentd40e8a0de05739c6d07f3da0c8c2c367f6875e02 (diff)
Make setting seed deterministic and not set seed by default
-rw-r--r--supervised/baseline.py18
1 files changed, 14 insertions, 4 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py
index 23272c3..e21aeee 100644
--- a/supervised/baseline.py
+++ b/supervised/baseline.py
@@ -3,6 +3,7 @@ import os
import random
import torch
+from torch.backends import cudnn
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
@@ -28,8 +29,9 @@ def build_parser():
parser.add_argument('--codename', default='cifar10-resnet50-256-lars-warmup',
type=str, help="Model descriptor (default: "
"'cifar10-resnet50-256-lars-warmup')")
- parser.add_argument('--seed', default=0, type=int,
- help='Random seed for reproducibility (default: 0)')
+ parser.add_argument('--seed', default=-1, type=int,
+ help='Random seed for reproducibility '
+ '(-1 for not set seed) (default: -1)')
data_group = parser.add_argument_group('Dataset parameters')
data_group.add_argument('--dataset_dir', default='dataset', type=str,
@@ -84,6 +86,15 @@ def build_parser():
return args
+def set_seed(args):
+ if args.seed == -1 or args.seed is None or args.seed == '':
+ cudnn.benchmark = True
+ else:
+ random.seed(args.seed)
+ torch.manual_seed(args.seed)
+ cudnn.deterministic = True
+
+
def prepare_dataset(args):
if args.dataset == 'cifar10' or args.dataset == 'cifar':
train_transform = transforms.Compose([
@@ -311,8 +322,7 @@ def save_checkpoint(args, epoch_log, model, optimizer):
if __name__ == '__main__':
args = build_parser()
- random.seed(args.seed)
- torch.manual_seed(args.seed)
+ set_seed(args)
train_set, test_set = prepare_dataset(args)
train_loader, test_loader = create_dataloader(args, train_set, test_set)