aboutsummaryrefslogtreecommitdiff
path: root/supervised/models.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2022-08-10 17:02:27 +0800
committerJordan Gong <jordan.gong@protonmail.com>2022-08-10 17:02:27 +0800
commit95d5e4e82df0de08210088d07d6d89692f1325d1 (patch)
treeed6b6788383ba77ff439afdc6da4490cb467eba9 /supervised/models.py
parenta035bbdd0cc5e39954272d13d1e1595db738846a (diff)
Add supervised ViT baseline
Diffstat (limited to 'supervised/models.py')
-rw-r--r--supervised/models.py16
1 files changed, 15 insertions, 1 deletions
diff --git a/supervised/models.py b/supervised/models.py
index eb75790..a613937 100644
--- a/supervised/models.py
+++ b/supervised/models.py
@@ -1,6 +1,6 @@
import torch
from torch import nn, Tensor
-from torchvision.models import ResNet
+from torchvision.models import ResNet, VisionTransformer
from torchvision.models.resnet import BasicBlock
@@ -29,6 +29,20 @@ class CIFARResNet50(ResNet):
return x
+class CIFARViTTiny(VisionTransformer):
+ # Hyperparams copied from https://github.com/omihub777/ViT-CIFAR/blob/f5c8f122b4a825bf284bc9b471ec895cc9f847ae/README.md#3-hyperparams
+ def __init__(self, num_classes):
+ super().__init__(
+ image_size=32,
+ patch_size=4,
+ num_layers=7,
+ num_heads=12,
+ hidden_dim=384,
+ mlp_dim=384,
+ num_classes=num_classes,
+ )
+
+
class ImageNetResNet50(ResNet):
def __init__(self):
super(ImageNetResNet50, self).__init__(