From 916cf90d04e57fee23092c966740fbe94fd92cff Mon Sep 17 00:00:00 2001
From: Jordan Gong
Date: Tue, 9 Feb 2021 21:21:57 +0800
Subject: Improve performance when disentangling

This is a HUGE performance optimization, up to 2x faster than before,
mainly because of the replacement of a randomized for-loop with a
randomized tensor.
---
 models/part_net.py | 17 ++++++++---------
 1 file changed, 8 insertions(+), 9 deletions(-)

(limited to 'models/part_net.py')

diff --git a/models/part_net.py b/models/part_net.py
index ac7c434..62a2bac 100644
--- a/models/part_net.py
+++ b/models/part_net.py
@@ -30,8 +30,8 @@ class FrameLevelPartFeatureExtractor(nn.Module):
 
     def forward(self, x):
         # Flatten frames in all batches
-        t, n, c, h, w = x.size()
-        x = x.view(-1, c, h, w)
+        n, t, c, h, w = x.size()
+        x = x.view(n * t, c, h, w)
 
         for fconv_block in self.fconv_blocks:
             x = fconv_block(x)
@@ -75,8 +75,8 @@ class TemporalFeatureAggregator(nn.Module):
                                         for _ in range(self.num_part)])
 
     def forward(self, x):
-        # p, t, n, c
-        x = x.permute(0, 2, 3, 1).contiguous()
+        # p, n, t, c
+        x = x.transpose(2, 3)
         p, n, c, t = x.size()
         feature = x.split(1, dim=0)
         feature = [f.squeeze(0) for f in feature]
@@ -134,19 +134,18 @@ class PartNet(nn.Module):
         self.max_pool = nn.AdaptiveMaxPool2d(1)
 
     def forward(self, x):
-        t, n, _, _, _ = x.size()
-        # t, n, c, h, w
+        n, t, _, _, _ = x.size()
         x = self.fpfe(x)
-        # t_n, c, h, w
+        # n * t x c x h x w
 
         # Horizontal Pooling
         _, c, h, w = x.size()
         split_size = h // self.num_part
         x = x.split(split_size, dim=2)
         x = [self.avg_pool(x_) + self.max_pool(x_) for x_ in x]
-        x = [x_.view(t, n, c) for x_ in x]
+        x = [x_.view(n, t, c) for x_ in x]
         x = torch.stack(x)
-        # p, t, n, c
+        # p, n, t, c
         x = self.tfa(x)
 
         return x
-- 
cgit v1.2.3
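
A note on the headline claim: the randomized for-loop this message refers to is
not visible in the hunks above, since the view is limited to models/part_net.py.
As a rough, hypothetical sketch of the technique named in the subject (per-sample
random frame selection vectorized into a single index tensor), with all names and
sizes made up here:

import torch

n, t, c = 4, 30, 128            # batch, frames per clip, channels (assumed)
k = 8                           # frames to sample per clip (assumed)
clips = torch.randn(n, t, c)    # batch-first layout, as in this commit

# Before: a Python-level loop draws random indices one sample at a time,
# launching a small kernel n times.
slow = torch.stack([clips[i, torch.randperm(t)[:k]] for i in range(n)])

# After: draw every sample's indices in one shot (argsort of uniform noise
# yields a random permutation per row) and gather with a single tensor op.
idx = torch.rand(n, t).argsort(dim=1)[:, :k]                   # n x k
fast = clips.gather(1, idx.unsqueeze(-1).expand(-1, -1, c))    # n x k x c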
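
The first hunk folds time into the batch dimension so the 2-D conv blocks
process all frames of all clips in one call. A minimal sketch of that pattern
with assumed sizes, where the single conv stands in for the fconv_blocks:

import torch
import torch.nn as nn

conv = nn.Conv2d(3, 8, 3, padding=1)   # stand-in for one fconv block

n, t, c, h, w = 4, 30, 3, 64, 44       # assumed sizes, batch-first as above
x = torch.randn(n, t, c, h, w)

y = conv(x.view(n * t, c, h, w))       # every frame in one conv call
y = y.view(n, t, 8, h, w)              # split batch and time back apart

Writing n * t instead of -1 also documents the layout that the later
x_.view(n, t, c) calls in PartNet.forward rely on.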
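
The second hunk is a memory win, not just a relabeling: the old
permute(0, 2, 3, 1).contiguous() on a time-first p, t, n, c tensor allocated
and copied the whole feature map, whereas transpose(2, 3) on the new
p, n, t, c layout returns a stride-swapped view of the same storage. A quick
check of that, with an assumed shape:

import torch

x = torch.randn(16, 4, 30, 128)   # p, n, t, c (assumed sizes)

y = x.transpose(2, 3)             # p, n, c, t -- a view, nothing copied

assert y.data_ptr() == x.data_ptr()   # shares x's storage
assert not y.is_contiguous()          # only the strides were swapped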