diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-02-19 22:39:49 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-02-19 22:39:49 +0800 |
commit | d12dd6b04a4e7c2b1ee43ab6f36f25d0c35ca364 (patch) | |
tree | 71b5209ce4b5cfb1d09b89fe133028bbfa481dc9 /models/rgb_part_net.py | |
parent | 4aa9044122878a8e2b887a8b170c036983431559 (diff) |
New branch with auto-encoder only
Diffstat (limited to 'models/rgb_part_net.py')
-rw-r--r-- | models/rgb_part_net.py | 63 |
1 files changed, 3 insertions, 60 deletions
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 67acac3..f18d675 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -2,9 +2,6 @@ import torch import torch.nn as nn from models.auto_encoder import AutoEncoder -from models.hpm import HorizontalPyramidMatching -from models.part_net import PartNet -from utils.triplet_loss import BatchAllTripletLoss class RGBPartNet(nn.Module): @@ -14,80 +11,26 @@ class RGBPartNet(nn.Module): ae_in_size: tuple[int, int] = (64, 48), ae_feature_channels: int = 64, f_a_c_p_dims: tuple[int, int, int] = (128, 128, 64), - hpm_use_1x1conv: bool = False, - hpm_scales: tuple[int, ...] = (1, 2, 4), - hpm_use_avg_pool: bool = True, - hpm_use_max_pool: bool = True, - fpfe_feature_channels: int = 32, - fpfe_kernel_sizes: tuple[tuple, ...] = ((5, 3), (3, 3), (3, 3)), - fpfe_paddings: tuple[tuple, ...] = ((2, 1), (1, 1), (1, 1)), - fpfe_halving: tuple[int, ...] = (0, 2, 3), - tfa_squeeze_ratio: int = 4, - tfa_num_parts: int = 16, - embedding_dims: int = 256, - triplet_margins: tuple[float, float] = (0.2, 0.2), image_log_on: bool = False ): super().__init__() (self.f_a_dim, self.f_c_dim, self.f_p_dim) = f_a_c_p_dims - self.hpm_num_parts = sum(hpm_scales) self.image_log_on = image_log_on self.ae = AutoEncoder( ae_in_channels, ae_in_size, ae_feature_channels, f_a_c_p_dims ) - self.pn = PartNet( - ae_in_channels, fpfe_feature_channels, fpfe_kernel_sizes, - fpfe_paddings, fpfe_halving, tfa_squeeze_ratio, tfa_num_parts - ) - out_channels = self.pn.tfa_in_channels - self.hpm = HorizontalPyramidMatching( - ae_feature_channels * 2, out_channels, hpm_use_1x1conv, - hpm_scales, hpm_use_avg_pool, hpm_use_max_pool - ) - self.num_total_parts = self.hpm_num_parts + tfa_num_parts - empty_fc = torch.empty(self.num_total_parts, - out_channels, embedding_dims) - self.fc_mat = nn.Parameter(empty_fc) - - (hpm_margin, pn_margin) = triplet_margins - self.hpm_ba_trip = BatchAllTripletLoss(hpm_margin) - self.pn_ba_trip = BatchAllTripletLoss(pn_margin) - - def fc(self, x): - return x @ self.fc_mat - def forward(self, x_c1, x_c2=None, y=None): + def forward(self, x_c1, x_c2=None): # Step 1: Disentanglement # n, t, c, h, w ((x_c, x_p), losses, images) = self._disentangle(x_c1, x_c2) - # Step 2.a: Static Gait Feature Aggregation & HPM - # n, c, h, w - x_c = self.hpm(x_c) - # p, n, c - - # Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation) - # n, t, c, h, w - x_p = self.pn(x_p) - # p, n, c - - # Step 3: Cat feature map together and fc - x = torch.cat((x_c, x_p)) - x = self.fc(x) - if self.training: - y = y.T - hpm_ba_trip = self.hpm_ba_trip( - x[:self.hpm_num_parts], y[:self.hpm_num_parts] - ) - pn_ba_trip = self.pn_ba_trip( - x[self.hpm_num_parts:], y[self.hpm_num_parts:] - ) - losses = torch.stack((*losses, hpm_ba_trip, pn_ba_trip)) + losses = torch.stack(losses) return losses, images else: - return x.unsqueeze(1).view(-1) + return x_c, x_p def _disentangle(self, x_c1_t2, x_c2_t2=None): n, t, c, h, w = x_c1_t2.size() |