summaryrefslogtreecommitdiff
path: root/models/part_net.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2020-12-30 22:16:22 +0800
committerJordan Gong <jordan.gong@protonmail.com>2020-12-30 22:16:22 +0800
commit2fddfd8f99f86f389117541421e457272f216d0b (patch)
treeb44922c780d251de2d7b4b5560e186c11b27c641 /models/part_net.py
parent078657c1f42f62e5d3834d8a9f2c0226daae7320 (diff)
Correct and refine PartNet
1. Let FocalConv block capable of processing frames in all batches 2. Correct input dims of TFA and output dims of HP 3. Change torch.unsqueeze and torch.cat to torch.stack
Diffstat (limited to 'models/part_net.py')
-rw-r--r--models/part_net.py37
1 files changed, 21 insertions, 16 deletions
diff --git a/models/part_net.py b/models/part_net.py
index 2698e49..fbf1c88 100644
--- a/models/part_net.py
+++ b/models/part_net.py
@@ -28,8 +28,16 @@ class FrameLevelPartFeatureExtractor(nn.Module):
for _params in zip(*params)]
def forward(self, x):
+ # Flatten frames in all batches
+ n, t, c, h, w = x.size()
+ x = x.view(-1, c, h, w)
+
for fconv_block in self.fconv_blocks:
x = fconv_block(x)
+
+ # Unfold frames to original batch
+ _, c, h, w = x.size()
+ x = x.view(n, t, c, h, w)
return x
@@ -70,33 +78,29 @@ class TemporalFeatureAggregator(nn.Module):
for _ in range(self.num_part)])
def forward(self, x):
- """
- Input: x, [p, n, c, s]
- """
- p, n, c, s = x.size()
- feature = x.split(1, 0)
- x = x.view(-1, c, s)
+ x = x.transpose(2, 3)
+ p, n, c, t = x.size()
+ feature = x.split(1, dim=0).squeeze(0)
+ x = x.view(-1, c, t)
# MTB1: ConvNet1d & Sigmoid
- logits3x1 = torch.cat(
- [conv(_.squeeze(0)).unsqueeze(0)
- for conv, _ in zip(self.conv1d3x1, feature)], dim=0
+ logits3x1 = torch.stack(
+ [conv(f) for conv, f in zip(self.conv1d3x1, feature)], dim=0
)
scores3x1 = torch.sigmoid(logits3x1)
# MTB1: Template Function
feature3x1 = self.avg_pool3x1(x) + self.max_pool3x1(x)
- feature3x1 = feature3x1.view(p, n, c, s)
+ feature3x1 = feature3x1.view(p, n, c, t)
feature3x1 = feature3x1 * scores3x1
# MTB2: ConvNet1d & Sigmoid
- logits3x3 = torch.cat(
- [conv(_.squeeze(0)).unsqueeze(0)
- for conv, _ in zip(self.conv1d3x3, feature)], dim=0
+ logits3x3 = torch.stack(
+ [conv(f) for conv, f in zip(self.conv1d3x3, feature)], dim=0
)
scores3x3 = torch.sigmoid(logits3x3)
# MTB2: Template Function
feature3x3 = self.avg_pool3x3(x) + self.max_pool3x3(x)
- feature3x3 = feature3x3.view(p, n, c, s)
+ feature3x3 = feature3x3.view(p, n, c, t)
feature3x3 = feature3x3 * scores3x3
# Temporal Pooling
@@ -138,8 +142,9 @@ class PartNet(nn.Module):
split_size = h // self.num_part
x = x.split(split_size, dim=3)
x = [self.avg_pool(x_) + self.max_pool(x_) for x_ in x]
- x = [x_.view(n, t, c, -1) for x_ in x]
- x = torch.cat(x, dim=3)
+ x = [x_.squeeze() for x_ in x]
+ x = torch.stack(x)
+ # p, n, t, c
x = self.tfa(x)
return x