diff options
Diffstat (limited to 'models/rgb_part_net.py')
-rw-r--r-- | models/rgb_part_net.py | 121 |
1 files changed, 43 insertions, 78 deletions
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 4a82da3..d3f8ade 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -1,9 +1,8 @@ import torch import torch.nn as nn +import torch.nn.functional as F from models.auto_encoder import AutoEncoder -from models.hpm import HorizontalPyramidMatching -from models.part_net import PartNet class RGBPartNet(nn.Module): @@ -13,12 +12,6 @@ class RGBPartNet(nn.Module): ae_in_size: tuple[int, int] = (64, 48), ae_feature_channels: int = 64, f_a_c_p_dims: tuple[int, int, int] = (128, 128, 64), - hpm_scales: tuple[int, ...] = (1, 2, 4), - hpm_use_avg_pool: bool = True, - hpm_use_max_pool: bool = True, - tfa_squeeze_ratio: int = 4, - tfa_num_parts: int = 16, - embedding_dims: tuple[int] = (256, 256), image_log_on: bool = False ): super().__init__() @@ -29,94 +22,66 @@ class RGBPartNet(nn.Module): self.ae = AutoEncoder( ae_in_channels, ae_in_size, ae_feature_channels, f_a_c_p_dims ) - self.pn_in_channels = ae_feature_channels * 2 - self.hpm = HorizontalPyramidMatching( - self.pn_in_channels, embedding_dims[0], hpm_scales, - hpm_use_avg_pool, hpm_use_max_pool - ) - self.pn = PartNet(self.pn_in_channels, embedding_dims[1], - tfa_num_parts, tfa_squeeze_ratio) - - self.num_parts = self.hpm.num_parts + tfa_num_parts def forward(self, x_c1, x_c2=None): - # Step 1: Disentanglement - # n, t, c, h, w - ((x_c, x_p), ae_losses, images) = self._disentangle(x_c1, x_c2) - - # Step 2.a: Static Gait Feature Aggregation & HPM - # n, c, h, w - x_c = self.hpm(x_c) - # p, n, d - - # Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation) - # n, t, c, h, w - x_p = self.pn(x_p) - # p, n, d + losses, features, images = self._disentangle(x_c1, x_c2) if self.training: - return x_c, x_p, ae_losses, images + losses = torch.stack(losses) + return losses, features, images else: - return x_c, x_p + return features def _disentangle(self, x_c1_t2, x_c2_t2=None): n, t, c, h, w = x_c1_t2.size() - device = x_c1_t2.device if self.training: x_c1_t1 = x_c1_t2[:, torch.randperm(t), :, :, :] ((f_a_, f_c_, f_p_), losses) = self.ae(x_c1_t2, x_c1_t1, x_c2_t2) - # Decode features - x_c = self._decode_cano_feature(f_c_, n, t, device) - x_p_ = self._decode_pose_feature(f_p_, n, t, device) - x_p = x_p_.view(n, t, self.pn_in_channels, self.h // 4, self.w // 4) + f_a = f_a_.view(n, t, -1) + f_c = f_c_.view(n, t, -1) + f_p = f_p_.view(n, t, -1) i_a, i_c, i_p = None, None, None if self.image_log_on: with torch.no_grad(): - i_a = self._decode_appr_feature(f_a_, n, t, device) - # Continue decoding canonical features - i_c = self.ae.decoder.trans_conv3(x_c) - i_c = torch.sigmoid(self.ae.decoder.trans_conv4(i_c)) - i_p_ = self.ae.decoder.trans_conv3(x_p_) - i_p_ = torch.sigmoid(self.ae.decoder.trans_conv4(i_p_)) + x_a, i_a = self._separate_decode( + f_a.mean(1), + torch.zeros_like(f_c[:, 0, :]), + torch.zeros_like(f_p[:, 0, :]) + ) + x_c, i_c = self._separate_decode( + torch.zeros_like(f_a[:, 0, :]), + f_c.mean(1), + torch.zeros_like(f_p[:, 0, :]), + ) + x_p_, i_p_ = self._separate_decode( + torch.zeros_like(f_a_), + torch.zeros_like(f_c_), + f_p_ + ) + x_p = tuple(_x_p.view(n, t, *_x_p.size()[1:]) for _x_p in x_p_) i_p = i_p_.view(n, t, c, h, w) - return (x_c, x_p), losses, (i_a, i_c, i_p) + return losses, (x_a, x_c, x_p), (i_a, i_c, i_p) else: # evaluating f_c_, f_p_ = self.ae(x_c1_t2) - x_c = self._decode_cano_feature(f_c_, n, t, device) - x_p_ = self._decode_pose_feature(f_p_, n, t, device) - x_p = x_p_.view(n, t, self.pn_in_channels, self.h // 4, self.w // 4) - return (x_c, x_p), None, None - - def _decode_appr_feature(self, f_a_, n, t, device): - # Decode appearance features - f_a = f_a_.view(n, t, -1) - x_a = self.ae.decoder( - f_a.mean(1), - torch.zeros((n, self.f_c_dim), device=device), - torch.zeros((n, self.f_p_dim), device=device) - ) - return x_a - - def _decode_cano_feature(self, f_c_, n, t, device): - # Decode average canonical features to higher dimension - f_c = f_c_.view(n, t, -1) - x_c = self.ae.decoder( - torch.zeros((n, self.f_a_dim), device=device), - f_c.mean(1), - torch.zeros((n, self.f_p_dim), device=device), - is_feature_map=True - ) - return x_c - - def _decode_pose_feature(self, f_p_, n, t, device): - # Decode pose features to images - x_p_ = self.ae.decoder( - torch.zeros((n * t, self.f_a_dim), device=device), - torch.zeros((n * t, self.f_c_dim), device=device), - f_p_, - is_feature_map=True + f_c = f_c_.view(n, t, -1) + f_p = f_p_.view(n, t, -1) + return (f_c, f_p), None, None + + def _separate_decode(self, f_a, f_c, f_p): + x_1 = torch.cat((f_a, f_c, f_p), dim=1) + x_1 = self.ae.decoder.fc(x_1).view( + -1, + self.ae.decoder.feature_channels * 8, + self.ae.decoder.h_0, + self.ae.decoder.w_0 ) - return x_p_ + x_1 = F.relu(x_1, inplace=True) + x_2 = self.ae.decoder.trans_conv1(x_1) + x_3 = self.ae.decoder.trans_conv2(x_2) + x_4 = self.ae.decoder.trans_conv3(x_3) + image = torch.sigmoid(self.ae.decoder.trans_conv4(x_4)) + x = (x_1, x_2, x_3, x_4) + return x, image |