From 95d5e4e82df0de08210088d07d6d89692f1325d1 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Wed, 10 Aug 2022 17:02:27 +0800 Subject: Add supervised ViT baseline --- supervised/baseline.py | 37 +++++++++++++++++++++++++++---------- 1 file changed, 27 insertions(+), 10 deletions(-) (limited to 'supervised/baseline.py') diff --git a/supervised/baseline.py b/supervised/baseline.py index 9c18f9f..db93304 100644 --- a/supervised/baseline.py +++ b/supervised/baseline.py @@ -17,7 +17,7 @@ sys.path.insert(0, path) from libs.datautils import Clip from libs.utils import Trainer, BaseConfig from libs.logging import BaseBatchLogRecord, BaseEpochLogRecord, Loggers -from models import CIFARResNet50 +from models import CIFARResNet50, CIFARViTTiny def parse_args_and_config(): @@ -38,6 +38,11 @@ def parse_args_and_config(): parser.add_argument('--config', type=argparse.FileType(mode='r'), help='Path to config file (optional)') + parser.add_argument('--backbone', default='resnet', type=str, + choices=('resnet', 'vit'), help='Backbone network') + parser.add_argument('--label-smooth', default=0., type=float, + help='Label smoothing in cross entropy') + dataset_group = parser.add_argument_group('Dataset parameters') dataset_group.add_argument('--dataset-dir', default='dataset', type=str, help="Path to dataset directory") @@ -70,9 +75,12 @@ def parse_args_and_config(): optim_group.add_argument('--weight-decay', default=1e-6, type=float, help='Weight decay (l2 regularization)') - sched_group = parser.add_argument_group('Optimizer parameters') + sched_group = parser.add_argument_group('Scheduler parameters') sched_group.add_argument('--sched', default='linear', type=str, - choices=(None, '', 'linear'), help="Name of scheduler") + choices=('const', None, 'linear', 'warmup-anneal'), + help="Name of scheduler") + sched_group.add_argument('--warmup-iters', default=5, type=int, + help='Epochs for warmup (`warmup-anneal` scheduler only)') args = parser.parse_args() if args.config: @@ -104,7 +112,8 @@ class SupBaselineConfig(BaseConfig): class SupBaselineTrainer(Trainer): - def __init__(self, **kwargs): + def __init__(self, backbone, **kwargs): + self.backbone = backbone super(SupBaselineTrainer, self).__init__(**kwargs) @dataclass @@ -148,15 +157,21 @@ class SupBaselineTrainer(Trainer): return train_set, test_set - @staticmethod - def _init_models(dataset: str) -> Iterable[tuple[str, torch.nn.Module]]: + def _init_models(self, dataset: str) -> Iterable[tuple[str, torch.nn.Module]]: if dataset in {'cifar10', 'cifar'}: - model = CIFARResNet50(num_classes=10) + num_classes = 10 elif dataset == 'cifar100': - model = CIFARResNet50(num_classes=100) + num_classes = 100 else: raise NotImplementedError(f"Unimplemented dataset: '{dataset}") + if self.backbone == 'resnet': + model = CIFARResNet50(num_classes) + elif self.backbone == 'vit': + model = CIFARViTTiny(num_classes) + else: + raise NotImplementedError(f"Unimplemented backbone: '{self.backbone}") + yield 'model', model @staticmethod @@ -242,7 +257,7 @@ class SupBaselineTrainer(Trainer): if __name__ == '__main__': args = parse_args_and_config() config = SupBaselineConfig.from_args(args) - device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') trainer = SupBaselineTrainer( seed=args.seed, checkpoint_dir=args.checkpoint_dir, @@ -250,7 +265,9 @@ if __name__ == '__main__': inf_mode=False, num_iters=args.num_iters, config=config, + backbone=args.backbone, ) loggers = trainer.init_logger(args.log_dir) - trainer.train(args.num_iters, torch.nn.CrossEntropyLoss(), loggers, device) + loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smooth) + trainer.train(args.num_iters, loss_fn, loggers, device) -- cgit v1.2.3