aboutsummaryrefslogtreecommitdiff
path: root/supervised/models.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2022-03-18 15:31:08 +0800
committerJordan Gong <jordan.gong@protonmail.com>2022-03-18 15:31:21 +0800
commitb475ecfa28c603010f550b0a8ad9204a5840b65f (patch)
tree4181a6a698e5e03ad5c90d671e3dcd09f3eb98d9 /supervised/models.py
parentaab36ef9a62fa00e7b968de28d0a3e6a5698aebd (diff)
Add CIFAR-100
Diffstat (limited to 'supervised/models.py')
-rw-r--r--supervised/models.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/supervised/models.py b/supervised/models.py
index f3ba35a..eb75790 100644
--- a/supervised/models.py
+++ b/supervised/models.py
@@ -5,9 +5,9 @@ from torchvision.models.resnet import BasicBlock
class CIFARResNet50(ResNet):
- def __init__(self):
+ def __init__(self, num_classes):
super(CIFARResNet50, self).__init__(
- block=BasicBlock, layers=[3, 4, 6, 3], num_classes=10
+ block=BasicBlock, layers=[3, 4, 6, 3], num_classes=num_classes
)
self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
stride=1, padding=1, bias=False)