aboutsummaryrefslogtreecommitdiff
path: root/supervised/baseline.py
diff options
context:
space:
mode:
Diffstat (limited to 'supervised/baseline.py')
-rw-r--r--supervised/baseline.py22
1 files changed, 16 insertions, 6 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py
index 15bb716..0217866 100644
--- a/supervised/baseline.py
+++ b/supervised/baseline.py
@@ -14,7 +14,7 @@ from torch.backends import cudnn
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
-from torchvision.datasets import CIFAR10, ImageNet
+from torchvision.datasets import CIFAR10, CIFAR100, ImageNet
from torchvision.transforms import transforms, InterpolationMode
from libs.datautils import color_distortion, Clip, RandomGaussianBlur
@@ -135,7 +135,8 @@ def init_logging(args):
def prepare_dataset(args):
- if args.dataset == 'cifar10' or args.dataset == 'cifar':
+ if args.dataset == 'cifar10' or args.dataset == 'cifar100' \
+ or args.dataset == 'cifar':
train_transform = transforms.Compose([
transforms.RandomResizedCrop(
args.crop_size,
@@ -151,9 +152,16 @@ def prepare_dataset(args):
transforms.ToTensor()
])
- train_set = CIFAR10(args.dataset_dir, train=True, transform=train_transform,
- download=True)
- test_set = CIFAR10(args.dataset_dir, train=False, transform=test_transform)
+ if args.dataset == 'cifar10' or args.dataset == 'cifar':
+ train_set = CIFAR10(args.dataset_dir, train=True,
+ transform=train_transform, download=True)
+ test_set = CIFAR10(args.dataset_dir, train=False,
+ transform=test_transform)
+ else: # CIFAR-100
+ train_set = CIFAR100(args.dataset_dir, train=True,
+ transform=train_transform, download=True)
+ test_set = CIFAR100(args.dataset_dir, train=False,
+ transform=test_transform)
elif args.dataset == 'imagenet1k' or args.dataset == 'imagenet1k':
train_transform = transforms.Compose([
transforms.RandomResizedCrop(
@@ -199,7 +207,9 @@ def create_dataloader(args, train_set, test_set):
def init_model(args):
if args.dataset == 'cifar10' or args.dataset == 'cifar':
- model = CIFARResNet50()
+ model = CIFARResNet50(num_classes=10)
+ elif args.dataset == 'cifar100':
+ model = CIFARResNet50(num_classes=100)
elif args.dataset == 'imagenet1k' or args.dataset == 'imagenet1k':
model = ImageNetResNet50()
else: