author | Jordan Gong <jordan.gong@protonmail.com> | 2020-12-31 21:00:01 +0800
committer | Jordan Gong <jordan.gong@protonmail.com> | 2020-12-31 21:00:01 +0800
commit | 86421e899c87976d8559795979415e3fae2bd7ed (patch)
tree | c4bfa828da64cc71b43ed12d3fd78be8a8930181 /models/rgb_part_net.py
parent | 57275a210b93f9ffd30a53e22c3c28f49f228d14 (diff)
Implement some parts of RGB-GaitPart wrapper
1. Triplet loss function and weight init function haven't been implemented yet
2. Tuplize features returned by the auto-encoder for later unpacking
3. Correct a comment error in the auto-encoder
4. Swap the batch_size and time dimensions in HPM and PartNet to avoid a redundant transpose (see the sketch after this list)
5. Find backbone problems in HPM and disable the backbone temporarily
6. Make the feature structure produced by HPM consistent with that produced by PartNet
7. Fix an average pooling dimension issue and an incorrect view change in HP
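A minimal sketch of the dimension swap in item 4, assuming clip tensors laid out as (batch_size, time, channel, height, width); the sizes below are illustrative only, not taken from the commit:

```python
import torch

# Illustrative clip tensor: n=4 clips, t=30 frames, 3x64x32 RGB frames
x = torch.randn(4, 30, 3, 64, 32)  # n, t, c, h, w

# One transpose up front makes frames indexable as x[t],
# so the per-frame loop needs no further transposes
x = x.transpose(0, 1)  # t, n, c, h, w
print(x.shape)  # torch.Size([30, 4, 3, 64, 32])
```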
Diffstat (limited to 'models/rgb_part_net.py')
-rw-r--r-- | models/rgb_part_net.py | 98
1 file changed, 98 insertions, 0 deletions
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
new file mode 100644
index 0000000..9768dec
--- /dev/null
+++ b/models/rgb_part_net.py
@@ -0,0 +1,98 @@
+import random
+
+import torch
+import torch.nn as nn
+
+from models import AutoEncoder, HorizontalPyramidMatching, PartNet
+
+
+class RGBPartNet(nn.Module):
+    def __init__(
+            self,
+            num_class: int = 74,
+            ae_in_channels: int = 3,
+            ae_feature_channels: int = 64,
+            f_a_c_p_dims: tuple[int, int, int] = (128, 128, 64),
+            hpm_scales: tuple[int, ...] = (1, 2, 4, 8),
+            hpm_use_avg_pool: bool = True,
+            hpm_use_max_pool: bool = True,
+            fpfe_feature_channels: int = 32,
+            fpfe_kernel_sizes: tuple[tuple, ...] = ((5, 3), (3, 3), (3, 3)),
+            fpfe_paddings: tuple[tuple, ...] = ((2, 1), (1, 1), (1, 1)),
+            fpfe_halving: tuple[int, ...] = (0, 2, 3),
+            tfa_squeeze_ratio: int = 4,
+            tfa_num_part: int = 16,
+    ):
+        super().__init__()
+        self.ae = AutoEncoder(
+            num_class, ae_in_channels, ae_feature_channels, f_a_c_p_dims
+        )
+        self.pn = PartNet(
+            ae_in_channels, fpfe_feature_channels, fpfe_kernel_sizes,
+            fpfe_paddings, fpfe_halving, tfa_squeeze_ratio, tfa_num_part
+        )
+        self.hpm = HorizontalPyramidMatching(
+            ae_in_channels, self.pn.tfa_in_channels, hpm_scales,
+            hpm_use_avg_pool, hpm_use_max_pool
+        )
+
+        self.mse_loss = nn.MSELoss()
+
+        # TODO Weight init here
+
+    def pose_sim_loss(self, f_p_c1: torch.Tensor,
+                      f_p_c2: torch.Tensor) -> torch.Tensor:
+        f_p_c1_mean = f_p_c1.mean(dim=0)
+        f_p_c2_mean = f_p_c2.mean(dim=0)
+        return self.mse_loss(f_p_c1_mean, f_p_c2_mean)
+
+    def forward(self, x_c1, x_c2, y):
+        # Step 0: Swap batch_size and time dimensions for next step
+        # n, t, c, h, w
+        x_c1, x_c2 = x_c1.transpose(0, 1), x_c2.transpose(0, 1)
+
+        # Step 1: Disentanglement
+        # t, n, c, h, w
+        num_frames = len(x_c1)
+        f_c_c1, f_p_c1, f_p_c2 = [], [], []
+        xrecon_loss, cano_cons_loss = 0, 0
+        for t2 in range(num_frames):
+            t1 = random.randrange(num_frames)
+            output = self.ae(x_c1[t1], x_c1[t2], x_c2[t2], y)
+            (feature_t2, xrecon_loss_t2, cano_cons_loss_t2) = output
+            (f_c_c1_t2, f_p_c1_t2, f_p_c2_t2) = feature_t2
+            # Features for next step
+            f_c_c1.append(f_c_c1_t2)
+            f_p_c1.append(f_p_c1_t2)
+            # Feature for pose similarity loss and losses per time step
+            f_p_c2.append(f_p_c2_t2)
+            xrecon_loss += xrecon_loss_t2
+            cano_cons_loss += cano_cons_loss_t2
+        f_c_c1 = torch.stack(f_c_c1)
+        f_p_c1 = torch.stack(f_p_c1)
+
+        # Step 2.a: HPM & Static Gait Feature Aggregation
+        # t, n, c, h, w
+        x_c = self.hpm(f_c_c1)
+        # p, t, n, c
+        x_c = x_c.mean(dim=1)
+        # p, n, c
+
+        # Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation)
+        # t, n, c, h, w
+        x_p = self.pn(f_p_c1)
+        # p, n, c
+
+        # Step 3: Concatenate features and calculate losses
+        x = torch.cat((x_c, x_p))
+        # Losses
+        xrecon_loss /= num_frames
+        f_p_c2 = torch.stack(f_p_c2)
+        pose_sim_loss = self.pose_sim_loss(f_p_c1, f_p_c2)
+        cano_cons_loss /= num_frames
+        # TODO Implement Batch All triplet loss function
+        batch_all_triplet_loss = 0
+        loss = (xrecon_loss + pose_sim_loss + cano_cons_loss
+                + batch_all_triplet_loss)
+
+        return x, loss
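For orientation, a smoke test of the new wrapper might look like the sketch below; it assumes AutoEncoder, HorizontalPyramidMatching and PartNet are importable from models with the constructor signatures used above, and all tensor sizes are made up for illustration:

```python
import torch

from models.rgb_part_net import RGBPartNet

model = RGBPartNet(num_class=74)

# Two clips per condition, 8 frames each, 3x64x32 RGB frames (n, t, c, h, w)
x_c1 = torch.randn(2, 8, 3, 64, 32)  # condition 1 clips
x_c2 = torch.randn(2, 8, 3, 64, 32)  # condition 2 clips
y = torch.randint(0, 74, (2,))       # class labels

x, loss = model(x_c1, x_c2, y)
loss.backward()  # loss should be a scalar tensor combining the terms above
```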