summaryrefslogtreecommitdiff
path: root/models/hpm.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-03-12 13:56:17 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-03-12 13:56:17 +0800
commitc74df416b00f837ba051f3947be92f76e7afbd88 (patch)
tree02983df94008bbb427c2066c5f619e0ffdefe1c5 /models/hpm.py
parent1b8d1614168ce6590c5e029c7f1007ac9b17048c (diff)
Code refactoring
1. Separate FCs and triplet losses for HPM and PartNet 2. Remove FC-equivalent 1x1 conv layers in HPM 3. Support adjustable learning rate schedulers
Diffstat (limited to 'models/hpm.py')
-rw-r--r--models/hpm.py25
1 files changed, 12 insertions, 13 deletions
diff --git a/models/hpm.py b/models/hpm.py
index 9879cfb..8186b20 100644
--- a/models/hpm.py
+++ b/models/hpm.py
@@ -9,32 +9,26 @@ class HorizontalPyramidMatching(nn.Module):
self,
in_channels: int,
out_channels: int = 128,
- use_1x1conv: bool = False,
scales: tuple[int, ...] = (1, 2, 4),
use_avg_pool: bool = True,
use_max_pool: bool = False,
- **kwargs
):
super().__init__()
- self.in_channels = in_channels
- self.out_channels = out_channels
- self.use_1x1conv = use_1x1conv
self.scales = scales
+ self.num_parts = sum(scales)
self.use_avg_pool = use_avg_pool
self.use_max_pool = use_max_pool
self.pyramids = nn.ModuleList([
- self._make_pyramid(scale, **kwargs) for scale in self.scales
+ self._make_pyramid(scale) for scale in scales
])
+ self.fc_mat = nn.Parameter(
+ torch.empty(self.num_parts, in_channels, out_channels)
+ )
- def _make_pyramid(self, scale: int, **kwargs):
+ def _make_pyramid(self, scale: int):
pyramid = nn.ModuleList([
- HorizontalPyramidPooling(self.in_channels,
- self.out_channels,
- use_1x1conv=self.use_1x1conv,
- use_avg_pool=self.use_avg_pool,
- use_max_pool=self.use_max_pool,
- **kwargs)
+ HorizontalPyramidPooling(self.use_avg_pool, self.use_max_pool)
for _ in range(scale)
])
return pyramid
@@ -52,4 +46,9 @@ class HorizontalPyramidMatching(nn.Module):
x_slice = x_slice.view(n, -1)
feature.append(x_slice)
x = torch.stack(feature)
+
+ # p, n, c
+ x = x @ self.fc_mat
+ # p, n, d
+
return x