diff options
Diffstat (limited to 'models/model.py')
-rw-r--r-- | models/model.py | 233 |
1 files changed, 141 insertions, 92 deletions
diff --git a/models/model.py b/models/model.py index 3f24936..25c8a4f 100644 --- a/models/model.py +++ b/models/model.py @@ -1,5 +1,6 @@ +import copy import os -from datetime import datetime +import random from typing import Union, Optional import numpy as np @@ -47,10 +48,10 @@ class Model: self.meta = model_config self.hp = hyperparameter_config - self.curr_iter = self.meta.get('restore_iter', 0) + self.restore_iter = self.curr_iter = self.meta.get('restore_iter', 0) self.total_iter = self.meta.get('total_iter', 80_000) - self.curr_iters = self.meta.get('restore_iters', (0, 0, 0)) - self.total_iters = self.meta.get('total_iters', (80000, 80000, 80000)) + self.restore_iters = self.meta.get('restore_iters', (self.curr_iter,)) + self.total_iters = self.meta.get('total_iters', (self.total_iter,)) self.is_train: bool = True self.in_channels: int = 3 @@ -70,6 +71,7 @@ class Model: self.scheduler: Optional[optim.lr_scheduler.StepLR] = None self.writer: Optional[SummaryWriter] = None self.image_log_on = system_config.get('image_log_on', False) + self.val_size = system_config.get('val_size', 10) self.CASIAB_GALLERY_SELECTOR = { 'selector': {'conditions': ClipConditions({r'nm-0[1-4]'})} @@ -83,7 +85,7 @@ class Model: @property def _model_sig(self) -> str: return '_'.join( - (self._model_name, str(self.curr_iter), str(self.total_iter)) + (self._model_name, str(self.curr_iter + 1), str(self.total_iter)) ) @property @@ -112,18 +114,18 @@ class Model: ], dataloader_config: DataloaderConfiguration, ): - for (curr_iter, total_iter, (condition, selector)) in zip( - self.curr_iters, self.total_iters, dataset_selectors.items() + for (restore_iter, total_iter, (condition, selector)) in zip( + self.restore_iters, self.total_iters, dataset_selectors.items() ): print(f'Training model {condition} ...') # Skip finished model - if curr_iter == total_iter: + if restore_iter == total_iter: continue # Check invalid restore iter - elif curr_iter > total_iter: + elif restore_iter > total_iter: raise ValueError("Restore iter '{}' should less than total " - "iter '{}'".format(curr_iter, total_iter)) - self.curr_iter = curr_iter + "iter '{}'".format(restore_iter, total_iter)) + self.restore_iter = self.curr_iter = restore_iter self.total_iter = total_iter self.fit( dict(**dataset_config, **{'selector': selector}), @@ -136,70 +138,85 @@ class Model: dataloader_config: DataloaderConfiguration, ): self.is_train = True - dataset = self._parse_dataset_config(dataset_config) - dataloader = self._parse_dataloader_config(dataset, dataloader_config) + # Validation dataset + # (the first `val_size` subjects from evaluation set) + val_dataset_config = copy.deepcopy(dataset_config) + train_size = dataset_config.get('train_size', 74) + val_dataset_config['train_size'] = train_size + self.val_size + val_dataset_config['selector']['classes'] = ClipClasses({ + str(c).zfill(3) + for c in range(train_size + 1, train_size + self.val_size + 1) + }) + val_dataset = self._parse_dataset_config(val_dataset_config) + val_dataloader = iter(self._parse_dataloader_config( + val_dataset, dataloader_config + )) + # Training dataset + train_dataset = self._parse_dataset_config(dataset_config) + train_dataloader = iter(self._parse_dataloader_config( + train_dataset, dataloader_config + )) # Prepare for model, optimizer and scheduler - model_hp = self.hp.get('model', {}) + model_hp: dict = self.hp.get('model', {}).copy() optim_hp: dict = self.hp.get('optimizer', {}).copy() sched_hp = self.hp.get('scheduler', {}) + self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp, image_log_on=self.image_log_on) + # Try to accelerate computation using CUDA or others self.rgb_pn = self.rgb_pn.to(self.device) self.optimizer = optim.Adam(self.rgb_pn.parameters(), **optim_hp) - sched_gamma = sched_hp.get('gamma', 0.9) - sched_step_size = sched_hp.get('step_size', 500) - self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=[ - lambda epoch: sched_gamma ** (epoch // sched_step_size), - ]) + start_step = sched_hp.get('start_step', 15_000) + final_gamma = sched_hp.get('final_gamma', 0.001) + all_step = self.total_iter - start_step + self.scheduler = optim.lr_scheduler.LambdaLR( + self.optimizer, + lambda t: final_gamma ** ((t - start_step) / all_step) + if t > start_step else 1, + ) self.writer = SummaryWriter(self._log_name) + # Set seeds for reproducibility + random.seed(0) + torch.manual_seed(0) self.rgb_pn.train() # Init weights at first iter if self.curr_iter == 0: self.rgb_pn.apply(self.init_weights) else: # Load saved state dicts + # Offset a iter to load last checkpoint + self.curr_iter -= 1 checkpoint = torch.load(self._checkpoint_name) - iter_, loss = checkpoint['iter'], checkpoint['loss'] - print('{0:5d} loss: {1:.3f}'.format(iter_, loss)) + random.setstate(checkpoint['rand_states'][0]) + torch.set_rng_state(checkpoint['rand_states'][1]) self.rgb_pn.load_state_dict(checkpoint['model_state_dict']) self.optimizer.load_state_dict(checkpoint['optim_state_dict']) self.scheduler.load_state_dict(checkpoint['sched_state_dict']) # Training start - start_time = datetime.now() - running_loss = torch.zeros(3, device=self.device) - print(f"{'Time':^8} {'Iter':^5} {'Loss':^6}", - f"{'Xrecon':^8} {'CanoCons':^8} {'PoseSim':^8}", - f"{'LR':^9}") - for (batch_c1, batch_c2) in dataloader: - self.curr_iter += 1 + for self.curr_iter in tqdm(range(self.restore_iter, self.total_iter), + desc='Training'): + batch_c1, batch_c2 = next(train_dataloader) # Zero the parameter gradients self.optimizer.zero_grad() # forward + backward + optimize x_c1 = batch_c1['clip'].to(self.device) x_c2 = batch_c2['clip'].to(self.device) - losses, images = self.rgb_pn(x_c1, x_c2) + losses, features, images = self.rgb_pn(x_c1, x_c2) loss = losses.sum() loss.backward() self.optimizer.step() + self.scheduler.step() - # Statistics and checkpoint - running_loss += losses.detach() - # Write losses to TensorBoard - self.writer.add_scalar('Loss/all', loss, self.curr_iter) - self.writer.add_scalars('Loss/details', dict(zip([ - 'Cross reconstruction loss', - 'Canonical consistency loss', - 'Pose similarity loss' - ], losses)), self.curr_iter) - - if self.curr_iter % 100 == 0: - lr = self.scheduler.get_last_lr()[0] - # Write learning rates - self.writer.add_scalar( - 'Learning rate/Auto-encoder', lr, self.curr_iter - ) + # Learning rate + self.writer.add_scalar( + 'Learning rate', self.scheduler.get_last_lr()[0], self.curr_iter + ) + # Other stats + self._write_stat('Train', loss, losses) + + if self.curr_iter % 100 == 99: # Write disentangled images if self.image_log_on: i_a, i_c, i_p = images @@ -216,30 +233,54 @@ class Model: self.writer.add_images( f'Pose image/batch {i}', p, self.curr_iter ) - time_used = datetime.now() - start_time - remaining_minute, second = divmod(time_used.seconds, 60) - hour, minute = divmod(remaining_minute, 60) - print(f'{hour:02}:{minute:02}:{second:02}', - f'{self.curr_iter:5d} {running_loss.sum() / 100:6.3f}', - '{:f} {:f} {:f}'.format(*running_loss / 100), - f'{lr:.3e}') - running_loss.zero_() - - # Step scheduler - self.scheduler.step() - - if self.curr_iter % 1000 == 0: + f_a, f_c, f_p = features + for i, (f_a_i, f_c_i, f_p_i) in enumerate( + zip(f_a, f_c, f_p) + ): + self.writer.add_images( + f'Appearance features/Layer {i}', + f_a_i[:, :3, :, :], self.curr_iter + ) + self.writer.add_images( + f'Canonical features/Layer {i}', + f_c_i[:, :3, :, :], self.curr_iter + ) + for j, p in enumerate(f_p_i): + self.writer.add_images( + f'Pose features/Layer {i}/batch{j}', + p[:, :3, :, :], self.curr_iter + ) + + # Calculate losses on testing batch + batch_c1, batch_c2 = next(val_dataloader) + x_c1 = batch_c1['clip'].to(self.device) + x_c2 = batch_c2['clip'].to(self.device) + with torch.no_grad(): + losses, _, _ = self.rgb_pn(x_c1, x_c2) + loss = losses.sum() + + self._write_stat('Val', loss, losses) + + # Checkpoint + if self.curr_iter % 1000 == 999: torch.save({ - 'iter': self.curr_iter, + 'rand_states': (random.getstate(), torch.get_rng_state()), 'model_state_dict': self.rgb_pn.state_dict(), 'optim_state_dict': self.optimizer.state_dict(), 'sched_state_dict': self.scheduler.state_dict(), - 'loss': loss, }, self._checkpoint_name) - if self.curr_iter == self.total_iter: - self.writer.close() - break + self.writer.close() + + def _write_stat( + self, postfix, loss, losses + ): + # Write losses to TensorBoard + self.writer.add_scalar(f'Loss/all {postfix}', loss, self.curr_iter) + self.writer.add_scalars(f'Loss/disentanglement {postfix}', dict(zip(( + 'Cross reconstruction loss', 'Canonical consistency loss', + 'Pose similarity loss' + ), losses)), self.curr_iter) def transform( self, @@ -248,12 +289,12 @@ class Model: dataset_selectors: dict[ str, dict[str, Union[ClipClasses, ClipConditions, ClipViews]] ], - dataloader_config: DataloaderConfiguration + dataloader_config: DataloaderConfiguration, + is_train: bool = False ): - self.is_train = False # Split gallery and probe dataset gallery_dataloader, probe_dataloaders = self._split_gallery_probe( - dataset_config, dataloader_config + dataset_config, dataloader_config, is_train ) # Get pretrained models at iter_ checkpoints = self._load_pretrained( @@ -261,41 +302,45 @@ class Model: ) # Init models - model_hp = self.hp.get('model', {}) + model_hp: dict = self.hp.get('model', {}).copy() self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp) # Try to accelerate computation using CUDA or others self.rgb_pn = self.rgb_pn.to(self.device) self.rgb_pn.eval() - gallery_samples, probe_samples = [], {} - # Gallery - checkpoint = torch.load(list(checkpoints.values())[0]) - self.rgb_pn.load_state_dict(checkpoint['model_state_dict']) - for sample in tqdm(gallery_dataloader, - desc='Transforming gallery', unit='clips'): - gallery_samples.append(self._get_eval_sample(sample)) - gallery_samples = default_collate(gallery_samples) - # Probe - for (condition, dataloader) in probe_dataloaders.items(): + gallery_samples, probe_samples = {}, {} + for (condition, probe_dataloader) in probe_dataloaders.items(): checkpoint = torch.load(checkpoints[condition]) self.rgb_pn.load_state_dict(checkpoint['model_state_dict']) + # Gallery + gallery_samples_c = [] + for sample in tqdm(gallery_dataloader, + desc=f'Transforming gallery {condition}', + unit='clips'): + gallery_samples_c.append(self._get_eval_sample(sample)) + gallery_samples[condition] = default_collate(gallery_samples_c) + # Probe probe_samples_c = [] - for sample in tqdm(dataloader, + for sample in tqdm(probe_dataloader, desc=f'Transforming probe {condition}', unit='clips'): probe_samples_c.append(self._get_eval_sample(sample)) - probe_samples[condition] = default_collate(probe_samples_c) + probe_samples_c = default_collate(probe_samples_c) + probe_samples_c['meta'] = self._probe_datasets_meta[condition] + probe_samples[condition] = probe_samples_c + gallery_samples['meta'] = self._gallery_dataset_meta return gallery_samples, probe_samples def _get_eval_sample(self, sample: dict[str, Union[list, torch.Tensor]]): - label = sample.pop('label').item() - clip = sample.pop('clip').to(self.device) - x_c, x_p = self.rgb_pn(clip).detach() + label, condition, view, clip = sample.values() + with torch.no_grad(): + feature_c, feature_p = self.rgb_pn(clip.to(self.device)) return { - **{'label': label}, - **sample, - **{'cano_feature': x_c, 'pose_feature': x_p} + 'label': label.item(), + 'condition': condition[0], + 'view': view[0], + 'feature': torch.cat((feature_c, feature_p)).view(-1) } def _load_pretrained( @@ -307,10 +352,11 @@ class Model: ] ) -> dict[str, str]: checkpoints = {} - for (iter_, (condition, selector)) in zip( - iters, dataset_selectors.items() + for (iter_, total_iter, (condition, selector)) in zip( + iters, self.total_iters, dataset_selectors.items() ): - self.curr_iter = iter_ + self.curr_iter = iter_ - 1 + self.total_iter = total_iter self._dataset_sig = self._make_signature( dict(**dataset_config, **selector), popped_keys=['root_dir', 'cache_on'] @@ -322,26 +368,29 @@ class Model: self, dataset_config: DatasetConfiguration, dataloader_config: DataloaderConfiguration, + is_train: bool = False ) -> tuple[DataLoader, dict[str, DataLoader]]: dataset_name = dataset_config.get('name', 'CASIA-B') if dataset_name == 'CASIA-B': + self.is_train = is_train gallery_dataset = self._parse_dataset_config( dict(**dataset_config, **self.CASIAB_GALLERY_SELECTOR) ) - self._gallery_dataset_meta = gallery_dataset.metadata - gallery_dataloader = self._parse_dataloader_config( - gallery_dataset, dataloader_config - ) probe_datasets = { condition: self._parse_dataset_config( dict(**dataset_config, **selector) ) for (condition, selector) in self.CASIAB_PROBE_SELECTORS.items() } + self._gallery_dataset_meta = gallery_dataset.metadata self._probe_datasets_meta = { condition: dataset.metadata for (condition, dataset) in probe_datasets.items() } + self.is_train = False + gallery_dataloader = self._parse_dataloader_config( + gallery_dataset, dataloader_config + ) probe_dataloaders = { condition: self._parse_dataloader_config( dataset, dataloader_config |