From b475ecfa28c603010f550b0a8ad9204a5840b65f Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Fri, 18 Mar 2022 15:31:08 +0800 Subject: Add CIFAR-100 --- supervised/baseline.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) (limited to 'supervised/baseline.py') diff --git a/supervised/baseline.py b/supervised/baseline.py index 15bb716..0217866 100644 --- a/supervised/baseline.py +++ b/supervised/baseline.py @@ -14,7 +14,7 @@ from torch.backends import cudnn from torch.nn import CrossEntropyLoss from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter -from torchvision.datasets import CIFAR10, ImageNet +from torchvision.datasets import CIFAR10, CIFAR100, ImageNet from torchvision.transforms import transforms, InterpolationMode from libs.datautils import color_distortion, Clip, RandomGaussianBlur @@ -135,7 +135,8 @@ def init_logging(args): def prepare_dataset(args): - if args.dataset == 'cifar10' or args.dataset == 'cifar': + if args.dataset == 'cifar10' or args.dataset == 'cifar100' \ + or args.dataset == 'cifar': train_transform = transforms.Compose([ transforms.RandomResizedCrop( args.crop_size, @@ -151,9 +152,16 @@ def prepare_dataset(args): transforms.ToTensor() ]) - train_set = CIFAR10(args.dataset_dir, train=True, transform=train_transform, - download=True) - test_set = CIFAR10(args.dataset_dir, train=False, transform=test_transform) + if args.dataset == 'cifar10' or args.dataset == 'cifar': + train_set = CIFAR10(args.dataset_dir, train=True, + transform=train_transform, download=True) + test_set = CIFAR10(args.dataset_dir, train=False, + transform=test_transform) + else: # CIFAR-100 + train_set = CIFAR100(args.dataset_dir, train=True, + transform=train_transform, download=True) + test_set = CIFAR100(args.dataset_dir, train=False, + transform=test_transform) elif args.dataset == 'imagenet1k' or args.dataset == 'imagenet1k': train_transform = transforms.Compose([ transforms.RandomResizedCrop( @@ -199,7 +207,9 @@ def create_dataloader(args, train_set, test_set): def init_model(args): if args.dataset == 'cifar10' or args.dataset == 'cifar': - model = CIFARResNet50() + model = CIFARResNet50(num_classes=10) + elif args.dataset == 'cifar100': + model = CIFARResNet50(num_classes=100) elif args.dataset == 'imagenet1k' or args.dataset == 'imagenet1k': model = ImageNetResNet50() else: -- cgit v1.2.3