author    | Jordan Gong <jordan.gong@protonmail.com> | 2022-07-14 14:54:11 +0800
committer | Jordan Gong <jordan.gong@protonmail.com> | 2022-07-14 14:55:00 +0800
commit    | cc807c751e2c14ef9a88e5c5be00b4eb082e705b (patch)
tree      | 04af8607e2906df68d3e77edbd2658353fb044a0 /supervised/baseline.py
parent    | b9d83e80b946437bb8dc0b586488fa756f52d732 (diff)
Refactor baseline with trainer
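
The diff below touches only supervised/baseline.py; the `Trainer` and `BaseConfig` base classes it now builds on live in `libs/utils.py` and `libs/logging.py`, which are not part of this commit. As rough orientation only, here is an inferred sketch of the base-class surface the refactored script relies on, reconstructed from how `SupBaselineTrainer` uses it in the diff; the real implementation almost certainly differs in detail.

```python
# Inferred sketch only -- libs/utils.py and libs/logging.py are NOT in this
# commit, and the real Trainer certainly differs. This only illustrates the
# hooks and attributes that the refactored SupBaselineTrainer (diff below)
# assumes its base class provides.
from abc import ABC, abstractmethod
from typing import Callable, Iterable

import torch
from torch.utils.data import DataLoader


class Trainer(ABC):
    # Filled in by the base class from the subclass hooks (assumption):
    models: dict[str, torch.nn.Module]        # e.g. {'model': ...}
    optims: dict[str, torch.optim.Optimizer]  # e.g. {'model_optim': ...}
    scheds: dict                              # e.g. {'model_optim_sched': ...}
    train_loader: DataLoader
    test_loader: DataLoader
    restore_iter: int                         # 0 unless a checkpoint was restored

    @staticmethod
    @abstractmethod
    def _prepare_dataset(dataset_config):
        """Return (train_set, test_set) for the configured dataset."""

    @staticmethod
    @abstractmethod
    def _init_models(dataset) -> Iterable[tuple[str, torch.nn.Module]]:
        """Yield (name, module) pairs."""

    @staticmethod
    @abstractmethod
    def _configure_optimizers(models, optim_config):
        """Yield (name, optimizer) pairs, one per model."""

    @staticmethod
    @abstractmethod
    def _configure_scheduler(optims, last_iter, num_iters, sched_config):
        """Yield (name, scheduler-or-None) pairs, one per optimizer."""

    def init_logger(self, log_dir):
        """Return the Loggers handle consumed by self.log()."""

    def log(self, loggers, record):
        """Write a Base{Batch,Epoch}LogRecord to the loggers."""

    def save_checkpoint(self, epoch_log):
        """Persist models/optimizers keyed by the epoch log record."""

    @abstractmethod
    def train(self, num_iters: int, loss_fn: Callable, logger, device):
        """Training loop implemented by the subclass."""
```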
Diffstat (limited to 'supervised/baseline.py')
-rw-r--r-- | supervised/baseline.py | 588
1 file changed, 232 insertions, 356 deletions
```diff
diff --git a/supervised/baseline.py b/supervised/baseline.py
index 0217866..5e1b32e 100644
--- a/supervised/baseline.py
+++ b/supervised/baseline.py
@@ -1,404 +1,280 @@
 import sys
+from dataclasses import dataclass
 from pathlib import Path
+from typing import Iterable, Callable
 
 path = str(Path(Path(__file__).parent.absolute()).parent.absolute())
 sys.path.insert(0, path)
 
 import argparse
 import os
-import random
 
 import torch
 import yaml
-from torch.backends import cudnn
-from torch.nn import CrossEntropyLoss
-from torch.utils.data import DataLoader
-from torch.utils.tensorboard import SummaryWriter
-from torchvision.datasets import CIFAR10, CIFAR100, ImageNet
-from torchvision.transforms import transforms, InterpolationMode
-
-from libs.datautils import color_distortion, Clip, RandomGaussianBlur
-from libs.optimizers import LARS
-from libs.schedulers import LinearWarmupAndCosineAnneal, LinearLR
-from libs.utils import training_log, setup_logging, EPOCH_LOGGER, BATCH_LOGGER
-from models import CIFARResNet50, ImageNetResNet50
-
-
-def build_parser():
-    def range_parser(range_string: str):
-        try:
-            range_ = tuple(map(float, range_string.split('-')))
-            return range_
-        except:
-            raise argparse.ArgumentTypeError("Range must be 'start-end.'")
-
-    def merge_yaml(args):
-        if args.config:
-            config = yaml.safe_load(args.config)
-            delattr(args, 'config')
-            args_dict = args.__dict__
-            for key, value in config.items():
-                if isinstance(value, list):
-                    args_dict[key] = tuple(value)
-                else:
-                    args_dict[key] = value
+from torch.utils.data import Dataset
+from torchvision.datasets import CIFAR10, CIFAR100
+from torchvision.transforms import transforms
+
+from libs.datautils import Clip
+from libs.schedulers import LinearLR
+from libs.utils import Trainer, BaseConfig
+from libs.logging import BaseBatchLogRecord, BaseEpochLogRecord, Loggers
+from models import CIFARResNet50
+
+
+def parse_args_and_config():
     parser = argparse.ArgumentParser(description='Supervised baseline')
-    parser.add_argument('--codename', default='cifar10-resnet50-256-lars-warmup',
+    parser.add_argument('--codename', default='cifar10-resnet50-256-adam-linear',
                         type=str, help="Model descriptor (default: "
-                                       "'cifar10-resnet50-256-lars-warmup')")
+                                       "'cifar10-resnet50-256-adam-linear')")
+    parser.add_argument('--log_dir', default='logs', type=str,
+                        help="Path to log directory (default: 'logs')")
+    parser.add_argument('--checkpoint_dir', default='checkpoints', type=str,
+                        help="Path to checkpoints directory (default: 'checkpoints')")
     parser.add_argument('--seed', default=-1, type=int,
                         help='Random seed for reproducibility '
                              '(-1 for not set seed) (default: -1)')
+    parser.add_argument('--num_iters', default=1000, type=int,
+                        help='Number of iters (epochs) (default: 1000)')
     parser.add_argument('--config', type=argparse.FileType(mode='r'),
                         help='Path to config file (optional)')
 
-    data_group = parser.add_argument_group('Dataset parameters')
-    data_group.add_argument('--dataset_dir', default='dataset', type=str,
-                            help="Path to dataset directory (default: 'dataset')")
-    data_group.add_argument('--dataset', default='cifar10', type=str,
-                            help="Name of dataset (default: 'cifar10')")
-    data_group.add_argument('--crop_size', default=32, type=int,
-                            help='Random crop size after resize (default: 32)')
-    data_group.add_argument('--crop_scale_range', default='0.8-1', type=range_parser,
-                            help='Random resize scale range (default: 0.8-1)')
-    data_group.add_argument('--hflip_p', default=0.5, type=float,
-                            help='Random horizontal flip probability (default: 0.5)')
-    data_group.add_argument('--distort_s', default=0.5, type=float,
-                            help='Distortion strength (default: 0.5)')
-    data_group.add_argument('--gaussian_ker_scale', default=10, type=float,
-                            help='Gaussian kernel scale factor '
-                                 '(equals to img_size / kernel_size) (default: 10)')
-    data_group.add_argument('--gaussian_sigma_range', default='0.1-2', type=range_parser,
-                            help='Random gaussian blur sigma range (default: 0.1-2)')
-    data_group.add_argument('--gaussian_p', default=0.5, type=float,
-                            help='Random gaussian blur probability (default: 0.5)')
-
-    train_group = parser.add_argument_group('Training parameters')
-    train_group.add_argument('--batch_size', default=256, type=int,
-                             help='Batch size (default: 256)')
-    train_group.add_argument('--restore_epoch', default=0, type=int,
-                             help='Restore epoch, 0 for training from scratch '
-                                  '(default: 0)')
-    train_group.add_argument('--n_epochs', default=1000, type=int,
-                             help='Number of epochs (default: 1000)')
-    train_group.add_argument('--warmup_epochs', default=10, type=int,
-                             help='Epochs for warmup '
-                                  '(only for `warmup-anneal` scheduler) (default: 10)')
-    train_group.add_argument('--n_workers', default=2, type=int,
-                             help='Number of dataloader processes (default: 2)')
-    train_group.add_argument('--optim', default='lars', type=str,
-                             help="Name of optimizer (default: 'lars')")
-    train_group.add_argument('--sched', default='warmup-anneal', type=str,
-                             help="Name of scheduler (default: 'warmup-anneal')")
-    train_group.add_argument('--lr', default=1, type=float,
+    dataset_group = parser.add_argument_group('Dataset parameters')
+    dataset_group.add_argument('--dataset_dir', default='dataset', type=str,
+                               help="Path to dataset directory (default: 'dataset')")
+    dataset_group.add_argument('--dataset', default='cifar10', type=str,
+                               choices=('cifar', 'cifar10, cifar100'),
+                               help="Name of dataset (default: 'cifar10')")
+    dataset_group.add_argument('--crop_size', default=32, type=int,
+                               help='Random crop size after resize (default: 32)')
+    dataset_group.add_argument('--crop_scale_range', nargs=2, default=(0.8, 1), type=float,
+                               help='Random resize scale range (default: 0.8 1)',
+                               metavar=('start', 'stop'))
+    dataset_group.add_argument('--hflip_prob', default=0.5, type=float,
+                               help='Random horizontal flip probability (default: 0.5)')
+
+    dataloader_group = parser.add_argument_group('Dataloader parameters')
+    dataloader_group.add_argument('--batch_size', default=256, type=int,
+                                  help='Batch size (default: 256)')
+    dataloader_group.add_argument('--num_workers', default=2, type=int,
+                                  help='Number of dataloader processes (default: 2)')
+
+    optim_group = parser.add_argument_group('Optimizer parameters')
+    optim_group.add_argument('--optim', default='adam', type=str,
+                             choices=('adam', 'sgd'),
+                             help="Name of optimizer (default: 'adam')")
+    optim_group.add_argument('--lr', default=1e-3, type=float,
                              help='Learning rate (default: 1)')
-    train_group.add_argument('--momentum', default=0.9, type=float,
-                             help='Momentum (default: 0.9')
-    train_group.add_argument('--weight_decay', default=1e-6, type=float,
+    optim_group.add_argument('--betas', nargs=2, default=(0.9, 0.999), type=float,
+                             help='Adam betas (default: 0.9 0.999)', metavar=('beta1', 'beta2'))
+    optim_group.add_argument('--momentum', default=0.9, type=float,
+                             help='SDG momentum (default: 0.9)')
+    optim_group.add_argument('--weight_decay', default=1e-6, type=float,
                              help='Weight decay (l2 regularization) (default: 1e-6)')
 
-    logging_group = parser.add_argument_group('Logging config')
-    logging_group.add_argument('--log_dir', default='logs', type=str,
-                               help="Path to log directory (default: 'logs')")
-    logging_group.add_argument('--tensorboard_dir', default='runs', type=str,
-                               help="Path to tensorboard directory (default: 'runs')")
-    logging_group.add_argument('--checkpoint_dir', default='checkpoints', type=str,
-                               help='Path to checkpoints directory '
-                                    "(default: 'checkpoints')")
+    sched_group = parser.add_argument_group('Optimizer parameters')
+    sched_group.add_argument('--sched', default='linear', type=str,
+                             choices=(None, '', 'linear'),
+                             help="Name of scheduler (default: None)")
 
     args = parser.parse_args()
-    merge_yaml(args)
-
-    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
-    args.log_root = os.path.join(args.log_dir, args.codename)
-    args.tensorboard_root = os.path.join(args.tensorboard_dir, args.codename)
-    args.checkpoint_root = os.path.join(args.checkpoint_dir, args.codename)
+    if args.config:
+        config = yaml.safe_load(args.config)
+        args.__dict__ |= {
+            k: tuple(v) if isinstance(v, list) else v
+            for k, v in config.items()
+        }
+    args.checkpoint_dir = os.path.join(args.checkpoint_dir, args.codename)
+    args.log_dir = os.path.join(args.log_dir, args.codename)
 
     return args
 
 
-def set_seed(args):
-    if args.seed == -1 or args.seed is None or args.seed == '':
-        cudnn.benchmark = True
-    else:
-        random.seed(args.seed)
-        torch.manual_seed(args.seed)
-        cudnn.deterministic = True
+@dataclass
+class SupBaselineConfig(BaseConfig):
+    @dataclass
+    class DatasetConfig(BaseConfig.DatasetConfig):
+        dataset_dir: str
+        crop_size: int
+        crop_scale_range: tuple[float, float]
+        hflip_prob: float
+
+    @dataclass
+    class OptimConfig(BaseConfig.OptimConfig):
+        momentum: float | None
+        betas: tuple[float, float] | None
+        weight_decay: float
+
+    @dataclass
+    class SchedConfig(BaseConfig.SchedConfig):
+        sched: str | None
 
 
-def init_logging(args):
-    setup_logging(BATCH_LOGGER, os.path.join(args.log_root, 'batch-log.csv'))
-    setup_logging(EPOCH_LOGGER, os.path.join(args.log_root, 'epoch-log.csv'))
-    with open(os.path.join(args.log_root, 'params.yaml'), 'w') as params:
-        yaml.safe_dump(args.__dict__, params)
+class SupBaselineTrainer(Trainer):
+    def __init__(self, **kwargs):
+        super(SupBaselineTrainer, self).__init__(**kwargs)
 
+    @dataclass
+    class BatchLogRecord(BaseBatchLogRecord):
+        lr: float
+        train_loss: float
 
-def prepare_dataset(args):
-    if args.dataset == 'cifar10' or args.dataset == 'cifar100' \
-            or args.dataset == 'cifar':
+    @dataclass
+    class EpochLogRecord(BaseEpochLogRecord):
+        eval_loss: float
+        eval_accuracy: float
+
+    @staticmethod
+    def _prepare_dataset(dataset_config: SupBaselineConfig.DatasetConfig) -> tuple[Dataset, Dataset]:
         train_transform = transforms.Compose([
             transforms.RandomResizedCrop(
-                args.crop_size,
-                scale=args.crop_scale_range,
-                interpolation=InterpolationMode.BICUBIC
+                dataset_config.crop_size,
+                scale=dataset_config.crop_scale_range,
+                interpolation=transforms.InterpolationMode.BICUBIC
            ),
-            transforms.RandomHorizontalFlip(args.hflip_p),
-            color_distortion(args.distort_s),
+            transforms.RandomHorizontalFlip(dataset_config.hflip_prob),
             transforms.ToTensor(),
-            Clip()
+            Clip(),
        ])
        test_transform = transforms.Compose([
-            transforms.ToTensor()
+            transforms.ToTensor(),
        ])
 
-        if args.dataset == 'cifar10' or args.dataset == 'cifar':
-            train_set = CIFAR10(args.dataset_dir, train=True,
+        if dataset_config.dataset in {'cifar10', 'cifar'}:
+            train_set = CIFAR10(dataset_config.dataset_dir, train=True,
                                 transform=train_transform, download=True)
-            test_set = CIFAR10(args.dataset_dir, train=False,
+            test_set = CIFAR10(dataset_config.dataset_dir, train=False,
                                transform=test_transform)
-        else:  # CIFAR-100
-            train_set = CIFAR100(args.dataset_dir, train=True,
+        elif dataset_config.dataset == 'cifar100':
+            train_set = CIFAR100(dataset_config.dataset_dir, train=True,
                                  transform=train_transform, download=True)
-            test_set = CIFAR100(args.dataset_dir, train=False,
+            test_set = CIFAR100(dataset_config.dataset_dir, train=False,
                                 transform=test_transform)
-    elif args.dataset == 'imagenet1k' or args.dataset == 'imagenet1k':
-        train_transform = transforms.Compose([
-            transforms.RandomResizedCrop(
-                args.crop_size,
-                scale=args.crop_scale_range,
-                interpolation=InterpolationMode.BICUBIC
-            ),
-            transforms.RandomHorizontalFlip(args.hflip_p),
-            color_distortion(args.distort_s),
-            transforms.ToTensor(),
-            RandomGaussianBlur(
-                kernel_size=args.crop_size // args.gaussian_ker_scale,
-                sigma_range=args.gaussian_sigma_range,
-                p=args.gaussian_p
-            ),
-            Clip()
-        ])
-        test_transform = transforms.Compose([
-            transforms.Resize(256),
-            transforms.CenterCrop(args.crop_size),
-            transforms.ToTensor(),
-        ])
-
-        train_set = ImageNet(args.dataset_dir, 'train', transform=train_transform)
-        test_set = ImageNet(args.dataset_dir, 'val', transform=test_transform)
-    else:
-        raise NotImplementedError(f"Dataset '{args.dataset}' is not implemented.")
-
-    return train_set, test_set
-
-
-def create_dataloader(args, train_set, test_set):
-    train_loader = DataLoader(train_set, batch_size=args.batch_size,
-                              shuffle=True, num_workers=args.n_workers)
-    test_loader = DataLoader(test_set, batch_size=args.batch_size,
-                             shuffle=False, num_workers=args.n_workers)
-
-    args.num_train_batches = len(train_loader)
-    args.num_test_batches = len(test_loader)
-
-    return train_loader, test_loader
-
-
-def init_model(args):
-    if args.dataset == 'cifar10' or args.dataset == 'cifar':
-        model = CIFARResNet50(num_classes=10)
-    elif args.dataset == 'cifar100':
-        model = CIFARResNet50(num_classes=100)
-    elif args.dataset == 'imagenet1k' or args.dataset == 'imagenet1k':
-        model = ImageNetResNet50()
-    else:
-        raise NotImplementedError(f"Dataset '{args.dataset}' is not implemented.")
-
-    return model
-
-
-def configure_optimizer(args, model):
-    def exclude_from_wd_and_adaptation(name):
-        if 'bn' in name:
-            return True
-        if args.optim == 'lars' and 'bias' in name:
-            return True
-
-    param_groups = [
-        {
-            'params': [p for name, p in model.named_parameters()
-                       if not exclude_from_wd_and_adaptation(name)],
-            'weight_decay': args.weight_decay,
-            'layer_adaptation': True,
-        },
-        {
-            'params': [p for name, p in model.named_parameters()
-                       if exclude_from_wd_and_adaptation(name)],
-            'weight_decay': 0.,
-            'layer_adaptation': False,
-        },
-    ]
-    if args.optim == 'adam':
-        optimizer = torch.optim.Adam(
-            param_groups,
-            lr=args.lr,
-            betas=(args.momentum, 0.999)
-        )
-    elif args.optim == 'sdg' or args.optim == 'lars':
-        optimizer = torch.optim.SGD(
-            param_groups,
-            lr=args.lr,
-            momentum=args.momentum
-        )
-    else:
-        raise NotImplementedError(f"Optimizer '{args.optim}' is not implemented.")
-
-    return optimizer
-
-
-@training_log(EPOCH_LOGGER)
-def load_checkpoint(args, model, optimizer):
-    checkpoint_path = os.path.join(args.checkpoint_root, f'{args.restore_epoch:04d}.pt')
-    checkpoint = torch.load(checkpoint_path)
-    model.load_state_dict(checkpoint.pop('model_state_dict'))
-    optimizer.load_state_dict(checkpoint.pop('optimizer_state_dict'))
-
-    return checkpoint
-
-
-def configure_scheduler(args, optimizer):
-    n_iters = args.n_epochs * args.num_train_batches
-    last_iter = args.restore_epoch * args.num_train_batches - 1
-    if args.sched == 'warmup-anneal':
-        scheduler = LinearWarmupAndCosineAnneal(
-            optimizer,
-            warm_up=args.warmup_epochs / args.n_epochs,
-            T_max=n_iters,
-            last_epoch=last_iter
-        )
-    elif args.sched == 'linear':
-        scheduler = LinearLR(
-            optimizer,
-            num_epochs=n_iters,
-            last_epoch=last_iter
-        )
-    elif args.sched is None or args.sched == '' or args.sched == 'const':
-        scheduler = None
-    else:
-        raise NotImplementedError(f"Scheduler '{args.sched}' is not implemented.")
-
-    return scheduler
-
-
-def wrap_lars(args, optimizer):
-    if args.optim == 'lars':
-        return LARS(optimizer)
-    else:
-        return optimizer
-
-
-def train(args, train_loader, model, loss_fn, optimizer):
-    model.train()
-    for batch, (images, targets) in enumerate(train_loader):
-        images, targets = images.to(args.device), targets.to(args.device)
-        model.zero_grad()
-        output = model(images)
-        loss = loss_fn(output, targets)
-        loss.backward()
-        optimizer.step()
-
-        yield batch, loss.item()
-
-
-def eval(args, test_loader, model, loss_fn):
-    model.eval()
-    with torch.no_grad():
-        for batch, (images, targets) in enumerate(test_loader):
-            images, targets = images.to(args.device), targets.to(args.device)
-            output = model(images)
-            loss = loss_fn(output, targets)
-            prediction = output.argmax(1)
-            accuracy = (prediction == targets).float().mean()
-
-            yield batch, loss.item(), accuracy.item()
-
-
-@training_log(BATCH_LOGGER)
-def batch_logger(args, writer, batch, epoch, loss, lr):
-    global_batch = epoch * args.num_train_batches + batch
-    writer.add_scalar('Batch loss/train', loss, global_batch + 1)
-    writer.add_scalar('Batch lr/train', lr, global_batch + 1)
-
-    return {
-        'batch': batch + 1,
-        'n_batches': args.num_train_batches,
-        'global_batch': global_batch + 1,
-        'epoch': epoch + 1,
-        'n_epochs': args.n_epochs,
-        'train_loss': loss,
-        'lr': lr,
-    }
-
-
-@training_log(EPOCH_LOGGER)
-def epoch_logger(args, writer, epoch, train_loss, test_loss, test_accuracy):
-    train_loss_mean = train_loss.mean().item()
-    test_loss_mean = test_loss.mean().item()
-    test_accuracy_mean = test_accuracy.mean().item()
-    writer.add_scalar('Epoch loss/train', train_loss_mean, epoch + 1)
-    writer.add_scalar('Epoch loss/test', test_loss_mean, epoch + 1)
-    writer.add_scalar('Accuracy/test', test_accuracy_mean, epoch + 1)
-
-    return {
-        'epoch': epoch + 1,
-        'n_epochs': args.n_epochs,
-        'train_loss': train_loss_mean,
-        'test_loss': test_loss_mean,
-        'test_accuracy': test_accuracy_mean
-    }
-
-
-def save_checkpoint(args, epoch_log, model, optimizer):
-    os.makedirs(args.checkpoint_root, exist_ok=True)
-
-    torch.save(epoch_log | {
-        'model_state_dict': model.state_dict(),
-        'optimizer_state_dict': optimizer.state_dict(),
-    }, os.path.join(args.checkpoint_root, f"{epoch_log['epoch']:04d}.pt"))
+        else:
+            raise NotImplementedError(f"Unimplemented dataset: '{dataset_config.dataset}")
+
+        return train_set, test_set
+
+    @staticmethod
+    def _init_models(dataset: str) -> Iterable[tuple[str, torch.nn.Module]]:
+        if dataset in {'cifar10', 'cifar'}:
+            model = CIFARResNet50(num_classes=10)
+        elif dataset == 'cifar100':
+            model = CIFARResNet50(num_classes=100)
+        else:
+            raise NotImplementedError(f"Unimplemented dataset: '{dataset}")
+
+        yield 'model', model
+
+    @staticmethod
+    def _configure_optimizers(
+            models: Iterable[tuple[str, torch.nn.Module]],
+            optim_config: SupBaselineConfig.OptimConfig,
+    ) -> Iterable[tuple[str, torch.optim.Optimizer]]:
+        for model_name, model in models:
+            param_groups = [
+                {
+                    'params': [p for name, p in model.named_parameters()
+                               if 'bn' not in name],
+                    'weight_decay': optim_config.weight_decay,
+                    'layer_adaptation': True,
+                },
+                {
+                    'params': [p for name, p in model.named_parameters()
+                               if 'bn' in name],
+                    'weight_decay': 0.,
+                    'layer_adaptation': False,
+                },
+            ]
+            if optim_config.optim == 'adam':
+                optimizer = torch.optim.Adam(
+                    param_groups,
+                    lr=optim_config.lr,
+                    betas=optim_config.betas,
+                )
+            elif optim_config.optim == 'sdg':
+                optimizer = torch.optim.SGD(
+                    param_groups,
+                    lr=optim_config.lr,
+                    momentum=optim_config.momentum,
+                )
+            else:
+                raise NotImplementedError(f"Unimplemented optimizer: '{optim_config.optim}'")
+
+            yield f"{model_name}_optim", optimizer
+
+    @staticmethod
+    def _configure_scheduler(
+            optims: Iterable[tuple[str, torch.optim.Optimizer]],
+            last_iter: int,
+            num_iters: int,
+            sched_config: SupBaselineConfig.SchedConfig
+    ) -> Iterable[tuple[str, torch.optim.lr_scheduler._LRScheduler]
+                  | tuple[str, None]]:
+        for optim_name, optim in optims:
+            if sched_config.sched == 'linear':
+                sched = LinearLR(optim, num_iters, last_epoch=last_iter)
+            elif sched_config.sched is None:
+                sched = None
+            else:
+                raise NotImplementedError(f"Unimplemented scheduler: {sched_config.sched}")
+            yield f"{optim_name}_sched", sched
+
+    def train(self, num_iters: int, loss_fn: Callable, logger: Loggers, device: torch.device):
+        model = self.models['model']
+        optim = self.optims['model_optim']
+        sched = self.scheds['model_optim_sched']
+        loader_size = len(self.train_loader)
+        num_batches = num_iters * loader_size
+        for iter_ in range(self.restore_iter, num_iters):
+            model.train()
+            for batch, (images, targets) in enumerate(self.train_loader):
+                global_batch = iter_ * loader_size + batch
+                images, targets = images.to(device), targets.to(device)
+                model.zero_grad()
+                output = model(images)
+                train_loss = loss_fn(output, targets)
+                train_loss.backward()
+                optim.step()
+                self.log(logger, self.BatchLogRecord(
+                    batch, num_batches, global_batch, iter_, num_iters,
+                    optim.param_groups[0]['lr'], train_loss.item()
+                ))
+            metrics = torch.Tensor(list(self.eval(loss_fn, device))).mean(0)
+            eval_loss = metrics[0].item()
+            eval_accuracy = metrics[1].item()
+            epoch_log = self.EpochLogRecord(iter_, num_iters, eval_loss, eval_accuracy)
+            self.log(logger, epoch_log)
+            self.save_checkpoint(epoch_log)
+            # Step after save checkpoint, otherwise the schedular will one iter ahead after restore
+            if sched is not None:
+                sched.step()
+
+    def eval(self, loss_fn: Callable, device: torch.device):
+        model = self.models['model']
+        model.eval()
+        with torch.no_grad():
+            for batch, (images, targets) in enumerate(self.test_loader):
+                images, targets = images.to(device), targets.to(device)
+                output = model(images)
+                loss = loss_fn(output, targets)
+                prediction = output.argmax(1)
+                accuracy = (prediction == targets).float().mean()
+                yield loss.item(), accuracy.item()
 
 
 if __name__ == '__main__':
-    args = build_parser()
-    set_seed(args)
-    init_logging(args)
-
-    train_set, test_set = prepare_dataset(args)
-    train_loader, test_loader = create_dataloader(args, train_set, test_set)
-    resnet = init_model(args).to(args.device)
-    xent = CrossEntropyLoss()
-    optimizer = configure_optimizer(args, resnet)
-    if args.restore_epoch > 0:
-        load_checkpoint(args, resnet, optimizer)
-    scheduler = configure_scheduler(args, optimizer)
-    optimizer = wrap_lars(args, optimizer)
-    writer = SummaryWriter(args.tensorboard_root)
-
-    for epoch in range(args.restore_epoch, args.n_epochs):
-        train_loss = torch.zeros(args.num_train_batches, device=args.device)
-        test_loss = torch.zeros(args.num_test_batches, device=args.device)
-        test_accuracy = torch.zeros(args.num_test_batches, device=args.device)
-        for batch, loss in train(args, train_loader, resnet, xent, optimizer):
-            train_loss[batch] = loss
-            batch_logger(args, writer, batch, epoch, loss, optimizer.param_groups[0]['lr'])
-            if scheduler and batch != args.num_train_batches - 1:
-                scheduler.step()
-        for batch, loss, accuracy in eval(args, test_loader, resnet, xent):
-            test_loss[batch] = loss
-            test_accuracy[batch] = accuracy
-        epoch_log = epoch_logger(args, writer, epoch, train_loss, test_loss, test_accuracy)
-        save_checkpoint(args, epoch_log, resnet, optimizer)
-        # Step after save checkpoint, otherwise the schedular
-        # will one iter ahead after restore
-        if scheduler:
-            scheduler.step()
+    args = parse_args_and_config()
+    config = SupBaselineConfig.from_args(args)
+    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+    trainer = SupBaselineTrainer(
+        seed=args.seed,
+        checkpoint_dir=args.checkpoint_dir,
+        device=device,
+        inf_mode=False,
+        num_iters=args.num_iters,
+        config=config,
+    )
+
+    loggers = trainer.init_logger(args.log_dir)
+    trainer.train(args.num_iters, torch.nn.CrossEntropyLoss(), loggers, device)
```
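
The new `--config` handling replaces the old `merge_yaml` helper with an in-place dict merge in `parse_args_and_config()`. A small self-contained illustration of that merge behaviour follows; the keys mirror the argparse destinations defined in the diff above, while the particular values are made up for the example.

```python
# Illustration only: how parse_args_and_config() merges a YAML file passed
# via --config. Option names come from the diff above; the values are made up.
import argparse
import yaml

example_yaml = """
codename: cifar100-resnet50-256-adam-linear
dataset: cifar100
crop_scale_range: [0.8, 1.0]   # YAML lists become tuples after the merge
optim: adam
lr: 1.0e-3
sched: linear
"""

# Stand-in for the Namespace argparse would produce from the defaults.
args = argparse.Namespace(codename='cifar10-resnet50-256-adam-linear',
                          dataset='cifar10', crop_scale_range=(0.8, 1),
                          optim='adam', lr=1e-3, sched='linear')

config = yaml.safe_load(example_yaml)
# Same merge expression as in the new parse_args_and_config():
# YAML values override the CLI values, and lists are converted to tuples.
args.__dict__ |= {k: tuple(v) if isinstance(v, list) else v
                  for k, v in config.items()}

print(args.dataset)           # cifar100
print(args.crop_scale_range)  # (0.8, 1.0)
```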