diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2022-08-10 17:02:27 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2022-08-10 17:02:27 +0800 |
commit | 95d5e4e82df0de08210088d07d6d89692f1325d1 (patch) | |
tree | ed6b6788383ba77ff439afdc6da4490cb467eba9 /supervised | |
parent | a035bbdd0cc5e39954272d13d1e1595db738846a (diff) |
Add supervised ViT baseline
Diffstat (limited to 'supervised')
-rw-r--r-- | supervised/baseline.py | 37 | ||||
-rw-r--r-- | supervised/config-vit.yaml | 29 | ||||
-rw-r--r-- | supervised/config.yaml | 3 | ||||
-rw-r--r-- | supervised/models.py | 16 |
4 files changed, 74 insertions, 11 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py index 9c18f9f..db93304 100644 --- a/supervised/baseline.py +++ b/supervised/baseline.py @@ -17,7 +17,7 @@ sys.path.insert(0, path) from libs.datautils import Clip from libs.utils import Trainer, BaseConfig from libs.logging import BaseBatchLogRecord, BaseEpochLogRecord, Loggers -from models import CIFARResNet50 +from models import CIFARResNet50, CIFARViTTiny def parse_args_and_config(): @@ -38,6 +38,11 @@ def parse_args_and_config(): parser.add_argument('--config', type=argparse.FileType(mode='r'), help='Path to config file (optional)') + parser.add_argument('--backbone', default='resnet', type=str, + choices=('resnet', 'vit'), help='Backbone network') + parser.add_argument('--label-smooth', default=0., type=float, + help='Label smoothing in cross entropy') + dataset_group = parser.add_argument_group('Dataset parameters') dataset_group.add_argument('--dataset-dir', default='dataset', type=str, help="Path to dataset directory") @@ -70,9 +75,12 @@ def parse_args_and_config(): optim_group.add_argument('--weight-decay', default=1e-6, type=float, help='Weight decay (l2 regularization)') - sched_group = parser.add_argument_group('Optimizer parameters') + sched_group = parser.add_argument_group('Scheduler parameters') sched_group.add_argument('--sched', default='linear', type=str, - choices=(None, '', 'linear'), help="Name of scheduler") + choices=('const', None, 'linear', 'warmup-anneal'), + help="Name of scheduler") + sched_group.add_argument('--warmup-iters', default=5, type=int, + help='Epochs for warmup (`warmup-anneal` scheduler only)') args = parser.parse_args() if args.config: @@ -104,7 +112,8 @@ class SupBaselineConfig(BaseConfig): class SupBaselineTrainer(Trainer): - def __init__(self, **kwargs): + def __init__(self, backbone, **kwargs): + self.backbone = backbone super(SupBaselineTrainer, self).__init__(**kwargs) @dataclass @@ -148,15 +157,21 @@ class SupBaselineTrainer(Trainer): return train_set, test_set - @staticmethod - def _init_models(dataset: str) -> Iterable[tuple[str, torch.nn.Module]]: + def _init_models(self, dataset: str) -> Iterable[tuple[str, torch.nn.Module]]: if dataset in {'cifar10', 'cifar'}: - model = CIFARResNet50(num_classes=10) + num_classes = 10 elif dataset == 'cifar100': - model = CIFARResNet50(num_classes=100) + num_classes = 100 else: raise NotImplementedError(f"Unimplemented dataset: '{dataset}") + if self.backbone == 'resnet': + model = CIFARResNet50(num_classes) + elif self.backbone == 'vit': + model = CIFARViTTiny(num_classes) + else: + raise NotImplementedError(f"Unimplemented backbone: '{self.backbone}") + yield 'model', model @staticmethod @@ -242,7 +257,7 @@ class SupBaselineTrainer(Trainer): if __name__ == '__main__': args = parse_args_and_config() config = SupBaselineConfig.from_args(args) - device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') trainer = SupBaselineTrainer( seed=args.seed, checkpoint_dir=args.checkpoint_dir, @@ -250,7 +265,9 @@ if __name__ == '__main__': inf_mode=False, num_iters=args.num_iters, config=config, + backbone=args.backbone, ) loggers = trainer.init_logger(args.log_dir) - trainer.train(args.num_iters, torch.nn.CrossEntropyLoss(), loggers, device) + loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smooth) + trainer.train(args.num_iters, loss_fn, loggers, device) diff --git a/supervised/config-vit.yaml b/supervised/config-vit.yaml new file mode 100644 index 0000000..5ed6689 --- /dev/null +++ b/supervised/config-vit.yaml @@ -0,0 +1,29 @@ +codename: cifar10-vit-adam-warmup-anneal +seed: -1 +num_iters: 1000 +log_dir: logs +checkpoint_dir: checkpoints + +backbone: vit +label_smooth: 0.1 + +dataset: cifar10 +dataset_dir: dataset +crop_size: 32 +crop_scale_range: + - 0.8 + - 1 +hflip_prob: 0.5 + +batch_size: 256 +num_workers: 2 + +optim: adam +lr: 0.001 +betas: + - 0.9 + - 0.999 +weight_decay: 5.0e-05 + +sched: warmup-anneal +warmup_iters: 5
\ No newline at end of file diff --git a/supervised/config.yaml b/supervised/config.yaml index 91d6dc0..a3cf4f6 100644 --- a/supervised/config.yaml +++ b/supervised/config.yaml @@ -4,6 +4,9 @@ num_iters: 1000 log_dir: logs checkpoint_dir: checkpoints +backbone: resnet +label_smooth: 0 + dataset: cifar10 dataset_dir: dataset crop_size: 32 diff --git a/supervised/models.py b/supervised/models.py index eb75790..a613937 100644 --- a/supervised/models.py +++ b/supervised/models.py @@ -1,6 +1,6 @@ import torch from torch import nn, Tensor -from torchvision.models import ResNet +from torchvision.models import ResNet, VisionTransformer from torchvision.models.resnet import BasicBlock @@ -29,6 +29,20 @@ class CIFARResNet50(ResNet): return x +class CIFARViTTiny(VisionTransformer): + # Hyperparams copied from https://github.com/omihub777/ViT-CIFAR/blob/f5c8f122b4a825bf284bc9b471ec895cc9f847ae/README.md#3-hyperparams + def __init__(self, num_classes): + super().__init__( + image_size=32, + patch_size=4, + num_layers=7, + num_heads=12, + hidden_dim=384, + mlp_dim=384, + num_classes=num_classes, + ) + + class ImageNetResNet50(ResNet): def __init__(self): super(ImageNetResNet50, self).__init__( |