summaryrefslogtreecommitdiff
path: root/models/hpm.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/hpm.py')
-rw-r--r--models/hpm.py18
1 files changed, 10 insertions, 8 deletions
diff --git a/models/hpm.py b/models/hpm.py
index 85a4e58..6dd58f4 100644
--- a/models/hpm.py
+++ b/models/hpm.py
@@ -23,17 +23,19 @@ class HorizontalPyramidMatching(nn.Module):
self.backbone = resnet50(pretrained=True)
self.in_channels = self.backbone.layer4[-1].conv1.in_channels
- self.pyramids = [
+ self.pyramids = nn.ModuleList([
self._make_pyramid(scale, **kwargs) for scale in self.scales
- ]
+ ])
def _make_pyramid(self, scale: int, **kwargs):
- pyramid = [HorizontalPyramidPooling(self.in_channels,
- self.out_channels,
- use_avg_pool=self.use_avg_pool,
- use_max_pool=self.use_max_pool,
- **kwargs)
- for _ in range(scale)]
+ pyramid = nn.ModuleList([
+ HorizontalPyramidPooling(self.in_channels,
+ self.out_channels,
+ use_avg_pool=self.use_avg_pool,
+ use_max_pool=self.use_max_pool,
+ **kwargs)
+ for _ in range(scale)
+ ])
return pyramid
def forward(self, x):