aboutsummaryrefslogtreecommitdiff
path: root/supervised
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2022-03-18 15:31:08 +0800
committerJordan Gong <jordan.gong@protonmail.com>2022-03-18 15:31:21 +0800
commitb475ecfa28c603010f550b0a8ad9204a5840b65f (patch)
tree4181a6a698e5e03ad5c90d671e3dcd09f3eb98d9 /supervised
parentaab36ef9a62fa00e7b968de28d0a3e6a5698aebd (diff)
Add CIFAR-100
Diffstat (limited to 'supervised')
-rw-r--r--supervised/baseline.py22
-rw-r--r--supervised/models.py4
2 files changed, 18 insertions, 8 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py
index 15bb716..0217866 100644
--- a/supervised/baseline.py
+++ b/supervised/baseline.py
@@ -14,7 +14,7 @@ from torch.backends import cudnn
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
-from torchvision.datasets import CIFAR10, ImageNet
+from torchvision.datasets import CIFAR10, CIFAR100, ImageNet
from torchvision.transforms import transforms, InterpolationMode
from libs.datautils import color_distortion, Clip, RandomGaussianBlur
@@ -135,7 +135,8 @@ def init_logging(args):
def prepare_dataset(args):
- if args.dataset == 'cifar10' or args.dataset == 'cifar':
+ if args.dataset == 'cifar10' or args.dataset == 'cifar100' \
+ or args.dataset == 'cifar':
train_transform = transforms.Compose([
transforms.RandomResizedCrop(
args.crop_size,
@@ -151,9 +152,16 @@ def prepare_dataset(args):
transforms.ToTensor()
])
- train_set = CIFAR10(args.dataset_dir, train=True, transform=train_transform,
- download=True)
- test_set = CIFAR10(args.dataset_dir, train=False, transform=test_transform)
+ if args.dataset == 'cifar10' or args.dataset == 'cifar':
+ train_set = CIFAR10(args.dataset_dir, train=True,
+ transform=train_transform, download=True)
+ test_set = CIFAR10(args.dataset_dir, train=False,
+ transform=test_transform)
+ else: # CIFAR-100
+ train_set = CIFAR100(args.dataset_dir, train=True,
+ transform=train_transform, download=True)
+ test_set = CIFAR100(args.dataset_dir, train=False,
+ transform=test_transform)
elif args.dataset == 'imagenet1k' or args.dataset == 'imagenet1k':
train_transform = transforms.Compose([
transforms.RandomResizedCrop(
@@ -199,7 +207,9 @@ def create_dataloader(args, train_set, test_set):
def init_model(args):
if args.dataset == 'cifar10' or args.dataset == 'cifar':
- model = CIFARResNet50()
+ model = CIFARResNet50(num_classes=10)
+ elif args.dataset == 'cifar100':
+ model = CIFARResNet50(num_classes=100)
elif args.dataset == 'imagenet1k' or args.dataset == 'imagenet1k':
model = ImageNetResNet50()
else:
diff --git a/supervised/models.py b/supervised/models.py
index f3ba35a..eb75790 100644
--- a/supervised/models.py
+++ b/supervised/models.py
@@ -5,9 +5,9 @@ from torchvision.models.resnet import BasicBlock
class CIFARResNet50(ResNet):
- def __init__(self):
+ def __init__(self, num_classes):
super(CIFARResNet50, self).__init__(
- block=BasicBlock, layers=[3, 4, 6, 3], num_classes=10
+ block=BasicBlock, layers=[3, 4, 6, 3], num_classes=num_classes
)
self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
stride=1, padding=1, bias=False)