diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-02-08 18:11:25 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-02-08 18:25:42 +0800 |
commit | 99ddd7c142a4ec97cb8bd14b204651790b3cf4ee (patch) | |
tree | a4ccbd08a7155e90df63aba60eb93ab2b7969c9b /models/hpm.py | |
parent | 507e1d163aaa6ea4be23e7f08ff6ce0ef58c830b (diff) |
Code refactoring, modifications and new features
1. Decode features outside of auto-encoder
2. Turn off HPM 1x1 conv by default
3. Change canonical feature map size from `feature_channels * 8 x 4 x 2` to `feature_channels * 2 x 16 x 8`
4. Use mean of canonical embeddings instead of mean of static features
5. Calculate static and dynamic loss separately
6. Calculate mean of parts in triplet loss instead of sum of parts
7. Add switch to log disentangled images
8. Change default configuration
Diffstat (limited to 'models/hpm.py')
-rw-r--r-- | models/hpm.py | 20 |
1 files changed, 8 insertions, 12 deletions
diff --git a/models/hpm.py b/models/hpm.py index 66503e3..9879cfb 100644 --- a/models/hpm.py +++ b/models/hpm.py @@ -9,14 +9,16 @@ class HorizontalPyramidMatching(nn.Module): self, in_channels: int, out_channels: int = 128, + use_1x1conv: bool = False, scales: tuple[int, ...] = (1, 2, 4), use_avg_pool: bool = True, - use_max_pool: bool = True, + use_max_pool: bool = False, **kwargs ): super().__init__() self.in_channels = in_channels self.out_channels = out_channels + self.use_1x1conv = use_1x1conv self.scales = scales self.use_avg_pool = use_avg_pool self.use_max_pool = use_max_pool @@ -29,6 +31,7 @@ class HorizontalPyramidMatching(nn.Module): pyramid = nn.ModuleList([ HorizontalPyramidPooling(self.in_channels, self.out_channels, + use_1x1conv=self.use_1x1conv, use_avg_pool=self.use_avg_pool, use_max_pool=self.use_max_pool, **kwargs) @@ -37,23 +40,16 @@ class HorizontalPyramidMatching(nn.Module): return pyramid def forward(self, x): - # Flatten canonical features in all batches - t, n, c, h, w = x.size() - x = x.view(t * n, c, h, w) - + n, c, h, w = x.size() feature = [] - for pyramid_index, pyramid in enumerate(self.pyramids): - h_per_hpp = h // self.scales[pyramid_index] + for scale, pyramid in zip(self.scales, self.pyramids): + h_per_hpp = h // scale for hpp_index, hpp in enumerate(pyramid): h_filter = torch.arange(hpp_index * h_per_hpp, (hpp_index + 1) * h_per_hpp) x_slice = x[:, :, h_filter, :] x_slice = hpp(x_slice) - x_slice = x_slice.view(t * n, -1) + x_slice = x_slice.view(n, -1) feature.append(x_slice) x = torch.stack(feature) - - # Unfold frames to original batch - p, _, c = x.size() - x = x.view(p, t, n, c) return x |