summaryrefslogtreecommitdiff
path: root/models/part_net.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/part_net.py')
-rw-r--r--models/part_net.py17
1 files changed, 8 insertions, 9 deletions
diff --git a/models/part_net.py b/models/part_net.py
index ac7c434..62a2bac 100644
--- a/models/part_net.py
+++ b/models/part_net.py
@@ -30,8 +30,8 @@ class FrameLevelPartFeatureExtractor(nn.Module):
def forward(self, x):
# Flatten frames in all batches
- t, n, c, h, w = x.size()
- x = x.view(-1, c, h, w)
+ n, t, c, h, w = x.size()
+ x = x.view(n * t, c, h, w)
for fconv_block in self.fconv_blocks:
x = fconv_block(x)
@@ -75,8 +75,8 @@ class TemporalFeatureAggregator(nn.Module):
for _ in range(self.num_part)])
def forward(self, x):
- # p, t, n, c
- x = x.permute(0, 2, 3, 1).contiguous()
+ # p, n, t, c
+ x = x.transpose(2, 3)
p, n, c, t = x.size()
feature = x.split(1, dim=0)
feature = [f.squeeze(0) for f in feature]
@@ -134,19 +134,18 @@ class PartNet(nn.Module):
self.max_pool = nn.AdaptiveMaxPool2d(1)
def forward(self, x):
- t, n, _, _, _ = x.size()
- # t, n, c, h, w
+ n, t, _, _, _ = x.size()
x = self.fpfe(x)
- # t_n, c, h, w
+ # n * t x c x h x w
# Horizontal Pooling
_, c, h, w = x.size()
split_size = h // self.num_part
x = x.split(split_size, dim=2)
x = [self.avg_pool(x_) + self.max_pool(x_) for x_ in x]
- x = [x_.view(t, n, c) for x_ in x]
+ x = [x_.view(n, t, c) for x_ in x]
x = torch.stack(x)
- # p, t, n, c
+ # p, n, t, c
x = self.tfa(x)
return x