diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2022-03-16 17:49:51 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2022-03-16 17:49:51 +0800 |
commit | 5869d0248fa958acd3447e6bffa8761b91e8e921 (patch) | |
tree | 4e2c0744400d9204bdfd23c58bafcf534c2119fb /supervised/models.py | |
parent | 608178533e93dc7e6fac6059fa139233ab046b63 (diff) |
Regular refactoring
Diffstat (limited to 'supervised/models.py')
-rw-r--r-- | supervised/models.py | 29 |
1 files changed, 29 insertions, 0 deletions
diff --git a/supervised/models.py b/supervised/models.py new file mode 100644 index 0000000..47a0dcf --- /dev/null +++ b/supervised/models.py @@ -0,0 +1,29 @@ +import torch +from torch import nn, Tensor +from torchvision.models import ResNet +from torchvision.models.resnet import BasicBlock + + +class CIFAR10ResNet50(ResNet): + def __init__(self): + super(CIFAR10ResNet50, self).__init__( + block=BasicBlock, layers=[3, 4, 6, 3], num_classes=10 + ) + self.conv1 = nn.Conv2d(3, 64, kernel_size=3, + stride=1, padding=1, bias=False) + + def forward(self, x: Tensor) -> Tensor: + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = torch.flatten(x, 1) + x = self.fc(x) + + return x |