diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2020-12-30 22:16:22 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2020-12-30 22:16:22 +0800 |
commit | 2fddfd8f99f86f389117541421e457272f216d0b (patch) | |
tree | b44922c780d251de2d7b4b5560e186c11b27c641 /models/part_net.py | |
parent | 078657c1f42f62e5d3834d8a9f2c0226daae7320 (diff) |
Correct and refine PartNet
1. Make the FocalConv block capable of processing frames in all batches
2. Correct input dims of TFA and output dims of HP
3. Change torch.unsqueeze and torch.cat to torch.stack
Diffstat (limited to 'models/part_net.py')
-rw-r--r-- | models/part_net.py | 37 |
1 files changed, 21 insertions, 16 deletions
diff --git a/models/part_net.py b/models/part_net.py index 2698e49..fbf1c88 100644 --- a/models/part_net.py +++ b/models/part_net.py @@ -28,8 +28,16 @@ class FrameLevelPartFeatureExtractor(nn.Module): for _params in zip(*params)] def forward(self, x): + # Flatten frames in all batches + n, t, c, h, w = x.size() + x = x.view(-1, c, h, w) + for fconv_block in self.fconv_blocks: x = fconv_block(x) + + # Unfold frames to original batch + _, c, h, w = x.size() + x = x.view(n, t, c, h, w) return x @@ -70,33 +78,29 @@ class TemporalFeatureAggregator(nn.Module): for _ in range(self.num_part)]) def forward(self, x): - """ - Input: x, [p, n, c, s] - """ - p, n, c, s = x.size() - feature = x.split(1, 0) - x = x.view(-1, c, s) + x = x.transpose(2, 3) + p, n, c, t = x.size() + feature = x.split(1, dim=0).squeeze(0) + x = x.view(-1, c, t) # MTB1: ConvNet1d & Sigmoid - logits3x1 = torch.cat( - [conv(_.squeeze(0)).unsqueeze(0) - for conv, _ in zip(self.conv1d3x1, feature)], dim=0 + logits3x1 = torch.stack( + [conv(f) for conv, f in zip(self.conv1d3x1, feature)], dim=0 ) scores3x1 = torch.sigmoid(logits3x1) # MTB1: Template Function feature3x1 = self.avg_pool3x1(x) + self.max_pool3x1(x) - feature3x1 = feature3x1.view(p, n, c, s) + feature3x1 = feature3x1.view(p, n, c, t) feature3x1 = feature3x1 * scores3x1 # MTB2: ConvNet1d & Sigmoid - logits3x3 = torch.cat( - [conv(_.squeeze(0)).unsqueeze(0) - for conv, _ in zip(self.conv1d3x3, feature)], dim=0 + logits3x3 = torch.stack( + [conv(f) for conv, f in zip(self.conv1d3x3, feature)], dim=0 ) scores3x3 = torch.sigmoid(logits3x3) # MTB2: Template Function feature3x3 = self.avg_pool3x3(x) + self.max_pool3x3(x) - feature3x3 = feature3x3.view(p, n, c, s) + feature3x3 = feature3x3.view(p, n, c, t) feature3x3 = feature3x3 * scores3x3 # Temporal Pooling @@ -138,8 +142,9 @@ class PartNet(nn.Module): split_size = h // self.num_part x = x.split(split_size, dim=3) x = [self.avg_pool(x_) + self.max_pool(x_) for x_ in x] - x = [x_.view(n, t, c, -1) 
for x_ in x] - x = torch.cat(x, dim=3) + x = [x_.squeeze() for x_ in x] + x = torch.stack(x) + # p, n, t, c x = self.tfa(x) return x |