diff options
Diffstat (limited to 'models')
| -rw-r--r-- | models/part_net.py | 3 | 
1 files changed, 3 insertions, 0 deletions
diff --git a/models/part_net.py b/models/part_net.py index 2116600..2698e49 100644 --- a/models/part_net.py +++ b/models/part_net.py @@ -132,11 +132,14 @@ class PartNet(nn.Module):      def forward(self, x):          x = self.fpfe(x) + +        # Horizontal Pooling          n, t, c, h, w = x.size()          split_size = h // self.num_part          x = x.split(split_size, dim=3)          x = [self.avg_pool(x_) + self.max_pool(x_) for x_ in x]          x = [x_.view(n, t, c, -1) for x_ in x]          x = torch.cat(x, dim=3) +          x = self.tfa(x)          return x  | 
