From 078657c1f42f62e5d3834d8a9f2c0226daae7320 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Wed, 30 Dec 2020 18:37:38 +0800 Subject: Combine FPFE and TFA to PartNet --- models/part_net.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'models') diff --git a/models/part_net.py b/models/part_net.py index 2116600..2698e49 100644 --- a/models/part_net.py +++ b/models/part_net.py @@ -132,11 +132,14 @@ class PartNet(nn.Module): def forward(self, x): x = self.fpfe(x) + + # Horizontal Pooling n, t, c, h, w = x.size() split_size = h // self.num_part x = x.split(split_size, dim=3) x = [self.avg_pool(x_) + self.max_pool(x_) for x_ in x] x = [x_.view(n, t, c, -1) for x_ in x] x = torch.cat(x, dim=3) + x = self.tfa(x) return x -- cgit v1.2.3