diff options
Diffstat (limited to 'models/part_net.py')
-rw-r--r-- | models/part_net.py | 17 |
1 files changed, 8 insertions, 9 deletions
diff --git a/models/part_net.py b/models/part_net.py index 6d8d4e1..f34f993 100644 --- a/models/part_net.py +++ b/models/part_net.py @@ -31,8 +31,8 @@ class FrameLevelPartFeatureExtractor(nn.Module): def forward(self, x): # Flatten frames in all batches - t, n, c, h, w = x.size() - x = x.view(-1, c, h, w) + n, t, c, h, w = x.size() + x = x.view(n * t, c, h, w) for fconv_block in self.fconv_blocks: x = fconv_block(x) @@ -76,8 +76,8 @@ class TemporalFeatureAggregator(nn.Module): for _ in range(self.num_part)]) def forward(self, x): - # p, t, n, c - x = x.permute(0, 2, 3, 1).contiguous() + # p, n, t, c + x = x.transpose(2, 3) p, n, c, t = x.size() feature = x.split(1, dim=0) feature = [f.squeeze(0) for f in feature] @@ -135,19 +135,18 @@ class PartNet(nn.Module): self.max_pool = nn.AdaptiveMaxPool2d(1) def forward(self, x): - t, n, _, _, _ = x.size() - # t, n, c, h, w + n, t, _, _, _ = x.size() x = self.fpfe(x) - # t_n, c, h, w + # n * t x c x h x w # Horizontal Pooling _, c, h, w = x.size() split_size = h // self.num_part x = x.split(split_size, dim=2) x = [self.avg_pool(x_) + self.max_pool(x_) for x_ in x] - x = [x_.view(t, n, c) for x_ in x] + x = [x_.view(n, t, c) for x_ in x] x = torch.stack(x) - # p, t, n, c + # p, n, t, c x = self.tfa(x) return x |