-rw-r--r--  models/hpm.py       18
-rw-r--r--  models/part_net.py  10
2 files changed, 16 insertions, 12 deletions
diff --git a/models/hpm.py b/models/hpm.py
index 85a4e58..6dd58f4 100644
--- a/models/hpm.py
+++ b/models/hpm.py
@@ -23,17 +23,19 @@ class HorizontalPyramidMatching(nn.Module):
         self.backbone = resnet50(pretrained=True)
         self.in_channels = self.backbone.layer4[-1].conv1.in_channels

-        self.pyramids = [
+        self.pyramids = nn.ModuleList([
             self._make_pyramid(scale, **kwargs) for scale in self.scales
-        ]
+        ])

     def _make_pyramid(self, scale: int, **kwargs):
-        pyramid = [HorizontalPyramidPooling(self.in_channels,
-                                            self.out_channels,
-                                            use_avg_pool=self.use_avg_pool,
-                                            use_max_pool=self.use_max_pool,
-                                            **kwargs)
-                   for _ in range(scale)]
+        pyramid = nn.ModuleList([
+            HorizontalPyramidPooling(self.in_channels,
+                                     self.out_channels,
+                                     use_avg_pool=self.use_avg_pool,
+                                     use_max_pool=self.use_max_pool,
+                                     **kwargs)
+            for _ in range(scale)
+        ])
         return pyramid

     def forward(self, x):
diff --git a/models/part_net.py b/models/part_net.py
index fbf1c88..66e61fc 100644
--- a/models/part_net.py
+++ b/models/part_net.py
@@ -24,8 +24,9 @@ class FrameLevelPartFeatureExtractor(nn.Module):
         params = (in_channels, out_channels, kernel_sizes,
                   paddings, halving, use_pools)
-        self.fconv_blocks = [FocalConv2dBlock(*_params)
-                             for _params in zip(*params)]
+        self.fconv_blocks = nn.ModuleList([
+            FocalConv2dBlock(*_params) for _params in zip(*params)
+        ])

     def forward(self, x):
         # Flatten frames in all batches
@@ -80,7 +81,8 @@ class TemporalFeatureAggregator(nn.Module):
     def forward(self, x):
         x = x.transpose(2, 3)
         p, n, c, t = x.size()
-        feature = x.split(1, dim=0).squeeze(0)
+        feature = x.split(1, dim=0)
+        feature = [f.squeeze(0) for f in feature]
         x = x.view(-1, c, t)

         # MTB1: ConvNet1d & Sigmoid
@@ -142,7 +144,7 @@ class PartNet(nn.Module):
         split_size = h // self.num_part
         x = x.split(split_size, dim=3)
         x = [self.avg_pool(x_) + self.max_pool(x_) for x_ in x]
-        x = [x_.squeeze() for x_ in x]
+        x = [x_.view(n, t, c, -1) for x_ in x]
         x = torch.stack(x)

         # p, n, t, c
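Note on the changes above: in PyTorch, submodules kept in a plain Python list are not registered with the parent nn.Module, so their parameters do not appear in parameters() or state_dict() and are not moved by to()/cuda(); wrapping them in nn.ModuleList is what makes them visible to the optimizer. A minimal sketch of the difference (hypothetical Net/blocks names, not from this repository):

import torch.nn as nn

class Net(nn.Module):
    def __init__(self, registered: bool):
        super().__init__()
        blocks = [nn.Linear(4, 4) for _ in range(3)]
        # nn.ModuleList registers each block as a submodule;
        # a plain list leaves their parameters untracked
        self.blocks = nn.ModuleList(blocks) if registered else blocks

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

print(len(list(Net(registered=False).parameters())))  # 0: optimizer would see no weights
print(len(list(Net(registered=True).parameters())))   # 6: weight and bias of each block

Similarly, Tensor.split() returns a tuple of tensors rather than a tensor, so squeeze(0) has to be applied per element, which is what the part_net.py hunk above now does.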