From b475ecfa28c603010f550b0a8ad9204a5840b65f Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Fri, 18 Mar 2022 15:31:08 +0800 Subject: Add CIFAR-100 --- supervised/baseline.py | 22 ++++++++++++++++------ supervised/models.py | 4 ++-- 2 files changed, 18 insertions(+), 8 deletions(-) (limited to 'supervised') diff --git a/supervised/baseline.py b/supervised/baseline.py index 15bb716..0217866 100644 --- a/supervised/baseline.py +++ b/supervised/baseline.py @@ -14,7 +14,7 @@ from torch.backends import cudnn from torch.nn import CrossEntropyLoss from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter -from torchvision.datasets import CIFAR10, ImageNet +from torchvision.datasets import CIFAR10, CIFAR100, ImageNet from torchvision.transforms import transforms, InterpolationMode from libs.datautils import color_distortion, Clip, RandomGaussianBlur @@ -135,7 +135,8 @@ def init_logging(args): def prepare_dataset(args): - if args.dataset == 'cifar10' or args.dataset == 'cifar': + if args.dataset == 'cifar10' or args.dataset == 'cifar100' \ + or args.dataset == 'cifar': train_transform = transforms.Compose([ transforms.RandomResizedCrop( args.crop_size, @@ -151,9 +152,16 @@ def prepare_dataset(args): transforms.ToTensor() ]) - train_set = CIFAR10(args.dataset_dir, train=True, transform=train_transform, - download=True) - test_set = CIFAR10(args.dataset_dir, train=False, transform=test_transform) + if args.dataset == 'cifar10' or args.dataset == 'cifar': + train_set = CIFAR10(args.dataset_dir, train=True, + transform=train_transform, download=True) + test_set = CIFAR10(args.dataset_dir, train=False, + transform=test_transform) + else: # CIFAR-100 + train_set = CIFAR100(args.dataset_dir, train=True, + transform=train_transform, download=True) + test_set = CIFAR100(args.dataset_dir, train=False, + transform=test_transform) elif args.dataset == 'imagenet1k' or args.dataset == 'imagenet1k': train_transform = transforms.Compose([ transforms.RandomResizedCrop( @@ -199,7 +207,9 @@ def create_dataloader(args, train_set, test_set): def init_model(args): if args.dataset == 'cifar10' or args.dataset == 'cifar': - model = CIFARResNet50() + model = CIFARResNet50(num_classes=10) + elif args.dataset == 'cifar100': + model = CIFARResNet50(num_classes=100) elif args.dataset == 'imagenet1k' or args.dataset == 'imagenet1k': model = ImageNetResNet50() else: diff --git a/supervised/models.py b/supervised/models.py index f3ba35a..eb75790 100644 --- a/supervised/models.py +++ b/supervised/models.py @@ -5,9 +5,9 @@ from torchvision.models.resnet import BasicBlock class CIFARResNet50(ResNet): - def __init__(self): + def __init__(self, num_classes): super(CIFARResNet50, self).__init__( - block=BasicBlock, layers=[3, 4, 6, 3], num_classes=10 + block=BasicBlock, layers=[3, 4, 6, 3], num_classes=num_classes ) self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) -- cgit v1.2.3