summary | refs | log | tree | commit | diff
path: root/models/part_net.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/part_net.py')
-rw-r--r-- models/part_net.py 28
1 files changed, 14 insertions, 14 deletions
diff --git a/models/part_net.py b/models/part_net.py
index 66e61fc..ac7c434 100644
--- a/models/part_net.py
+++ b/models/part_net.py
@@ -30,15 +30,11 @@ class FrameLevelPartFeatureExtractor(nn.Module):
def forward(self, x):
# Flatten frames in all batches
- n, t, c, h, w = x.size()
+ t, n, c, h, w = x.size()
x = x.view(-1, c, h, w)
for fconv_block in self.fconv_blocks:
x = fconv_block(x)
-
- # Unfold frames to original batch
- _, c, h, w = x.size()
- x = x.view(n, t, c, h, w)
return x
@@ -79,7 +75,8 @@ class TemporalFeatureAggregator(nn.Module):
for _ in range(self.num_part)])
def forward(self, x):
- x = x.transpose(2, 3)
+ # p, t, n, c
+ x = x.permute(0, 2, 3, 1).contiguous()
p, n, c, t = x.size()
feature = x.split(1, dim=0)
feature = [f.squeeze(0) for f in feature]
@@ -87,7 +84,7 @@ class TemporalFeatureAggregator(nn.Module):
# MTB1: ConvNet1d & Sigmoid
logits3x1 = torch.stack(
- [conv(f) for conv, f in zip(self.conv1d3x1, feature)], dim=0
+ [conv(f) for conv, f in zip(self.conv1d3x1, feature)]
)
scores3x1 = torch.sigmoid(logits3x1)
# MTB1: Template Function
@@ -97,7 +94,7 @@ class TemporalFeatureAggregator(nn.Module):
# MTB2: ConvNet1d & Sigmoid
logits3x3 = torch.stack(
- [conv(f) for conv, f in zip(self.conv1d3x3, feature)], dim=0
+ [conv(f) for conv, f in zip(self.conv1d3x3, feature)]
)
scores3x3 = torch.sigmoid(logits3x3)
# MTB2: Template Function
@@ -128,25 +125,28 @@ class PartNet(nn.Module):
)
num_fconv_blocks = len(self.fpfe.fconv_blocks)
- tfa_in_channels = feature_channels * 2 ** (num_fconv_blocks - 1)
+ self.tfa_in_channels = feature_channels * 2 ** (num_fconv_blocks - 1)
self.tfa = TemporalFeatureAggregator(
- tfa_in_channels, squeeze_ratio, self.num_part
+ self.tfa_in_channels, squeeze_ratio, self.num_part
)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.max_pool = nn.AdaptiveMaxPool2d(1)
def forward(self, x):
+ t, n, _, _, _ = x.size()
+ # t, n, c, h, w
x = self.fpfe(x)
+ # t_n, c, h, w
# Horizontal Pooling
- n, t, c, h, w = x.size()
+ _, c, h, w = x.size()
split_size = h // self.num_part
- x = x.split(split_size, dim=3)
+ x = x.split(split_size, dim=2)
x = [self.avg_pool(x_) + self.max_pool(x_) for x_ in x]
- x = [x_.view(n, t, c, -1) for x_ in x]
+ x = [x_.view(t, n, c) for x_ in x]
x = torch.stack(x)
- # p, n, t, c
+ # p, t, n, c
x = self.tfa(x)
return x