summaryrefslogtreecommitdiff
path: root/models/part_net.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-02-09 21:28:38 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-02-09 21:28:38 +0800
commit58ef39d75098bce92654492e09edf1e83033d0c8 (patch)
tree8af7fe4fb5adfe1b189353dcff4efc38f62cd0c4 /models/part_net.py
parentd380e04df37593e414bd5641db100613fb2ad882 (diff)
parent916cf90d04e57fee23092c966740fbe94fd92cff (diff)
Merge branch 'master' into python3.8
# Conflicts: # models/rgb_part_net.py
Diffstat (limited to 'models/part_net.py')
-rw-r--r--models/part_net.py17
1 files changed, 8 insertions, 9 deletions
diff --git a/models/part_net.py b/models/part_net.py
index 6d8d4e1..f34f993 100644
--- a/models/part_net.py
+++ b/models/part_net.py
@@ -31,8 +31,8 @@ class FrameLevelPartFeatureExtractor(nn.Module):
def forward(self, x):
# Flatten frames in all batches
- t, n, c, h, w = x.size()
- x = x.view(-1, c, h, w)
+ n, t, c, h, w = x.size()
+ x = x.view(n * t, c, h, w)
for fconv_block in self.fconv_blocks:
x = fconv_block(x)
@@ -76,8 +76,8 @@ class TemporalFeatureAggregator(nn.Module):
for _ in range(self.num_part)])
def forward(self, x):
- # p, t, n, c
- x = x.permute(0, 2, 3, 1).contiguous()
+ # p, n, t, c
+ x = x.transpose(2, 3)
p, n, c, t = x.size()
feature = x.split(1, dim=0)
feature = [f.squeeze(0) for f in feature]
@@ -135,19 +135,18 @@ class PartNet(nn.Module):
self.max_pool = nn.AdaptiveMaxPool2d(1)
def forward(self, x):
- t, n, _, _, _ = x.size()
- # t, n, c, h, w
+ n, t, _, _, _ = x.size()
x = self.fpfe(x)
- # t_n, c, h, w
+ # n * t x c x h x w
# Horizontal Pooling
_, c, h, w = x.size()
split_size = h // self.num_part
x = x.split(split_size, dim=2)
x = [self.avg_pool(x_) + self.max_pool(x_) for x_ in x]
- x = [x_.view(t, n, c) for x_ in x]
+ x = [x_.view(n, t, c) for x_ in x]
x = torch.stack(x)
- # p, t, n, c
+ # p, n, t, c
x = self.tfa(x)
return x