diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2020-12-31 10:51:38 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2020-12-31 10:51:38 +0800 |
commit | 07de5a9797a4f53779fe43c14ee51bdc4e248463 (patch) | |
tree | cf0c7f57657766bae005c8182e07c7c291482f5e /models/hpm.py | |
parent | 2fddfd8f99f86f389117541421e457272f216d0b (diff) |
Bug Fixes in HPM and PartNet
1. Register list of torch.nn.Module to the network using torch.nn.ModuleList
2. Fix operation error in squeeze list of tensor
3. Replace squeeze with view in HP in case batch size is 1
Diffstat (limited to 'models/hpm.py')
-rw-r--r-- | models/hpm.py | 18 |
1 files changed, 10 insertions, 8 deletions
diff --git a/models/hpm.py b/models/hpm.py index 85a4e58..6dd58f4 100644 --- a/models/hpm.py +++ b/models/hpm.py @@ -23,17 +23,19 @@ class HorizontalPyramidMatching(nn.Module): self.backbone = resnet50(pretrained=True) self.in_channels = self.backbone.layer4[-1].conv1.in_channels - self.pyramids = [ + self.pyramids = nn.ModuleList([ self._make_pyramid(scale, **kwargs) for scale in self.scales - ] + ]) def _make_pyramid(self, scale: int, **kwargs): - pyramid = [HorizontalPyramidPooling(self.in_channels, - self.out_channels, - use_avg_pool=self.use_avg_pool, - use_max_pool=self.use_max_pool, - **kwargs) - for _ in range(scale)] + pyramid = nn.ModuleList([ + HorizontalPyramidPooling(self.in_channels, + self.out_channels, + use_avg_pool=self.use_avg_pool, + use_max_pool=self.use_max_pool, + **kwargs) + for _ in range(scale) + ]) return pyramid def forward(self, x): |