diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-04-06 21:19:26 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-04-06 21:19:26 +0800 |
commit | 7f6c65fa954a43c9d248219525cad35fe1b8a046 (patch) | |
tree | 67fb188dd59b517a02ac25cea630e9f094522afb /models/rgb_part_net.py | |
parent | b9f35fbe7d78b3c478086ea26c2a76f72ce35687 (diff) |
Network modification
This commit change the auto-encoder by removing fc and optimizing latent space features
Diffstat (limited to 'models/rgb_part_net.py')
-rw-r--r-- | models/rgb_part_net.py | 104 |
1 files changed, 36 insertions, 68 deletions
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 4a82da3..ecc38c0 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -12,7 +12,7 @@ class RGBPartNet(nn.Module): ae_in_channels: int = 3, ae_in_size: tuple[int, int] = (64, 48), ae_feature_channels: int = 64, - f_a_c_p_dims: tuple[int, int, int] = (128, 128, 64), + f_a_c_p_dims: tuple[int, int, int] = (192, 192, 128), hpm_scales: tuple[int, ...] = (1, 2, 4), hpm_use_avg_pool: bool = True, hpm_use_max_pool: bool = True, @@ -23,100 +23,68 @@ class RGBPartNet(nn.Module): ): super().__init__() self.h, self.w = ae_in_size - (self.f_a_dim, self.f_c_dim, self.f_p_dim) = f_a_c_p_dims self.image_log_on = image_log_on self.ae = AutoEncoder( ae_in_channels, ae_in_size, ae_feature_channels, f_a_c_p_dims ) - self.pn_in_channels = ae_feature_channels * 2 self.hpm = HorizontalPyramidMatching( - self.pn_in_channels, embedding_dims[0], hpm_scales, + f_a_c_p_dims[1], embedding_dims[0], hpm_scales, hpm_use_avg_pool, hpm_use_max_pool ) - self.pn = PartNet(self.pn_in_channels, embedding_dims[1], - tfa_num_parts, tfa_squeeze_ratio) + self.pn = PartNet( + f_a_c_p_dims[2], embedding_dims[1], tfa_num_parts, tfa_squeeze_ratio + ) self.num_parts = self.hpm.num_parts + tfa_num_parts def forward(self, x_c1, x_c2=None): # Step 1: Disentanglement # n, t, c, h, w - ((x_c, x_p), ae_losses, images) = self._disentangle(x_c1, x_c2) + (f_a, f_c, f_p), ae_losses = self._disentangle(x_c1, x_c2) # Step 2.a: Static Gait Feature Aggregation & HPM # n, c, h, w - x_c = self.hpm(x_c) + f_c_mean = f_c.mean(1) + x_c = self.hpm(f_c_mean) # p, n, d # Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation) # n, t, c, h, w - x_p = self.pn(x_p) + x_p = self.pn(f_p) # p, n, d if self.training: - return x_c, x_p, ae_losses, images + i_a, i_c, i_p = None, None, None + if self.image_log_on: + f_a_mean = f_a.mean(1) + i_a = self.ae.decoder( + f_a_mean, + torch.zeros_like(f_c_mean), + torch.zeros_like(f_p[:, 0]) + ) + i_c = self.ae.decoder( + torch.zeros_like(f_a_mean), + f_c_mean, + torch.zeros_like(f_p[:, 0]) + ) + f_p_size = f_p.size() + i_p = self.ae.decoder( + torch.zeros(f_p_size[0] * f_p_size[1], *f_a.shape[2:], + device=f_a.device), + torch.zeros(f_p_size[0] * f_p_size[1], *f_c.shape[2:], + device=f_c.device), + f_p.view(-1, *f_p_size[2:]) + ).view(x_c1.size()) + return x_c, x_p, ae_losses, (i_a, i_c, i_p) else: return x_c, x_p def _disentangle(self, x_c1_t2, x_c2_t2=None): - n, t, c, h, w = x_c1_t2.size() - device = x_c1_t2.device if self.training: - x_c1_t1 = x_c1_t2[:, torch.randperm(t), :, :, :] - ((f_a_, f_c_, f_p_), losses) = self.ae(x_c1_t2, x_c1_t1, x_c2_t2) - # Decode features - x_c = self._decode_cano_feature(f_c_, n, t, device) - x_p_ = self._decode_pose_feature(f_p_, n, t, device) - x_p = x_p_.view(n, t, self.pn_in_channels, self.h // 4, self.w // 4) - - i_a, i_c, i_p = None, None, None - if self.image_log_on: - with torch.no_grad(): - i_a = self._decode_appr_feature(f_a_, n, t, device) - # Continue decoding canonical features - i_c = self.ae.decoder.trans_conv3(x_c) - i_c = torch.sigmoid(self.ae.decoder.trans_conv4(i_c)) - i_p_ = self.ae.decoder.trans_conv3(x_p_) - i_p_ = torch.sigmoid(self.ae.decoder.trans_conv4(i_p_)) - i_p = i_p_.view(n, t, c, h, w) - - return (x_c, x_p), losses, (i_a, i_c, i_p) - + x_c1_t1 = x_c1_t2[:, torch.randperm(x_c1_t2.size(1)), :, :, :] + features, losses = self.ae(x_c1_t2, x_c1_t1, x_c2_t2) + return features, losses else: # evaluating - f_c_, f_p_ = self.ae(x_c1_t2) - x_c = self._decode_cano_feature(f_c_, n, t, device) - x_p_ = self._decode_pose_feature(f_p_, n, t, device) - x_p = x_p_.view(n, t, self.pn_in_channels, self.h // 4, self.w // 4) - return (x_c, x_p), None, None - - def _decode_appr_feature(self, f_a_, n, t, device): - # Decode appearance features - f_a = f_a_.view(n, t, -1) - x_a = self.ae.decoder( - f_a.mean(1), - torch.zeros((n, self.f_c_dim), device=device), - torch.zeros((n, self.f_p_dim), device=device) - ) - return x_a - - def _decode_cano_feature(self, f_c_, n, t, device): - # Decode average canonical features to higher dimension - f_c = f_c_.view(n, t, -1) - x_c = self.ae.decoder( - torch.zeros((n, self.f_a_dim), device=device), - f_c.mean(1), - torch.zeros((n, self.f_p_dim), device=device), - is_feature_map=True - ) - return x_c - - def _decode_pose_feature(self, f_p_, n, t, device): - # Decode pose features to images - x_p_ = self.ae.decoder( - torch.zeros((n * t, self.f_a_dim), device=device), - torch.zeros((n * t, self.f_c_dim), device=device), - f_p_, - is_feature_map=True - ) - return x_p_ + features = self.ae(x_c1_t2) + return features, None |