summaryrefslogtreecommitdiff
path: root/models/part_net.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2020-12-31 10:51:38 +0800
committerJordan Gong <jordan.gong@protonmail.com>2020-12-31 10:51:38 +0800
commit07de5a9797a4f53779fe43c14ee51bdc4e248463 (patch)
treecf0c7f57657766bae005c8182e07c7c291482f5e /models/part_net.py
parent2fddfd8f99f86f389117541421e457272f216d0b (diff)
Bug Fixes in HPM and PartNet
1. Register list of torch.nn.Module to the network using torch.nn.ModuleList 2. Fix operation error in squeeze list of tensor 3. Replace squeeze with view in HP in case batch size is 1
Diffstat (limited to 'models/part_net.py')
-rw-r--r--models/part_net.py10
1 files changed, 6 insertions, 4 deletions
diff --git a/models/part_net.py b/models/part_net.py
index fbf1c88..66e61fc 100644
--- a/models/part_net.py
+++ b/models/part_net.py
@@ -24,8 +24,9 @@ class FrameLevelPartFeatureExtractor(nn.Module):
params = (in_channels, out_channels, kernel_sizes,
paddings, halving, use_pools)
- self.fconv_blocks = [FocalConv2dBlock(*_params)
- for _params in zip(*params)]
+ self.fconv_blocks = nn.ModuleList([
+ FocalConv2dBlock(*_params) for _params in zip(*params)
+ ])
def forward(self, x):
# Flatten frames in all batches
@@ -80,7 +81,8 @@ class TemporalFeatureAggregator(nn.Module):
def forward(self, x):
x = x.transpose(2, 3)
p, n, c, t = x.size()
- feature = x.split(1, dim=0).squeeze(0)
+ feature = x.split(1, dim=0)
+ feature = [f.squeeze(0) for f in feature]
x = x.view(-1, c, t)
# MTB1: ConvNet1d & Sigmoid
@@ -142,7 +144,7 @@ class PartNet(nn.Module):
split_size = h // self.num_part
x = x.split(split_size, dim=3)
x = [self.avg_pool(x_) + self.max_pool(x_) for x_ in x]
- x = [x_.squeeze() for x_ in x]
+ x = [x_.view(n, t, c, -1) for x_ in x]
x = torch.stack(x)
# p, n, t, c