-rw-r--r-- | models/__init__.py     |  3
-rw-r--r-- | models/auto_encoder.py |  4
-rw-r--r-- | models/hpm.py          | 29
-rw-r--r-- | models/part_net.py     | 28
-rw-r--r-- | models/rgb_part_net.py | 98
5 files changed, 135 insertions, 27 deletions
diff --git a/models/__init__.py b/models/__init__.py
index 3b4d86e..51c86af 100644
--- a/models/__init__.py
+++ b/models/__init__.py
@@ -1 +1,4 @@
 from .model import Model
+from .auto_encoder import AutoEncoder
+from .hpm import HorizontalPyramidMatching
+from .part_net import PartNet
diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index de38572..c84061c 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -129,7 +129,7 @@ class AutoEncoder(nn.Module):
         self.xent_loss = nn.CrossEntropyLoss()
 
     def forward(self, x_c1_t1, x_c1_t2, x_c2_t2, y):
-        # t2 is random time step
+        # t1 is random time step
         (f_a_c1_t1, f_c_c1_t1, _) = self.encoder(x_c1_t1)
         (_, f_c_c1_t2, f_p_c1_t2) = self.encoder(x_c1_t2)
         (_, f_c_c2_t2, f_p_c2_t2) = self.encoder(x_c2_t2)
@@ -142,5 +142,5 @@ class AutoEncoder(nn.Module):
                              + self.mse_loss(f_c_c1_t2, f_c_c2_t2)
                              + self.xent_loss(y, y_))
 
-        return (f_c_c1_t2, f_p_c1_t2, f_p_c2_t2,
+        return ((f_c_c1_t2, f_p_c1_t2, f_p_c2_t2),
                 xrecon_loss_t2, cano_cons_loss_t2)
diff --git a/models/hpm.py b/models/hpm.py
index 4a1f1a4..5553094 100644
--- a/models/hpm.py
+++ b/models/hpm.py
@@ -8,20 +8,25 @@ from models.layers import HorizontalPyramidPooling
 class HorizontalPyramidMatching(nn.Module):
     def __init__(
             self,
+            in_channels: int = 3,
+            out_channels: int = 128,
             scales: tuple[int, ...] = (1, 2, 4, 8),
-            out_channels: int = 256,
             use_avg_pool: bool = True,
             use_max_pool: bool = True,
+            use_backbone: bool = False,
             **kwargs
     ):
         super().__init__()
-        self.scales = scales
+        self.in_channels = in_channels
         self.out_channels = out_channels
+        self.scales = scales
         self.use_avg_pool = use_avg_pool
         self.use_max_pool = use_max_pool
+        self.use_backbone = use_backbone
-        self.backbone = resnet50(pretrained=True)
-        self.in_channels = self.backbone.layer4[-1].conv1.in_channels
+        if self.use_backbone:
+            self.backbone = resnet50(pretrained=True)
+            self.in_channels = self.backbone.layer4[-1].conv1.in_channels
 
         self.pyramids = nn.ModuleList([
             self._make_pyramid(scale, **kwargs)
             for scale in self.scales
@@ -40,12 +45,14 @@ class HorizontalPyramidMatching(nn.Module):
 
     def forward(self, x):
         # Flatten frames in all batches
-        n, t, c, h, w = x.size()
+        t, n, c, h, w = x.size()
         x = x.view(-1, c, h, w)
 
-        x = self.backbone(x)
-        n, c, h, w = x.size()
+        if self.use_backbone:
+            # FIXME Inconsistent dimensions
+            x = self.backbone(x)
 
+        t_n, _, h, _ = x.size()
         feature = []
         for pyramid_index, pyramid in enumerate(self.pyramids):
             h_per_hpp = h // self.scales[pyramid_index]
@@ -54,11 +61,11 @@ class HorizontalPyramidMatching(nn.Module):
                                 (hpp_index + 1) * h_per_hpp)
                 x_slice = x[:, :, h_filter, :]
                 x_slice = hpp(x_slice)
-                x_slice = x_slice.view(n, -1)
+                x_slice = x_slice.view(t_n, -1)
                 feature.append(x_slice)
-        x = torch.cat(feature, dim=1)
+        x = torch.stack(feature)
 
         # Unfold frames to original batch
-        _, d = x.size()
-        x = x.view(n, t, d)
+        p, _, c = x.size()
+        x = x.view(p, t, n, c)
         return x
diff --git a/models/part_net.py b/models/part_net.py
index 66e61fc..ac7c434 100644
--- a/models/part_net.py
+++ b/models/part_net.py
@@ -30,15 +30,11 @@ class FrameLevelPartFeatureExtractor(nn.Module):
 
     def forward(self, x):
         # Flatten frames in all batches
-        n, t, c, h, w = x.size()
+        t, n, c, h, w = x.size()
         x = x.view(-1, c, h, w)
 
         for fconv_block in self.fconv_blocks:
             x = fconv_block(x)
-
-        # Unfold frames to original batch
-        _, c, h, w = x.size()
-        x = x.view(n, t, c, h, w)
         return x
 
 
@@ -79,7 +75,8 @@ class TemporalFeatureAggregator(nn.Module):
                                      for _ in range(self.num_part)])
 
     def forward(self, x):
-        x = x.transpose(2, 3)
+        # p, t, n, c
+        x = x.permute(0, 2, 3, 1).contiguous()
         p, n, c, t = x.size()
         feature = x.split(1, dim=0)
         feature = [f.squeeze(0) for f in feature]
@@ -87,7 +84,7 @@ class TemporalFeatureAggregator(nn.Module):
 
         # MTB1: ConvNet1d & Sigmoid
         logits3x1 = torch.stack(
-            [conv(f) for conv, f in zip(self.conv1d3x1, feature)], dim=0
+            [conv(f) for conv, f in zip(self.conv1d3x1, feature)]
         )
         scores3x1 = torch.sigmoid(logits3x1)
         # MTB1: Template Function
@@ -97,7 +94,7 @@ class TemporalFeatureAggregator(nn.Module):
 
         # MTB2: ConvNet1d & Sigmoid
         logits3x3 = torch.stack(
-            [conv(f) for conv, f in zip(self.conv1d3x3, feature)], dim=0
+            [conv(f) for conv, f in zip(self.conv1d3x3, feature)]
        )
         scores3x3 = torch.sigmoid(logits3x3)
         # MTB2: Template Function
@@ -128,25 +125,28 @@ class PartNet(nn.Module):
         )
         num_fconv_blocks = len(self.fpfe.fconv_blocks)
-        tfa_in_channels = feature_channels * 2 ** (num_fconv_blocks - 1)
+        self.tfa_in_channels = feature_channels * 2 ** (num_fconv_blocks - 1)
         self.tfa = TemporalFeatureAggregator(
-            tfa_in_channels, squeeze_ratio, self.num_part
+            self.tfa_in_channels, squeeze_ratio, self.num_part
         )
         self.avg_pool = nn.AdaptiveAvgPool2d(1)
         self.max_pool = nn.AdaptiveMaxPool2d(1)
 
     def forward(self, x):
+        t, n, _, _, _ = x.size()
+        # t, n, c, h, w
         x = self.fpfe(x)
+        # t_n, c, h, w
 
         # Horizontal Pooling
-        n, t, c, h, w = x.size()
+        _, c, h, w = x.size()
         split_size = h // self.num_part
-        x = x.split(split_size, dim=3)
+        x = x.split(split_size, dim=2)
         x = [self.avg_pool(x_) + self.max_pool(x_) for x_ in x]
-        x = [x_.view(n, t, c, -1) for x_ in x]
+        x = [x_.view(t, n, c) for x_ in x]
         x = torch.stack(x)
 
-        # p, n, t, c
+        # p, t, n, c
         x = self.tfa(x)
         return x
 
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
new file mode 100644
index 0000000..9768dec
--- /dev/null
+++ b/models/rgb_part_net.py
@@ -0,0 +1,98 @@
+import random
+
+import torch
+import torch.nn as nn
+
+from models import AutoEncoder, HorizontalPyramidMatching, PartNet
+
+
+class RGBPartNet(nn.Module):
+    def __init__(
+            self,
+            num_class: int = 74,
+            ae_in_channels: int = 3,
+            ae_feature_channels: int = 64,
+            f_a_c_p_dims: tuple[int, int, int] = (128, 128, 64),
+            hpm_scales: tuple[int, ...] = (1, 2, 4, 8),
+            hpm_use_avg_pool: bool = True,
+            hpm_use_max_pool: bool = True,
+            fpfe_feature_channels: int = 32,
+            fpfe_kernel_sizes: tuple[tuple, ...] = ((5, 3), (3, 3), (3, 3)),
+            fpfe_paddings: tuple[tuple, ...] = ((2, 1), (1, 1), (1, 1)),
+            fpfe_halving: tuple[int, ...] = (0, 2, 3),
+            tfa_squeeze_ratio: int = 4,
+            tfa_num_part: int = 16,
+    ):
+        super().__init__()
+        self.ae = AutoEncoder(
+            num_class, ae_in_channels, ae_feature_channels, f_a_c_p_dims
+        )
+        self.pn = PartNet(
+            ae_in_channels, fpfe_feature_channels, fpfe_kernel_sizes,
+            fpfe_paddings, fpfe_halving, tfa_squeeze_ratio, tfa_num_part
+        )
+        self.hpm = HorizontalPyramidMatching(
+            ae_in_channels, self.pn.tfa_in_channels, hpm_scales,
+            hpm_use_avg_pool, hpm_use_max_pool
+        )
+
+        self.mse_loss = nn.MSELoss()
+
+        # TODO Weight init here
+
+    def pose_sim_loss(self, f_p_c1: torch.Tensor,
+                      f_p_c2: torch.Tensor) -> torch.Tensor:
+        f_p_c1_mean = f_p_c1.mean(dim=0)
+        f_p_c2_mean = f_p_c2.mean(dim=0)
+        return self.mse_loss(f_p_c1_mean, f_p_c2_mean)
+
+    def forward(self, x_c1, x_c2, y):
+        # Step 0: Swap batch_size and time dimensions for next step
+        # n, t, c, h, w
+        x_c1, x_c2 = x_c1.transpose(0, 1), x_c2.transpose(0, 1)
+
+        # Step 1: Disentanglement
+        # t, n, c, h, w
+        num_frames = len(x_c1)
+        f_c_c1, f_p_c1, f_p_c2 = [], [], []
+        xrecon_loss, cano_cons_loss = 0, 0
+        for t2 in range(num_frames):
+            t1 = random.randrange(num_frames)
+            output = self.ae(x_c1[t1], x_c1[t2], x_c2[t2], y)
+            (feature_t2, xrecon_loss_t2, cano_cons_loss_t2) = output
+            (f_c_c1_t2, f_p_c1_t2, f_p_c2_t2) = feature_t2
+            # Features for next step
+            f_c_c1.append(f_c_c1_t2)
+            f_p_c1.append(f_p_c1_t2)
+            # Losses per time step
+            f_p_c2.append(f_p_c2_t2)
+            xrecon_loss += xrecon_loss_t2
+            cano_cons_loss += cano_cons_loss_t2
+        f_c_c1 = torch.stack(f_c_c1)
+        f_p_c1 = torch.stack(f_p_c1)
+
+        # Step 2.a: HPM & Static Gait Feature Aggregation
+        # t, n, c, h, w
+        x_c = self.hpm(f_c_c1)
+        # p, t, n, c
+        x_c = x_c.mean(dim=1)
+        # p, n, c
+
+        # Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation)
+        # t, n, c, h, w
+        x_p = self.pn(f_p_c1)
+        # p, n, c
+
+        # Step 3: Cat feature maps together and calculate losses
+        x = torch.cat([x_c, x_p])
+        # Losses
+        xrecon_loss /= num_frames
+        f_p_c2 = torch.stack(f_p_c2)
+        pose_sim_loss = self.pose_sim_loss(f_p_c1, f_p_c2)
+        cano_cons_loss /= num_frames
+        # TODO Implement Batch All triplet loss function
+        batch_all_triplet_loss = 0
+        loss = (xrecon_loss + pose_sim_loss + cano_cons_loss
+                + batch_all_triplet_loss)
+
+        return x, loss
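
The thread running through this commit is a new dimension convention: sequence inputs arrive as (t, n, c, h, w), frames are flattened into the batch axis before the 2D convolution stacks, and aggregated outputs carry a leading part dimension p. A minimal walk-through of that bookkeeping with dummy tensors (every size below is an illustrative assumption, not a value taken from the repo):

import torch

t, n, c, h, w = 30, 4, 3, 64, 32        # frames, batch, channels, height, width
x = torch.rand(t, n, c, h, w)

# Flatten frames in all batches, as FPFE and HPM now do before their convolutions
x = x.view(-1, c, h, w)                 # (t*n, c, h, w)

# ... 2D convolutions operate on each frame here ...

# Outputs are later unfolded with a leading part dimension p,
# matching the "p, t, n, c" comments in hpm.py and part_net.py
p, d = 16, 128                          # parts and feature dim, both assumed
features = torch.rand(p, t * n, d)
features = features.view(p, t, n, d)    # (p, t, n, d)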
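For orientation, HorizontalPyramidMatching splits the feature map into scale horizontal strips per pyramid level and pools each strip into one vector per frame. A toy version of the slicing loop in its forward, with HorizontalPyramidPooling (defined in models/layers.py, which this diff does not show) stubbed by plain adaptive average pooling:

import torch
import torch.nn as nn

x = torch.rand(8, 256, 16, 8)           # (t*n, c, h, w), sizes assumed
pool = nn.AdaptiveAvgPool2d(1)          # stand-in for HorizontalPyramidPooling
scale = 4                               # strips at this pyramid level
h_per_strip = x.size(2) // scale
strips = []
for i in range(scale):
    x_slice = x[:, :, i * h_per_strip:(i + 1) * h_per_strip, :]
    strips.append(pool(x_slice).view(x.size(0), -1))   # (t*n, c)
strips = torch.stack(strips)            # (scale, t*n, c)

Stacking the strips, rather than concatenating them along dim=1 as before, is what lets the caller unfold the result to (p, t, n, c) afterwards.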
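rgb_part_net.py leaves batch_all_triplet_loss as a TODO. A sketch of the standard batch-all formulation (Hermans et al., "In Defense of the Triplet Loss for Person Re-Identification"): form every valid anchor/positive/negative triplet in the batch and average the hinge over the triplets that are still active. The helper name and margin value are assumptions, not the repo's API:

import torch
import torch.nn.functional as F

def batch_all_triplet_loss(x, y, margin=0.2):
    # x: (n, d) features, y: (n,) integer class labels
    dist = torch.cdist(x, x)                        # pairwise Euclidean, (n, n)
    same = y.unsqueeze(0) == y.unsqueeze(1)
    eye = torch.eye(len(y), dtype=torch.bool, device=y.device)
    pos_mask = same & ~eye                          # same label, different sample
    neg_mask = ~same
    # hinge[a, p, n] = d(a, p) - d(a, n) + margin, over all index triples
    hinge = dist.unsqueeze(2) - dist.unsqueeze(1) + margin
    valid = pos_mask.unsqueeze(2) & neg_mask.unsqueeze(1)
    losses = F.relu(hinge[valid])
    active = (losses > 0).sum()
    return losses.sum() / active if active > 0 else losses.sum()

Since x_c and x_p are both (p, n, c), the loss would presumably be computed per part and summed, e.g. sum(batch_all_triplet_loss(part, y) for part in x).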