diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2022-03-18 15:31:08 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2022-03-18 15:31:21 +0800 |
commit | b475ecfa28c603010f550b0a8ad9204a5840b65f (patch) | |
tree | 4181a6a698e5e03ad5c90d671e3dcd09f3eb98d9 /supervised/baseline.py | |
parent | aab36ef9a62fa00e7b968de28d0a3e6a5698aebd (diff) |
Add CIFAR-100
Diffstat (limited to 'supervised/baseline.py')
-rw-r--r-- | supervised/baseline.py | 22 |
1 files changed, 16 insertions, 6 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py index 15bb716..0217866 100644 --- a/supervised/baseline.py +++ b/supervised/baseline.py @@ -14,7 +14,7 @@ from torch.backends import cudnn from torch.nn import CrossEntropyLoss from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter -from torchvision.datasets import CIFAR10, ImageNet +from torchvision.datasets import CIFAR10, CIFAR100, ImageNet from torchvision.transforms import transforms, InterpolationMode from libs.datautils import color_distortion, Clip, RandomGaussianBlur @@ -135,7 +135,8 @@ def init_logging(args): def prepare_dataset(args): - if args.dataset == 'cifar10' or args.dataset == 'cifar': + if args.dataset == 'cifar10' or args.dataset == 'cifar100' \ + or args.dataset == 'cifar': train_transform = transforms.Compose([ transforms.RandomResizedCrop( args.crop_size, @@ -151,9 +152,16 @@ def prepare_dataset(args): transforms.ToTensor() ]) - train_set = CIFAR10(args.dataset_dir, train=True, transform=train_transform, - download=True) - test_set = CIFAR10(args.dataset_dir, train=False, transform=test_transform) + if args.dataset == 'cifar10' or args.dataset == 'cifar': + train_set = CIFAR10(args.dataset_dir, train=True, + transform=train_transform, download=True) + test_set = CIFAR10(args.dataset_dir, train=False, + transform=test_transform) + else: # CIFAR-100 + train_set = CIFAR100(args.dataset_dir, train=True, + transform=train_transform, download=True) + test_set = CIFAR100(args.dataset_dir, train=False, + transform=test_transform) elif args.dataset == 'imagenet1k' or args.dataset == 'imagenet1k': train_transform = transforms.Compose([ transforms.RandomResizedCrop( @@ -199,7 +207,9 @@ def create_dataloader(args, train_set, test_set): def init_model(args): if args.dataset == 'cifar10' or args.dataset == 'cifar': - model = CIFARResNet50() + model = CIFARResNet50(num_classes=10) + elif args.dataset == 'cifar100': + model = CIFARResNet50(num_classes=100) elif args.dataset == 'imagenet1k' or args.dataset == 'imagenet1k': model = ImageNetResNet50() else: |