diff options
-rw-r--r-- | models/hpm.py | 18 | ||||
-rw-r--r-- | models/part_net.py | 10 |
2 files changed, 16 insertions, 12 deletions
diff --git a/models/hpm.py b/models/hpm.py index 85a4e58..6dd58f4 100644 --- a/models/hpm.py +++ b/models/hpm.py @@ -23,17 +23,19 @@ class HorizontalPyramidMatching(nn.Module): self.backbone = resnet50(pretrained=True) self.in_channels = self.backbone.layer4[-1].conv1.in_channels - self.pyramids = [ + self.pyramids = nn.ModuleList([ self._make_pyramid(scale, **kwargs) for scale in self.scales - ] + ]) def _make_pyramid(self, scale: int, **kwargs): - pyramid = [HorizontalPyramidPooling(self.in_channels, - self.out_channels, - use_avg_pool=self.use_avg_pool, - use_max_pool=self.use_max_pool, - **kwargs) - for _ in range(scale)] + pyramid = nn.ModuleList([ + HorizontalPyramidPooling(self.in_channels, + self.out_channels, + use_avg_pool=self.use_avg_pool, + use_max_pool=self.use_max_pool, + **kwargs) + for _ in range(scale) + ]) return pyramid def forward(self, x): diff --git a/models/part_net.py b/models/part_net.py index fbf1c88..66e61fc 100644 --- a/models/part_net.py +++ b/models/part_net.py @@ -24,8 +24,9 @@ class FrameLevelPartFeatureExtractor(nn.Module): params = (in_channels, out_channels, kernel_sizes, paddings, halving, use_pools) - self.fconv_blocks = [FocalConv2dBlock(*_params) - for _params in zip(*params)] + self.fconv_blocks = nn.ModuleList([ + FocalConv2dBlock(*_params) for _params in zip(*params) + ]) def forward(self, x): # Flatten frames in all batches @@ -80,7 +81,8 @@ class TemporalFeatureAggregator(nn.Module): def forward(self, x): x = x.transpose(2, 3) p, n, c, t = x.size() - feature = x.split(1, dim=0).squeeze(0) + feature = x.split(1, dim=0) + feature = [f.squeeze(0) for f in feature] x = x.view(-1, c, t) # MTB1: ConvNet1d & Sigmoid @@ -142,7 +144,7 @@ class PartNet(nn.Module): split_size = h // self.num_part x = x.split(split_size, dim=3) x = [self.avg_pool(x_) + self.max_pool(x_) for x_ in x] - x = [x_.squeeze() for x_ in x] + x = [x_.view(n, t, c, -1) for x_ in x] x = torch.stack(x) # p, n, t, c |