diff options
Diffstat (limited to 'models/part_net.py')
-rw-r--r-- | models/part_net.py | 10 |
1 files changed, 6 insertions, 4 deletions
diff --git a/models/part_net.py b/models/part_net.py index fbf1c88..66e61fc 100644 --- a/models/part_net.py +++ b/models/part_net.py @@ -24,8 +24,9 @@ class FrameLevelPartFeatureExtractor(nn.Module): params = (in_channels, out_channels, kernel_sizes, paddings, halving, use_pools) - self.fconv_blocks = [FocalConv2dBlock(*_params) - for _params in zip(*params)] + self.fconv_blocks = nn.ModuleList([ + FocalConv2dBlock(*_params) for _params in zip(*params) + ]) def forward(self, x): # Flatten frames in all batches @@ -80,7 +81,8 @@ class TemporalFeatureAggregator(nn.Module): def forward(self, x): x = x.transpose(2, 3) p, n, c, t = x.size() - feature = x.split(1, dim=0).squeeze(0) + feature = x.split(1, dim=0) + feature = [f.squeeze(0) for f in feature] x = x.view(-1, c, t) # MTB1: ConvNet1d & Sigmoid @@ -142,7 +144,7 @@ class PartNet(nn.Module): split_size = h // self.num_part x = x.split(split_size, dim=3) x = [self.avg_pool(x_) + self.max_pool(x_) for x_ in x] - x = [x_.squeeze() for x_ in x] + x = [x_.view(n, t, c, -1) for x_ in x] x = torch.stack(x) # p, n, t, c |