From 07de5a9797a4f53779fe43c14ee51bdc4e248463 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Thu, 31 Dec 2020 10:51:38 +0800 Subject: Bug Fixes in HPM and PartNet 1. Register list of torch.nn.Module to the network using torch.nn.ModuleList 2. Fix operation error in squeeze list of tensor 3. Replace squeeze with view in HP in case batch size is 1 --- models/hpm.py | 18 ++++++++++-------- models/part_net.py | 10 ++++++---- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/models/hpm.py b/models/hpm.py index 85a4e58..6dd58f4 100644 --- a/models/hpm.py +++ b/models/hpm.py @@ -23,17 +23,19 @@ class HorizontalPyramidMatching(nn.Module): self.backbone = resnet50(pretrained=True) self.in_channels = self.backbone.layer4[-1].conv1.in_channels - self.pyramids = [ + self.pyramids = nn.ModuleList([ self._make_pyramid(scale, **kwargs) for scale in self.scales - ] + ]) def _make_pyramid(self, scale: int, **kwargs): - pyramid = [HorizontalPyramidPooling(self.in_channels, - self.out_channels, - use_avg_pool=self.use_avg_pool, - use_max_pool=self.use_max_pool, - **kwargs) - for _ in range(scale)] + pyramid = nn.ModuleList([ + HorizontalPyramidPooling(self.in_channels, + self.out_channels, + use_avg_pool=self.use_avg_pool, + use_max_pool=self.use_max_pool, + **kwargs) + for _ in range(scale) + ]) return pyramid def forward(self, x): diff --git a/models/part_net.py b/models/part_net.py index fbf1c88..66e61fc 100644 --- a/models/part_net.py +++ b/models/part_net.py @@ -24,8 +24,9 @@ class FrameLevelPartFeatureExtractor(nn.Module): params = (in_channels, out_channels, kernel_sizes, paddings, halving, use_pools) - self.fconv_blocks = [FocalConv2dBlock(*_params) - for _params in zip(*params)] + self.fconv_blocks = nn.ModuleList([ + FocalConv2dBlock(*_params) for _params in zip(*params) + ]) def forward(self, x): # Flatten frames in all batches @@ -80,7 +81,8 @@ class TemporalFeatureAggregator(nn.Module): def forward(self, x): x = x.transpose(2, 3) p, n, c, t = x.size() - feature = x.split(1, dim=0).squeeze(0) + feature = x.split(1, dim=0) + feature = [f.squeeze(0) for f in feature] x = x.view(-1, c, t) # MTB1: ConvNet1d & Sigmoid @@ -142,7 +144,7 @@ class PartNet(nn.Module): split_size = h // self.num_part x = x.split(split_size, dim=3) x = [self.avg_pool(x_) + self.max_pool(x_) for x_ in x] - x = [x_.squeeze() for x_ in x] + x = [x_.view(n, t, c, -1) for x_ in x] x = torch.stack(x) # p, n, t, c -- cgit v1.2.3