summaryrefslogtreecommitdiff
path: root/models/rgb_part_net.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-01-03 17:17:21 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-01-03 17:17:21 +0800
commit425af1da453203a3d5b526d3e30af9c9f9faaa72 (patch)
tree0847f7de481deb183842ae03a6061a5b6a66e26c /models/rgb_part_net.py
parent2ac1787e4580521848460215e6b06f4bb1648f06 (diff)
Add separate fully connected layers
Diffstat (limited to 'models/rgb_part_net.py')
-rw-r--r--models/rgb_part_net.py18
1 files changed, 14 insertions, 4 deletions
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index ba5a00e..02345d6 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -22,7 +22,9 @@ class RGBPartNet(nn.Module):
fpfe_paddings: tuple[tuple, ...] = ((2, 1), (1, 1), (1, 1)),
fpfe_halving: tuple[int, ...] = (0, 2, 3),
tfa_squeeze_ratio: int = 4,
- tfa_num_part: int = 16,
+ tfa_num_parts: int = 16,
+ embedding_dims: int = 256,
+ triplet_margin: int = 0.2
):
super().__init__()
self.ae = AutoEncoder(
@@ -30,15 +32,22 @@ class RGBPartNet(nn.Module):
)
self.pn = PartNet(
ae_in_channels, fpfe_feature_channels, fpfe_kernel_sizes,
- fpfe_paddings, fpfe_halving, tfa_squeeze_ratio, tfa_num_part
+ fpfe_paddings, fpfe_halving, tfa_squeeze_ratio, tfa_num_parts
)
+ out_channels = self.pn.tfa_in_channels
self.hpm = HorizontalPyramidMatching(
- ae_feature_channels * 8, self.pn.tfa_in_channels, hpm_scales,
+ ae_feature_channels * 8, out_channels, hpm_scales,
hpm_use_avg_pool, hpm_use_max_pool
)
+ total_parts = sum(hpm_scales) + tfa_num_parts
+ empty_fc = torch.empty(total_parts, out_channels, embedding_dims)
+ self.fc_mat = nn.Parameter(nn.init.xavier_uniform_(empty_fc))
# TODO Weight inti here
+ def fc(self, x):
+ return torch.matmul(x, self.fc_mat)
+
def forward(self, x_c1, x_c2, y=None):
# Step 0: Swap batch_size and time dimensions for next step
# n, t, c, h, w
@@ -60,8 +69,9 @@ class RGBPartNet(nn.Module):
x_p = self.pn(x_p_c1)
# p, n, c
- # Step 3: Cat feature map together and calculate losses
+ # Step 3: Cat feature map together and fc
x = torch.cat((x_c, x_p))
+ x = self.fc(x)
if self.training:
# TODO Implement Batch All triplet loss function