aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2022-03-15 19:45:40 +0800
committerJordan Gong <jordan.gong@protonmail.com>2022-03-16 12:58:56 +0800
commit2f06ac98982323c3775faba1f5f64b52b5586b70 (patch)
treef28d66061e848b40361fba8afd69f10e967849ec
parent587ccee452d18c44cf33523350b16c79c485d0e7 (diff)
Remove data augmentation on test set
-rw-r--r--supervised/baseline.py11
1 files changed, 8 insertions, 3 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py
index 6a83f72..9a83079 100644
--- a/supervised/baseline.py
+++ b/supervised/baseline.py
@@ -74,15 +74,20 @@ def get_color_distortion(s=1.0):
return color_distort
-transform = transforms.Compose([
+train_transform = transforms.Compose([
transforms.RandomResizedCrop(32, interpolation=InterpolationMode.BICUBIC),
transforms.RandomHorizontalFlip(0.5),
get_color_distortion(0.5),
transforms.ToTensor()
])
-train_set = CIFAR10(DATASET_ROOT, train=True, transform=transform, download=True)
-test_set = CIFAR10(DATASET_ROOT, train=False, transform=transform, download=True)
+test_transform = transforms.Compose([
+ transforms.ToTensor()
+])
+
+train_set = CIFAR10(DATASET_ROOT, train=True, transform=train_transform,
+ download=True)
+test_set = CIFAR10(DATASET_ROOT, train=False, transform=test_transform)
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE,
shuffle=True, num_workers=N_WORKERS)