summaryrefslogtreecommitdiff
path: root/models/hpm.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-03-12 20:30:25 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-03-12 20:34:00 +0800
commit30b475c0a27e0f848743abf0f909607defc6a3ee (patch)
treeaaab163d3d76a835c32ce5014ce62637550d0b0d /models/hpm.py
parent3d8fc322623ba61610fd206b9f52b406e85cae61 (diff)
parente83ae0bcb5c763636fd522c2712a3c8aef558f3c (diff)
Merge branch 'data_parallel' into data_parallel_py3.8
# Conflicts: # models/hpm.py # models/model.py # models/rgb_part_net.py # utils/configuration.py # utils/triplet_loss.py
Diffstat (limited to 'models/hpm.py')
-rw-r--r--models/hpm.py25
1 files changed, 12 insertions, 13 deletions
diff --git a/models/hpm.py b/models/hpm.py
index b49be3a..8320569 100644
--- a/models/hpm.py
+++ b/models/hpm.py
@@ -11,32 +11,26 @@ class HorizontalPyramidMatching(nn.Module):
self,
in_channels: int,
out_channels: int = 128,
- use_1x1conv: bool = False,
scales: Tuple[int, ...] = (1, 2, 4),
use_avg_pool: bool = True,
use_max_pool: bool = False,
- **kwargs
):
super().__init__()
- self.in_channels = in_channels
- self.out_channels = out_channels
- self.use_1x1conv = use_1x1conv
self.scales = scales
+ self.num_parts = sum(scales)
self.use_avg_pool = use_avg_pool
self.use_max_pool = use_max_pool
self.pyramids = nn.ModuleList([
- self._make_pyramid(scale, **kwargs) for scale in self.scales
+ self._make_pyramid(scale) for scale in scales
])
+ self.fc_mat = nn.Parameter(
+ torch.empty(self.num_parts, in_channels, out_channels)
+ )
- def _make_pyramid(self, scale: int, **kwargs):
+ def _make_pyramid(self, scale: int):
pyramid = nn.ModuleList([
- HorizontalPyramidPooling(self.in_channels,
- self.out_channels,
- use_1x1conv=self.use_1x1conv,
- use_avg_pool=self.use_avg_pool,
- use_max_pool=self.use_max_pool,
- **kwargs)
+ HorizontalPyramidPooling(self.use_avg_pool, self.use_max_pool)
for _ in range(scale)
])
return pyramid
@@ -54,4 +48,9 @@ class HorizontalPyramidMatching(nn.Module):
x_slice = x_slice.view(n, -1)
feature.append(x_slice)
x = torch.stack(feature)
+
+ # p, n, c
+ x = x @ self.fc_mat
+ # p, n, d
+
return x