aboutsummaryrefslogtreecommitdiff
path: root/posrecon/main.py
blob: c5c946d7a72037a226d7269be1e3101ac7789174 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Iterable

import argparse

import os
import sys
import torch
import torch.nn.functional as F
import yaml

path = str(Path(Path(__file__).parent.absolute()).parent.absolute())
sys.path.insert(0, path)

from libs.criteria import InfoNCELoss
from libs.logging import Loggers, BaseBatchLogRecord
from libs.optimizers import LARS
from posrecon.models import simclr_pos_recon_vit
from simclr.main import SimCLRTrainer, SimCLRConfig


def parse_args_and_config():
    """Parse command-line options and overlay an optional YAML config file.

    Options are first parsed from ``sys.argv``. If ``--config`` is given, the
    YAML file is loaded and every top-level key overrides the attribute of
    the same name on the parsed namespace (YAML lists become tuples so they
    match the ``nargs=2`` defaults). Finally ``checkpoint_dir`` and
    ``log_dir`` are suffixed with the run's ``codename``.

    Returns:
        argparse.Namespace: the merged options.
    """
    parser = argparse.ArgumentParser(
        description='SimCLR w/ positional reconstruction',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument('--codename', default='cifar10-pos-recon-simclr-128-lars-warmup',
                        type=str, help="Model descriptor")
    parser.add_argument('--log-dir', default='logs', type=str,
                        help="Path to log directory")
    parser.add_argument('--checkpoint-dir', default='checkpoints', type=str,
                        help="Path to checkpoints directory")
    parser.add_argument('--seed', default=None, type=int,
                        help='Random seed for reproducibility')
    parser.add_argument('--num-iters', default=23438, type=int,
                        help='Number of iters (default is 50 epochs equiv., '
                             'around dataset_size * epochs / batch_size)')
    parser.add_argument('--config', type=argparse.FileType(mode='r'),
                        help='Path to config file (optional)')

    # TODO: Add model hyperparams dataclass
    parser.add_argument('--hid-dim', default=2048, type=int,
                        help='Number of dimension of embedding')
    parser.add_argument('--out-dim', default=128, type=int,
                        help='Number of dimension after projection')
    parser.add_argument('--temp', default=0.5, type=float,
                        help='Temperature in InfoNCE loss')

    # ViT backbone hyperparameters
    parser.add_argument('--pat-sz', default=16, type=int,
                        help='Size of image patches')
    parser.add_argument('--nlayers', default=7, type=int,
                        help='Depth of Transformer blocks')
    parser.add_argument('--nheads', default=12, type=int,
                        help='Number of attention heads')
    parser.add_argument('--embed-dim', default=384, type=int,
                        help='Number of ViT embedding dimension')
    parser.add_argument('--mlp-dim', default=384, type=int,
                        help='Number of MLP dimension')
    parser.add_argument('--dropout', default=0., type=float,
                        help='MLP dropout rate')

    # Positional-reconstruction hyperparameters
    parser.add_argument('--fixed-pos', default=False,
                        action=argparse.BooleanOptionalAction,
                        help='Fixed or learned positional embedding')
    parser.add_argument('--shuff-rate', default=0.75, type=float,
                        help='Ratio of shuffling sequence')
    parser.add_argument('--mask-rate', default=0.75, type=float,
                        help='Ratio of masking sequence')

    dataset_group = parser.add_argument_group('Dataset parameters')
    dataset_group.add_argument('--dataset-dir', default='dataset', type=str,
                               help="Path to dataset directory")
    # Each dataset name must be its own tuple element; argparse compares the
    # command-line value against the elements one by one.
    dataset_group.add_argument('--dataset', default='cifar10', type=str,
                               choices=('cifar10', 'cifar100', 'imagenet'),
                               help="Name of dataset")
    dataset_group.add_argument('--crop-size', default=32, type=int,
                               help='Random crop size after resize')
    dataset_group.add_argument('--crop-scale-range', nargs=2, default=(0.8, 1),
                               type=float, help='Random resize scale range',
                               metavar=('start', 'stop'))
    dataset_group.add_argument('--hflip-prob', default=0.5, type=float,
                               help='Random horizontal flip probability')
    dataset_group.add_argument('--distort-strength', default=0.5, type=float,
                               help='Distortion strength')
    dataset_group.add_argument('--gauss-ker-scale', default=10, type=float,
                               help='Gaussian kernel scale factor '
                                    '(s = img_size / ker_size)')
    dataset_group.add_argument('--gauss-sigma-range', nargs=2, default=(0.1, 2),
                               type=float, help='Random gaussian blur sigma range',
                               metavar=('start', 'stop'))
    dataset_group.add_argument('--gauss-prob', default=0.5, type=float,
                               help='Random gaussian blur probability')

    dataloader_group = parser.add_argument_group('Dataloader parameters')
    dataloader_group.add_argument('--batch-size', default=128, type=int,
                                  help='Batch size')
    dataloader_group.add_argument('--num-workers', default=2, type=int,
                                  help='Number of dataloader processes')

    optim_group = parser.add_argument_group('Optimizer parameters')
    optim_group.add_argument('--optim', default='lars', type=str,
                             choices=('adam', 'sgd', 'lars'),
                             help="Name of optimizer")
    optim_group.add_argument('--lr', default=1., type=float,
                             help='Learning rate')
    optim_group.add_argument('--betas', nargs=2, default=(0.9, 0.999), type=float,
                             help='Adam betas', metavar=('beta1', 'beta2'))
    optim_group.add_argument('--momentum', default=0.9, type=float,
                             help='SGD momentum')
    optim_group.add_argument('--weight-decay', default=1e-6, type=float,
                             help='Weight decay (l2 regularization)')

    sched_group = parser.add_argument_group('Scheduler parameters')
    # `None` cannot be typed on the command line (type=str); it is kept in the
    # choices as in the original, presumably for values coming from the YAML
    # config -- note the config overlay below bypasses argparse validation.
    sched_group.add_argument('--sched', default='warmup-anneal', type=str,
                             choices=('const', None, 'linear', 'warmup-anneal'),
                             help="Name of scheduler")
    sched_group.add_argument('--warmup-iters', default=2344, type=int,
                             help='Iters for warmup (`warmup-anneal` scheduler only)')

    args = parser.parse_args()
    if args.config:
        # An empty YAML document loads as None; treat it as "no overrides"
        # instead of failing on `None.items()`.
        config = yaml.safe_load(args.config) or {}
        # Config values win over command-line values. Lists are converted to
        # tuples to match the tuple defaults of the nargs=2 options.
        vars(args).update({
            k: tuple(v) if isinstance(v, list) else v
            for k, v in config.items()
        })
    # Keep each run's artifacts in a sub-directory named after the codename.
    args.checkpoint_dir = os.path.join(args.checkpoint_dir, args.codename)
    args.log_dir = os.path.join(args.log_dir, args.codename)

    return args


class PosReconTrainer(SimCLRTrainer):
    """SimCLR trainer with an auxiliary positional-embedding reconstruction loss.

    Each training step shuffles the positional embedding, masks part of the
    sequence, and optimizes the sum of a smooth-L1 reconstruction loss (on
    the model's predicted positional embedding) and the contrastive loss
    returned by ``loss_fn``.
    """

    def __init__(
            self,
            vit_config: dict,
            fixed_pos: bool,
            shuff_rate: float,
            mask_rate: float,
            *args,
            **kwargs
    ):
        """Store the positional-reconstruction settings and build the trainer.

        Args:
            vit_config: keyword settings forwarded to ``simclr_pos_recon_vit``.
            fixed_pos: if true the positional embedding is created on demand
                via ``init_pos_embed(device, True)``; otherwise it is
                registered as a model entry and added to the optimizer.
            shuff_rate: rate passed to ``backbone.shuffle_pos_embed``.
            mask_rate: rate passed to ``backbone.generate_masks``.
            *args, **kwargs: forwarded to ``SimCLRTrainer.__init__`` after the
                fixed encoder name ``'vit'``.
        """
        # Set before calling the base constructor -- presumably because the
        # base __init__ invokes _init_models/_custom_init_fn, which read these
        # attributes (TODO confirm against SimCLRTrainer).
        self.vit_config = vit_config
        self.fixed_pos = fixed_pos
        self.shuff_rate = shuff_rate
        self.mask_rate = mask_rate
        super().__init__('vit', *args, **kwargs)

    @dataclass
    class BatchLogRecord(BaseBatchLogRecord):
        """Per-iteration log entry; fields not measured in a step are None."""
        # Learning rate of the first optimizer param group (training steps).
        lr: float | None
        # Training-step losses / accuracy.
        loss_recon_train: float | None
        loss_contra_train: float | None
        acc_contra_train: float | None
        # Evaluation metrics averaged over the test loader.
        loss_recon_eval: float | None
        loss_shuff_recon_eval: float | None
        loss_contra_eval: float | None
        acc_contra_eval: float | None
        # Mean last-dimension norms of intermediate tensors (training steps).
        norm_patch_embed: float | None
        norm_pos_embed: float | None
        norm_features: float | None
        norm_rep: float | None
        norm_pos_hat: float | None
        norm_embed: float | None

    def _init_models(self, dataset: str) -> Iterable[tuple[str, torch.nn.Module]]:
        """Yield the single ``('model', module)`` pair for a supported dataset.

        Raises:
            NotImplementedError: if ``dataset`` is not a known dataset name.
        """
        if dataset in {'cifar10', 'cifar100', 'cifar',
                       'imagenet1k', 'imagenet'}:
            model = simclr_pos_recon_vit(self.vit_config, self.hid_dim,
                                         out_dim=self.out_dim, probe=True)
        else:
            raise NotImplementedError(f"Unimplemented dataset: '{dataset}'")

        yield 'model', model

    def _custom_init_fn(self, config: SimCLRConfig):
        """Register the learned positional embedding and wrap optimizers in LARS.

        With a learned embedding (``fixed_pos`` false) the embedding is added
        to ``self.models`` and to the model optimizer as an extra param group;
        the checkpoint is then loaded again and the schedulers rebuilt,
        presumably so that the newly added entries are restored as well
        (TODO confirm against SimCLRTrainer's own initialization order).
        """
        if not self.fixed_pos:
            device = self.models['model'].device
            pos_embed = self.models['model'].backbone.init_pos_embed(device, False)
            self.models['pos_embed'] = pos_embed
            self.optims['model_optim'].add_param_group({
                'params': pos_embed,
                'weight_decay': config.optim_config.weight_decay,
                'layer_adaptation': True,
            })
            self._auto_load_checkpoint(self._checkpoint_dir, self._inf_mode,
                                       **(self.models | self.optims))
            self.scheds = dict(self._configure_scheduler(
                self.optims.items(), self.restore_iter - 1,
                self.num_iters, config.sched_config
            ))
        # Applied in both modes: every optimizer is wrapped when LARS is selected.
        self.optims = {n: LARS(o) if config.optim_config.optim == 'lars' else o
                       for n, o in self.optims.items()}

    def _model_and_pos_embed(self, device: torch.device):
        """Return ``(model, pos_embed)`` for the current positional mode.

        Looks entries up by key rather than unpacking ``self.models.values()``
        so the result does not depend on dict size or insertion order.
        """
        model = self.models['model']
        if self.fixed_pos:
            pos_embed = model.backbone.init_pos_embed(device, True)
        else:
            pos_embed = self.models['pos_embed']
        return model, pos_embed

    def train(self, num_iters: int, loss_fn: Callable, logger: Loggers, device: torch.device):
        """Run the training loop from ``self.restore_iter`` up to ``num_iters``.

        Args:
            num_iters: total number of iterations of the run.
            loss_fn: called with the concatenated embeddings of both views;
                must return ``(contrastive_loss, contrastive_accuracy)``.
            logger: loggers passed through to ``self.log``.
            device: device the input batches are moved to.
        """
        model, pos_embed = self._model_and_pos_embed(device)

        optim = self.optims['model_optim']
        sched = self.scheds['model_optim_sched']
        train_loader = iter(self.train_loader)
        # Evaluate and checkpoint about every 1% of the run; clamp to 1 so a
        # run shorter than 100 iterations does not divide by zero below.
        eval_interval = max(num_iters // 100, 1)

        model.train()
        for iter_ in range(self.restore_iter, num_iters):
            input_, _ = next(train_loader)
            input_1, input_2 = input_[0].to(device), input_[1].to(device)
            # Un-shuffled copy: used both as reconstruction target and as the
            # source for shuffling.
            pos_embed_clone = pos_embed.detach().clone()
            target = pos_embed_clone.expand(input_1.size(0), -1, -1)

            # In-place shuffle positional embedding, dangerous but no better choice
            pos_embed.data, *unshuff = model.backbone.shuffle_pos_embed(pos_embed_clone, self.shuff_rate)
            visible_indices_1 = model.backbone.generate_masks(self.mask_rate)
            visible_indices_2 = model.backbone.generate_masks(self.mask_rate)

            embed_1, pos_hat, rep, features, patch_embed = model(input_1, pos_embed, visible_indices_1)
            embed_2, *_ = model(input_2, pos_embed, visible_indices_2)
            embed = torch.cat([embed_1, embed_2])
            pos_hat = pos_hat.view(input_1.size(0), -1, model.backbone.embed_dim)

            # Offset the visible indices by the number of prefix tokens --
            # presumably converting them to positions in the patch-only
            # sequence that pos_hat/target are indexed by (TODO confirm).
            visible_seq_indices_ex_cls = visible_indices_1 - model.backbone.num_prefix_tokens
            recon_loss = F.smooth_l1_loss(pos_hat[:, visible_seq_indices_ex_cls, :],
                                          target[:, visible_seq_indices_ex_cls, :])
            contra_loss, contra_acc = loss_fn(embed)
            loss = recon_loss + contra_loss

            optim.zero_grad()
            loss.backward()
            optim.step()

            # Undo the in-place shuffle after the optimizer has updated the
            # (shuffled) embedding.
            pos_embed.data = model.backbone.unshuffle_pos_embed(pos_embed.data, *unshuff)

            patch_embed_norm = patch_embed.norm(dim=-1).mean()
            pos_embed_norm = pos_embed.norm(dim=-1).mean()
            features_norm = features.norm(dim=-1).mean()
            rep_norm = rep.norm(dim=-1).mean()
            pos_hat_norm = pos_hat.norm(dim=-1).mean()
            embed_norm = embed.norm(dim=-1).mean()

            self.log(logger, self.BatchLogRecord(
                iter_, num_iters, iter_, iter_, num_iters,
                optim.param_groups[0]['lr'],
                recon_loss.item(), contra_loss.item(), contra_acc.item(),
                loss_recon_eval=None, loss_shuff_recon_eval=None, loss_contra_eval=None, acc_contra_eval=None,
                norm_patch_embed=patch_embed_norm.item(), norm_pos_embed=pos_embed_norm.item(),
                norm_features=features_norm.item(), norm_rep=rep_norm.item(),
                norm_pos_hat=pos_hat_norm.item(), norm_embed=embed_norm.item(),
            ))
            if (iter_ + 1) % eval_interval == 0:
                # Average the per-batch metric tuples over the test loader.
                metrics = torch.Tensor(list(self.eval(loss_fn, device))).mean(0)
                recon_loss, shuffle_recon_loss, contra_loss, contra_acc = metrics
                eval_log = self.BatchLogRecord(
                    iter_, num_iters, iter_, iter_, num_iters, lr=None,
                    loss_recon_train=None, loss_contra_train=None, acc_contra_train=None,
                    loss_recon_eval=recon_loss.item(), loss_shuff_recon_eval=shuffle_recon_loss.item(),
                    loss_contra_eval=contra_loss.item(), acc_contra_eval=contra_acc.item(),
                    norm_patch_embed=None, norm_pos_embed=None, norm_features=None, norm_rep=None,
                    norm_pos_hat=None, norm_embed=None,
                )
                self.log(logger, eval_log)
                self.save_checkpoint(eval_log)
                # eval() switched the model to eval mode; switch back.
                model.train()
            if sched is not None:
                sched.step()

    def eval(self, loss_fn: Callable, device: torch.device):
        """Yield per-batch evaluation metrics over ``self.test_loader``.

        The model is run twice per batch with all positions visible: once
        with the original positional embedding and once with a randomly
        permuted one (the permutation is drawn once per call).

        Yields:
            tuple of floats ``(recon_loss, shuffled_recon_loss, contra_loss,
            contra_acc)`` for each batch.
        """
        model, pos_embed = self._model_and_pos_embed(device)
        pos_embed_clone = pos_embed.detach().clone()
        model.eval()
        seq_len = pos_embed_clone.size(1)
        shuffled_pos_embed = pos_embed_clone[:, torch.randperm(seq_len), :]
        all_visible_indices = torch.arange(seq_len) + model.backbone.num_prefix_tokens
        with torch.no_grad():
            for input_, _ in self.test_loader:
                # Both views are concatenated into a single batch.
                input_ = torch.cat(input_).to(device)
                target = pos_embed_clone.expand(input_.size(0), -1, -1)
                curr_bz = input_.size(0)
                embed, pos_hat, *_ = model(input_, pos_embed_clone, all_visible_indices)
                _, shuffled_pos_hat, *_ = model(input_, shuffled_pos_embed, all_visible_indices)
                pos_hat = pos_hat.view(curr_bz, seq_len, -1)
                shuffled_pos_hat = shuffled_pos_hat.view(curr_bz, seq_len, -1)
                recon_loss = F.smooth_l1_loss(pos_hat, target)
                # Both reconstructions are scored against the un-shuffled target.
                shuffled_recon_loss = F.smooth_l1_loss(shuffled_pos_hat, target)
                contra_loss, contra_acc = loss_fn(embed)

                yield (recon_loss.item(), shuffled_recon_loss.item(),
                       contra_loss.item(), contra_acc.item())


if __name__ == '__main__':
    # Resolve run options (CLI + optional YAML) and the derived SimCLR config.
    args = parse_args_and_config()
    config = SimCLRConfig.from_args(args)

    # The image is cut into a square grid of patches; the ViT head outputs
    # one embedding-sized vector per patch plus the hidden representation.
    img_size = config.dataset_config.crop_size
    patches_per_side = img_size // args.pat_sz
    seq_len = patches_per_side ** 2
    head_width = seq_len * args.embed_dim + args.hid_dim
    vit_config = {
        'img_size': img_size,
        'patch_size': args.pat_sz,
        'depth': args.nlayers,
        'num_heads': args.nheads,
        'embed_dim': args.embed_dim,
        'mlp_ratio': args.mlp_dim / args.embed_dim,
        'drop_rate': args.dropout,
        'num_classes': head_width,
        'no_embed_class': True,
    }

    # Use the GPU when one is available.
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device_name)

    trainer_kwargs = {
        'seed': args.seed,
        'checkpoint_dir': args.checkpoint_dir,
        'device': device,
        'inf_mode': True,
        'num_iters': args.num_iters,
        'config': config,
        'hid_dim': args.hid_dim,
        'out_dim': args.out_dim,
        'mask_rate': args.mask_rate,
        'shuff_rate': args.shuff_rate,
        'vit_config': vit_config,
        'fixed_pos': args.fixed_pos,
    }
    trainer = PosReconTrainer(**trainer_kwargs)

    # Set up logging, then train with the InfoNCE contrastive criterion.
    loggers = trainer.init_logger(args.log_dir)
    criterion = InfoNCELoss(args.temp)
    trainer.train(args.num_iters, criterion, loggers, device)