summaryrefslogtreecommitdiff
path: root/models/hpm.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-03-12 13:59:18 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-03-12 13:59:18 +0800
commit2a7d3c04eab1f3c2e5306d1597399582229a87e5 (patch)
tree060bbd3d0b9d1f3823219225097fb4d74eb311fe /models/hpm.py
parent39fb3e19601aaccd572ea023b117543b9d791b56 (diff)
parentd63b267dd15388dd323d9b8672cdb9461b96c885 (diff)
Merge branch 'python3.8' into python3.7
# Conflicts: # utils/configuration.py
Diffstat (limited to 'models/hpm.py')
-rw-r--r--models/hpm.py25
1 files changed, 12 insertions, 13 deletions
diff --git a/models/hpm.py b/models/hpm.py
index b49be3a..8320569 100644
--- a/models/hpm.py
+++ b/models/hpm.py
@@ -11,32 +11,26 @@ class HorizontalPyramidMatching(nn.Module):
self,
in_channels: int,
out_channels: int = 128,
- use_1x1conv: bool = False,
scales: Tuple[int, ...] = (1, 2, 4),
use_avg_pool: bool = True,
use_max_pool: bool = False,
- **kwargs
):
super().__init__()
- self.in_channels = in_channels
- self.out_channels = out_channels
- self.use_1x1conv = use_1x1conv
self.scales = scales
+ self.num_parts = sum(scales)
self.use_avg_pool = use_avg_pool
self.use_max_pool = use_max_pool
self.pyramids = nn.ModuleList([
- self._make_pyramid(scale, **kwargs) for scale in self.scales
+ self._make_pyramid(scale) for scale in scales
])
+ self.fc_mat = nn.Parameter(
+ torch.empty(self.num_parts, in_channels, out_channels)
+ )
- def _make_pyramid(self, scale: int, **kwargs):
+ def _make_pyramid(self, scale: int):
pyramid = nn.ModuleList([
- HorizontalPyramidPooling(self.in_channels,
- self.out_channels,
- use_1x1conv=self.use_1x1conv,
- use_avg_pool=self.use_avg_pool,
- use_max_pool=self.use_max_pool,
- **kwargs)
+ HorizontalPyramidPooling(self.use_avg_pool, self.use_max_pool)
for _ in range(scale)
])
return pyramid
@@ -54,4 +48,9 @@ class HorizontalPyramidMatching(nn.Module):
x_slice = x_slice.view(n, -1)
feature.append(x_slice)
x = torch.stack(feature)
+
+ # p, n, c
+ x = x @ self.fc_mat
+ # p, n, d
+
return x