diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2020-12-31 10:51:38 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2020-12-31 10:51:38 +0800 |
commit | 07de5a9797a4f53779fe43c14ee51bdc4e248463 (patch) | |
tree | cf0c7f57657766bae005c8182e07c7c291482f5e /models | |
parent | 2fddfd8f99f86f389117541421e457272f216d0b (diff) |
Bug Fixes in HPM and PartNet
1. Register lists of torch.nn.Module instances with the network using torch.nn.ModuleList
2. Fix an operation error when squeezing a list of tensors
3. Replace squeeze with view in HP so it still works when the batch size is 1
Diffstat (limited to 'models')
-rw-r--r-- | models/hpm.py | 18 | ||||
-rw-r--r-- | models/part_net.py | 10 |
2 files changed, 16 insertions, 12 deletions
diff --git a/models/hpm.py b/models/hpm.py index 85a4e58..6dd58f4 100644 --- a/models/hpm.py +++ b/models/hpm.py @@ -23,17 +23,19 @@ class HorizontalPyramidMatching(nn.Module): self.backbone = resnet50(pretrained=True) self.in_channels = self.backbone.layer4[-1].conv1.in_channels - self.pyramids = [ + self.pyramids = nn.ModuleList([ self._make_pyramid(scale, **kwargs) for scale in self.scales - ] + ]) def _make_pyramid(self, scale: int, **kwargs): - pyramid = [HorizontalPyramidPooling(self.in_channels, - self.out_channels, - use_avg_pool=self.use_avg_pool, - use_max_pool=self.use_max_pool, - **kwargs) - for _ in range(scale)] + pyramid = nn.ModuleList([ + HorizontalPyramidPooling(self.in_channels, + self.out_channels, + use_avg_pool=self.use_avg_pool, + use_max_pool=self.use_max_pool, + **kwargs) + for _ in range(scale) + ]) return pyramid def forward(self, x): diff --git a/models/part_net.py b/models/part_net.py index fbf1c88..66e61fc 100644 --- a/models/part_net.py +++ b/models/part_net.py @@ -24,8 +24,9 @@ class FrameLevelPartFeatureExtractor(nn.Module): params = (in_channels, out_channels, kernel_sizes, paddings, halving, use_pools) - self.fconv_blocks = [FocalConv2dBlock(*_params) - for _params in zip(*params)] + self.fconv_blocks = nn.ModuleList([ + FocalConv2dBlock(*_params) for _params in zip(*params) + ]) def forward(self, x): # Flatten frames in all batches @@ -80,7 +81,8 @@ class TemporalFeatureAggregator(nn.Module): def forward(self, x): x = x.transpose(2, 3) p, n, c, t = x.size() - feature = x.split(1, dim=0).squeeze(0) + feature = x.split(1, dim=0) + feature = [f.squeeze(0) for f in feature] x = x.view(-1, c, t) # MTB1: ConvNet1d & Sigmoid @@ -142,7 +144,7 @@ class PartNet(nn.Module): split_size = h // self.num_part x = x.split(split_size, dim=3) x = [self.avg_pool(x_) + self.max_pool(x_) for x_ in x] - x = [x_.squeeze() for x_ in x] + x = [x_.view(n, t, c, -1) for x_ in x] x = torch.stack(x) # p, n, t, c |