From 86421e899c87976d8559795979415e3fae2bd7ed Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Thu, 31 Dec 2020 21:00:01 +0800 Subject: Implement some parts of RGB-GaitPart wrapper 1. Triplet loss function and weight init function haven't been implement yet 2. Tuplize features returned by auto-encoder for later unpack 3. Correct comment error in auto-encoder 4. Swap batch_size dim and time dim in HPM and PartNet in case of redundant transpose 5. Find backbone problems in HPM and disable it temporarily 6. Make feature structure by HPM consistent to that by PartNet 7. Fix average pooling dimension issue and incorrect view change in HP --- models/part_net.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) (limited to 'models/part_net.py') diff --git a/models/part_net.py b/models/part_net.py index 66e61fc..ac7c434 100644 --- a/models/part_net.py +++ b/models/part_net.py @@ -30,15 +30,11 @@ class FrameLevelPartFeatureExtractor(nn.Module): def forward(self, x): # Flatten frames in all batches - n, t, c, h, w = x.size() + t, n, c, h, w = x.size() x = x.view(-1, c, h, w) for fconv_block in self.fconv_blocks: x = fconv_block(x) - - # Unfold frames to original batch - _, c, h, w = x.size() - x = x.view(n, t, c, h, w) return x @@ -79,7 +75,8 @@ class TemporalFeatureAggregator(nn.Module): for _ in range(self.num_part)]) def forward(self, x): - x = x.transpose(2, 3) + # p, t, n, c + x = x.permute(0, 2, 3, 1).contiguous() p, n, c, t = x.size() feature = x.split(1, dim=0) feature = [f.squeeze(0) for f in feature] @@ -87,7 +84,7 @@ class TemporalFeatureAggregator(nn.Module): # MTB1: ConvNet1d & Sigmoid logits3x1 = torch.stack( - [conv(f) for conv, f in zip(self.conv1d3x1, feature)], dim=0 + [conv(f) for conv, f in zip(self.conv1d3x1, feature)] ) scores3x1 = torch.sigmoid(logits3x1) # MTB1: Template Function @@ -97,7 +94,7 @@ class TemporalFeatureAggregator(nn.Module): # MTB2: ConvNet1d & Sigmoid logits3x3 = torch.stack( - [conv(f) for conv, f in zip(self.conv1d3x3, feature)], dim=0 + [conv(f) for conv, f in zip(self.conv1d3x3, feature)] ) scores3x3 = torch.sigmoid(logits3x3) # MTB2: Template Function @@ -128,25 +125,28 @@ class PartNet(nn.Module): ) num_fconv_blocks = len(self.fpfe.fconv_blocks) - tfa_in_channels = feature_channels * 2 ** (num_fconv_blocks - 1) + self.tfa_in_channels = feature_channels * 2 ** (num_fconv_blocks - 1) self.tfa = TemporalFeatureAggregator( - tfa_in_channels, squeeze_ratio, self.num_part + self.tfa_in_channels, squeeze_ratio, self.num_part ) self.avg_pool = nn.AdaptiveAvgPool2d(1) self.max_pool = nn.AdaptiveMaxPool2d(1) def forward(self, x): + t, n, _, _, _ = x.size() + # t, n, c, h, w x = self.fpfe(x) + # t_n, c, h, w # Horizontal Pooling - n, t, c, h, w = x.size() + _, c, h, w = x.size() split_size = h // self.num_part - x = x.split(split_size, dim=3) + x = x.split(split_size, dim=2) x = [self.avg_pool(x_) + self.max_pool(x_) for x_ in x] - x = [x_.view(n, t, c, -1) for x_ in x] + x = [x_.view(t, n, c) for x_ in x] x = torch.stack(x) - # p, n, t, c + # p, t, n, c x = self.tfa(x) return x -- cgit v1.2.3