diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2022-03-16 19:42:05 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2022-03-16 19:42:05 +0800 |
commit | 8ddb482b8d3c79009e77bbd15c37f311c6e72aad (patch) | |
tree | 4b6967c400e1b1b27011f97f19073892306a048c /supervised/models.py | |
parent | 35525c0bc6b85c06dda1e88e1addd9a1cfd5a675 (diff) |
Add ImageNet support
Diffstat (limited to 'supervised/models.py')
-rw-r--r-- | supervised/models.py | 11 |
1 files changed, 9 insertions, 2 deletions
diff --git a/supervised/models.py b/supervised/models.py index 47a0dcf..f3ba35a 100644 --- a/supervised/models.py +++ b/supervised/models.py @@ -4,9 +4,9 @@ from torchvision.models import ResNet from torchvision.models.resnet import BasicBlock -class CIFAR10ResNet50(ResNet): +class CIFARResNet50(ResNet): def __init__(self): - super(CIFAR10ResNet50, self).__init__( + super(CIFARResNet50, self).__init__( block=BasicBlock, layers=[3, 4, 6, 3], num_classes=10 ) self.conv1 = nn.Conv2d(3, 64, kernel_size=3, @@ -27,3 +27,10 @@ class CIFAR10ResNet50(ResNet): x = self.fc(x) return x + + +class ImageNetResNet50(ResNet): + def __init__(self): + super(ImageNetResNet50, self).__init__( + block=BasicBlock, layers=[3, 4, 6, 3], num_classes=1000 + ) |