summaryrefslogtreecommitdiff
path: root/models/rgb_part_net.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-03-12 20:30:25 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-03-12 20:34:00 +0800
commit30b475c0a27e0f848743abf0f909607defc6a3ee (patch)
treeaaab163d3d76a835c32ce5014ce62637550d0b0d /models/rgb_part_net.py
parent3d8fc322623ba61610fd206b9f52b406e85cae61 (diff)
parente83ae0bcb5c763636fd522c2712a3c8aef558f3c (diff)
Merge branch 'data_parallel' into data_parallel_py3.8
# Conflicts: # models/hpm.py # models/model.py # models/rgb_part_net.py # utils/configuration.py # utils/triplet_loss.py
Diffstat (limited to 'models/rgb_part_net.py')
-rw-r--r--models/rgb_part_net.py32
1 files changed, 10 insertions, 22 deletions
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 2853571..81f198e 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -15,39 +15,31 @@ class RGBPartNet(nn.Module):
ae_in_size: Tuple[int, int] = (64, 48),
ae_feature_channels: int = 64,
f_a_c_p_dims: Tuple[int, int, int] = (128, 128, 64),
- hpm_use_1x1conv: bool = False,
hpm_scales: Tuple[int, ...] = (1, 2, 4),
hpm_use_avg_pool: bool = True,
hpm_use_max_pool: bool = True,
tfa_squeeze_ratio: int = 4,
tfa_num_parts: int = 16,
- embedding_dims: int = 256,
+ embedding_dims: Tuple[int] = (256, 256),
image_log_on: bool = False
):
super().__init__()
self.h, self.w = ae_in_size
(self.f_a_dim, self.f_c_dim, self.f_p_dim) = f_a_c_p_dims
- self.hpm_num_parts = sum(hpm_scales)
self.image_log_on = image_log_on
self.ae = AutoEncoder(
ae_in_channels, ae_in_size, ae_feature_channels, f_a_c_p_dims
)
self.pn_in_channels = ae_feature_channels * 2
- self.pn = PartNet(
- self.pn_in_channels, tfa_squeeze_ratio, tfa_num_parts
- )
self.hpm = HorizontalPyramidMatching(
- ae_feature_channels * 2, self.pn_in_channels, hpm_use_1x1conv,
- hpm_scales, hpm_use_avg_pool, hpm_use_max_pool
+ self.pn_in_channels, embedding_dims[0], hpm_scales,
+ hpm_use_avg_pool, hpm_use_max_pool
)
- self.num_total_parts = self.hpm_num_parts + tfa_num_parts
- empty_fc = torch.empty(self.num_total_parts,
- self.pn_in_channels, embedding_dims)
- self.fc_mat = nn.Parameter(empty_fc)
+ self.pn = PartNet(self.pn_in_channels, embedding_dims[1],
+ tfa_num_parts, tfa_squeeze_ratio)
- def fc(self, x):
- return x @ self.fc_mat
+ self.num_parts = self.hpm.num_parts + tfa_num_parts
def forward(self, x_c1, x_c2=None):
# Step 1: Disentanglement
@@ -57,21 +49,17 @@ class RGBPartNet(nn.Module):
# Step 2.a: Static Gait Feature Aggregation & HPM
# n, c, h, w
x_c = self.hpm(x_c)
- # p, n, c
+ # p, n, d
# Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation)
# n, t, c, h, w
x_p = self.pn(x_p)
- # p, n, c
-
- # Step 3: Cat feature map together and fc
- x = torch.cat((x_c, x_p))
- x = self.fc(x)
+ # p, n, d
if self.training:
- return x.transpose(0, 1), images, f_loss
+ return x_c.transpose(0, 1), x_p.transpose(0, 1), images, f_loss
else:
- return x.unsqueeze(1).view(-1)
+ return torch.cat((x_c, x_p)).unsqueeze(1).view(-1)
def _disentangle(self, x_c1_t2, x_c2_t2=None):
n, t, c, h, w = x_c1_t2.size()