aboutsummaryrefslogtreecommitdiff
path: root/supervised/models.py
diff options
context:
space:
mode:
Diffstat (limited to 'supervised/models.py')
-rw-r--r--supervised/models.py16
1 files changed, 15 insertions, 1 deletions
diff --git a/supervised/models.py b/supervised/models.py
index eb75790..a613937 100644
--- a/supervised/models.py
+++ b/supervised/models.py
@@ -1,6 +1,6 @@
import torch
from torch import nn, Tensor
-from torchvision.models import ResNet
+from torchvision.models import ResNet, VisionTransformer
from torchvision.models.resnet import BasicBlock
@@ -29,6 +29,20 @@ class CIFARResNet50(ResNet):
return x
+class CIFARViTTiny(VisionTransformer):
+ # Hyperparams copied from https://github.com/omihub777/ViT-CIFAR/blob/f5c8f122b4a825bf284bc9b471ec895cc9f847ae/README.md#3-hyperparams
+ def __init__(self, num_classes):
+ super().__init__(
+ image_size=32,
+ patch_size=4,
+ num_layers=7,
+ num_heads=12,
+ hidden_dim=384,
+ mlp_dim=384,
+ num_classes=num_classes,
+ )
+
+
class ImageNetResNet50(ResNet):
def __init__(self):
super(ImageNetResNet50, self).__init__(