diff options
Diffstat (limited to 'posrecon')
-rw-r--r-- | posrecon/main.py | 335 |
1 files changed, 317 insertions, 18 deletions
diff --git a/posrecon/main.py b/posrecon/main.py index 0660ea0..c5c946d 100644 --- a/posrecon/main.py +++ b/posrecon/main.py @@ -1,31 +1,330 @@ +from dataclasses import dataclass +from pathlib import Path from typing import Callable, Iterable +import argparse + +import os +import sys import torch -from torch.utils.data import Dataset +import torch.nn.functional as F +import yaml + +path = str(Path(Path(__file__).parent.absolute()).parent.absolute()) +sys.path.insert(0, path) + +from libs.criteria import InfoNCELoss +from libs.logging import Loggers, BaseBatchLogRecord +from libs.optimizers import LARS +from posrecon.models import simclr_pos_recon_vit +from simclr.main import SimCLRTrainer, SimCLRConfig + + +def parse_args_and_config(): + parser = argparse.ArgumentParser( + description='SimCLR w/ positional reconstruction', + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument('--codename', default='cifar10-pos-recon-simclr-128-lars-warmup', + type=str, help="Model descriptor") + parser.add_argument('--log-dir', default='logs', type=str, + help="Path to log directory") + parser.add_argument('--checkpoint-dir', default='checkpoints', type=str, + help="Path to checkpoints directory") + parser.add_argument('--seed', default=None, type=int, + help='Random seed for reproducibility') + parser.add_argument('--num-iters', default=23438, type=int, + help='Number of iters (default is 50 epochs equiv., ' + 'around dataset_size * epochs / batch_size)') + parser.add_argument('--config', type=argparse.FileType(mode='r'), + help='Path to config file (optional)') + + # TODO: Add model hyperparams dataclass + parser.add_argument('--hid-dim', default=2048, type=int, + help='Number of dimension of embedding') + parser.add_argument('--out-dim', default=128, type=int, + help='Number of dimension after projection') + parser.add_argument('--temp', default=0.5, type=float, + help='Temperature in InfoNCE loss') + + parser.add_argument('--pat-sz', default=16, type=int, + help='Size of image patches') + parser.add_argument('--nlayers', default=7, type=int, + help='Depth of Transformer blocks') + parser.add_argument('--nheads', default=12, type=int, + help='Number of attention heads') + parser.add_argument('--embed-dim', default=384, type=int, + help='Number of ViT embedding dimension') + parser.add_argument('--mlp-dim', default=384, type=int, + help='Number of MLP dimension') + parser.add_argument('--dropout', default=0., type=float, + help='MLP dropout rate') + + parser.add_argument('--fixed-pos', default=False, + action=argparse.BooleanOptionalAction, + help='Fixed or learned positional embedding') + parser.add_argument('--shuff-rate', default=0.75, type=float, + help='Ratio of shuffling sequence') + parser.add_argument('--mask-rate', default=0.75, type=float, + help='Ratio of masking sequence') + + dataset_group = parser.add_argument_group('Dataset parameters') + dataset_group.add_argument('--dataset-dir', default='dataset', type=str, + help="Path to dataset directory") + dataset_group.add_argument('--dataset', default='cifar10', type=str, + choices=('cifar10, cifar100', 'imagenet'), + help="Name of dataset") + dataset_group.add_argument('--crop-size', default=32, type=int, + help='Random crop size after resize') + dataset_group.add_argument('--crop-scale-range', nargs=2, default=(0.8, 1), + type=float, help='Random resize scale range', + metavar=('start', 'stop')) + dataset_group.add_argument('--hflip-prob', default=0.5, type=float, + help='Random horizontal flip probability') + dataset_group.add_argument('--distort-strength', default=0.5, type=float, + help='Distortion strength') + dataset_group.add_argument('--gauss-ker-scale', default=10, type=float, + help='Gaussian kernel scale factor ' + '(s = img_size / ker_size)') + dataset_group.add_argument('--gauss-sigma-range', nargs=2, default=(0.1, 2), + type=float, help='Random gaussian blur sigma range', + metavar=('start', 'stop')) + dataset_group.add_argument('--gauss-prob', default=0.5, type=float, + help='Random gaussian blur probability') + + dataloader_group = parser.add_argument_group('Dataloader parameters') + dataloader_group.add_argument('--batch-size', default=128, type=int, + help='Batch size') + dataloader_group.add_argument('--num-workers', default=2, type=int, + help='Number of dataloader processes') + + optim_group = parser.add_argument_group('Optimizer parameters') + optim_group.add_argument('--optim', default='lars', type=str, + choices=('adam', 'sgd', 'lars'), + help="Name of optimizer") + optim_group.add_argument('--lr', default=1., type=float, + help='Learning rate') + optim_group.add_argument('--betas', nargs=2, default=(0.9, 0.999), type=float, + help='Adam betas', metavar=('beta1', 'beta2')) + optim_group.add_argument('--momentum', default=0.9, type=float, + help='SDG momentum') + optim_group.add_argument('--weight-decay', default=1e-6, type=float, + help='Weight decay (l2 regularization)') + + sched_group = parser.add_argument_group('Scheduler parameters') + sched_group.add_argument('--sched', default='warmup-anneal', type=str, + choices=('const', None, 'linear', 'warmup-anneal'), + help="Name of scheduler") + sched_group.add_argument('--warmup-iters', default=2344, type=int, + help='Epochs for warmup (`warmup-anneal` scheduler only)') + + args = parser.parse_args() + if args.config: + config = yaml.safe_load(args.config) + args.__dict__ |= { + k: tuple(v) if isinstance(v, list) else v + for k, v in config.items() + } + args.checkpoint_dir = os.path.join(args.checkpoint_dir, args.codename) + args.log_dir = os.path.join(args.log_dir, args.codename) + + return args -from libs.logging import Loggers -from libs.utils import Trainer, BaseConfig +class PosReconTrainer(SimCLRTrainer): + def __init__( + self, + vit_config: dict, + fixed_pos: bool, + shuff_rate: float, + mask_rate: float, + *args, + **kwargs + ): + self.vit_config = vit_config + self.fixed_pos = fixed_pos + self.shuff_rate = shuff_rate + self.mask_rate = mask_rate + super(PosReconTrainer, self).__init__('vit', *args, **kwargs) -class PosReconTrainer(Trainer): - def __init__(self, *args, **kwargs): - super(PosReconTrainer, self).__init__(*args, **kwargs) + @dataclass + class BatchLogRecord(BaseBatchLogRecord): + lr: float | None + loss_recon_train: float | None + loss_contra_train: float | None + acc_contra_train: float | None + loss_recon_eval: float | None + loss_shuff_recon_eval: float | None + loss_contra_eval: float | None + acc_contra_eval: float | None + norm_patch_embed: float | None + norm_pos_embed: float | None + norm_features: float | None + norm_rep: float | None + norm_pos_hat: float | None + norm_embed: float | None - @staticmethod - def _prepare_dataset(dataset_config: BaseConfig.DatasetConfig) -> tuple[Dataset, Dataset]: - pass + def _init_models(self, dataset: str) -> Iterable[tuple[str, torch.nn.Module]]: + if dataset in {'cifar10', 'cifar100', 'cifar', + 'imagenet1k', 'imagenet'}: + model = simclr_pos_recon_vit(self.vit_config, self.hid_dim, + out_dim=self.out_dim, probe=True) + else: + raise NotImplementedError(f"Unimplemented dataset: '{dataset}") - @staticmethod - def _init_models(dataset: str) -> Iterable[tuple[str, torch.nn.Module]]: - pass + yield 'model', model - @staticmethod - def _configure_optimizers(models: Iterable[tuple[str, torch.nn.Module]], optim_config: BaseConfig.OptimConfig) -> \ - Iterable[tuple[str, torch.optim.Optimizer]]: - pass + def _custom_init_fn(self, config: SimCLRConfig): + if not self.fixed_pos: + device = self.models['model'].device + pos_embed = self.models['model'].backbone.init_pos_embed(device, False) + self.models['pos_embed'] = pos_embed + self.optims['model_optim'].add_param_group({ + 'params': pos_embed, + 'weight_decay': config.optim_config.weight_decay, + 'layer_adaptation': True, + }) + self._auto_load_checkpoint(self._checkpoint_dir, self._inf_mode, + **(self.models | self.optims)) + self.scheds = dict(self._configure_scheduler( + self.optims.items(), self.restore_iter - 1, + self.num_iters, config.sched_config + )) + self.optims = {n: LARS(o) if config.optim_config.optim == 'lars' else o + for n, o in self.optims.items()} def train(self, num_iters: int, loss_fn: Callable, logger: Loggers, device: torch.device): - pass + if self.fixed_pos: + model = self.models['model'] + pos_embed = model.backbone.init_pos_embed(device, True) + else: + model, pos_embed = self.models.values() + + optim = self.optims['model_optim'] + sched = self.scheds['model_optim_sched'] + train_loader = iter(self.train_loader) + + model.train() + for iter_ in range(self.restore_iter, num_iters): + input_, _ = next(train_loader) + input_1, input_2 = input_[0].to(device), input_[1].to(device) + pos_embed_clone = pos_embed.detach().clone() + target = pos_embed_clone.expand(input_1.size(0), -1, -1) + + # In-place shuffle positional embedding, dangerous but no better choice + pos_embed.data, *unshuff = model.backbone.shuffle_pos_embed(pos_embed_clone, self.shuff_rate) + visible_indices_1 = model.backbone.generate_masks(self.mask_rate) + visible_indices_2 = model.backbone.generate_masks(self.mask_rate) + + embed_1, pos_hat, rep, features, patch_embed = model(input_1, pos_embed, visible_indices_1) + embed_2, *_ = model(input_2, pos_embed, visible_indices_2) + embed = torch.cat([embed_1, embed_2]) + pos_hat = pos_hat.view(input_1.size(0), -1, model.backbone.embed_dim) + + visible_seq_indices_ex_cls = visible_indices_1 - model.backbone.num_prefix_tokens + recon_loss = F.smooth_l1_loss(pos_hat[:, visible_seq_indices_ex_cls, :], + target[:, visible_seq_indices_ex_cls, :]) + contra_loss, contra_acc = loss_fn(embed) + loss = recon_loss + contra_loss + + optim.zero_grad() + loss.backward() + optim.step() + + pos_embed.data = model.backbone.unshuffle_pos_embed(pos_embed.data, *unshuff) + + patch_embed_norm = patch_embed.norm(dim=-1).mean() + pos_embed_norm = pos_embed.norm(dim=-1).mean() + features_norm = features.norm(dim=-1).mean() + rep_norm = rep.norm(dim=-1).mean() + pos_hat_norm = pos_hat.norm(dim=-1).mean() + embed_norm = embed.norm(dim=-1).mean() + + self.log(logger, self.BatchLogRecord( + iter_, num_iters, iter_, iter_, num_iters, + optim.param_groups[0]['lr'], + recon_loss.item(), contra_loss.item(), contra_acc.item(), + loss_recon_eval=None, loss_shuff_recon_eval=None, loss_contra_eval=None, acc_contra_eval=None, + norm_patch_embed=patch_embed_norm.item(), norm_pos_embed=pos_embed_norm.item(), + norm_features=features_norm.item(), norm_rep=rep_norm.item(), + norm_pos_hat=pos_hat_norm.item(), norm_embed=embed_norm.item(), + )) + if (iter_ + 1) % (num_iters // 100) == 0: + metrics = torch.Tensor(list(self.eval(loss_fn, device))).mean(0) + recon_loss, shuffle_recon_loss, contra_loss, contra_acc = metrics + eval_log = self.BatchLogRecord( + iter_, num_iters, iter_, iter_, num_iters, lr=None, + loss_recon_train=None, loss_contra_train=None, acc_contra_train=None, + loss_recon_eval=recon_loss.item(), loss_shuff_recon_eval=shuffle_recon_loss.item(), + loss_contra_eval=contra_loss.item(), acc_contra_eval=contra_acc.item(), + norm_patch_embed=None, norm_pos_embed=None, norm_features=None, norm_rep=None, + norm_pos_hat=None, norm_embed=None, + ) + self.log(logger, eval_log) + self.save_checkpoint(eval_log) + model.train() + if sched is not None: + sched.step() def eval(self, loss_fn: Callable, device: torch.device): - pass + if self.fixed_pos: + model = self.models['model'] + pos_embed = model.backbone.init_pos_embed(device, True) + else: + model, pos_embed = self.models.values() + pos_embed_clone = pos_embed.detach().clone() + model.eval() + seq_len = pos_embed_clone.size(1) + shuffled_pos_embed = pos_embed_clone[:, torch.randperm(seq_len), :] + all_visible_indices = torch.arange(seq_len) + model.backbone.num_prefix_tokens + with torch.no_grad(): + for input_, _ in self.test_loader: + input_ = torch.cat(input_).to(device) + target = pos_embed_clone.expand(input_.size(0), -1, -1) + curr_bz = input_.size(0) + embed, pos_hat, *_ = model(input_, pos_embed_clone, all_visible_indices) + _, shuffled_pos_hat, *_ = model(input_, shuffled_pos_embed, all_visible_indices) + pos_hat = pos_hat.view(curr_bz, seq_len, -1) + shuffled_pos_hat = shuffled_pos_hat.view(curr_bz, seq_len, -1) + recon_loss = F.smooth_l1_loss(pos_hat, target) + shuffled_recon_loss = F.smooth_l1_loss(shuffled_pos_hat, target) + contra_loss, contra_acc = loss_fn(embed) + + yield (recon_loss.item(), shuffled_recon_loss.item(), + contra_loss.item(), contra_acc.item()) + + +if __name__ == '__main__': + args = parse_args_and_config() + config = SimCLRConfig.from_args(args) + img_size = config.dataset_config.crop_size + seq_len = (img_size // args.pat_sz) ** 2 + vit_config = dict( + img_size=img_size, + patch_size=args.pat_sz, + depth=args.nlayers, + num_heads=args.nheads, + embed_dim=args.embed_dim, + mlp_ratio=args.mlp_dim / args.embed_dim, + drop_rate=args.dropout, + num_classes=seq_len * args.embed_dim + args.hid_dim, + no_embed_class=True, + ) + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + trainer = PosReconTrainer( + seed=args.seed, + checkpoint_dir=args.checkpoint_dir, + device=device, + inf_mode=True, + num_iters=args.num_iters, + config=config, + hid_dim=args.hid_dim, + out_dim=args.out_dim, + mask_rate=args.mask_rate, + shuff_rate=args.shuff_rate, + vit_config=vit_config, + fixed_pos=args.fixed_pos, + ) + + loggers = trainer.init_logger(args.log_dir) + trainer.train(args.num_iters, InfoNCELoss(args.temp), loggers, device) |