From 95d5e4e82df0de08210088d07d6d89692f1325d1 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Wed, 10 Aug 2022 17:02:27 +0800 Subject: Add supervised ViT baseline --- supervised/models.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) (limited to 'supervised/models.py') diff --git a/supervised/models.py b/supervised/models.py index eb75790..a613937 100644 --- a/supervised/models.py +++ b/supervised/models.py @@ -1,6 +1,6 @@ import torch from torch import nn, Tensor -from torchvision.models import ResNet +from torchvision.models import ResNet, VisionTransformer from torchvision.models.resnet import BasicBlock @@ -29,6 +29,20 @@ class CIFARResNet50(ResNet): return x +class CIFARViTTiny(VisionTransformer): + # Hyperparams copied from https://github.com/omihub777/ViT-CIFAR/blob/f5c8f122b4a825bf284bc9b471ec895cc9f847ae/README.md#3-hyperparams + def __init__(self, num_classes): + super().__init__( + image_size=32, + patch_size=4, + num_layers=7, + num_heads=12, + hidden_dim=384, + mlp_dim=384, + num_classes=num_classes, + ) + + class ImageNetResNet50(ResNet): def __init__(self): super(ImageNetResNet50, self).__init__( -- cgit v1.2.3