From 2fddfd8f99f86f389117541421e457272f216d0b Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Wed, 30 Dec 2020 22:16:22 +0800 Subject: Correct and refine PartNet 1. Let FocalConv block capable of processing frames in all batches 2. Correct input dims of TFA and output dims of HP 3. Change torch.unsqueeze and torch.cat to torch.stack --- models/part_net.py | 37 +++++++++++++++++++++---------------- 1 file changed, 21 insertions(+), 16 deletions(-) (limited to 'models') diff --git a/models/part_net.py b/models/part_net.py index 2698e49..fbf1c88 100644 --- a/models/part_net.py +++ b/models/part_net.py @@ -28,8 +28,16 @@ class FrameLevelPartFeatureExtractor(nn.Module): for _params in zip(*params)] def forward(self, x): + # Flatten frames in all batches + n, t, c, h, w = x.size() + x = x.view(-1, c, h, w) + for fconv_block in self.fconv_blocks: x = fconv_block(x) + + # Unfold frames to original batch + _, c, h, w = x.size() + x = x.view(n, t, c, h, w) return x @@ -70,33 +78,29 @@ class TemporalFeatureAggregator(nn.Module): for _ in range(self.num_part)]) def forward(self, x): - """ - Input: x, [p, n, c, s] - """ - p, n, c, s = x.size() - feature = x.split(1, 0) - x = x.view(-1, c, s) + x = x.transpose(2, 3) + p, n, c, t = x.size() + feature = x.split(1, dim=0).squeeze(0) + x = x.view(-1, c, t) # MTB1: ConvNet1d & Sigmoid - logits3x1 = torch.cat( - [conv(_.squeeze(0)).unsqueeze(0) - for conv, _ in zip(self.conv1d3x1, feature)], dim=0 + logits3x1 = torch.stack( + [conv(f) for conv, f in zip(self.conv1d3x1, feature)], dim=0 ) scores3x1 = torch.sigmoid(logits3x1) # MTB1: Template Function feature3x1 = self.avg_pool3x1(x) + self.max_pool3x1(x) - feature3x1 = feature3x1.view(p, n, c, s) + feature3x1 = feature3x1.view(p, n, c, t) feature3x1 = feature3x1 * scores3x1 # MTB2: ConvNet1d & Sigmoid - logits3x3 = torch.cat( - [conv(_.squeeze(0)).unsqueeze(0) - for conv, _ in zip(self.conv1d3x3, feature)], dim=0 + logits3x3 = torch.stack( + [conv(f) for conv, f in zip(self.conv1d3x3, feature)], dim=0 ) scores3x3 = torch.sigmoid(logits3x3) # MTB2: Template Function feature3x3 = self.avg_pool3x3(x) + self.max_pool3x3(x) - feature3x3 = feature3x3.view(p, n, c, s) + feature3x3 = feature3x3.view(p, n, c, t) feature3x3 = feature3x3 * scores3x3 # Temporal Pooling @@ -138,8 +142,9 @@ class PartNet(nn.Module): split_size = h // self.num_part x = x.split(split_size, dim=3) x = [self.avg_pool(x_) + self.max_pool(x_) for x_ in x] - x = [x_.view(n, t, c, -1) for x_ in x] - x = torch.cat(x, dim=3) + x = [x_.squeeze() for x_ in x] + x = torch.stack(x) + # p, n, t, c x = self.tfa(x) return x -- cgit v1.2.3