| author | Jordan Gong <jordan.gong@protonmail.com> | 2021-02-09 21:21:57 +0800 |
|---|---|---|
| committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-02-09 21:21:57 +0800 |
| commit | 916cf90d04e57fee23092c966740fbe94fd92cff (patch) | |
| tree | e8d07b33601768fe30e309082c33b0c051382a72 /models/part_net.py | |
| parent | 9f3d4cc14ad36e515b56e86fb8e26f519bde831e (diff) | |
Improve performance when disentangling
This is a HUGE performance optimization, up to 2x faster than before, mainly because a randomized Python for-loop was replaced with a randomized tensor operation.
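The diffstat below covers only models/part_net.py, so the randomized for-loop the message refers to is not part of this diff. As a hedged illustration of the general pattern (every name here is invented for the sketch, not taken from the repository): picking one random frame per sequence in a Python loop issues one small indexing op per batch item, while drawing all indices as a single tensor lets one advanced-indexing call do the whole batch.

```python
import torch

features = torch.randn(64, 30, 256)  # hypothetical (n, t, c) batch

def pick_random_loop(x):
    # Slow path: one Python-level indexing call per batch item
    out = []
    for i in range(x.size(0)):
        j = torch.randint(x.size(1), ()).item()
        out.append(x[i, j])
    return torch.stack(out)

def pick_random_tensor(x):
    # Fast path: draw every random index at once, then gather in one call
    n, t, _ = x.size()
    idx = torch.randint(t, (n,))
    return x[torch.arange(n), idx]

assert pick_random_loop(features).shape == pick_random_tensor(features).shape
```

The vectorized version keeps all the work inside a single kernel launch instead of n separate ones, which is where a roughly 2x speedup can plausibly come from.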
Diffstat (limited to 'models/part_net.py')
| -rw-r--r-- | models/part_net.py | 17 |

1 file changed, 8 insertions, 9 deletions
diff --git a/models/part_net.py b/models/part_net.py
index ac7c434..62a2bac 100644
--- a/models/part_net.py
+++ b/models/part_net.py
@@ -30,8 +30,8 @@ class FrameLevelPartFeatureExtractor(nn.Module):
 
     def forward(self, x):
         # Flatten frames in all batches
-        t, n, c, h, w = x.size()
-        x = x.view(-1, c, h, w)
+        n, t, c, h, w = x.size()
+        x = x.view(n * t, c, h, w)
 
         for fconv_block in self.fconv_blocks:
             x = fconv_block(x)
@@ -75,8 +75,8 @@ class TemporalFeatureAggregator(nn.Module):
                                          for _ in range(self.num_part)])
 
     def forward(self, x):
-        # p, t, n, c
-        x = x.permute(0, 2, 3, 1).contiguous()
+        # p, n, t, c
+        x = x.transpose(2, 3)
         p, n, c, t = x.size()
         feature = x.split(1, dim=0)
         feature = [f.squeeze(0) for f in feature]
@@ -134,19 +134,18 @@ class PartNet(nn.Module):
         self.max_pool = nn.AdaptiveMaxPool2d(1)
 
     def forward(self, x):
-        t, n, _, _, _ = x.size()
-        # t, n, c, h, w
+        n, t, _, _, _ = x.size()
         x = self.fpfe(x)
-        # t_n, c, h, w
+        # n * t x c x h x w
 
         # Horizontal Pooling
         _, c, h, w = x.size()
         split_size = h // self.num_part
         x = x.split(split_size, dim=2)
         x = [self.avg_pool(x_) + self.max_pool(x_) for x_ in x]
-        x = [x_.view(t, n, c) for x_ in x]
+        x = [x_.view(n, t, c) for x_ in x]
         x = torch.stack(x)
-        # p, t, n, c
+        # p, n, t, c
         x = self.tfa(x)
 
         return x
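The diff switches the model from time-first (t, n, ...) to batch-first (n, t, ...) tensors and swaps a permute(...).contiguous() for a plain transpose. A minimal shape-only sketch of those two changes, with dummy sizes invented for illustration (the model classes themselves are not reproduced):

```python
import torch

n, t, c, h, w = 4, 30, 3, 64, 44      # dummy sizes, not the project's
x = torch.randn(n, t, c, h, w)

# Batch-first flattening: frames of all sequences become one big batch
# for the frame-level convolutions, then fold back afterwards.
frames = x.view(n * t, c, h, w)
restored = frames.view(n, t, c, h, w)

# In TemporalFeatureAggregator, (p, n, t, c) -> (p, n, c, t) only swaps
# the last two dims, so a single transpose suffices and returns a view
# without copying data.
y = torch.randn(8, n, t, 256)         # p, n, t, c
y_t = y.transpose(2, 3)               # p, n, c, t
assert y_t.size() == (8, n, 256, t)
```

The likely payoff: the old time-first layout needed permute(0, 2, 3, 1) followed by .contiguous(), which materializes a full copy of the tensor, whereas transpose(2, 3) on the batch-first layout is a zero-copy view that the subsequent split and squeeze calls handle fine.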