diff options
-rw-r--r-- | models/auto_encoder.py | 14 | ||||
-rw-r--r-- | models/hpm.py | 20 | ||||
-rw-r--r-- | models/model.py | 99 | ||||
-rw-r--r-- | models/part_net.py | 20 | ||||
-rw-r--r-- | models/rgb_part_net.py | 37 |
5 files changed, 108 insertions, 82 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py index 96dfdb3..dc7843a 100644 --- a/models/auto_encoder.py +++ b/models/auto_encoder.py @@ -134,25 +134,13 @@ class AutoEncoder(nn.Module): x_c1_t2_pred_ = self.decoder(f_a_c1_t1_, f_c_c1_t1_, f_p_c1_t2_) x_c1_t2_pred = x_c1_t2_pred_.view(n, t, c, h, w) - xrecon_loss = torch.stack([ - F.mse_loss(x_c1_t2[:, i], x_c1_t2_pred[:, i]) - for i in range(t) - ]).sum() - f_c_c1_t1 = f_c_c1_t1_.view(f_size[1]) f_c_c2_t2 = f_c_c2_t2_.view(f_size[1]) - cano_cons_loss = torch.stack([ - F.mse_loss(f_c_c1_t1[:, i], f_c_c1_t2[:, i]) - + F.mse_loss(f_c_c1_t2[:, i], f_c_c2_t2[:, i]) - for i in range(t) - ]).mean() - f_p_c2_t2 = f_p_c2_t2_.view(f_size[2]) - pose_sim_loss = F.mse_loss(f_p_c1_t2.mean(1), f_p_c2_t2.mean(1)) return ( (f_a_c1_t2, f_c_c1_t2, f_p_c1_t2), - (xrecon_loss, cano_cons_loss, pose_sim_loss * 10) + (x_c1_t2_pred, (f_c_c1_t1, f_c_c2_t2), f_p_c2_t2) ) else: # evaluating return f_a_c1_t2, f_c_c1_t2, f_p_c1_t2 diff --git a/models/hpm.py b/models/hpm.py index 8186b20..fa0f69e 100644 --- a/models/hpm.py +++ b/models/hpm.py @@ -33,8 +33,9 @@ class HorizontalPyramidMatching(nn.Module): ]) return pyramid - def forward(self, x): - n, c, h, w = x.size() + def _horizontal_pyramid_pool(self, x): + n, t, c, h, w = x.size() + x = x.view(n * t, c, h, w) feature = [] for scale, pyramid in zip(self.scales, self.pyramids): h_per_hpp = h // scale @@ -43,12 +44,23 @@ class HorizontalPyramidMatching(nn.Module): (hpp_index + 1) * h_per_hpp) x_slice = x[:, :, h_filter, :] x_slice = hpp(x_slice) - x_slice = x_slice.view(n, -1) + x_slice = x_slice.view(n, t, c) feature.append(x_slice) x = torch.stack(feature) + return x + def forward(self, f_c1_t2, f_c1_t1=None, f_c2_t2=None): + # n, t, c, h, w + f_c1_t2_ = self._horizontal_pyramid_pool(f_c1_t2) + # p, n, t, c + x = f_c1_t2_.mean(2) # p, n, c x = x @ self.fc_mat # p, n, d - return x + if self.training: + f_c1_t1_ = self._horizontal_pyramid_pool(f_c1_t1) + f_c2_t2_ = self._horizontal_pyramid_pool(f_c2_t2) + return x, (f_c1_t2_, f_c1_t1_, f_c2_t2_) + else: + return x diff --git a/models/model.py b/models/model.py index 45067e6..6118bdf 100644 --- a/models/model.py +++ b/models/model.py @@ -267,11 +267,15 @@ class Model: # forward + backward + optimize x_c1 = batch_c1['clip'].to(self.device) x_c2 = batch_c2['clip'].to(self.device) - embed_c, embed_p, ae_losses, images = self.rgb_pn(x_c1, x_c2) + embedding, f_loss, images = self.rgb_pn(x_c1, x_c2) + ae_losses = self._disentangling_loss(x_c1, f_loss) y = batch_c1['label'].to(self.device) - losses, hpm_result, pn_result = self._classification_loss( - embed_c, embed_p, ae_losses, y - ) + results = self._classification_loss(embedding, y) + losses = torch.stack(( + *ae_losses, + results[0]['loss'].mean(), + results[1]['loss'].mean() + )) loss = losses.sum() loss.backward() self.optimizer.step() @@ -282,9 +286,7 @@ class Model: 'Auto-encoder', 'HPM', 'PartNet' ), self.scheduler.get_last_lr())), self.curr_iter) # Other stats - self._write_stat( - 'Train', embed_c, embed_p, hpm_result, pn_result, loss, losses - ) + self._write_stat('Train', embedding, results, losses) # Write disentangled images if self.image_log_on and self.curr_iter % self.image_log_steps \ @@ -306,8 +308,8 @@ class Model: if self.curr_iter % 100 == 99: # Validation - embed_c = self._flatten_embedding(embed_c) - embed_p = self._flatten_embedding(embed_p) + embed_c = self._flatten_embedding(embedding[0]) + embed_p = self._flatten_embedding(embedding[1]) self._write_embedding('HPM Train', embed_c, x_c1, y) self._write_embedding('PartNet Train', embed_p, x_c1, y) @@ -316,18 +318,19 @@ class Model: x_c1 = batch_c1['clip'].to(self.device) x_c2 = batch_c2['clip'].to(self.device) with torch.no_grad(): - embed_c, embed_p, ae_losses, _ = self.rgb_pn(x_c1, x_c2) + embedding, f_loss, images = self.rgb_pn(x_c1, x_c2) + ae_losses = self._disentangling_loss(x_c1, f_loss) y = batch_c1['label'].to(self.device) - losses, hpm_result, pn_result = self._classification_loss( - embed_c, embed_p, ae_losses, y - ) - loss = losses.sum() - - self._write_stat( - 'Val', embed_c, embed_p, hpm_result, pn_result, loss, losses - ) - embed_c = self._flatten_embedding(embed_c) - embed_p = self._flatten_embedding(embed_p) + results = self._classification_loss(embedding, y) + losses = torch.stack(( + *ae_losses, + results[0]['loss'].mean(), + results[1]['loss'].mean() + )) + + self._write_stat('Val', embedding, results, losses) + embed_c = self._flatten_embedding(embedding[0]) + embed_p = self._flatten_embedding(embedding[1]) self._write_embedding('HPM Val', embed_c, x_c1, y) self._write_embedding('PartNet Val', embed_p, x_c1, y) @@ -342,21 +345,39 @@ class Model: self.writer.close() - def _classification_loss(self, embed_c, embed_p, ae_losses, y): + @staticmethod + def _disentangling_loss(x_c1_t2, f_loss): + n, t, c, h, w = x_c1_t2.size() + x_c1_t2_pred = f_loss[0] + xrecon_loss = torch.stack([ + F.mse_loss(x_c1_t2[:, i, :, :, :], x_c1_t2_pred[:, i, :, :, :]) + for i in range(t) + ]).sum() + cano_cons_loss = torch.stack([ + torch.stack([ + F.mse_loss(_f_c_c1_t1[:, i, :], _f_c_c1_t2[:, i, :]) + + F.mse_loss(_f_c_c1_t2[:, i, :], _f_c_c2_t2[:, i, :]) + for i in range(t) + ]).mean() + for _f_c_c1_t2, _f_c_c1_t1, _f_c_c2_t2 in zip(*f_loss[1]) + ]).sum() + pose_sim_loss = torch.stack([ + F.mse_loss(_f_p_c1_t2.mean(1), _f_p_c2_t2.mean(1)) + for _f_p_c1_t2, _f_p_c2_t2 in zip(*f_loss[2]) + ]).sum() + + return xrecon_loss, cano_cons_loss * 10, pose_sim_loss * 100 + + def _classification_loss(self, embedding, y): # Duplicate labels for each part y_triplet = y.repeat(self.rgb_pn.num_parts, 1) hpm_result = self.triplet_loss_hpm( - embed_c, y_triplet[:self.rgb_pn.hpm.num_parts] + embedding[0], y_triplet[:self.rgb_pn.hpm.num_parts] ) pn_result = self.triplet_loss_pn( - embed_p, y_triplet[self.rgb_pn.hpm.num_parts:] + embedding[1], y_triplet[self.rgb_pn.hpm.num_parts:] ) - losses = torch.stack(( - *ae_losses, - hpm_result.pop('loss').mean(), - pn_result.pop('loss').mean() - )) - return losses, hpm_result, pn_result + return hpm_result, pn_result def _write_embedding(self, tag, embed, x, y): frame = x[:, 0, :, :, :].cpu() @@ -374,11 +395,8 @@ class Model: def _flatten_embedding(self, embed): return embed.detach().transpose(0, 1).reshape(self.k * self.pr, -1) - def _write_stat( - self, postfix, embed_c, embed_p, hpm_result, pn_result, loss, losses - ): + def _write_stat(self, postfix, embeddings, results, losses): # Write losses to TensorBoard - self.writer.add_scalar(f'Loss/all {postfix}', loss, self.curr_iter) self.writer.add_scalars(f'Loss/disentanglement {postfix}', dict(zip(( 'Cross reconstruction loss', 'Canonical consistency loss', 'Pose similarity loss' @@ -388,30 +406,31 @@ class Model: 'PartNet': losses[4] }, self.curr_iter) # None-zero losses in batch - if hpm_result['counts'] is not None and pn_result['counts'] is not None: + if results[0]['counts'] is not None \ + and results[1]['counts'] is not None: self.writer.add_scalars(f'Loss/non-zero counts {postfix}', { - 'HPM': hpm_result['counts'].mean(), - 'PartNet': pn_result['counts'].mean() + 'HPM': results[0]['counts'].mean(), + 'PartNet': results[1]['counts'].mean() }, self.curr_iter) # Embedding distance - mean_hpm_dist = hpm_result['dist'].mean(0) + mean_hpm_dist = results[0]['dist'].mean(0) self._add_ranked_scalars( f'Embedding/HPM distance {postfix}', mean_hpm_dist, self.num_pos_pairs, self.num_pairs, self.curr_iter ) - mean_pn_dist = pn_result['dist'].mean(0) + mean_pn_dist = results[1]['dist'].mean(0) self._add_ranked_scalars( f'Embedding/ParNet distance {postfix}', mean_pn_dist, self.num_pos_pairs, self.num_pairs, self.curr_iter ) # Embedding norm - mean_hpm_embedding = embed_c.mean(0) + mean_hpm_embedding = embeddings[0].mean(0) mean_hpm_norm = mean_hpm_embedding.norm(dim=-1) self._add_ranked_scalars( f'Embedding/HPM norm {postfix}', mean_hpm_norm, self.k, self.pr * self.k, self.curr_iter ) - mean_pa_embedding = embed_p.mean(0) + mean_pa_embedding = embeddings[1].mean(0) mean_pa_norm = mean_pa_embedding.norm(dim=-1) self._add_ranked_scalars( f'Embedding/PartNet norm {postfix}', mean_pa_norm, diff --git a/models/part_net.py b/models/part_net.py index f2236bf..65a2c14 100644 --- a/models/part_net.py +++ b/models/part_net.py @@ -127,23 +127,27 @@ class PartNet(nn.Module): torch.empty(num_parts, in_channels, embedding_dims) ) - def forward(self, x): + def _horizontal_pool(self, x): n, t, c, h, w = x.size() x = x.view(n * t, c, h, w) - # n * t x c x h x w - - # Horizontal Pooling - _, c, h, w = x.size() split_size = h // self.num_part x = x.split(split_size, dim=2) x = [self.avg_pool(x_) + self.max_pool(x_) for x_ in x] x = [x_.view(n, t, c) for x_ in x] x = torch.stack(x) + return x + def forward(self, f_c1_t2, f_c2_t2=None): + # n, t, c, h, w + f_c1_t2_ = self._horizontal_pool(f_c1_t2) # p, n, t, c - x = self.tfa(x) - + x = self.tfa(f_c1_t2_) # p, n, c x = x @ self.fc_mat # p, n, d - return x + + if self.training: + f_c2_t2_ = self._horizontal_pool(f_c2_t2) + return x, (f_c1_t2_, f_c2_t2_) + else: + return x diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index b0169e3..06cbf28 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -39,26 +39,26 @@ class RGBPartNet(nn.Module): self.num_parts = self.hpm.num_parts + tfa_num_parts def forward(self, x_c1, x_c2=None): - # Step 1: Disentanglement - # n, t, c, h, w - (f_a, f_c, f_p), ae_losses = self._disentangle(x_c1, x_c2) + if self.training: + # Step 1: Disentanglement + # n, t, c, h, w + (f_a, f_c, f_p), f_loss = self._disentangle(x_c1, x_c2) - # Step 2.a: Static Gait Feature Aggregation & HPM - # n, c, h, w - f_c_mean = f_c.mean(1) - x_c = self.hpm(f_c_mean) - # p, n, d + # Step 2.a: Static Gait Feature Aggregation & HPM + # n, t, c, h, w + x_c, f_c_loss = self.hpm(f_c, *f_loss[1]) + # p, n, d / p, n, t, c - # Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation) - # n, t, c, h, w - x_p = self.pn(f_p) - # p, n, d + # Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation) + # n, t, c, h, w + x_p, f_p_loss = self.pn(f_p, f_loss[2]) + # p, n, d / p, n, t, c - if self.training: i_a, i_c, i_p = None, None, None if self.image_log_on: with torch.no_grad(): f_a_mean = f_a.mean(1) + f_c_mean = f_c.mean(1) i_a = self.ae.decoder( f_a_mean, torch.zeros_like(f_c_mean), @@ -77,15 +77,18 @@ class RGBPartNet(nn.Module): device=f_c.device), f_p.view(-1, *f_p_size[2:]) ).view(x_c1.size()) - return x_c, x_p, ae_losses, (i_a, i_c, i_p) - else: + return (x_c, x_p), (f_loss[0], f_c_loss, f_p_loss), (i_a, i_c, i_p) + else: # Evaluating + f_c, f_p = self._disentangle(x_c1, x_c2) + x_c = self.hpm(f_c) + x_p = self.pn(f_p) return x_c, x_p def _disentangle(self, x_c1_t2, x_c2_t2=None): if self.training: x_c1_t1 = x_c1_t2[:, torch.randperm(x_c1_t2.size(1)), :, :, :] - features, losses = self.ae(x_c1_t2, x_c1_t1, x_c2_t2) - return features, losses + features, f_loss = self.ae(x_c1_t2, x_c1_t1, x_c2_t2) + return features, f_loss else: # evaluating features = self.ae(x_c1_t2) return features, None |