summaryrefslogtreecommitdiff
path: root/models/part_net.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-02-09 21:29:06 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-02-09 21:29:06 +0800
commit045fdb1d8f381ef1dafdec33e87fc2b6736615e4 (patch)
tree644053b57c152e3eb8e7e885d87890991740834e /models/part_net.py
parent31e0294cdb2ffd5241c7e85a6e1e98a4ee20ae28 (diff)
parent58ef39d75098bce92654492e09edf1e83033d0c8 (diff)
Merge branch 'python3.8' into python3.7
Diffstat (limited to 'models/part_net.py')
-rw-r--r--models/part_net.py17
1 files changed, 8 insertions, 9 deletions
diff --git a/models/part_net.py b/models/part_net.py
index 6d8d4e1..f34f993 100644
--- a/models/part_net.py
+++ b/models/part_net.py
@@ -31,8 +31,8 @@ class FrameLevelPartFeatureExtractor(nn.Module):
def forward(self, x):
# Flatten frames in all batches
- t, n, c, h, w = x.size()
- x = x.view(-1, c, h, w)
+ n, t, c, h, w = x.size()
+ x = x.view(n * t, c, h, w)
for fconv_block in self.fconv_blocks:
x = fconv_block(x)
@@ -76,8 +76,8 @@ class TemporalFeatureAggregator(nn.Module):
for _ in range(self.num_part)])
def forward(self, x):
- # p, t, n, c
- x = x.permute(0, 2, 3, 1).contiguous()
+ # p, n, t, c
+ x = x.transpose(2, 3)
p, n, c, t = x.size()
feature = x.split(1, dim=0)
feature = [f.squeeze(0) for f in feature]
@@ -135,19 +135,18 @@ class PartNet(nn.Module):
self.max_pool = nn.AdaptiveMaxPool2d(1)
def forward(self, x):
- t, n, _, _, _ = x.size()
- # t, n, c, h, w
+ n, t, _, _, _ = x.size()
x = self.fpfe(x)
- # t_n, c, h, w
+ # n * t x c x h x w
# Horizontal Pooling
_, c, h, w = x.size()
split_size = h // self.num_part
x = x.split(split_size, dim=2)
x = [self.avg_pool(x_) + self.max_pool(x_) for x_ in x]
- x = [x_.view(t, n, c) for x_ in x]
+ x = [x_.view(n, t, c) for x_ in x]
x = torch.stack(x)
- # p, t, n, c
+ # p, n, t, c
x = self.tfa(x)
return x