diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2022-08-10 17:02:27 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2022-08-10 17:02:27 +0800 |
commit | 95d5e4e82df0de08210088d07d6d89692f1325d1 (patch) | |
tree | ed6b6788383ba77ff439afdc6da4490cb467eba9 /supervised/models.py | |
parent | a035bbdd0cc5e39954272d13d1e1595db738846a (diff) |
Add supervised ViT baseline
Diffstat (limited to 'supervised/models.py')
-rw-r--r-- | supervised/models.py | 16 |
1 files changed, 15 insertions, 1 deletions
diff --git a/supervised/models.py b/supervised/models.py index eb75790..a613937 100644 --- a/supervised/models.py +++ b/supervised/models.py @@ -1,6 +1,6 @@ import torch from torch import nn, Tensor -from torchvision.models import ResNet +from torchvision.models import ResNet, VisionTransformer from torchvision.models.resnet import BasicBlock @@ -29,6 +29,20 @@ class CIFARResNet50(ResNet): return x +class CIFARViTTiny(VisionTransformer): + # Hyperparams copied from https://github.com/omihub777/ViT-CIFAR/blob/f5c8f122b4a825bf284bc9b471ec895cc9f847ae/README.md#3-hyperparams + def __init__(self, num_classes): + super().__init__( + image_size=32, + patch_size=4, + num_layers=7, + num_heads=12, + hidden_dim=384, + mlp_dim=384, + num_classes=num_classes, + ) + + class ImageNetResNet50(ResNet): def __init__(self): super(ImageNetResNet50, self).__init__( |