diff options
Diffstat (limited to 'supervised/models.py')
-rw-r--r-- | supervised/models.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/supervised/models.py b/supervised/models.py index f3ba35a..eb75790 100644 --- a/supervised/models.py +++ b/supervised/models.py @@ -5,9 +5,9 @@ from torchvision.models.resnet import BasicBlock class CIFARResNet50(ResNet): - def __init__(self): + def __init__(self, num_classes): super(CIFARResNet50, self).__init__( - block=BasicBlock, layers=[3, 4, 6, 3], num_classes=10 + block=BasicBlock, layers=[3, 4, 6, 3], num_classes=num_classes ) self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) |