diff options
Diffstat (limited to 'supervised')
-rw-r--r-- | supervised/baseline.py | 22 | ||||
-rw-r--r-- | supervised/models.py | 4 |
2 files changed, 18 insertions, 8 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py index 15bb716..0217866 100644 --- a/supervised/baseline.py +++ b/supervised/baseline.py @@ -14,7 +14,7 @@ from torch.backends import cudnn from torch.nn import CrossEntropyLoss from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter -from torchvision.datasets import CIFAR10, ImageNet +from torchvision.datasets import CIFAR10, CIFAR100, ImageNet from torchvision.transforms import transforms, InterpolationMode from libs.datautils import color_distortion, Clip, RandomGaussianBlur @@ -135,7 +135,8 @@ def init_logging(args): def prepare_dataset(args): - if args.dataset == 'cifar10' or args.dataset == 'cifar': + if args.dataset == 'cifar10' or args.dataset == 'cifar100' \ + or args.dataset == 'cifar': train_transform = transforms.Compose([ transforms.RandomResizedCrop( args.crop_size, @@ -151,9 +152,16 @@ def prepare_dataset(args): transforms.ToTensor() ]) - train_set = CIFAR10(args.dataset_dir, train=True, transform=train_transform, - download=True) - test_set = CIFAR10(args.dataset_dir, train=False, transform=test_transform) + if args.dataset == 'cifar10' or args.dataset == 'cifar': + train_set = CIFAR10(args.dataset_dir, train=True, + transform=train_transform, download=True) + test_set = CIFAR10(args.dataset_dir, train=False, + transform=test_transform) + else: # CIFAR-100 + train_set = CIFAR100(args.dataset_dir, train=True, + transform=train_transform, download=True) + test_set = CIFAR100(args.dataset_dir, train=False, + transform=test_transform) elif args.dataset == 'imagenet1k' or args.dataset == 'imagenet1k': train_transform = transforms.Compose([ transforms.RandomResizedCrop( @@ -199,7 +207,9 @@ def create_dataloader(args, train_set, test_set): def init_model(args): if args.dataset == 'cifar10' or args.dataset == 'cifar': - model = CIFARResNet50() + model = CIFARResNet50(num_classes=10) + elif args.dataset == 'cifar100': + model = CIFARResNet50(num_classes=100) elif args.dataset == 'imagenet1k' or args.dataset == 'imagenet1k': model = ImageNetResNet50() else: diff --git a/supervised/models.py b/supervised/models.py index f3ba35a..eb75790 100644 --- a/supervised/models.py +++ b/supervised/models.py @@ -5,9 +5,9 @@ from torchvision.models.resnet import BasicBlock class CIFARResNet50(ResNet): - def __init__(self): + def __init__(self, num_classes): super(CIFARResNet50, self).__init__( - block=BasicBlock, layers=[3, 4, 6, 3], num_classes=10 + block=BasicBlock, layers=[3, 4, 6, 3], num_classes=num_classes ) self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) |