aboutsummaryrefslogtreecommitdiff
path: root/supervised/models.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2022-03-16 19:42:05 +0800
committerJordan Gong <jordan.gong@protonmail.com>2022-03-16 19:42:05 +0800
commit8ddb482b8d3c79009e77bbd15c37f311c6e72aad (patch)
tree4b6967c400e1b1b27011f97f19073892306a048c /supervised/models.py
parent35525c0bc6b85c06dda1e88e1addd9a1cfd5a675 (diff)
Add ImageNet support
Diffstat (limited to 'supervised/models.py')
-rw-r--r--supervised/models.py11
1 files changed, 9 insertions, 2 deletions
diff --git a/supervised/models.py b/supervised/models.py
index 47a0dcf..f3ba35a 100644
--- a/supervised/models.py
+++ b/supervised/models.py
@@ -4,9 +4,9 @@ from torchvision.models import ResNet
from torchvision.models.resnet import BasicBlock
-class CIFAR10ResNet50(ResNet):
+class CIFARResNet50(ResNet):
def __init__(self):
- super(CIFAR10ResNet50, self).__init__(
+ super(CIFARResNet50, self).__init__(
block=BasicBlock, layers=[3, 4, 6, 3], num_classes=10
)
self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
@@ -27,3 +27,10 @@ class CIFAR10ResNet50(ResNet):
x = self.fc(x)
return x
+
+
+class ImageNetResNet50(ResNet):
+ def __init__(self):
+ super(ImageNetResNet50, self).__init__(
+ block=BasicBlock, layers=[3, 4, 6, 3], num_classes=1000
+ )