summaryrefslogtreecommitdiff
path: root/models/part_net.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/part_net.py')
-rw-r--r--models/part_net.py10
1 files changed, 6 insertions, 4 deletions
diff --git a/models/part_net.py b/models/part_net.py
index fbf1c88..66e61fc 100644
--- a/models/part_net.py
+++ b/models/part_net.py
@@ -24,8 +24,9 @@ class FrameLevelPartFeatureExtractor(nn.Module):
params = (in_channels, out_channels, kernel_sizes,
paddings, halving, use_pools)
- self.fconv_blocks = [FocalConv2dBlock(*_params)
- for _params in zip(*params)]
+ self.fconv_blocks = nn.ModuleList([
+ FocalConv2dBlock(*_params) for _params in zip(*params)
+ ])
def forward(self, x):
# Flatten frames in all batches
@@ -80,7 +81,8 @@ class TemporalFeatureAggregator(nn.Module):
def forward(self, x):
x = x.transpose(2, 3)
p, n, c, t = x.size()
- feature = x.split(1, dim=0).squeeze(0)
+ feature = x.split(1, dim=0)
+ feature = [f.squeeze(0) for f in feature]
x = x.view(-1, c, t)
# MTB1: ConvNet1d & Sigmoid
@@ -142,7 +144,7 @@ class PartNet(nn.Module):
split_size = h // self.num_part
x = x.split(split_size, dim=3)
x = [self.avg_pool(x_) + self.max_pool(x_) for x_ in x]
- x = [x_.squeeze() for x_ in x]
+ x = [x_.view(n, t, c, -1) for x_ in x]
x = torch.stack(x)
# p, n, t, c