summaryrefslogtreecommitdiff
path: root/models/hpm.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2020-12-31 10:51:38 +0800
committerJordan Gong <jordan.gong@protonmail.com>2020-12-31 10:51:38 +0800
commit07de5a9797a4f53779fe43c14ee51bdc4e248463 (patch)
treecf0c7f57657766bae005c8182e07c7c291482f5e /models/hpm.py
parent2fddfd8f99f86f389117541421e457272f216d0b (diff)
Bug Fixes in HPM and PartNet
1. Register lists of torch.nn.Module to the network using torch.nn.ModuleList. 2. Fix an operation error when squeezing a list of tensors. 3. Replace squeeze with view in HP in case batch size is 1.
Diffstat (limited to 'models/hpm.py')
-rw-r--r--models/hpm.py18
1 files changed, 10 insertions, 8 deletions
diff --git a/models/hpm.py b/models/hpm.py
index 85a4e58..6dd58f4 100644
--- a/models/hpm.py
+++ b/models/hpm.py
@@ -23,17 +23,19 @@ class HorizontalPyramidMatching(nn.Module):
self.backbone = resnet50(pretrained=True)
self.in_channels = self.backbone.layer4[-1].conv1.in_channels
- self.pyramids = [
+ self.pyramids = nn.ModuleList([
self._make_pyramid(scale, **kwargs) for scale in self.scales
- ]
+ ])
def _make_pyramid(self, scale: int, **kwargs):
- pyramid = [HorizontalPyramidPooling(self.in_channels,
- self.out_channels,
- use_avg_pool=self.use_avg_pool,
- use_max_pool=self.use_max_pool,
- **kwargs)
- for _ in range(scale)]
+ pyramid = nn.ModuleList([
+ HorizontalPyramidPooling(self.in_channels,
+ self.out_channels,
+ use_avg_pool=self.use_avg_pool,
+ use_max_pool=self.use_max_pool,
+ **kwargs)
+ for _ in range(scale)
+ ])
return pyramid
def forward(self, x):