author    Jordan Gong <jordan.gong@protonmail.com>  2022-08-10 17:02:27 +0800
committer Jordan Gong <jordan.gong@protonmail.com>  2022-08-10 17:02:27 +0800
commit    95d5e4e82df0de08210088d07d6d89692f1325d1 (patch)
tree      ed6b6788383ba77ff439afdc6da4490cb467eba9 /supervised/baseline.py
parent    a035bbdd0cc5e39954272d13d1e1595db738846a (diff)
Add supervised ViT baseline
Diffstat (limited to 'supervised/baseline.py')
-rw-r--r--  supervised/baseline.py | 37
1 file changed, 27 insertions(+), 10 deletions(-)
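
The change set adds a --backbone flag choosing between the existing ResNet-50 and a new ViT-Tiny model, a --label-smooth factor for the cross-entropy loss, and, under the scheduler group (previously mislabeled 'Optimizer parameters'), 'const' and 'warmup-anneal' schedule choices together with a --warmup-iters option; _init_models() now dispatches on the selected backbone. A typical invocation would be python supervised/baseline.py --backbone vit --label-smooth 0.1 --sched warmup-anneal --warmup-iters 5, with the remaining flags unchanged.

CIFARViTTiny itself is defined in models.py, which this diff does not touch, so its real definition is not shown here. As a rough sketch only: a CIFAR-sized ViT-Tiny could be built on torchvision's VisionTransformer with the usual Tiny dimensions (192-d embeddings, 12 layers, 3 heads) and a 4x4 patch size for 32x32 inputs. Everything below is an assumption, not the repository's actual code.

# Hypothetical reconstruction of models.CIFARViTTiny; the real class may differ.
from torchvision.models import VisionTransformer  # requires torchvision >= 0.12


class CIFARViTTiny(VisionTransformer):
    def __init__(self, num_classes: int):
        super().__init__(
            image_size=32,    # CIFAR images are 32x32
            patch_size=4,     # yields 8x8 = 64 patches per image
            num_layers=12,    # ViT-Tiny depth
            num_heads=3,      # 192 / 64 = 3 attention heads
            hidden_dim=192,   # Tiny embedding width
            mlp_dim=768,      # 4x hidden_dim, as in the original ViT
            num_classes=num_classes,
        )
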
diff --git a/supervised/baseline.py b/supervised/baseline.py
index 9c18f9f..db93304 100644
--- a/supervised/baseline.py
+++ b/supervised/baseline.py
@@ -17,7 +17,7 @@ sys.path.insert(0, path)
 from libs.datautils import Clip
 from libs.utils import Trainer, BaseConfig
 from libs.logging import BaseBatchLogRecord, BaseEpochLogRecord, Loggers
-from models import CIFARResNet50
+from models import CIFARResNet50, CIFARViTTiny


 def parse_args_and_config():
@@ -38,6 +38,11 @@ def parse_args_and_config():
     parser.add_argument('--config', type=argparse.FileType(mode='r'),
                         help='Path to config file (optional)')

+    parser.add_argument('--backbone', default='resnet', type=str,
+                        choices=('resnet', 'vit'), help='Backbone network')
+    parser.add_argument('--label-smooth', default=0., type=float,
+                        help='Label smoothing in cross entropy')
+
     dataset_group = parser.add_argument_group('Dataset parameters')
     dataset_group.add_argument('--dataset-dir', default='dataset', type=str,
                                help="Path to dataset directory")
@@ -70,9 +75,12 @@ def parse_args_and_config():
     optim_group.add_argument('--weight-decay', default=1e-6, type=float,
                              help='Weight decay (l2 regularization)')

-    sched_group = parser.add_argument_group('Optimizer parameters')
+    sched_group = parser.add_argument_group('Scheduler parameters')
     sched_group.add_argument('--sched', default='linear', type=str,
-                             choices=(None, '', 'linear'), help="Name of scheduler")
+                             choices=('const', None, 'linear', 'warmup-anneal'),
+                             help="Name of scheduler")
+    sched_group.add_argument('--warmup-iters', default=5, type=int,
+                             help='Epochs for warmup (`warmup-anneal` scheduler only)')

     args = parser.parse_args()
     if args.config:
@@ -104,7 +112,8 @@ class SupBaselineConfig(BaseConfig):


 class SupBaselineTrainer(Trainer):
-    def __init__(self, **kwargs):
+    def __init__(self, backbone, **kwargs):
+        self.backbone = backbone
         super(SupBaselineTrainer, self).__init__(**kwargs)

     @dataclass
@@ -148,15 +157,21 @@ class SupBaselineTrainer(Trainer):
         return train_set, test_set

-    @staticmethod
-    def _init_models(dataset: str) -> Iterable[tuple[str, torch.nn.Module]]:
+    def _init_models(self, dataset: str) -> Iterable[tuple[str, torch.nn.Module]]:
         if dataset in {'cifar10', 'cifar'}:
-            model = CIFARResNet50(num_classes=10)
+            num_classes = 10
         elif dataset == 'cifar100':
-            model = CIFARResNet50(num_classes=100)
+            num_classes = 100
         else:
             raise NotImplementedError(f"Unimplemented dataset: '{dataset}'")
+        if self.backbone == 'resnet':
+            model = CIFARResNet50(num_classes)
+        elif self.backbone == 'vit':
+            model = CIFARViTTiny(num_classes)
+        else:
+            raise NotImplementedError(f"Unimplemented backbone: '{self.backbone}'")
+
         yield 'model', model

     @staticmethod
@@ -242,7 +257,7 @@ class SupBaselineTrainer(Trainer):
 if __name__ == '__main__':
     args = parse_args_and_config()
     config = SupBaselineConfig.from_args(args)
-    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     trainer = SupBaselineTrainer(
         seed=args.seed,
         checkpoint_dir=args.checkpoint_dir,
@@ -250,7 +265,9 @@ if __name__ == '__main__':
         inf_mode=False,
         num_iters=args.num_iters,
         config=config,
+        backbone=args.backbone,
     )
     loggers = trainer.init_logger(args.log_dir)
-    trainer.train(args.num_iters, torch.nn.CrossEntropyLoss(), loggers, device)
+    loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smooth)
+    trainer.train(args.num_iters, loss_fn, loggers, device)
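
The diff only registers the 'warmup-anneal' name; the schedule itself is constructed elsewhere (in libs.utils, which this diff does not show). Below is a minimal sketch of the linear-warmup-plus-cosine-annealing schedule that the name and the new --warmup-iters flag suggest, built on torch.optim.lr_scheduler.LambdaLR; the function and argument names are illustrative, not the repository's API.

import math

import torch


def warmup_anneal(optimizer: torch.optim.Optimizer, warmup_iters: int,
                  total_iters: int) -> torch.optim.lr_scheduler.LambdaLR:
    # Scale the base LR linearly from ~0 up to 1x over `warmup_iters` steps,
    # then decay it to 0 along a half cosine for the remaining steps.
    def lr_lambda(step: int) -> float:
        if step < warmup_iters:
            return (step + 1) / warmup_iters
        progress = (step - warmup_iters) / max(1, total_iters - warmup_iters)
        return 0.5 * (1. + math.cos(math.pi * progress))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

Note also that CrossEntropyLoss(label_smoothing=...) requires PyTorch 1.10 or newer; with --label-smooth 0.1 the target distribution becomes 0.9 on the true class plus 0.1 spread uniformly over all classes, a common regularizer when training ViTs on small datasets.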