aboutsummaryrefslogtreecommitdiff
path: root/supervised
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2022-08-10 17:02:27 +0800
committerJordan Gong <jordan.gong@protonmail.com>2022-08-10 17:02:27 +0800
commit95d5e4e82df0de08210088d07d6d89692f1325d1 (patch)
treeed6b6788383ba77ff439afdc6da4490cb467eba9 /supervised
parenta035bbdd0cc5e39954272d13d1e1595db738846a (diff)
Add supervised ViT baseline
Diffstat (limited to 'supervised')
-rw-r--r--supervised/baseline.py37
-rw-r--r--supervised/config-vit.yaml29
-rw-r--r--supervised/config.yaml3
-rw-r--r--supervised/models.py16
4 files changed, 74 insertions, 11 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py
index 9c18f9f..db93304 100644
--- a/supervised/baseline.py
+++ b/supervised/baseline.py
@@ -17,7 +17,7 @@ sys.path.insert(0, path)
from libs.datautils import Clip
from libs.utils import Trainer, BaseConfig
from libs.logging import BaseBatchLogRecord, BaseEpochLogRecord, Loggers
-from models import CIFARResNet50
+from models import CIFARResNet50, CIFARViTTiny
def parse_args_and_config():
@@ -38,6 +38,11 @@ def parse_args_and_config():
parser.add_argument('--config', type=argparse.FileType(mode='r'),
help='Path to config file (optional)')
+ parser.add_argument('--backbone', default='resnet', type=str,
+ choices=('resnet', 'vit'), help='Backbone network')
+ parser.add_argument('--label-smooth', default=0., type=float,
+ help='Label smoothing in cross entropy')
+
dataset_group = parser.add_argument_group('Dataset parameters')
dataset_group.add_argument('--dataset-dir', default='dataset', type=str,
help="Path to dataset directory")
@@ -70,9 +75,12 @@ def parse_args_and_config():
optim_group.add_argument('--weight-decay', default=1e-6, type=float,
help='Weight decay (l2 regularization)')
- sched_group = parser.add_argument_group('Optimizer parameters')
+ sched_group = parser.add_argument_group('Scheduler parameters')
sched_group.add_argument('--sched', default='linear', type=str,
- choices=(None, '', 'linear'), help="Name of scheduler")
+ choices=('const', None, 'linear', 'warmup-anneal'),
+ help="Name of scheduler")
+ sched_group.add_argument('--warmup-iters', default=5, type=int,
+ help='Epochs for warmup (`warmup-anneal` scheduler only)')
args = parser.parse_args()
if args.config:
@@ -104,7 +112,8 @@ class SupBaselineConfig(BaseConfig):
class SupBaselineTrainer(Trainer):
- def __init__(self, **kwargs):
+ def __init__(self, backbone, **kwargs):
+ self.backbone = backbone
super(SupBaselineTrainer, self).__init__(**kwargs)
@dataclass
@@ -148,15 +157,21 @@ class SupBaselineTrainer(Trainer):
return train_set, test_set
- @staticmethod
- def _init_models(dataset: str) -> Iterable[tuple[str, torch.nn.Module]]:
+ def _init_models(self, dataset: str) -> Iterable[tuple[str, torch.nn.Module]]:
if dataset in {'cifar10', 'cifar'}:
- model = CIFARResNet50(num_classes=10)
+ num_classes = 10
elif dataset == 'cifar100':
- model = CIFARResNet50(num_classes=100)
+ num_classes = 100
else:
raise NotImplementedError(f"Unimplemented dataset: '{dataset}")
+ if self.backbone == 'resnet':
+ model = CIFARResNet50(num_classes)
+ elif self.backbone == 'vit':
+ model = CIFARViTTiny(num_classes)
+ else:
+ raise NotImplementedError(f"Unimplemented backbone: '{self.backbone}")
+
yield 'model', model
@staticmethod
@@ -242,7 +257,7 @@ class SupBaselineTrainer(Trainer):
if __name__ == '__main__':
args = parse_args_and_config()
config = SupBaselineConfig.from_args(args)
- device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
trainer = SupBaselineTrainer(
seed=args.seed,
checkpoint_dir=args.checkpoint_dir,
@@ -250,7 +265,9 @@ if __name__ == '__main__':
inf_mode=False,
num_iters=args.num_iters,
config=config,
+ backbone=args.backbone,
)
loggers = trainer.init_logger(args.log_dir)
- trainer.train(args.num_iters, torch.nn.CrossEntropyLoss(), loggers, device)
+ loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smooth)
+ trainer.train(args.num_iters, loss_fn, loggers, device)
diff --git a/supervised/config-vit.yaml b/supervised/config-vit.yaml
new file mode 100644
index 0000000..5ed6689
--- /dev/null
+++ b/supervised/config-vit.yaml
@@ -0,0 +1,29 @@
+codename: cifar10-vit-adam-warmup-anneal
+seed: -1
+num_iters: 1000
+log_dir: logs
+checkpoint_dir: checkpoints
+
+backbone: vit
+label_smooth: 0.1
+
+dataset: cifar10
+dataset_dir: dataset
+crop_size: 32
+crop_scale_range:
+ - 0.8
+ - 1
+hflip_prob: 0.5
+
+batch_size: 256
+num_workers: 2
+
+optim: adam
+lr: 0.001
+betas:
+ - 0.9
+ - 0.999
+weight_decay: 5.0e-05
+
+sched: warmup-anneal
+warmup_iters: 5 \ No newline at end of file
diff --git a/supervised/config.yaml b/supervised/config.yaml
index 91d6dc0..a3cf4f6 100644
--- a/supervised/config.yaml
+++ b/supervised/config.yaml
@@ -4,6 +4,9 @@ num_iters: 1000
log_dir: logs
checkpoint_dir: checkpoints
+backbone: resnet
+label_smooth: 0
+
dataset: cifar10
dataset_dir: dataset
crop_size: 32
diff --git a/supervised/models.py b/supervised/models.py
index eb75790..a613937 100644
--- a/supervised/models.py
+++ b/supervised/models.py
@@ -1,6 +1,6 @@
import torch
from torch import nn, Tensor
-from torchvision.models import ResNet
+from torchvision.models import ResNet, VisionTransformer
from torchvision.models.resnet import BasicBlock
@@ -29,6 +29,20 @@ class CIFARResNet50(ResNet):
return x
+class CIFARViTTiny(VisionTransformer):
+ # Hyperparams copied from https://github.com/omihub777/ViT-CIFAR/blob/f5c8f122b4a825bf284bc9b471ec895cc9f847ae/README.md#3-hyperparams
+ def __init__(self, num_classes):
+ super().__init__(
+ image_size=32,
+ patch_size=4,
+ num_layers=7,
+ num_heads=12,
+ hidden_dim=384,
+ mlp_dim=384,
+ num_classes=num_classes,
+ )
+
+
class ImageNetResNet50(ResNet):
def __init__(self):
super(ImageNetResNet50, self).__init__(