diff options
-rw-r--r-- | config.py | 2 | ||||
-rw-r--r-- | models/model.py | 249 | ||||
-rw-r--r-- | models/rgb_part_net.py | 2 | ||||
-rw-r--r-- | utils/configuration.py | 1 | ||||
-rw-r--r-- | utils/triplet_loss.py | 9 |
5 files changed, 165 insertions, 98 deletions
@@ -19,6 +19,8 @@ config: Configuration = { 'root_dir': 'data/CASIA-B-MRCNN-V2/SEG', # The number of subjects for training 'train_size': 74, + # The number of subjects for validating (Part of testing set) + 'val_size': 10, # Number of sampled frames per sequence (Training only) 'num_sampled_frames': 30, # Truncate clips longer than `truncate_threshold` diff --git a/models/model.py b/models/model.py index 36a4f7f..766e513 100644 --- a/models/model.py +++ b/models/model.py @@ -1,6 +1,6 @@ +import copy import os import random -from datetime import datetime from typing import Union, Optional import numpy as np @@ -52,9 +52,9 @@ class Model: self.meta = model_config self.hp = hyperparameter_config - self.curr_iter = self.meta.get('restore_iter', 0) + self.restore_iter = self.curr_iter = self.meta.get('restore_iter', 0) self.total_iter = self.meta.get('total_iter', 80_000) - self.curr_iters = self.meta.get('restore_iters', (self.curr_iter,)) + self.restore_iters = self.meta.get('restore_iters', (self.curr_iter,)) self.total_iters = self.meta.get('total_iters', (self.total_iter,)) self.is_train: bool = True @@ -62,6 +62,8 @@ class Model: self.in_size: tuple[int, int] = (64, 48) self.pr: Optional[int] = None self.k: Optional[int] = None + self.num_pairs: Optional[int] = None + self.num_pos_pairs: Optional[int] = None self._gallery_dataset_meta: Optional[dict[str, list]] = None self._probe_datasets_meta: Optional[dict[str, dict[str, list]]] = None @@ -90,7 +92,7 @@ class Model: @property def _model_sig(self) -> str: return '_'.join( - (self._model_name, str(self.curr_iter), str(self.total_iter)) + (self._model_name, str(self.curr_iter + 1), str(self.total_iter)) ) @property @@ -119,18 +121,18 @@ class Model: ], dataloader_config: DataloaderConfiguration, ): - for (curr_iter, total_iter, (condition, selector)) in zip( - self.curr_iters, self.total_iters, dataset_selectors.items() + for (restore_iter, total_iter, (condition, selector)) in zip( + self.restore_iters, self.total_iters, dataset_selectors.items() ): print(f'Training model {condition} ...') # Skip finished model - if curr_iter == total_iter: + if restore_iter == total_iter: continue # Check invalid restore iter - elif curr_iter > total_iter: + elif restore_iter > total_iter: raise ValueError("Restore iter '{}' should less than total " - "iter '{}'".format(curr_iter, total_iter)) - self.curr_iter = curr_iter + "iter '{}'".format(restore_iter, total_iter)) + self.restore_iter = self.curr_iter = restore_iter self.total_iter = total_iter self.fit( dict(**dataset_config, **{'selector': selector}), @@ -143,8 +145,24 @@ class Model: dataloader_config: DataloaderConfiguration, ): self.is_train = True - dataset = self._parse_dataset_config(dataset_config) - dataloader = self._parse_dataloader_config(dataset, dataloader_config) + # Validation dataset + # (the first `val_size` subjects from evaluation set) + val_size = dataset_config.pop('val_size', 10) + val_dataset_config = copy.deepcopy(dataset_config) + train_size = dataset_config.get('train_size', 74) + val_dataset_config['train_size'] = train_size + val_size + val_dataset_config['selector']['classes'] = ClipClasses({ + str(c).zfill(3) for c in range(train_size, train_size + val_size) + }) + val_dataset = self._parse_dataset_config(val_dataset_config) + val_dataloader = iter(self._parse_dataloader_config( + val_dataset, dataloader_config + )) + # Training dataset + train_dataset = self._parse_dataset_config(dataset_config) + train_dataloader = iter(self._parse_dataloader_config( + train_dataset, dataloader_config + )) # Prepare for model, optimizer and scheduler model_hp: dict = self.hp.get('model', {}).copy() triplet_is_hard = model_hp.pop('triplet_is_hard', True) @@ -177,8 +195,8 @@ class Model: triplet_is_hard, triplet_is_mean, None ) - num_pairs = (self.pr*self.k-1) * (self.pr*self.k) // 2 - num_pos_pairs = (self.k*(self.k-1)//2) * self.pr + self.num_pairs = (self.pr*self.k-1) * (self.pr*self.k) // 2 + self.num_pos_pairs = (self.k*(self.k-1)//2) * self.pr # Try to accelerate computation using CUDA or others self.rgb_pn = self.rgb_pn.to(self.device) @@ -191,6 +209,7 @@ class Model: {'params': self.rgb_pn.pn.parameters(), **pn_optim_hp}, ], **optim_hp) + # Scheduler start_step = sched_hp.get('start_step', 15_000) final_gamma = sched_hp.get('final_gamma', 0.001) ae_start_step = ae_sched_hp.get('start_step', start_step) @@ -221,6 +240,8 @@ class Model: if self.curr_iter == 0: self.rgb_pn.apply(self.init_weights) else: # Load saved state dicts + # Offset a iter to load last checkpoint + self.curr_iter -= 1 checkpoint = torch.load(self._checkpoint_name) random.setstate(checkpoint['rand_states'][0]) torch.set_rng_state(checkpoint['rand_states'][1]) @@ -229,13 +250,9 @@ class Model: self.scheduler.load_state_dict(checkpoint['sched_state_dict']) # Training start - start_time = datetime.now() - running_loss = torch.zeros(5, device=self.device) - print(f"{'Time':^8} {'Iter':^5} {'Loss':^6}", - f"{'Xrecon':^8} {'CanoCons':^8} {'PoseSim':^8}", - f"{'TripHPM':^8} {'TripPN':^8} {'LRs':^29}") - for (batch_c1, batch_c2) in dataloader: - self.curr_iter += 1 + for self.curr_iter in tqdm(range(self.restore_iter, self.total_iter), + desc='Training'): + batch_c1, batch_c2 = next(train_dataloader) # Zero the parameter gradients self.optimizer.zero_grad() # forward + backward + optimize @@ -243,72 +260,24 @@ class Model: x_c2 = batch_c2['clip'].to(self.device) embed_c, embed_p, ae_losses, images = self.rgb_pn(x_c1, x_c2) y = batch_c1['label'].to(self.device) - # Duplicate labels for each part - y = y.repeat(self.rgb_pn.num_parts, 1) - trip_loss_hpm, hpm_dist, hpm_num_non_zero = self.triplet_loss_hpm( - embed_c, y[:self.rgb_pn.hpm.num_parts] + losses, hpm_result, pn_result = self._classification_loss( + embed_c, embed_p, ae_losses, y ) - trip_loss_pn, pn_dist, pn_num_non_zero = self.triplet_loss_pn( - embed_p, y[self.rgb_pn.hpm.num_parts:] - ) - losses = torch.stack(( - *ae_losses, - trip_loss_hpm.mean(), - trip_loss_pn.mean() - )) loss = losses.sum() loss.backward() self.optimizer.step() + self.scheduler.step() - # Statistics and checkpoint - running_loss += losses.detach() - # Write losses to TensorBoard - self.writer.add_scalar('Loss/all', loss, self.curr_iter) - self.writer.add_scalars('Loss/disentanglement', dict(zip(( - 'Cross reconstruction loss', 'Canonical consistency loss', - 'Pose similarity loss' - ), ae_losses)), self.curr_iter) - self.writer.add_scalars('Loss/triplet loss', { - 'HPM': losses[3], - 'PartNet': losses[4] - }, self.curr_iter) - # None-zero losses in batch - if hpm_num_non_zero is not None and pn_num_non_zero is not None: - self.writer.add_scalars('Loss/non-zero counts', { - 'HPM': hpm_num_non_zero.mean(), - 'PartNet': pn_num_non_zero.mean() - }, self.curr_iter) - # Embedding distance - mean_hpm_dist = hpm_dist.mean(0) - self._add_ranked_scalars( - 'Embedding/HPM distance', mean_hpm_dist, - num_pos_pairs, num_pairs, self.curr_iter - ) - mean_pa_dist = pn_dist.mean(0) - self._add_ranked_scalars( - 'Embedding/ParNet distance', mean_pa_dist, - num_pos_pairs, num_pairs, self.curr_iter - ) - # Embedding norm - mean_hpm_embedding = embed_c.mean(0) - mean_hpm_norm = mean_hpm_embedding.norm(dim=-1) - self._add_ranked_scalars( - 'Embedding/HPM norm', mean_hpm_norm, - self.k, self.pr * self.k, self.curr_iter - ) - mean_pa_embedding = embed_p.mean(0) - mean_pa_norm = mean_pa_embedding.norm(dim=-1) - self._add_ranked_scalars( - 'Embedding/PartNet norm', mean_pa_norm, - self.k, self.pr * self.k, self.curr_iter - ) # Learning rate - lrs = self.scheduler.get_last_lr() self.writer.add_scalars('Learning rate', dict(zip(( 'Auto-encoder', 'HPM', 'PartNet' - ), lrs)), self.curr_iter) + ), self.scheduler.get_last_lr())), self.curr_iter) + # Other stats + self._write_stat( + 'Train', embed_c, embed_p, hpm_result, pn_result, loss, losses + ) - if self.curr_iter % 100 == 0: + if self.curr_iter % 100 == 99: # Write disentangled images if self.image_log_on: i_a, i_c, i_p = images @@ -325,19 +294,35 @@ class Model: self.writer.add_images( f'Pose image/batch {i}', p, self.curr_iter ) - time_used = datetime.now() - start_time - remaining_minute, second = divmod(time_used.seconds, 60) - hour, minute = divmod(remaining_minute, 60) - print(f'{hour:02}:{minute:02}:{second:02}', - f'{self.curr_iter:5d} {running_loss.sum() / 100:6.3f}', - '{:f} {:f} {:f} {:f} {:f}'.format(*running_loss / 100), - '{:.3e} {:.3e} {:.3e}'.format(*lrs)) - running_loss.zero_() - - # Step scheduler - self.scheduler.step() - if self.curr_iter % 1000 == 0: + # Validation + embed_c = self._flatten_embedding(embed_c) + embed_p = self._flatten_embedding(embed_p) + self._write_embedding('HPM Train', embed_c, x_c1, y) + self._write_embedding('PartNet Train', embed_p, x_c1, y) + + # Calculate losses on testing batch + batch_c1, batch_c2 = next(val_dataloader) + x_c1 = batch_c1['clip'].to(self.device) + x_c2 = batch_c2['clip'].to(self.device) + with torch.no_grad(): + embed_c, embed_p, ae_losses, _ = self.rgb_pn(x_c1, x_c2) + y = batch_c1['label'].to(self.device) + losses, hpm_result, pn_result = self._classification_loss( + embed_c, embed_p, ae_losses, y + ) + loss = losses.sum() + + self._write_stat( + 'Val', embed_c, embed_p, hpm_result, pn_result, loss, losses + ) + embed_c = self._flatten_embedding(embed_c) + embed_p = self._flatten_embedding(embed_p) + self._write_embedding('HPM Val', embed_c, x_c1, y) + self._write_embedding('PartNet Val', embed_p, x_c1, y) + + # Checkpoint + if self.curr_iter % 1000 == 999: torch.save({ 'rand_states': (random.getstate(), torch.get_rng_state()), 'model_state_dict': self.rgb_pn.state_dict(), @@ -345,9 +330,83 @@ class Model: 'sched_state_dict': self.scheduler.state_dict(), }, self._checkpoint_name) - if self.curr_iter == self.total_iter: - self.writer.close() - break + self.writer.close() + + def _classification_loss(self, embed_c, embed_p, ae_losses, y): + # Duplicate labels for each part + y_triplet = y.repeat(self.rgb_pn.num_parts, 1) + hpm_result = self.triplet_loss_hpm( + embed_c, y_triplet[:self.rgb_pn.hpm.num_parts] + ) + pn_result = self.triplet_loss_pn( + embed_p, y_triplet[self.rgb_pn.hpm.num_parts:] + ) + losses = torch.stack(( + *ae_losses, + hpm_result.pop('loss').mean(), + pn_result.pop('loss').mean() + )) + return losses, hpm_result, pn_result + + def _write_embedding(self, tag, embed, x, y): + frame = x[:, 0, :, :, :].cpu() + n, c, h, w = frame.size() + padding = torch.zeros(n, c, h, (h-w) // 2) + padded_frame = torch.cat((padding, frame, padding), dim=-1) + self.writer.add_embedding( + embed, + metadata=y.cpu().tolist(), + label_img=padded_frame, + global_step=self.curr_iter, + tag=tag + ) + + def _flatten_embedding(self, embed): + return embed.detach().transpose(0, 1).reshape(self.k * self.pr, -1) + + def _write_stat( + self, postfix, embed_c, embed_p, hpm_result, pn_result, loss, losses + ): + # Write losses to TensorBoard + self.writer.add_scalar(f'Loss/all {postfix}', loss, self.curr_iter) + self.writer.add_scalars(f'Loss/disentanglement {postfix}', dict(zip(( + 'Cross reconstruction loss', 'Canonical consistency loss', + 'Pose similarity loss' + ), losses[:3])), self.curr_iter) + self.writer.add_scalars(f'Loss/triplet loss {postfix}', { + 'HPM': losses[3], + 'PartNet': losses[4] + }, self.curr_iter) + # None-zero losses in batch + if hpm_result['counts'] is not None and pn_result['counts'] is not None: + self.writer.add_scalars(f'Loss/non-zero counts {postfix}', { + 'HPM': hpm_result['counts'].mean(), + 'PartNet': pn_result['counts'].mean() + }, self.curr_iter) + # Embedding distance + mean_hpm_dist = hpm_result['dist'].mean(0) + self._add_ranked_scalars( + f'Embedding/HPM distance {postfix}', mean_hpm_dist, + self.num_pos_pairs, self.num_pairs, self.curr_iter + ) + mean_pn_dist = pn_result['dist'].mean(0) + self._add_ranked_scalars( + f'Embedding/ParNet distance {postfix}', mean_pn_dist, + self.num_pos_pairs, self.num_pairs, self.curr_iter + ) + # Embedding norm + mean_hpm_embedding = embed_c.mean(0) + mean_hpm_norm = mean_hpm_embedding.norm(dim=-1) + self._add_ranked_scalars( + f'Embedding/HPM norm {postfix}', mean_hpm_norm, + self.k, self.pr * self.k, self.curr_iter + ) + mean_pa_embedding = embed_p.mean(0) + mean_pa_norm = mean_pa_embedding.norm(dim=-1) + self._add_ranked_scalars( + f'Embedding/PartNet norm {postfix}', mean_pa_norm, + self.k, self.pr * self.k, self.curr_iter + ) def _add_ranked_scalars( self, @@ -441,12 +500,12 @@ class Model: def _get_eval_sample(self, sample: dict[str, Union[list, torch.Tensor]]): label, condition, view, clip = sample.values() with torch.no_grad(): - feature = self.rgb_pn(clip.to(self.device)) + feature_c, feature_p = self.rgb_pn(clip.to(self.device)) return { 'label': label.item(), 'condition': condition[0], 'view': view[0], - 'feature': feature + 'feature': torch.cat((feature_c, feature_p)).view(-1) } @staticmethod diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index c136040..4a82da3 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -57,7 +57,7 @@ class RGBPartNet(nn.Module): if self.training: return x_c, x_p, ae_losses, images else: - return torch.cat((x_c.view(-1), x_p.view(-1))) + return x_c, x_p def _disentangle(self, x_c1_t2, x_c2_t2=None): n, t, c, h, w = x_c1_t2.size() diff --git a/utils/configuration.py b/utils/configuration.py index f6ac182..157d249 100644 --- a/utils/configuration.py +++ b/utils/configuration.py @@ -14,6 +14,7 @@ class DatasetConfiguration(TypedDict): name: str root_dir: str train_size: int + val_size: int num_sampled_frames: int truncate_threshold: int discard_threshold: int diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py index 03fff21..5e3a97a 100644 --- a/utils/triplet_loss.py +++ b/utils/triplet_loss.py @@ -28,6 +28,7 @@ class BatchTripletLoss(nn.Module): else: # is_all positive_negative_dist = self._all_distance(dist, y, p, n) + non_zero_counts = None if self.margin: losses = F.relu(self.margin + positive_negative_dist).view(p, -1) non_zero_counts = (losses != 0).sum(1).float() @@ -35,14 +36,18 @@ class BatchTripletLoss(nn.Module): loss_metric = self._none_zero_mean(losses, non_zero_counts) else: # is_sum loss_metric = losses.sum(1) - return loss_metric, flat_dist, non_zero_counts else: # Soft margin losses = F.softplus(positive_negative_dist).view(p, -1) if self.is_mean: loss_metric = losses.mean(1) else: # is_sum loss_metric = losses.sum(1) - return loss_metric, flat_dist, None + + return { + 'loss': loss_metric, + 'dist': flat_dist, + 'counts': non_zero_counts + } @staticmethod def _batch_distance(x): |