diff options
Diffstat (limited to 'supervised/models.py')
-rw-r--r-- | supervised/models.py | 16 |
1 files changed, 15 insertions, 1 deletions
diff --git a/supervised/models.py b/supervised/models.py index eb75790..a613937 100644 --- a/supervised/models.py +++ b/supervised/models.py @@ -1,6 +1,6 @@ import torch from torch import nn, Tensor -from torchvision.models import ResNet +from torchvision.models import ResNet, VisionTransformer from torchvision.models.resnet import BasicBlock @@ -29,6 +29,20 @@ class CIFARResNet50(ResNet): return x +class CIFARViTTiny(VisionTransformer): + # Hyperparams copied from https://github.com/omihub777/ViT-CIFAR/blob/f5c8f122b4a825bf284bc9b471ec895cc9f847ae/README.md#3-hyperparams + def __init__(self, num_classes): + super().__init__( + image_size=32, + patch_size=4, + num_layers=7, + num_heads=12, + hidden_dim=384, + mlp_dim=384, + num_classes=num_classes, + ) + + class ImageNetResNet50(ResNet): def __init__(self): super(ImageNetResNet50, self).__init__( |