diff options
Diffstat (limited to 'models/rgb_part_net.py')
-rw-r--r-- | models/rgb_part_net.py | 18 |
1 files changed, 14 insertions, 4 deletions
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index ba5a00e..02345d6 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -22,7 +22,9 @@ class RGBPartNet(nn.Module): fpfe_paddings: tuple[tuple, ...] = ((2, 1), (1, 1), (1, 1)), fpfe_halving: tuple[int, ...] = (0, 2, 3), tfa_squeeze_ratio: int = 4, - tfa_num_part: int = 16, + tfa_num_parts: int = 16, + embedding_dims: int = 256, + triplet_margin: int = 0.2 ): super().__init__() self.ae = AutoEncoder( @@ -30,15 +32,22 @@ class RGBPartNet(nn.Module): ) self.pn = PartNet( ae_in_channels, fpfe_feature_channels, fpfe_kernel_sizes, - fpfe_paddings, fpfe_halving, tfa_squeeze_ratio, tfa_num_part + fpfe_paddings, fpfe_halving, tfa_squeeze_ratio, tfa_num_parts ) + out_channels = self.pn.tfa_in_channels self.hpm = HorizontalPyramidMatching( - ae_feature_channels * 8, self.pn.tfa_in_channels, hpm_scales, + ae_feature_channels * 8, out_channels, hpm_scales, hpm_use_avg_pool, hpm_use_max_pool ) + total_parts = sum(hpm_scales) + tfa_num_parts + empty_fc = torch.empty(total_parts, out_channels, embedding_dims) + self.fc_mat = nn.Parameter(nn.init.xavier_uniform_(empty_fc)) # TODO Weight inti here + def fc(self, x): + return torch.matmul(x, self.fc_mat) + def forward(self, x_c1, x_c2, y=None): # Step 0: Swap batch_size and time dimensions for next step # n, t, c, h, w @@ -60,8 +69,9 @@ class RGBPartNet(nn.Module): x_p = self.pn(x_p_c1) # p, n, c - # Step 3: Cat feature map together and calculate losses + # Step 3: Cat feature map together and fc x = torch.cat((x_c, x_p)) + x = self.fc(x) if self.training: # TODO Implement Batch All triplet loss function |