diff options
Diffstat (limited to 'models/part_net.py')
-rw-r--r-- | models/part_net.py | 28 |
1 files changed, 14 insertions, 14 deletions
diff --git a/models/part_net.py b/models/part_net.py index 66e61fc..ac7c434 100644 --- a/models/part_net.py +++ b/models/part_net.py @@ -30,15 +30,11 @@ class FrameLevelPartFeatureExtractor(nn.Module): def forward(self, x): # Flatten frames in all batches - n, t, c, h, w = x.size() + t, n, c, h, w = x.size() x = x.view(-1, c, h, w) for fconv_block in self.fconv_blocks: x = fconv_block(x) - - # Unfold frames to original batch - _, c, h, w = x.size() - x = x.view(n, t, c, h, w) return x @@ -79,7 +75,8 @@ class TemporalFeatureAggregator(nn.Module): for _ in range(self.num_part)]) def forward(self, x): - x = x.transpose(2, 3) + # p, t, n, c + x = x.permute(0, 2, 3, 1).contiguous() p, n, c, t = x.size() feature = x.split(1, dim=0) feature = [f.squeeze(0) for f in feature] @@ -87,7 +84,7 @@ class TemporalFeatureAggregator(nn.Module): # MTB1: ConvNet1d & Sigmoid logits3x1 = torch.stack( - [conv(f) for conv, f in zip(self.conv1d3x1, feature)], dim=0 + [conv(f) for conv, f in zip(self.conv1d3x1, feature)] ) scores3x1 = torch.sigmoid(logits3x1) # MTB1: Template Function @@ -97,7 +94,7 @@ class TemporalFeatureAggregator(nn.Module): # MTB2: ConvNet1d & Sigmoid logits3x3 = torch.stack( - [conv(f) for conv, f in zip(self.conv1d3x3, feature)], dim=0 + [conv(f) for conv, f in zip(self.conv1d3x3, feature)] ) scores3x3 = torch.sigmoid(logits3x3) # MTB2: Template Function @@ -128,25 +125,28 @@ class PartNet(nn.Module): ) num_fconv_blocks = len(self.fpfe.fconv_blocks) - tfa_in_channels = feature_channels * 2 ** (num_fconv_blocks - 1) + self.tfa_in_channels = feature_channels * 2 ** (num_fconv_blocks - 1) self.tfa = TemporalFeatureAggregator( - tfa_in_channels, squeeze_ratio, self.num_part + self.tfa_in_channels, squeeze_ratio, self.num_part ) self.avg_pool = nn.AdaptiveAvgPool2d(1) self.max_pool = nn.AdaptiveMaxPool2d(1) def forward(self, x): + t, n, _, _, _ = x.size() + # t, n, c, h, w x = self.fpfe(x) + # t_n, c, h, w # Horizontal Pooling - n, t, c, h, w = x.size() + _, c, h, w = x.size() split_size = h // self.num_part - x = x.split(split_size, dim=3) + x = x.split(split_size, dim=2) x = [self.avg_pool(x_) + self.max_pool(x_) for x_ in x] - x = [x_.view(n, t, c, -1) for x_ in x] + x = [x_.view(t, n, c) for x_ in x] x = torch.stack(x) - # p, n, t, c + # p, t, n, c x = self.tfa(x) return x |