diff options
Diffstat (limited to 'models')
-rw-r--r-- | models/part_net.py | 37 |
1 files changed, 21 insertions, 16 deletions
diff --git a/models/part_net.py b/models/part_net.py index 2698e49..fbf1c88 100644 --- a/models/part_net.py +++ b/models/part_net.py @@ -28,8 +28,16 @@ class FrameLevelPartFeatureExtractor(nn.Module): for _params in zip(*params)] def forward(self, x): + # Flatten frames in all batches + n, t, c, h, w = x.size() + x = x.view(-1, c, h, w) + for fconv_block in self.fconv_blocks: x = fconv_block(x) + + # Unfold frames to original batch + _, c, h, w = x.size() + x = x.view(n, t, c, h, w) return x @@ -70,33 +78,29 @@ class TemporalFeatureAggregator(nn.Module): for _ in range(self.num_part)]) def forward(self, x): - """ - Input: x, [p, n, c, s] - """ - p, n, c, s = x.size() - feature = x.split(1, 0) - x = x.view(-1, c, s) + x = x.transpose(2, 3) + p, n, c, t = x.size() + feature = x.split(1, dim=0).squeeze(0) + x = x.view(-1, c, t) # MTB1: ConvNet1d & Sigmoid - logits3x1 = torch.cat( - [conv(_.squeeze(0)).unsqueeze(0) - for conv, _ in zip(self.conv1d3x1, feature)], dim=0 + logits3x1 = torch.stack( + [conv(f) for conv, f in zip(self.conv1d3x1, feature)], dim=0 ) scores3x1 = torch.sigmoid(logits3x1) # MTB1: Template Function feature3x1 = self.avg_pool3x1(x) + self.max_pool3x1(x) - feature3x1 = feature3x1.view(p, n, c, s) + feature3x1 = feature3x1.view(p, n, c, t) feature3x1 = feature3x1 * scores3x1 # MTB2: ConvNet1d & Sigmoid - logits3x3 = torch.cat( - [conv(_.squeeze(0)).unsqueeze(0) - for conv, _ in zip(self.conv1d3x3, feature)], dim=0 + logits3x3 = torch.stack( + [conv(f) for conv, f in zip(self.conv1d3x3, feature)], dim=0 ) scores3x3 = torch.sigmoid(logits3x3) # MTB2: Template Function feature3x3 = self.avg_pool3x3(x) + self.max_pool3x3(x) - feature3x3 = feature3x3.view(p, n, c, s) + feature3x3 = feature3x3.view(p, n, c, t) feature3x3 = feature3x3 * scores3x3 # Temporal Pooling @@ -138,8 +142,9 @@ class PartNet(nn.Module): split_size = h // self.num_part x = x.split(split_size, dim=3) x = [self.avg_pool(x_) + self.max_pool(x_) for x_ in x] - x = [x_.view(n, t, c, -1) for x_ in x] - x = torch.cat(x, dim=3) + x = [x_.squeeze() for x_ in x] + x = torch.stack(x) + # p, n, t, c x = self.tfa(x) return x |