summaryrefslogtreecommitdiff
path: root/models/part_net.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/part_net.py')
-rw-r--r--models/part_net.py20
1 files changed, 12 insertions, 8 deletions
diff --git a/models/part_net.py b/models/part_net.py
index de19c8c..06884e9 100644
--- a/models/part_net.py
+++ b/models/part_net.py
@@ -128,23 +128,27 @@ class PartNet(nn.Module):
torch.empty(num_parts, in_channels, embedding_dims)
)
- def forward(self, x):
+ def _horizontal_pool(self, x):
n, t, c, h, w = x.size()
x = x.view(n * t, c, h, w)
- # n * t x c x h x w
-
- # Horizontal Pooling
- _, c, h, w = x.size()
split_size = h // self.num_part
x = x.split(split_size, dim=2)
x = [self.avg_pool(x_) + self.max_pool(x_) for x_ in x]
x = [x_.view(n, t, c) for x_ in x]
x = torch.stack(x)
+ return x
+ def forward(self, f_c1_t2, f_c2_t2=None):
+ # n, t, c, h, w
+ f_c1_t2_ = self._horizontal_pool(f_c1_t2)
# p, n, t, c
- x = self.tfa(x)
-
+ x = self.tfa(f_c1_t2_)
# p, n, c
x = x @ self.fc_mat
# p, n, d
- return x
+
+ if self.training:
+ f_c2_t2_ = self._horizontal_pool(f_c2_t2)
+ return x, (f_c1_t2_, f_c2_t2_)
+ else:
+ return x