diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2022-03-15 19:45:40 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2022-03-16 12:58:56 +0800 |
commit | 2f06ac98982323c3775faba1f5f64b52b5586b70 (patch) | |
tree | f28d66061e848b40361fba8afd69f10e967849ec /supervised/baseline.py | |
parent | 587ccee452d18c44cf33523350b16c79c485d0e7 (diff) |
Remove data augmentation on test set
Diffstat (limited to 'supervised/baseline.py')
-rw-r--r-- | supervised/baseline.py | 11 |
1 files changed, 8 insertions, 3 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py index 6a83f72..9a83079 100644 --- a/supervised/baseline.py +++ b/supervised/baseline.py @@ -74,15 +74,20 @@ def get_color_distortion(s=1.0): return color_distort -transform = transforms.Compose([ +train_transform = transforms.Compose([ transforms.RandomResizedCrop(32, interpolation=InterpolationMode.BICUBIC), transforms.RandomHorizontalFlip(0.5), get_color_distortion(0.5), transforms.ToTensor() ]) -train_set = CIFAR10(DATASET_ROOT, train=True, transform=transform, download=True) -test_set = CIFAR10(DATASET_ROOT, train=False, transform=transform, download=True) +test_transform = transforms.Compose([ + transforms.ToTensor() +]) + +train_set = CIFAR10(DATASET_ROOT, train=True, transform=train_transform, + download=True) +test_set = CIFAR10(DATASET_ROOT, train=False, transform=test_transform) train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=N_WORKERS) |