summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2020-12-30 18:37:38 +0800
committerJordan Gong <jordan.gong@protonmail.com>2020-12-30 18:37:38 +0800
commit078657c1f42f62e5d3834d8a9f2c0226daae7320 (patch)
treeab2b9efbd756a20ec898acaf3f1ed9e1a2d3f200
parent5ba9390a7e4e8dbf366cfd280403cf37a6b22bab (diff)
Combine FPFE and TFA to PartNet
-rw-r--r--models/part_net.py3
1 files changed, 3 insertions, 0 deletions
diff --git a/models/part_net.py b/models/part_net.py
index 2116600..2698e49 100644
--- a/models/part_net.py
+++ b/models/part_net.py
@@ -132,11 +132,14 @@ class PartNet(nn.Module):
def forward(self, x):
x = self.fpfe(x)
+
+ # Horizontal Pooling
n, t, c, h, w = x.size()
split_size = h // self.num_part
x = x.split(split_size, dim=3)
x = [self.avg_pool(x_) + self.max_pool(x_) for x_ in x]
x = [x_.view(n, t, c, -1) for x_ in x]
x = torch.cat(x, dim=3)
+
x = self.tfa(x)
return x