summaryrefslogtreecommitdiff
path: root/models/rgb_part_net.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-03-12 13:56:17 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-03-12 13:56:17 +0800
commitc74df416b00f837ba051f3947be92f76e7afbd88 (patch)
tree02983df94008bbb427c2066c5f619e0ffdefe1c5 /models/rgb_part_net.py
parent1b8d1614168ce6590c5e029c7f1007ac9b17048c (diff)
Code refactoring
1. Separate FCs and triplet losses for HPM and PartNet 2. Remove FC-equivalent 1x1 conv layers in HPM 3. Support adjustable learning rate schedulers
Diffstat (limited to 'models/rgb_part_net.py')
-rw-r--r--models/rgb_part_net.py32
1 files changed, 10 insertions, 22 deletions
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 8a0f3a7..c38a567 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -13,39 +13,31 @@ class RGBPartNet(nn.Module):
ae_in_size: tuple[int, int] = (64, 48),
ae_feature_channels: int = 64,
f_a_c_p_dims: tuple[int, int, int] = (128, 128, 64),
- hpm_use_1x1conv: bool = False,
hpm_scales: tuple[int, ...] = (1, 2, 4),
hpm_use_avg_pool: bool = True,
hpm_use_max_pool: bool = True,
tfa_squeeze_ratio: int = 4,
tfa_num_parts: int = 16,
- embedding_dims: int = 256,
+ embedding_dims: tuple[int] = (256, 256),
image_log_on: bool = False
):
super().__init__()
self.h, self.w = ae_in_size
(self.f_a_dim, self.f_c_dim, self.f_p_dim) = f_a_c_p_dims
- self.hpm_num_parts = sum(hpm_scales)
self.image_log_on = image_log_on
self.ae = AutoEncoder(
ae_in_channels, ae_in_size, ae_feature_channels, f_a_c_p_dims
)
self.pn_in_channels = ae_feature_channels * 2
- self.pn = PartNet(
- self.pn_in_channels, tfa_squeeze_ratio, tfa_num_parts
- )
self.hpm = HorizontalPyramidMatching(
- ae_feature_channels * 2, self.pn_in_channels, hpm_use_1x1conv,
- hpm_scales, hpm_use_avg_pool, hpm_use_max_pool
+ self.pn_in_channels, embedding_dims[0], hpm_scales,
+ hpm_use_avg_pool, hpm_use_max_pool
)
- self.num_total_parts = self.hpm_num_parts + tfa_num_parts
- empty_fc = torch.empty(self.num_total_parts,
- self.pn_in_channels, embedding_dims)
- self.fc_mat = nn.Parameter(empty_fc)
+ self.pn = PartNet(self.pn_in_channels, embedding_dims[1],
+ tfa_num_parts, tfa_squeeze_ratio)
- def fc(self, x):
- return x @ self.fc_mat
+ self.num_parts = self.hpm.num_parts + tfa_num_parts
def forward(self, x_c1, x_c2=None):
# Step 1: Disentanglement
@@ -55,21 +47,17 @@ class RGBPartNet(nn.Module):
# Step 2.a: Static Gait Feature Aggregation & HPM
# n, c, h, w
x_c = self.hpm(x_c)
- # p, n, c
+ # p, n, d
# Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation)
# n, t, c, h, w
x_p = self.pn(x_p)
- # p, n, c
-
- # Step 3: Cat feature map together and fc
- x = torch.cat((x_c, x_p))
- x = self.fc(x)
+ # p, n, d
if self.training:
- return x, ae_losses, images
+ return x_c, x_p, ae_losses, images
else:
- return x.unsqueeze(1).view(-1)
+ return torch.cat((x_c, x_p)).unsqueeze(1).view(-1)
def _disentangle(self, x_c1_t2, x_c2_t2=None):
n, t, c, h, w = x_c1_t2.size()