diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2022-03-18 15:31:08 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2022-03-18 15:31:21 +0800 |
commit | b475ecfa28c603010f550b0a8ad9204a5840b65f (patch) | |
tree | 4181a6a698e5e03ad5c90d671e3dcd09f3eb98d9 | |
parent | aab36ef9a62fa00e7b968de28d0a3e6a5698aebd (diff) |
Add CIFAR-100
-rw-r--r-- | readme.md | 2 | ||||
-rw-r--r-- | supervised/baseline.py | 22 | ||||
-rw-r--r-- | supervised/models.py | 4 |
3 files changed, 19 insertions, 9 deletions
@@ -27,7 +27,7 @@ - [x] ResNet - [ ] ViT - [x] CIFAR-10 - - [ ] CIFAR-100 + - [x] CIFAR-100 - [x] ImageNet-1k - Self-supervised baseline diff --git a/supervised/baseline.py b/supervised/baseline.py index 15bb716..0217866 100644 --- a/supervised/baseline.py +++ b/supervised/baseline.py @@ -14,7 +14,7 @@ from torch.backends import cudnn from torch.nn import CrossEntropyLoss from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter -from torchvision.datasets import CIFAR10, ImageNet +from torchvision.datasets import CIFAR10, CIFAR100, ImageNet from torchvision.transforms import transforms, InterpolationMode from libs.datautils import color_distortion, Clip, RandomGaussianBlur @@ -135,7 +135,8 @@ def init_logging(args): def prepare_dataset(args): - if args.dataset == 'cifar10' or args.dataset == 'cifar': + if args.dataset == 'cifar10' or args.dataset == 'cifar100' \ + or args.dataset == 'cifar': train_transform = transforms.Compose([ transforms.RandomResizedCrop( args.crop_size, @@ -151,9 +152,16 @@ def prepare_dataset(args): transforms.ToTensor() ]) - train_set = CIFAR10(args.dataset_dir, train=True, transform=train_transform, - download=True) - test_set = CIFAR10(args.dataset_dir, train=False, transform=test_transform) + if args.dataset == 'cifar10' or args.dataset == 'cifar': + train_set = CIFAR10(args.dataset_dir, train=True, + transform=train_transform, download=True) + test_set = CIFAR10(args.dataset_dir, train=False, + transform=test_transform) + else: # CIFAR-100 + train_set = CIFAR100(args.dataset_dir, train=True, + transform=train_transform, download=True) + test_set = CIFAR100(args.dataset_dir, train=False, + transform=test_transform) elif args.dataset == 'imagenet1k' or args.dataset == 'imagenet1k': train_transform = transforms.Compose([ transforms.RandomResizedCrop( @@ -199,7 +207,9 @@ def create_dataloader(args, train_set, test_set): def init_model(args): if args.dataset == 'cifar10' or args.dataset == 'cifar': - model = CIFARResNet50() + model = CIFARResNet50(num_classes=10) + elif args.dataset == 'cifar100': + model = CIFARResNet50(num_classes=100) elif args.dataset == 'imagenet1k' or args.dataset == 'imagenet1k': model = ImageNetResNet50() else: diff --git a/supervised/models.py b/supervised/models.py index f3ba35a..eb75790 100644 --- a/supervised/models.py +++ b/supervised/models.py @@ -5,9 +5,9 @@ from torchvision.models.resnet import BasicBlock class CIFARResNet50(ResNet): - def __init__(self): + def __init__(self, num_classes): super(CIFARResNet50, self).__init__( - block=BasicBlock, layers=[3, 4, 6, 3], num_classes=10 + block=BasicBlock, layers=[3, 4, 6, 3], num_classes=num_classes ) self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) |