diff options
Diffstat (limited to 'models/hpm.py')
-rw-r--r-- | models/hpm.py | 18 |
1 files changed, 10 insertions, 8 deletions
diff --git a/models/hpm.py b/models/hpm.py index 85a4e58..6dd58f4 100644 --- a/models/hpm.py +++ b/models/hpm.py @@ -23,17 +23,19 @@ class HorizontalPyramidMatching(nn.Module): self.backbone = resnet50(pretrained=True) self.in_channels = self.backbone.layer4[-1].conv1.in_channels - self.pyramids = [ + self.pyramids = nn.ModuleList([ self._make_pyramid(scale, **kwargs) for scale in self.scales - ] + ]) def _make_pyramid(self, scale: int, **kwargs): - pyramid = [HorizontalPyramidPooling(self.in_channels, - self.out_channels, - use_avg_pool=self.use_avg_pool, - use_max_pool=self.use_max_pool, - **kwargs) - for _ in range(scale)] + pyramid = nn.ModuleList([ + HorizontalPyramidPooling(self.in_channels, + self.out_channels, + use_avg_pool=self.use_avg_pool, + use_max_pool=self.use_max_pool, + **kwargs) + for _ in range(scale) + ]) return pyramid def forward(self, x): |