From b294b715ec0de6ba94199f3b068dc828095fd2f1 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Sat, 10 Apr 2021 22:34:25 +0800 Subject: Calculate pose similarity loss and canonical consistency loss of each part after pooling --- models/model.py | 99 ++++++++++++++++++++++++++++++++++----------------------- 1 file changed, 59 insertions(+), 40 deletions(-) (limited to 'models/model.py') diff --git a/models/model.py b/models/model.py index 45067e6..6118bdf 100644 --- a/models/model.py +++ b/models/model.py @@ -267,11 +267,15 @@ class Model: # forward + backward + optimize x_c1 = batch_c1['clip'].to(self.device) x_c2 = batch_c2['clip'].to(self.device) - embed_c, embed_p, ae_losses, images = self.rgb_pn(x_c1, x_c2) + embedding, f_loss, images = self.rgb_pn(x_c1, x_c2) + ae_losses = self._disentangling_loss(x_c1, f_loss) y = batch_c1['label'].to(self.device) - losses, hpm_result, pn_result = self._classification_loss( - embed_c, embed_p, ae_losses, y - ) + results = self._classification_loss(embedding, y) + losses = torch.stack(( + *ae_losses, + results[0]['loss'].mean(), + results[1]['loss'].mean() + )) loss = losses.sum() loss.backward() self.optimizer.step() @@ -282,9 +286,7 @@ class Model: 'Auto-encoder', 'HPM', 'PartNet' ), self.scheduler.get_last_lr())), self.curr_iter) # Other stats - self._write_stat( - 'Train', embed_c, embed_p, hpm_result, pn_result, loss, losses - ) + self._write_stat('Train', embedding, results, losses) # Write disentangled images if self.image_log_on and self.curr_iter % self.image_log_steps \ @@ -306,8 +308,8 @@ class Model: if self.curr_iter % 100 == 99: # Validation - embed_c = self._flatten_embedding(embed_c) - embed_p = self._flatten_embedding(embed_p) + embed_c = self._flatten_embedding(embedding[0]) + embed_p = self._flatten_embedding(embedding[1]) self._write_embedding('HPM Train', embed_c, x_c1, y) self._write_embedding('PartNet Train', embed_p, x_c1, y) @@ -316,18 +318,19 @@ class Model: x_c1 = batch_c1['clip'].to(self.device) x_c2 = batch_c2['clip'].to(self.device) with torch.no_grad(): - embed_c, embed_p, ae_losses, _ = self.rgb_pn(x_c1, x_c2) + embedding, f_loss, images = self.rgb_pn(x_c1, x_c2) + ae_losses = self._disentangling_loss(x_c1, f_loss) y = batch_c1['label'].to(self.device) - losses, hpm_result, pn_result = self._classification_loss( - embed_c, embed_p, ae_losses, y - ) - loss = losses.sum() - - self._write_stat( - 'Val', embed_c, embed_p, hpm_result, pn_result, loss, losses - ) - embed_c = self._flatten_embedding(embed_c) - embed_p = self._flatten_embedding(embed_p) + results = self._classification_loss(embedding, y) + losses = torch.stack(( + *ae_losses, + results[0]['loss'].mean(), + results[1]['loss'].mean() + )) + + self._write_stat('Val', embedding, results, losses) + embed_c = self._flatten_embedding(embedding[0]) + embed_p = self._flatten_embedding(embedding[1]) self._write_embedding('HPM Val', embed_c, x_c1, y) self._write_embedding('PartNet Val', embed_p, x_c1, y) @@ -342,21 +345,39 @@ class Model: self.writer.close() - def _classification_loss(self, embed_c, embed_p, ae_losses, y): + @staticmethod + def _disentangling_loss(x_c1_t2, f_loss): + n, t, c, h, w = x_c1_t2.size() + x_c1_t2_pred = f_loss[0] + xrecon_loss = torch.stack([ + F.mse_loss(x_c1_t2[:, i, :, :, :], x_c1_t2_pred[:, i, :, :, :]) + for i in range(t) + ]).sum() + cano_cons_loss = torch.stack([ + torch.stack([ + F.mse_loss(_f_c_c1_t1[:, i, :], _f_c_c1_t2[:, i, :]) + + F.mse_loss(_f_c_c1_t2[:, i, :], _f_c_c2_t2[:, i, :]) + for i in range(t) + ]).mean() + for _f_c_c1_t2, _f_c_c1_t1, _f_c_c2_t2 in zip(*f_loss[1]) + ]).sum() + pose_sim_loss = torch.stack([ + F.mse_loss(_f_p_c1_t2.mean(1), _f_p_c2_t2.mean(1)) + for _f_p_c1_t2, _f_p_c2_t2 in zip(*f_loss[2]) + ]).sum() + + return xrecon_loss, cano_cons_loss * 10, pose_sim_loss * 100 + + def _classification_loss(self, embedding, y): # Duplicate labels for each part y_triplet = y.repeat(self.rgb_pn.num_parts, 1) hpm_result = self.triplet_loss_hpm( - embed_c, y_triplet[:self.rgb_pn.hpm.num_parts] + embedding[0], y_triplet[:self.rgb_pn.hpm.num_parts] ) pn_result = self.triplet_loss_pn( - embed_p, y_triplet[self.rgb_pn.hpm.num_parts:] + embedding[1], y_triplet[self.rgb_pn.hpm.num_parts:] ) - losses = torch.stack(( - *ae_losses, - hpm_result.pop('loss').mean(), - pn_result.pop('loss').mean() - )) - return losses, hpm_result, pn_result + return hpm_result, pn_result def _write_embedding(self, tag, embed, x, y): frame = x[:, 0, :, :, :].cpu() @@ -374,11 +395,8 @@ class Model: def _flatten_embedding(self, embed): return embed.detach().transpose(0, 1).reshape(self.k * self.pr, -1) - def _write_stat( - self, postfix, embed_c, embed_p, hpm_result, pn_result, loss, losses - ): + def _write_stat(self, postfix, embeddings, results, losses): # Write losses to TensorBoard - self.writer.add_scalar(f'Loss/all {postfix}', loss, self.curr_iter) self.writer.add_scalars(f'Loss/disentanglement {postfix}', dict(zip(( 'Cross reconstruction loss', 'Canonical consistency loss', 'Pose similarity loss' @@ -388,30 +406,31 @@ class Model: 'PartNet': losses[4] }, self.curr_iter) # None-zero losses in batch - if hpm_result['counts'] is not None and pn_result['counts'] is not None: + if results[0]['counts'] is not None \ + and results[1]['counts'] is not None: self.writer.add_scalars(f'Loss/non-zero counts {postfix}', { - 'HPM': hpm_result['counts'].mean(), - 'PartNet': pn_result['counts'].mean() + 'HPM': results[0]['counts'].mean(), + 'PartNet': results[1]['counts'].mean() }, self.curr_iter) # Embedding distance - mean_hpm_dist = hpm_result['dist'].mean(0) + mean_hpm_dist = results[0]['dist'].mean(0) self._add_ranked_scalars( f'Embedding/HPM distance {postfix}', mean_hpm_dist, self.num_pos_pairs, self.num_pairs, self.curr_iter ) - mean_pn_dist = pn_result['dist'].mean(0) + mean_pn_dist = results[1]['dist'].mean(0) self._add_ranked_scalars( f'Embedding/ParNet distance {postfix}', mean_pn_dist, self.num_pos_pairs, self.num_pairs, self.curr_iter ) # Embedding norm - mean_hpm_embedding = embed_c.mean(0) + mean_hpm_embedding = embeddings[0].mean(0) mean_hpm_norm = mean_hpm_embedding.norm(dim=-1) self._add_ranked_scalars( f'Embedding/HPM norm {postfix}', mean_hpm_norm, self.k, self.pr * self.k, self.curr_iter ) - mean_pa_embedding = embed_p.mean(0) + mean_pa_embedding = embeddings[1].mean(0) mean_pa_norm = mean_pa_embedding.norm(dim=-1) self._add_ranked_scalars( f'Embedding/PartNet norm {postfix}', mean_pa_norm, -- cgit v1.2.3