summaryrefslogtreecommitdiff
path: root/models/hpm.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-03-12 20:12:33 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-03-12 20:12:33 +0800
commite83ae0bcb5c763636fd522c2712a3c8aef558f3c (patch)
treeb80da057e4c4574ea95fa9f3d3b2fe8c999e3440 /models/hpm.py
parentf2f7713efa03a877bc96ced37314b4c4a6dc1963 (diff)
parent2ea916b2a963eae7d47151b41c8c78a578c402e2 (diff)
Merge branch 'master' into data_parallel
# Conflicts: # models/auto_encoder.py # models/model.py # models/rgb_part_net.py
Diffstat (limited to 'models/hpm.py')
-rw-r--r--models/hpm.py25
1 files changed, 12 insertions, 13 deletions
diff --git a/models/hpm.py b/models/hpm.py
index 9879cfb..8186b20 100644
--- a/models/hpm.py
+++ b/models/hpm.py
@@ -9,32 +9,26 @@ class HorizontalPyramidMatching(nn.Module):
self,
in_channels: int,
out_channels: int = 128,
- use_1x1conv: bool = False,
scales: tuple[int, ...] = (1, 2, 4),
use_avg_pool: bool = True,
use_max_pool: bool = False,
- **kwargs
):
super().__init__()
- self.in_channels = in_channels
- self.out_channels = out_channels
- self.use_1x1conv = use_1x1conv
self.scales = scales
+ self.num_parts = sum(scales)
self.use_avg_pool = use_avg_pool
self.use_max_pool = use_max_pool
self.pyramids = nn.ModuleList([
- self._make_pyramid(scale, **kwargs) for scale in self.scales
+ self._make_pyramid(scale) for scale in scales
])
+ self.fc_mat = nn.Parameter(
+ torch.empty(self.num_parts, in_channels, out_channels)
+ )
- def _make_pyramid(self, scale: int, **kwargs):
+ def _make_pyramid(self, scale: int):
pyramid = nn.ModuleList([
- HorizontalPyramidPooling(self.in_channels,
- self.out_channels,
- use_1x1conv=self.use_1x1conv,
- use_avg_pool=self.use_avg_pool,
- use_max_pool=self.use_max_pool,
- **kwargs)
+ HorizontalPyramidPooling(self.use_avg_pool, self.use_max_pool)
for _ in range(scale)
])
return pyramid
@@ -52,4 +46,9 @@ class HorizontalPyramidMatching(nn.Module):
x_slice = x_slice.view(n, -1)
feature.append(x_slice)
x = torch.stack(feature)
+
+ # p, n, c
+ x = x @ self.fc_mat
+ # p, n, d
+
return x