summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Jordan Gong <jordan.gong@protonmail.com> 2020-12-30 22:16:22 +0800
committer Jordan Gong <jordan.gong@protonmail.com> 2020-12-30 22:16:22 +0800
commit 2fddfd8f99f86f389117541421e457272f216d0b (patch)
tree b44922c780d251de2d7b4b5560e186c11b27c641
parent 078657c1f42f62e5d3834d8a9f2c0226daae7320 (diff)
Correct and refine PartNet
1. Make the FocalConv block capable of processing frames in all batches 2. Correct input dims of TFA and output dims of HP 3. Change torch.unsqueeze and torch.cat to torch.stack
-rw-r--r-- models/part_net.py 37
1 files changed, 21 insertions, 16 deletions
diff --git a/models/part_net.py b/models/part_net.py
index 2698e49..fbf1c88 100644
--- a/models/part_net.py
+++ b/models/part_net.py
@@ -28,8 +28,16 @@ class FrameLevelPartFeatureExtractor(nn.Module):
for _params in zip(*params)]
def forward(self, x):
+ # Flatten frames in all batches
+ n, t, c, h, w = x.size()
+ x = x.view(-1, c, h, w)
+
for fconv_block in self.fconv_blocks:
x = fconv_block(x)
+
+ # Unfold frames to original batch
+ _, c, h, w = x.size()
+ x = x.view(n, t, c, h, w)
return x
@@ -70,33 +78,29 @@ class TemporalFeatureAggregator(nn.Module):
for _ in range(self.num_part)])
def forward(self, x):
- """
- Input: x, [p, n, c, s]
- """
- p, n, c, s = x.size()
- feature = x.split(1, 0)
- x = x.view(-1, c, s)
+ x = x.transpose(2, 3)
+ p, n, c, t = x.size()
+ feature = x.split(1, dim=0).squeeze(0)
+ x = x.view(-1, c, t)
# MTB1: ConvNet1d & Sigmoid
- logits3x1 = torch.cat(
- [conv(_.squeeze(0)).unsqueeze(0)
- for conv, _ in zip(self.conv1d3x1, feature)], dim=0
+ logits3x1 = torch.stack(
+ [conv(f) for conv, f in zip(self.conv1d3x1, feature)], dim=0
)
scores3x1 = torch.sigmoid(logits3x1)
# MTB1: Template Function
feature3x1 = self.avg_pool3x1(x) + self.max_pool3x1(x)
- feature3x1 = feature3x1.view(p, n, c, s)
+ feature3x1 = feature3x1.view(p, n, c, t)
feature3x1 = feature3x1 * scores3x1
# MTB2: ConvNet1d & Sigmoid
- logits3x3 = torch.cat(
- [conv(_.squeeze(0)).unsqueeze(0)
- for conv, _ in zip(self.conv1d3x3, feature)], dim=0
+ logits3x3 = torch.stack(
+ [conv(f) for conv, f in zip(self.conv1d3x3, feature)], dim=0
)
scores3x3 = torch.sigmoid(logits3x3)
# MTB2: Template Function
feature3x3 = self.avg_pool3x3(x) + self.max_pool3x3(x)
- feature3x3 = feature3x3.view(p, n, c, s)
+ feature3x3 = feature3x3.view(p, n, c, t)
feature3x3 = feature3x3 * scores3x3
# Temporal Pooling
@@ -138,8 +142,9 @@ class PartNet(nn.Module):
split_size = h // self.num_part
x = x.split(split_size, dim=3)
x = [self.avg_pool(x_) + self.max_pool(x_) for x_ in x]
- x = [x_.view(n, t, c, -1) for x_ in x]
- x = torch.cat(x, dim=3)
+ x = [x_.squeeze() for x_ in x]
+ x = torch.stack(x)
+ # p, n, t, c
x = self.tfa(x)
return x