diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-02-08 18:31:52 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-02-08 18:31:52 +0800 |
commit | d380e04df37593e414bd5641db100613fb2ad882 (patch) | |
tree | 1e3b3ea55a464d59d790711372bbca42cb203d0a /models/hpm.py | |
parent | a040400d7caa267d4bfbe8e5520568806f92b3d4 (diff) | |
parent | 99ddd7c142a4ec97cb8bd14b204651790b3cf4ee (diff) |
Merge branch 'master' into python3.8
# Conflicts:
# models/hpm.py
# models/layers.py
# models/model.py
# models/rgb_part_net.py
# utils/configuration.py
Diffstat (limited to 'models/hpm.py')
-rw-r--r-- | models/hpm.py | 20 |
1 files changed, 8 insertions, 12 deletions
diff --git a/models/hpm.py b/models/hpm.py index 7505ed7..b49be3a 100644 --- a/models/hpm.py +++ b/models/hpm.py @@ -11,14 +11,16 @@ class HorizontalPyramidMatching(nn.Module): self, in_channels: int, out_channels: int = 128, + use_1x1conv: bool = False, scales: Tuple[int, ...] = (1, 2, 4), use_avg_pool: bool = True, - use_max_pool: bool = True, + use_max_pool: bool = False, **kwargs ): super().__init__() self.in_channels = in_channels self.out_channels = out_channels + self.use_1x1conv = use_1x1conv self.scales = scales self.use_avg_pool = use_avg_pool self.use_max_pool = use_max_pool @@ -31,6 +33,7 @@ class HorizontalPyramidMatching(nn.Module): pyramid = nn.ModuleList([ HorizontalPyramidPooling(self.in_channels, self.out_channels, + use_1x1conv=self.use_1x1conv, use_avg_pool=self.use_avg_pool, use_max_pool=self.use_max_pool, **kwargs) @@ -39,23 +42,16 @@ class HorizontalPyramidMatching(nn.Module): return pyramid def forward(self, x): - # Flatten canonical features in all batches - t, n, c, h, w = x.size() - x = x.view(t * n, c, h, w) - + n, c, h, w = x.size() feature = [] - for pyramid_index, pyramid in enumerate(self.pyramids): - h_per_hpp = h // self.scales[pyramid_index] + for scale, pyramid in zip(self.scales, self.pyramids): + h_per_hpp = h // scale for hpp_index, hpp in enumerate(pyramid): h_filter = torch.arange(hpp_index * h_per_hpp, (hpp_index + 1) * h_per_hpp) x_slice = x[:, :, h_filter, :] x_slice = hpp(x_slice) - x_slice = x_slice.view(t * n, -1) + x_slice = x_slice.view(n, -1) feature.append(x_slice) x = torch.stack(feature) - - # Unfold frames to original batch - p, _, c = x.size() - x = x.view(p, t, n, c) return x |