From 5ba9390a7e4e8dbf366cfd280403cf37a6b22bab Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Wed, 30 Dec 2020 16:58:17 +0800 Subject: Combine FPFE and TFA to GaitPart --- models/fpfe.py | 38 -------------- models/layers.py | 27 ++++++++++ models/part_net.py | 142 +++++++++++++++++++++++++++++++++++++++++++++++++++++ models/tfa.py | 77 ----------------------------- 4 files changed, 169 insertions(+), 115 deletions(-) delete mode 100644 models/fpfe.py create mode 100644 models/part_net.py delete mode 100644 models/tfa.py (limited to 'models') diff --git a/models/fpfe.py b/models/fpfe.py deleted file mode 100644 index 28a0440..0000000 --- a/models/fpfe.py +++ /dev/null @@ -1,38 +0,0 @@ -import torch.nn as nn - -from models.layers import FocalConv2d - - -class FrameLevelPartFeatureExtractor(nn.Module): - - def __init__(self, in_channels: int): - super(FrameLevelPartFeatureExtractor, self).__init__() - nf = 32 - - self.focal_conv1 = FocalConv2d(in_channels, nf, kernel_size=5, - padding=2, halving=1) - self.focal_conv2 = FocalConv2d(nf, nf, kernel_size=3, - padding=1, halving=1) - self.focal_conv3 = FocalConv2d(nf, nf * 2, kernel_size=3, - padding=1, halving=4) - self.focal_conv4 = FocalConv2d(nf * 2, nf * 2, kernel_size=3, - padding=1, halving=4) - self.focal_conv5 = FocalConv2d(nf * 2, nf * 4, kernel_size=3, - padding=1, halving=8) - self.focal_conv6 = FocalConv2d(nf * 4, nf * 4, kernel_size=3, - padding=1, halving=8) - self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2) - - def forward(self, x): - x = self.focal_conv1(x) - x = self.focal_conv2(x) - x = self.max_pool(x) - - x = self.focal_conv3(x) - x = self.focal_conv4(x) - x = self.max_pool(x) - - x = self.focal_conv5(x) - x = self.focal_conv6(x) - - return x diff --git a/models/layers.py b/models/layers.py index 62a3cc6..c69ae07 100644 --- a/models/layers.py +++ b/models/layers.py @@ -119,6 +119,33 @@ class FocalConv2d(BasicConv2d): return F.leaky_relu(z, inplace=True) +class FocalConv2dBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_sizes: tuple[int, int], + paddings: tuple[int, int], + halving: int, + use_pool: bool = True, + **kwargs + ): + super().__init__() + self.use_pool = use_pool + self.fconv1 = FocalConv2d(in_channels, out_channels, kernel_sizes[0], + halving, padding=paddings[0], **kwargs) + self.fconv2 = FocalConv2d(out_channels, out_channels, kernel_sizes[1], + halving, padding=paddings[1], **kwargs) + self.max_pool = nn.MaxPool2d(2) + + def forward(self, x): + x = self.fconv1(x) + x = self.fconv2(x) + if self.use_pool: + x = self.max_pool(x) + return x + + class BasicConv1d(nn.Module): def __init__( self, diff --git a/models/part_net.py b/models/part_net.py new file mode 100644 index 0000000..2116600 --- /dev/null +++ b/models/part_net.py @@ -0,0 +1,142 @@ +import copy + +import torch +import torch.nn as nn + +from models.layers import BasicConv1d, FocalConv2dBlock + + +class FrameLevelPartFeatureExtractor(nn.Module): + + def __init__( + self, + in_channels: int = 3, + feature_channels: int = 32, + kernel_sizes: tuple[tuple, ...] = ((5, 3), (3, 3), (3, 3)), + paddings: tuple[tuple, ...] = ((2, 1), (1, 1), (1, 1)), + halving: tuple[int, ...] = (0, 2, 3) + ): + super().__init__() + num_blocks = len(kernel_sizes) + out_channels = [feature_channels * 2 ** i for i in range(num_blocks)] + in_channels = [in_channels] + out_channels[:-1] + use_pools = [True] * (num_blocks - 1) + [False] + params = (in_channels, out_channels, kernel_sizes, + paddings, halving, use_pools) + + self.fconv_blocks = [FocalConv2dBlock(*_params) + for _params in zip(*params)] + + def forward(self, x): + for fconv_block in self.fconv_blocks: + x = fconv_block(x) + return x + + +class TemporalFeatureAggregator(nn.Module): + def __init__( + self, + in_channels: int, + squeeze_ratio: int = 4, + num_part: int = 16 + ): + super().__init__() + hidden_dim = in_channels // squeeze_ratio + self.num_part = num_part + + # MTB1 + conv3x1 = nn.Sequential( + BasicConv1d(in_channels, hidden_dim, kernel_size=3, padding=1), + nn.LeakyReLU(inplace=True), + BasicConv1d(hidden_dim, in_channels, kernel_size=1, padding=0) + ) + self.conv1d3x1 = self._parted(conv3x1) + self.avg_pool3x1 = nn.AvgPool1d(kernel_size=3, stride=1, padding=1) + self.max_pool3x1 = nn.MaxPool1d(kernel_size=3, stride=1, padding=1) + + # MTB2 + conv3x3 = nn.Sequential( + BasicConv1d(in_channels, hidden_dim, kernel_size=3, padding=1), + nn.LeakyReLU(inplace=True), + BasicConv1d(hidden_dim, in_channels, kernel_size=3, padding=1) + ) + self.conv1d3x3 = self._parted(conv3x3) + self.avg_pool3x3 = nn.AvgPool1d(kernel_size=5, stride=1, padding=2) + self.max_pool3x3 = nn.MaxPool1d(kernel_size=5, stride=1, padding=2) + + def _parted(self, module: nn.Module): + """Duplicate module `part_num` times.""" + return nn.ModuleList([copy.deepcopy(module) + for _ in range(self.num_part)]) + + def forward(self, x): + """ + Input: x, [p, n, c, s] + """ + p, n, c, s = x.size() + feature = x.split(1, 0) + x = x.view(-1, c, s) + + # MTB1: ConvNet1d & Sigmoid + logits3x1 = torch.cat( + [conv(_.squeeze(0)).unsqueeze(0) + for conv, _ in zip(self.conv1d3x1, feature)], dim=0 + ) + scores3x1 = torch.sigmoid(logits3x1) + # MTB1: Template Function + feature3x1 = self.avg_pool3x1(x) + self.max_pool3x1(x) + feature3x1 = feature3x1.view(p, n, c, s) + feature3x1 = feature3x1 * scores3x1 + + # MTB2: ConvNet1d & Sigmoid + logits3x3 = torch.cat( + [conv(_.squeeze(0)).unsqueeze(0) + for conv, _ in zip(self.conv1d3x3, feature)], dim=0 + ) + scores3x3 = torch.sigmoid(logits3x3) + # MTB2: Template Function + feature3x3 = self.avg_pool3x3(x) + self.max_pool3x3(x) + feature3x3 = feature3x3.view(p, n, c, s) + feature3x3 = feature3x3 * scores3x3 + + # Temporal Pooling + ret = (feature3x1 + feature3x3).max(-1)[0] + return ret + + +class PartNet(nn.Module): + def __init__( + self, + in_channels: int = 3, + feature_channels: int = 32, + kernel_sizes: tuple[tuple, ...] = ((5, 3), (3, 3), (3, 3)), + paddings: tuple[tuple, ...] = ((2, 1), (1, 1), (1, 1)), + halving: tuple[int, ...] = (0, 2, 3), + squeeze_ratio: int = 4, + num_part: int = 16 + ): + super().__init__() + self.num_part = num_part + self.fpfe = FrameLevelPartFeatureExtractor( + in_channels, feature_channels, kernel_sizes, paddings, halving + ) + + num_fconv_blocks = len(self.fpfe.fconv_blocks) + tfa_in_channels = feature_channels * 2 ** (num_fconv_blocks - 1) + self.tfa = TemporalFeatureAggregator( + tfa_in_channels, squeeze_ratio, self.num_part + ) + + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.max_pool = nn.AdaptiveMaxPool2d(1) + + def forward(self, x): + x = self.fpfe(x) + n, t, c, h, w = x.size() + split_size = h // self.num_part + x = x.split(split_size, dim=3) + x = [self.avg_pool(x_) + self.max_pool(x_) for x_ in x] + x = [x_.view(n, t, c, -1) for x_ in x] + x = torch.cat(x, dim=3) + x = self.tfa(x) + return x diff --git a/models/tfa.py b/models/tfa.py deleted file mode 100644 index b80328a..0000000 --- a/models/tfa.py +++ /dev/null @@ -1,77 +0,0 @@ -import copy - -import torch -import torch.nn as nn - -from models.layers import BasicConv1d - - -class TemporalFeatureAggregator(nn.Module): - def __init__( - self, - in_channels: int, - squeeze: int = 4, - num_part: int = 16 - ): - super(TemporalFeatureAggregator, self).__init__() - hidden_dim = in_channels // squeeze - self.num_part = num_part - - # MTB1 - conv3x1 = nn.Sequential( - BasicConv1d(in_channels, hidden_dim, kernel_size=3, padding=1), - nn.LeakyReLU(inplace=True), - BasicConv1d(hidden_dim, in_channels, kernel_size=1, padding=0) - ) - self.conv1d3x1 = self._parted(conv3x1) - self.avg_pool3x1 = nn.AvgPool1d(kernel_size=3, stride=1, padding=1) - self.max_pool3x1 = nn.MaxPool1d(kernel_size=3, stride=1, padding=1) - - # MTB2 - conv3x3 = nn.Sequential( - BasicConv1d(in_channels, hidden_dim, kernel_size=3, padding=1), - nn.LeakyReLU(inplace=True), - BasicConv1d(hidden_dim, in_channels, kernel_size=3, padding=1) - ) - self.conv1d3x3 = self._parted(conv3x3) - self.avg_pool3x3 = nn.AvgPool1d(kernel_size=5, stride=1, padding=2) - self.max_pool3x3 = nn.MaxPool1d(kernel_size=5, stride=1, padding=2) - - def _parted(self, module: nn.Module): - """Duplicate module `part_num` times.""" - return nn.ModuleList([copy.deepcopy(module) - for _ in range(self.num_part)]) - - def forward(self, x): - """ - Input: x, [p, n, c, s] - """ - p, n, c, s = x.size() - feature = x.split(1, 0) - x = x.view(-1, c, s) - - # MTB1: ConvNet1d & Sigmoid - logits3x1 = torch.cat( - [conv(_.squeeze(0)).unsqueeze(0) - for conv, _ in zip(self.conv1d3x1, feature)], dim=0 - ) - scores3x1 = torch.sigmoid(logits3x1) - # MTB1: Template Function - feature3x1 = self.avg_pool3x1(x) + self.max_pool3x1(x) - feature3x1 = feature3x1.view(p, n, c, s) - feature3x1 = feature3x1 * scores3x1 - - # MTB2: ConvNet1d & Sigmoid - logits3x3 = torch.cat( - [conv(_.squeeze(0)).unsqueeze(0) - for conv, _ in zip(self.conv1d3x3, feature)], dim=0 - ) - scores3x3 = torch.sigmoid(logits3x3) - # MTB2: Template Function - feature3x3 = self.avg_pool3x3(x) + self.max_pool3x3(x) - feature3x3 = feature3x3.view(p, n, c, s) - feature3x3 = feature3x3 * scores3x3 - - # Temporal Pooling - ret = (feature3x1 + feature3x3).max(-1)[0] - return ret -- cgit v1.2.3