diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2020-12-31 21:00:01 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2020-12-31 21:00:01 +0800 |
commit | 86421e899c87976d8559795979415e3fae2bd7ed (patch) | |
tree | c4bfa828da64cc71b43ed12d3fd78be8a8930181 | |
parent | 57275a210b93f9ffd30a53e22c3c28f49f228d14 (diff) |
Implement some parts of RGB-GaitPart wrapper
1. Triplet loss function and weight init function haven't been implement yet
2. Tuplize features returned by auto-encoder for later unpack
3. Correct comment error in auto-encoder
4. Swap batch_size dim and time dim in HPM and PartNet in case of redundant transpose
5. Find backbone problems in HPM and disable it temporarily
6. Make feature structure by HPM consistent to that by PartNet
7. Fix average pooling dimension issue and incorrect view change in HP
-rw-r--r-- | models/__init__.py | 3 | ||||
-rw-r--r-- | models/auto_encoder.py | 4 | ||||
-rw-r--r-- | models/hpm.py | 29 | ||||
-rw-r--r-- | models/part_net.py | 28 | ||||
-rw-r--r-- | models/rgb_part_net.py | 98 |
5 files changed, 135 insertions, 27 deletions
diff --git a/models/__init__.py b/models/__init__.py index 3b4d86e..51c86af 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -1 +1,4 @@ from .model import Model +from .auto_encoder import AutoEncoder +from .hpm import HorizontalPyramidMatching +from .part_net import PartNet diff --git a/models/auto_encoder.py b/models/auto_encoder.py index de38572..c84061c 100644 --- a/models/auto_encoder.py +++ b/models/auto_encoder.py @@ -129,7 +129,7 @@ class AutoEncoder(nn.Module): self.xent_loss = nn.CrossEntropyLoss() def forward(self, x_c1_t1, x_c1_t2, x_c2_t2, y): - # t2 is random time step + # t1 is random time step (f_a_c1_t1, f_c_c1_t1, _) = self.encoder(x_c1_t1) (_, f_c_c1_t2, f_p_c1_t2) = self.encoder(x_c1_t2) (_, f_c_c2_t2, f_p_c2_t2) = self.encoder(x_c2_t2) @@ -142,5 +142,5 @@ class AutoEncoder(nn.Module): + self.mse_loss(f_c_c1_t2, f_c_c2_t2) + self.xent_loss(y, y_)) - return (f_c_c1_t2, f_p_c1_t2, f_p_c2_t2, + return ((f_c_c1_t2, f_p_c1_t2, f_p_c2_t2), xrecon_loss_t2, cano_cons_loss_t2) diff --git a/models/hpm.py b/models/hpm.py index 4a1f1a4..5553094 100644 --- a/models/hpm.py +++ b/models/hpm.py @@ -8,20 +8,25 @@ from models.layers import HorizontalPyramidPooling class HorizontalPyramidMatching(nn.Module): def __init__( self, + in_channels: int = 3, + out_channels: int = 128, scales: tuple[int, ...] = (1, 2, 4, 8), - out_channels: int = 256, use_avg_pool: bool = True, use_max_pool: bool = True, + use_backbone: bool = False, **kwargs ): super().__init__() - self.scales = scales + self.in_channels = in_channels self.out_channels = out_channels + self.scales = scales self.use_avg_pool = use_avg_pool self.use_max_pool = use_max_pool + self.use_backbone = use_backbone - self.backbone = resnet50(pretrained=True) - self.in_channels = self.backbone.layer4[-1].conv1.in_channels + if self.use_backbone: + self.backbone = resnet50(pretrained=True) + self.in_channels = self.backbone.layer4[-1].conv1.in_channels self.pyramids = nn.ModuleList([ self._make_pyramid(scale, **kwargs) for scale in self.scales @@ -40,12 +45,14 @@ class HorizontalPyramidMatching(nn.Module): def forward(self, x): # Flatten frames in all batches - n, t, c, h, w = x.size() + t, n, c, h, w = x.size() x = x.view(-1, c, h, w) - x = self.backbone(x) - n, c, h, w = x.size() + if self.use_backbone: + # FIXME Inconsistent dimensions + x = self.backbone(x) + t_n, _, h, _ = x.size() feature = [] for pyramid_index, pyramid in enumerate(self.pyramids): h_per_hpp = h // self.scales[pyramid_index] @@ -54,11 +61,11 @@ class HorizontalPyramidMatching(nn.Module): (hpp_index + 1) * h_per_hpp) x_slice = x[:, :, h_filter, :] x_slice = hpp(x_slice) - x_slice = x_slice.view(n, -1) + x_slice = x_slice.view(t_n, -1) feature.append(x_slice) - x = torch.cat(feature, dim=1) + x = torch.stack(feature) # Unfold frames to original batch - _, d = x.size() - x = x.view(n, t, d) + p, _, c = x.size() + x = x.view(p, t, n, c) return x diff --git a/models/part_net.py b/models/part_net.py index 66e61fc..ac7c434 100644 --- a/models/part_net.py +++ b/models/part_net.py @@ -30,15 +30,11 @@ class FrameLevelPartFeatureExtractor(nn.Module): def forward(self, x): # Flatten frames in all batches - n, t, c, h, w = x.size() + t, n, c, h, w = x.size() x = x.view(-1, c, h, w) for fconv_block in self.fconv_blocks: x = fconv_block(x) - - # Unfold frames to original batch - _, c, h, w = x.size() - x = x.view(n, t, c, h, w) return x @@ -79,7 +75,8 @@ class TemporalFeatureAggregator(nn.Module): for _ in range(self.num_part)]) def forward(self, x): - x = x.transpose(2, 3) + # p, t, n, c + x = x.permute(0, 2, 3, 1).contiguous() p, n, c, t = x.size() feature = x.split(1, dim=0) feature = [f.squeeze(0) for f in feature] @@ -87,7 +84,7 @@ class TemporalFeatureAggregator(nn.Module): # MTB1: ConvNet1d & Sigmoid logits3x1 = torch.stack( - [conv(f) for conv, f in zip(self.conv1d3x1, feature)], dim=0 + [conv(f) for conv, f in zip(self.conv1d3x1, feature)] ) scores3x1 = torch.sigmoid(logits3x1) # MTB1: Template Function @@ -97,7 +94,7 @@ class TemporalFeatureAggregator(nn.Module): # MTB2: ConvNet1d & Sigmoid logits3x3 = torch.stack( - [conv(f) for conv, f in zip(self.conv1d3x3, feature)], dim=0 + [conv(f) for conv, f in zip(self.conv1d3x3, feature)] ) scores3x3 = torch.sigmoid(logits3x3) # MTB2: Template Function @@ -128,25 +125,28 @@ class PartNet(nn.Module): ) num_fconv_blocks = len(self.fpfe.fconv_blocks) - tfa_in_channels = feature_channels * 2 ** (num_fconv_blocks - 1) + self.tfa_in_channels = feature_channels * 2 ** (num_fconv_blocks - 1) self.tfa = TemporalFeatureAggregator( - tfa_in_channels, squeeze_ratio, self.num_part + self.tfa_in_channels, squeeze_ratio, self.num_part ) self.avg_pool = nn.AdaptiveAvgPool2d(1) self.max_pool = nn.AdaptiveMaxPool2d(1) def forward(self, x): + t, n, _, _, _ = x.size() + # t, n, c, h, w x = self.fpfe(x) + # t_n, c, h, w # Horizontal Pooling - n, t, c, h, w = x.size() + _, c, h, w = x.size() split_size = h // self.num_part - x = x.split(split_size, dim=3) + x = x.split(split_size, dim=2) x = [self.avg_pool(x_) + self.max_pool(x_) for x_ in x] - x = [x_.view(n, t, c, -1) for x_ in x] + x = [x_.view(t, n, c) for x_ in x] x = torch.stack(x) - # p, n, t, c + # p, t, n, c x = self.tfa(x) return x diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py new file mode 100644 index 0000000..9768dec --- /dev/null +++ b/models/rgb_part_net.py @@ -0,0 +1,98 @@ +import random + +import torch +import torch.nn as nn + +from models import AutoEncoder, HorizontalPyramidMatching, PartNet + + +class RGBPartNet(nn.Module): + def __init__( + self, + num_class: int = 74, + ae_in_channels: int = 3, + ae_feature_channels: int = 64, + f_a_c_p_dims: tuple[int, int, int] = (128, 128, 64), + hpm_scales: tuple[int, ...] = (1, 2, 4, 8), + hpm_use_avg_pool: bool = True, + hpm_use_max_pool: bool = True, + fpfe_feature_channels: int = 32, + fpfe_kernel_sizes: tuple[tuple, ...] = ((5, 3), (3, 3), (3, 3)), + fpfe_paddings: tuple[tuple, ...] = ((2, 1), (1, 1), (1, 1)), + fpfe_halving: tuple[int, ...] = (0, 2, 3), + tfa_squeeze_ratio: int = 4, + tfa_num_part: int = 16, + ): + super().__init__() + self.ae = AutoEncoder( + num_class, ae_in_channels, ae_feature_channels, f_a_c_p_dims + ) + self.pn = PartNet( + ae_in_channels, fpfe_feature_channels, fpfe_kernel_sizes, + fpfe_paddings, fpfe_halving, tfa_squeeze_ratio, tfa_num_part + ) + self.hpm = HorizontalPyramidMatching( + ae_in_channels, self.pn.tfa_in_channels, hpm_scales, + hpm_use_avg_pool, hpm_use_max_pool + ) + + self.mse_loss = nn.MSELoss() + + # TODO Weight inti here + + def pose_sim_loss(self, f_p_c1: torch.Tensor, + f_p_c2: torch.Tensor) -> torch.Tensor: + f_p_c1_mean = f_p_c1.mean(dim=0) + f_p_c2_mean = f_p_c2.mean(dim=0) + return self.mse_loss(f_p_c1_mean, f_p_c2_mean).item() + + def forward(self, x_c1, x_c2, y): + # Step 0: Swap batch_size and time dimensions for next step + # n, t, c, h, w + x_c1, x_c2 = x_c1.transpose(0, 1), x_c2.transpose(0, 1) + + # Step 1: Disentanglement + # t, n, c, h, w + num_frames = len(x_c1) + f_c_c1, f_p_c1, f_p_c2 = [], [], [] + xrecon_loss, cano_cons_loss = 0, 0 + for t2 in range(num_frames): + t1 = random.randrange(num_frames) + output = self.ae(x_c1[t1], x_c1[t2], x_c2[t2], y) + (feature_t2, xrecon_loss_t2, cano_cons_loss_t2) = output + (f_c_c1_t2, f_p_c1_t2, f_p_c2_t2) = feature_t2 + # Features for next step + f_c_c1.append(f_c_c1_t2) + f_p_c1.append(f_p_c1_t2) + # Losses per time step + f_p_c2.append(f_p_c2_t2) + xrecon_loss += xrecon_loss_t2 + cano_cons_loss += cano_cons_loss_t2 + f_c_c1 = torch.stack(f_c_c1) + f_p_c1 = torch.stack(f_p_c1) + + # Step 2.a: HPM & Static Gait Feature Aggregation + # t, n, c, h, w + x_c = self.hpm(f_c_c1) + # p, t, n, c + x_c = x_c.mean(dim=1) + # p, n, c + + # Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation) + # t, n, c, h, w + x_p = self.pn(f_p_c1) + # p, n, c + + # Step 3: Cat feature map together and calculate losses + x = torch.cat(x_c, x_p) + # Losses + xrecon_loss /= num_frames + f_p_c2 = torch.stack(f_p_c2) + pose_sim_loss = self.pose_sim_loss(f_p_c1, f_p_c2) + cano_cons_loss /= num_frames + # TODO Implement Batch All triplet loss function + batch_all_triplet_loss = 0 + loss = (xrecon_loss + pose_sim_loss + cano_cons_loss + + batch_all_triplet_loss) + + return x, loss |