From c74df416b00f837ba051f3947be92f76e7afbd88 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Fri, 12 Mar 2021 13:56:17 +0800 Subject: Code refactoring 1. Separate FCs and triplet losses for HPM and PartNet 2. Remove FC-equivalent 1x1 conv layers in HPM 3. Support adjustable learning rate schedulers --- models/model.py | 119 +++++++++++++++++++++++++++++++------------------------- 1 file changed, 67 insertions(+), 52 deletions(-) (limited to 'models/model.py') diff --git a/models/model.py b/models/model.py index cb455f3..adea626 100644 --- a/models/model.py +++ b/models/model.py @@ -13,13 +13,15 @@ from torch.utils.data.dataloader import default_collate from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm +from models.hpm import HorizontalPyramidMatching +from models.part_net import PartNet from models.rgb_part_net import RGBPartNet from utils.configuration import DataloaderConfiguration, \ HyperparameterConfiguration, DatasetConfiguration, ModelConfiguration, \ SystemConfiguration from utils.dataset import CASIAB, ClipConditions, ClipViews, ClipClasses from utils.sampler import TripletSampler -from utils.triplet_loss import JointBatchTripletLoss, BatchTripletLoss +from utils.triplet_loss import BatchTripletLoss class Model: @@ -69,7 +71,8 @@ class Model: self._dataset_sig: str = 'undefined' self.rgb_pn: Optional[RGBPartNet] = None - self.triplet_loss: Optional[JointBatchTripletLoss] = None + self.triplet_loss_hpm: Optional[BatchTripletLoss] = None + self.triplet_loss_pn: Optional[BatchTripletLoss] = None self.optimizer: Optional[optim.Adam] = None self.scheduler: Optional[optim.lr_scheduler.StepLR] = None self.writer: Optional[SummaryWriter] = None @@ -149,26 +152,28 @@ class Model: triplet_margins = model_hp.pop('triplet_margins', None) optim_hp: dict = self.hp.get('optimizer', {}).copy() ae_optim_hp = optim_hp.pop('auto_encoder', {}) - pn_optim_hp = optim_hp.pop('part_net', {}) hpm_optim_hp = optim_hp.pop('hpm', {}) - fc_optim_hp = optim_hp.pop('fc', {}) + pn_optim_hp = optim_hp.pop('part_net', {}) sched_hp = self.hp.get('scheduler', {}) + ae_sched_hp = sched_hp.get('auto_encoder', {}) + hpm_sched_hp = sched_hp.get('hpm', {}) + pn_sched_hp = sched_hp.get('part_net', {}) + self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp, image_log_on=self.image_log_on) # Hard margins if triplet_margins: - # Same margins - if triplet_margins[0] == triplet_margins[1]: - self.triplet_loss = BatchTripletLoss( - triplet_is_hard, triplet_margins[0] - ) - else: # Different margins - self.triplet_loss = JointBatchTripletLoss( - self.rgb_pn.hpm_num_parts, - triplet_is_hard, triplet_is_mean, triplet_margins - ) + self.triplet_loss_hpm = BatchTripletLoss( + triplet_is_hard, triplet_is_mean, triplet_margins[0] + ) + self.triplet_loss_pn = BatchTripletLoss( + triplet_is_hard, triplet_is_mean, triplet_margins[1] + ) else: # Soft margins - self.triplet_loss = BatchTripletLoss( + self.triplet_loss_hpm = BatchTripletLoss( + triplet_is_hard, triplet_is_mean, None + ) + self.triplet_loss_pn = BatchTripletLoss( triplet_is_hard, triplet_is_mean, None ) @@ -177,25 +182,33 @@ class Model: # Try to accelerate computation using CUDA or others self.rgb_pn = self.rgb_pn.to(self.device) - self.triplet_loss = self.triplet_loss.to(self.device) + self.triplet_loss_hpm = self.triplet_loss_hpm.to(self.device) + self.triplet_loss_pn = self.triplet_loss_pn.to(self.device) + self.optimizer = optim.Adam([ {'params': self.rgb_pn.ae.parameters(), **ae_optim_hp}, - {'params': self.rgb_pn.pn.parameters(), **pn_optim_hp}, {'params': self.rgb_pn.hpm.parameters(), **hpm_optim_hp}, - {'params': self.rgb_pn.fc_mat, **fc_optim_hp} + {'params': self.rgb_pn.pn.parameters(), **pn_optim_hp}, ], **optim_hp) - sched_final_gamma = sched_hp.get('final_gamma', 0.001) - sched_start_step = sched_hp.get('start_step', 15_000) - all_step = self.total_iter - sched_start_step - - def lr_lambda(epoch): - if epoch > sched_start_step: - passed_step = epoch - sched_start_step - return sched_final_gamma ** (passed_step / all_step) - else: - return 1 + + start_step = sched_hp.get('start_step', 15_000) + final_gamma = sched_hp.get('final_gamma', 0.001) + ae_start_step = ae_sched_hp.get('start_step', start_step) + ae_final_gamma = ae_sched_hp.get('final_gamma', final_gamma) + ae_all_step = self.total_iter - ae_start_step + hpm_start_step = hpm_sched_hp.get('start_step', start_step) + hpm_final_gamma = hpm_sched_hp.get('final_gamma', final_gamma) + hpm_all_step = self.total_iter - hpm_start_step + pn_start_step = pn_sched_hp.get('start_step', start_step) + pn_final_gamma = pn_sched_hp.get('final_gamma', final_gamma) + pn_all_step = self.total_iter - pn_start_step self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=[ - lr_lambda, lr_lambda, lr_lambda, lr_lambda + lambda t: ae_final_gamma ** ((t - ae_start_step) / ae_all_step) + if t > ae_start_step else 1, + lambda t: hpm_final_gamma ** ((t - hpm_start_step) / hpm_all_step) + if t > hpm_start_step else 1, + lambda t: pn_final_gamma ** ((t - pn_start_step) / pn_all_step) + if t > pn_start_step else 1, ]) self.writer = SummaryWriter(self._log_name) @@ -220,7 +233,7 @@ class Model: running_loss = torch.zeros(5, device=self.device) print(f"{'Time':^8} {'Iter':^5} {'Loss':^6}", f"{'Xrecon':^8} {'CanoCons':^8} {'PoseSim':^8}", - f"{'BATripH':^8} {'BATripP':^8} {'LR':^9}") + f"{'TripHPM':^8} {'TripPN':^8} {'LRs':^29}") for (batch_c1, batch_c2) in dataloader: self.curr_iter += 1 # Zero the parameter gradients @@ -228,17 +241,20 @@ class Model: # forward + backward + optimize x_c1 = batch_c1['clip'].to(self.device) x_c2 = batch_c2['clip'].to(self.device) - embedding, ae_losses, images = self.rgb_pn(x_c1, x_c2) + embedding_c, embedding_p, ae_losses, images = self.rgb_pn(x_c1, x_c2) y = batch_c1['label'].to(self.device) # Duplicate labels for each part - y = y.repeat(self.rgb_pn.num_total_parts, 1) - trip_loss, dist, num_non_zero = self.triplet_loss(embedding, y) - losses = torch.cat(( - ae_losses, - torch.stack(( - trip_loss[:self.rgb_pn.hpm_num_parts].mean(), - trip_loss[self.rgb_pn.hpm_num_parts:].mean() - )) + y = y.repeat(self.rgb_pn.num_parts, 1) + trip_loss_hpm, hpm_dist, hpm_num_non_zero = self.triplet_loss_hpm( + embedding_c, y[:self.rgb_pn.hpm.num_parts] + ) + trip_loss_pn, pn_dist, pn_num_non_zero = self.triplet_loss_pn( + embedding_p, y[self.rgb_pn.hpm.num_parts:] + ) + losses = torch.stack(( + *ae_losses, + trip_loss_hpm.mean(), + trip_loss_pn.mean() )) loss = losses.sum() loss.backward() @@ -257,30 +273,30 @@ class Model: 'PartNet': losses[4] }, self.curr_iter) # None-zero losses in batch - if num_non_zero is not None: + if hpm_num_non_zero is not None and hpm_num_non_zero is not None: self.writer.add_scalars('Loss/non-zero counts', { - 'HPM': num_non_zero[:self.rgb_pn.hpm_num_parts].mean(), - 'PartNet': num_non_zero[self.rgb_pn.hpm_num_parts:].mean() + 'HPM': hpm_num_non_zero.mean(), + 'PartNet': pn_num_non_zero.mean() }, self.curr_iter) # Embedding distance - mean_hpm_dist = dist[:self.rgb_pn.hpm_num_parts].mean(0) + mean_hpm_dist = hpm_dist.mean(0) self._add_ranked_scalars( 'Embedding/HPM distance', mean_hpm_dist, num_pos_pairs, num_pairs, self.curr_iter ) - mean_pa_dist = dist[self.rgb_pn.hpm_num_parts:].mean(0) + mean_pa_dist = pn_dist.mean(0) self._add_ranked_scalars( 'Embedding/ParNet distance', mean_pa_dist, num_pos_pairs, num_pairs, self.curr_iter ) # Embedding norm - mean_hpm_embedding = embedding[:self.rgb_pn.hpm_num_parts].mean(0) + mean_hpm_embedding = embedding_c.mean(0) mean_hpm_norm = mean_hpm_embedding.norm(dim=-1) self._add_ranked_scalars( 'Embedding/HPM norm', mean_hpm_norm, self.k, self.pr * self.k, self.curr_iter ) - mean_pa_embedding = embedding[self.rgb_pn.hpm_num_parts:].mean(0) + mean_pa_embedding = embedding_p.mean(0) mean_pa_norm = mean_pa_embedding.norm(dim=-1) self._add_ranked_scalars( 'Embedding/PartNet norm', mean_pa_norm, @@ -288,10 +304,9 @@ class Model: ) # Learning rate lrs = self.scheduler.get_last_lr() - # Write learning rates - self.writer.add_scalar( - 'Learning rate', lrs[0], self.curr_iter - ) + self.writer.add_scalars('Learning rate', dict(zip(( + 'Auto-encoder', 'HPM', 'PartNet' + ), lrs)), self.curr_iter) if self.curr_iter % 100 == 0: # Write disentangled images @@ -316,7 +331,7 @@ class Model: print(f'{hour:02}:{minute:02}:{second:02}', f'{self.curr_iter:5d} {running_loss.sum() / 100:6.3f}', '{:f} {:f} {:f} {:f} {:f}'.format(*running_loss / 100), - f'{lrs[0]:.3e}') + '{:.3e} {:.3e} {:.3e}'.format(*lrs)) running_loss.zero_() # Step scheduler @@ -548,7 +563,7 @@ class Model: nn.init.zeros_(m.bias) elif isinstance(m, nn.Linear): nn.init.xavier_uniform_(m.weight) - elif isinstance(m, RGBPartNet): + elif isinstance(m, (HorizontalPyramidMatching, PartNet)): nn.init.xavier_uniform_(m.fc_mat) def _parse_dataset_config( -- cgit v1.2.3 From 2ea916b2a963eae7d47151b41c8c78a578c402e2 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Fri, 12 Mar 2021 15:31:44 +0800 Subject: Make evaluate method static --- models/model.py | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) (limited to 'models/model.py') diff --git a/models/model.py b/models/model.py index adea626..b09d600 100644 --- a/models/model.py +++ b/models/model.py @@ -241,15 +241,15 @@ class Model: # forward + backward + optimize x_c1 = batch_c1['clip'].to(self.device) x_c2 = batch_c2['clip'].to(self.device) - embedding_c, embedding_p, ae_losses, images = self.rgb_pn(x_c1, x_c2) + embed_c, embed_p, ae_losses, images = self.rgb_pn(x_c1, x_c2) y = batch_c1['label'].to(self.device) # Duplicate labels for each part y = y.repeat(self.rgb_pn.num_parts, 1) trip_loss_hpm, hpm_dist, hpm_num_non_zero = self.triplet_loss_hpm( - embedding_c, y[:self.rgb_pn.hpm.num_parts] + embed_c, y[:self.rgb_pn.hpm.num_parts] ) trip_loss_pn, pn_dist, pn_num_non_zero = self.triplet_loss_pn( - embedding_p, y[self.rgb_pn.hpm.num_parts:] + embed_p, y[self.rgb_pn.hpm.num_parts:] ) losses = torch.stack(( *ae_losses, @@ -290,13 +290,13 @@ class Model: num_pos_pairs, num_pairs, self.curr_iter ) # Embedding norm - mean_hpm_embedding = embedding_c.mean(0) + mean_hpm_embedding = embed_c.mean(0) mean_hpm_norm = mean_hpm_embedding.norm(dim=-1) self._add_ranked_scalars( 'Embedding/HPM norm', mean_hpm_norm, self.k, self.pr * self.k, self.curr_iter ) - mean_pa_embedding = embedding_p.mean(0) + mean_pa_embedding = embed_p.mean(0) mean_pa_norm = mean_pa_embedding.norm(dim=-1) self._add_ranked_scalars( 'Embedding/PartNet norm', mean_pa_norm, @@ -425,13 +425,16 @@ class Model: unit='clips'): gallery_samples_c.append(self._get_eval_sample(sample)) gallery_samples[condition] = default_collate(gallery_samples_c) + gallery_samples['meta'] = self._gallery_dataset_meta # Probe probe_samples_c = [] for sample in tqdm(probe_dataloader, desc=f'Transforming probe {condition}', unit='clips'): probe_samples_c.append(self._get_eval_sample(sample)) - probe_samples[condition] = default_collate(probe_samples_c) + probe_samples_c = default_collate(probe_samples_c) + probe_samples_c['meta'] = self._probe_datasets_meta[condition] + probe_samples[condition] = probe_samples_c return gallery_samples, probe_samples @@ -446,15 +449,15 @@ class Model: **{'feature': feature} } + @staticmethod def evaluate( - self, - gallery_samples: dict[str, Union[list[str], torch.Tensor]], - probe_samples: dict[str, dict[str, Union[list[str], torch.Tensor]]], + gallery_samples: dict[str, dict[str, Union[list, torch.Tensor]]], + probe_samples: dict[str, dict[str, Union[list, torch.Tensor]]], num_ranks: int = 5 ) -> dict[str, torch.Tensor]: - conditions = gallery_samples.keys() - gallery_views_meta = self._gallery_dataset_meta['views'] - probe_views_meta = list(self._probe_datasets_meta.values())[0]['views'] + conditions = list(probe_samples.keys()) + gallery_views_meta = gallery_samples['meta']['views'] + probe_views_meta = probe_samples[conditions[0]]['meta']['views'] accuracy = { condition: torch.empty( len(gallery_views_meta), len(probe_views_meta), num_ranks @@ -467,7 +470,7 @@ class Model: (labels_g, _, views_g, features_g) = gallery_samples_c.values() views_g = np.asarray(views_g) probe_samples_c = probe_samples[condition] - (labels_p, _, views_p, features_p) = probe_samples_c.values() + (labels_p, _, views_p, features_p, _) = probe_samples_c.values() views_p = np.asarray(views_p) accuracy_c = accuracy[condition] for (v_g_i, view_g) in enumerate(gallery_views_meta): @@ -492,7 +495,6 @@ class Model: positive_counts = positive_mat.sum(0) total_counts, _ = dist.size() accuracy_c[v_g_i, v_p_i, :] = positive_counts / total_counts - return accuracy def _load_pretrained( -- cgit v1.2.3