diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-03-12 20:12:33 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-03-12 20:12:33 +0800 |
commit | e83ae0bcb5c763636fd522c2712a3c8aef558f3c (patch) | |
tree | b80da057e4c4574ea95fa9f3d3b2fe8c999e3440 /models/rgb_part_net.py | |
parent | f2f7713efa03a877bc96ced37314b4c4a6dc1963 (diff) | |
parent | 2ea916b2a963eae7d47151b41c8c78a578c402e2 (diff) |
Merge branch 'master' into data_parallel
# Conflicts:
# models/auto_encoder.py
# models/model.py
# models/rgb_part_net.py
Diffstat (limited to 'models/rgb_part_net.py')
-rw-r--r-- | models/rgb_part_net.py | 32 |
1 files changed, 10 insertions, 22 deletions
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 7785bb7..fdeed17 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -13,39 +13,31 @@ class RGBPartNet(nn.Module): ae_in_size: tuple[int, int] = (64, 48), ae_feature_channels: int = 64, f_a_c_p_dims: tuple[int, int, int] = (128, 128, 64), - hpm_use_1x1conv: bool = False, hpm_scales: tuple[int, ...] = (1, 2, 4), hpm_use_avg_pool: bool = True, hpm_use_max_pool: bool = True, tfa_squeeze_ratio: int = 4, tfa_num_parts: int = 16, - embedding_dims: int = 256, + embedding_dims: tuple[int] = (256, 256), image_log_on: bool = False ): super().__init__() self.h, self.w = ae_in_size (self.f_a_dim, self.f_c_dim, self.f_p_dim) = f_a_c_p_dims - self.hpm_num_parts = sum(hpm_scales) self.image_log_on = image_log_on self.ae = AutoEncoder( ae_in_channels, ae_in_size, ae_feature_channels, f_a_c_p_dims ) self.pn_in_channels = ae_feature_channels * 2 - self.pn = PartNet( - self.pn_in_channels, tfa_squeeze_ratio, tfa_num_parts - ) self.hpm = HorizontalPyramidMatching( - ae_feature_channels * 2, self.pn_in_channels, hpm_use_1x1conv, - hpm_scales, hpm_use_avg_pool, hpm_use_max_pool + self.pn_in_channels, embedding_dims[0], hpm_scales, + hpm_use_avg_pool, hpm_use_max_pool ) - self.num_total_parts = self.hpm_num_parts + tfa_num_parts - empty_fc = torch.empty(self.num_total_parts, - self.pn_in_channels, embedding_dims) - self.fc_mat = nn.Parameter(empty_fc) + self.pn = PartNet(self.pn_in_channels, embedding_dims[1], + tfa_num_parts, tfa_squeeze_ratio) - def fc(self, x): - return x @ self.fc_mat + self.num_parts = self.hpm.num_parts + tfa_num_parts def forward(self, x_c1, x_c2=None): # Step 1: Disentanglement @@ -55,21 +47,17 @@ class RGBPartNet(nn.Module): # Step 2.a: Static Gait Feature Aggregation & HPM # n, c, h, w x_c = self.hpm(x_c) - # p, n, c + # p, n, d # Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation) # n, t, c, h, w x_p = self.pn(x_p) - # p, n, c - - # Step 3: Cat feature map together and fc - x = torch.cat((x_c, x_p)) - x = self.fc(x) + # p, n, d if self.training: - return x.transpose(0, 1), images, f_loss + return x_c.transpose(0, 1), x_p.transpose(0, 1), images, f_loss else: - return x.unsqueeze(1).view(-1) + return torch.cat((x_c, x_p)).unsqueeze(1).view(-1) def _disentangle(self, x_c1_t2, x_c2_t2=None): n, t, c, h, w = x_c1_t2.size() |