aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--supervised/baseline.py11
1 files changed, 8 insertions, 3 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py
index 6a83f72..9a83079 100644
--- a/supervised/baseline.py
+++ b/supervised/baseline.py
@@ -74,15 +74,20 @@ def get_color_distortion(s=1.0):
return color_distort
-transform = transforms.Compose([
+train_transform = transforms.Compose([
transforms.RandomResizedCrop(32, interpolation=InterpolationMode.BICUBIC),
transforms.RandomHorizontalFlip(0.5),
get_color_distortion(0.5),
transforms.ToTensor()
])
-train_set = CIFAR10(DATASET_ROOT, train=True, transform=transform, download=True)
-test_set = CIFAR10(DATASET_ROOT, train=False, transform=transform, download=True)
+test_transform = transforms.Compose([
+ transforms.ToTensor()
+])
+
+train_set = CIFAR10(DATASET_ROOT, train=True, transform=train_transform,
+ download=True)
+test_set = CIFAR10(DATASET_ROOT, train=False, transform=test_transform)
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE,
shuffle=True, num_workers=N_WORKERS)