-rw-r--r--   models/auto_encoder.py   31
-rw-r--r--   models/hpm.py            22
-rw-r--r--   models/rgb_part_net.py   40
3 files changed, 57 insertions, 36 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index c84061c..234111a 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -95,10 +95,13 @@ class Decoder(nn.Module):
         self.trans_conv4 = DCGANConvTranspose2d(feature_channels, out_channels,
                                                 is_last_layer=True)
 
-    def forward(self, f_appearance, f_canonical, f_pose):
+    def forward(self, f_appearance, f_canonical, f_pose, no_trans_conv=False):
         x = torch.cat((f_appearance, f_canonical, f_pose), dim=1)
         x = self.fc(x)
         x = F.relu(x.view(-1, self.feature_channels * 8, 4, 2), inplace=True)
+        # Decode canonical features without transpose convolutions
+        if no_trans_conv:
+            return x
         x = self.trans_conv1(x)
         x = self.trans_conv2(x)
         x = self.trans_conv3(x)
@@ -131,16 +134,32 @@ class AutoEncoder(nn.Module):
     def forward(self, x_c1_t1, x_c1_t2, x_c2_t2, y):
         # t1 is random time step
         (f_a_c1_t1, f_c_c1_t1, _) = self.encoder(x_c1_t1)
-        (_, f_c_c1_t2, f_p_c1_t2) = self.encoder(x_c1_t2)
+        (f_a_c1_t2, f_c_c1_t2, f_p_c1_t2) = self.encoder(x_c1_t2)
         (_, f_c_c2_t2, f_p_c2_t2) = self.encoder(x_c2_t2)
 
         x_c1_t2_ = self.decoder(f_a_c1_t1, f_c_c1_t1, f_p_c1_t2)
         xrecon_loss_t2 = self.mse_loss(x_c1_t2, x_c1_t2_)
 
-        y_ = self.classifier(f_c_c1_t2)
+        y_ = self.classifier(f_c_c1_t2.contiguous())
         cano_cons_loss_t2 = (self.mse_loss(f_c_c1_t1, f_c_c1_t2)
                              + self.mse_loss(f_c_c1_t2, f_c_c2_t2)
-                             + self.xent_loss(y, y_))
+                             + self.xent_loss(y_, y))
 
-        return ((f_c_c1_t2, f_p_c1_t2, f_p_c2_t2),
-                xrecon_loss_t2, cano_cons_loss_t2)
+        f_a_size, f_c_size, f_p_size = (
+            f_a_c1_t2.size(), f_c_c1_t2.size(), f_p_c1_t2.size()
+        )
+        # Decode canonical features for HPM
+        x_c_c1_t2 = self.decoder(
+            torch.zeros(f_a_size), f_c_c1_t1, torch.zeros(f_p_size),
+            no_trans_conv=True
+        )
+        # Decode pose features for Part Net
+        x_p_c1_t2 = self.decoder(
+            torch.zeros(f_a_size), torch.zeros(f_c_size), f_p_c1_t2
+        )
+
+        return (
+            (x_c_c1_t2, x_p_c1_t2),
+            (f_p_c1_t2, f_p_c2_t2),
+            (xrecon_loss_t2, cano_cons_loss_t2)
+        )
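Note: a minimal standalone sketch (not part of this patch) of the early-return path added to Decoder.forward above. After the fully connected layer the concatenated codes are reshaped to (N, feature_channels * 8, 4, 2); with no_trans_conv=True that map is returned as-is, which is why HorizontalPyramidMatching is constructed below with ae_feature_channels * 8 input channels. The fc output size and the flat feature dimensions are assumptions for illustration only.

import torch
import torch.nn as nn

feature_channels = 64                      # ae_feature_channels default
f_a_dim, f_c_dim, f_p_dim = 128, 128, 64   # f_a_c_p_dims default
# Assumed fc size, inferred from the view(-1, feature_channels * 8, 4, 2) call
fc = nn.Linear(f_a_dim + f_c_dim + f_p_dim, feature_channels * 8 * 4 * 2)

n = 4                                      # illustrative batch size
f_appearance = torch.zeros(n, f_a_dim)     # zeroed, as in the HPM branch
f_canonical = torch.randn(n, f_c_dim)
f_pose = torch.zeros(n, f_p_dim)

x = torch.cat((f_appearance, f_canonical, f_pose), dim=1)
x = torch.relu(fc(x).view(-1, feature_channels * 8, 4, 2))
print(x.size())  # torch.Size([4, 512, 4, 2]) -> HPM sees 64 * 8 channels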
diff --git a/models/hpm.py b/models/hpm.py
index 5553094..66503e3 100644
--- a/models/hpm.py
+++ b/models/hpm.py
@@ -1,6 +1,5 @@
 import torch
 import torch.nn as nn
-from torchvision.models import resnet50
 
 from models.layers import HorizontalPyramidPooling
 
@@ -8,12 +7,11 @@ from models.layers import HorizontalPyramidPooling
 class HorizontalPyramidMatching(nn.Module):
     def __init__(
             self,
-            in_channels: int = 3,
+            in_channels: int,
             out_channels: int = 128,
-            scales: tuple[int, ...] = (1, 2, 4, 8),
+            scales: tuple[int, ...] = (1, 2, 4),
             use_avg_pool: bool = True,
             use_max_pool: bool = True,
-            use_backbone: bool = False,
             **kwargs
     ):
         super().__init__()
@@ -22,11 +20,6 @@ class HorizontalPyramidMatching(nn.Module):
         self.scales = scales
         self.use_avg_pool = use_avg_pool
         self.use_max_pool = use_max_pool
-        self.use_backbone = use_backbone
-
-        if self.use_backbone:
-            self.backbone = resnet50(pretrained=True)
-            self.in_channels = self.backbone.layer4[-1].conv1.in_channels
 
         self.pyramids = nn.ModuleList([
             self._make_pyramid(scale, **kwargs) for scale in self.scales
@@ -44,15 +37,10 @@ class HorizontalPyramidMatching(nn.Module):
         return pyramid
 
     def forward(self, x):
-        # Flatten frames in all batches
+        # Flatten canonical features in all batches
        t, n, c, h, w = x.size()
-        x = x.view(-1, c, h, w)
-
-        if self.use_backbone:
-            # FIXME Inconsistent dimensions
-            x = self.backbone(x)
+        x = x.view(t * n, c, h, w)
 
-        t_n, _, h, _ = x.size()
         feature = []
         for pyramid_index, pyramid in enumerate(self.pyramids):
             h_per_hpp = h // self.scales[pyramid_index]
@@ -61,7 +49,7 @@ class HorizontalPyramidMatching(nn.Module):
             for hpp_index, hpp in enumerate(pyramid):
                 h_filter = torch.arange(hpp_index * h_per_hpp,
                                         (hpp_index + 1) * h_per_hpp)
                 x_slice = x[:, :, h_filter, :]
                 x_slice = hpp(x_slice)
-                x_slice = x_slice.view(t_n, -1)
+                x_slice = x_slice.view(t * n, -1)
                 feature.append(x_slice)
         x = torch.stack(feature)
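Note: a minimal sketch of the strip-slicing logic in the reshaped HorizontalPyramidMatching.forward, with nn.AdaptiveAvgPool2d standing in for HorizontalPyramidPooling; tensor sizes and scale values here are illustrative only.

import torch
import torch.nn as nn

t, n, c, h, w = 2, 3, 512, 16, 8         # illustrative sizes
scales = (1, 2, 4)
x = torch.randn(t, n, c, h, w)
x = x.view(t * n, c, h, w)               # flatten time and batch dims

pool = nn.AdaptiveAvgPool2d(1)           # stand-in for one HPP module
feature = []
for scale in scales:
    h_per_hpp = h // scale
    for hpp_index in range(scale):
        # Select one horizontal strip of the feature map
        h_filter = torch.arange(hpp_index * h_per_hpp,
                                (hpp_index + 1) * h_per_hpp)
        x_slice = x[:, :, h_filter, :]             # (t*n, c, h/scale, w)
        x_slice = pool(x_slice).view(t * n, -1)    # (t*n, c)
        feature.append(x_slice)

x = torch.stack(feature)
print(x.size())  # torch.Size([7, 6, 512]): sum(scales), t*n, c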
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 377c108..0ff8251 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -13,7 +13,7 @@ class RGBPartNet(nn.Module):
             ae_in_channels: int = 3,
             ae_feature_channels: int = 64,
             f_a_c_p_dims: tuple[int, int, int] = (128, 128, 64),
-            hpm_scales: tuple[int, ...] = (1, 2, 4, 8),
+            hpm_scales: tuple[int, ...] = (1, 2, 4),
             hpm_use_avg_pool: bool = True,
             hpm_use_max_pool: bool = True,
             fpfe_feature_channels: int = 32,
@@ -32,7 +32,7 @@ class RGBPartNet(nn.Module):
             fpfe_paddings, fpfe_halving, tfa_squeeze_ratio, tfa_num_part
         )
         self.hpm = HorizontalPyramidMatching(
-            ae_in_channels, self.pn.tfa_in_channels, hpm_scales,
+            ae_feature_channels * 8, self.pn.tfa_in_channels, hpm_scales,
            hpm_use_avg_pool, hpm_use_max_pool
         )
 
@@ -54,38 +54,52 @@ class RGBPartNet(nn.Module):
         # Step 1: Disentanglement
         # t, n, c, h, w
         num_frames = len(x_c1)
-        f_c_c1, f_p_c1, f_p_c2 = [], [], []
+        # Decoded canonical features and Pose images
+        x_c_c1, x_p_c1 = [], []
+        # Features required to calculate losses
+        f_p_c1, f_p_c2 = [], []
         xrecon_loss, cano_cons_loss = torch.zeros(1), torch.zeros(1)
         for t2 in range(num_frames):
             t1 = random.randrange(num_frames)
             output = self.ae(x_c1[t1], x_c1[t2], x_c2[t2], y)
-            (feature_t2, xrecon_loss_t2, cano_cons_loss_t2) = output
-            (f_c_c1_t2, f_p_c1_t2, f_p_c2_t2) = feature_t2
-            # Features for next step
-            f_c_c1.append(f_c_c1_t2)
-            f_p_c1.append(f_p_c1_t2)
+            (x_c1_t2, f_p_t2, losses) = output
+
+            # Decoded features or image
+            (x_c_c1_t2, x_p_c1_t2) = x_c1_t2
+            # Canonical Features for HPM
+            x_c_c1.append(x_c_c1_t2)
+            # Pose image for Part Net
+            x_p_c1.append(x_p_c1_t2)
+            # Losses per time step
+            # Used in pose similarity loss
+            (f_p_c1_t2, f_p_c2_t2) = f_p_t2
+            f_p_c1.append(f_p_c1_t2)
             f_p_c2.append(f_p_c2_t2)
 
+            # Cross reconstruction loss and canonical loss
+            (xrecon_loss_t2, cano_cons_loss_t2) = losses
             xrecon_loss += xrecon_loss_t2
             cano_cons_loss += cano_cons_loss_t2
-        f_c_c1 = torch.stack(f_c_c1)
-        f_p_c1 = torch.stack(f_p_c1)
+
+        x_c_c1 = torch.stack(x_c_c1)
+        x_p_c1 = torch.stack(x_p_c1)
 
         # Step 2.a: HPM & Static Gait Feature Aggregation
         # t, n, c, h, w
-        x_c = self.hpm(f_c_c1)
+        x_c = self.hpm(x_c_c1)
         # p, t, n, c
         x_c = x_c.mean(dim=1)
         # p, n, c
 
         # Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation)
         # t, n, c, h, w
-        x_p = self.pn(f_p_c1)
+        x_p = self.pn(x_p_c1)
         # p, n, c
 
         # Step 3: Cat feature map together and calculate losses
-        x = torch.cat(x_c, x_p)
+        x = torch.cat([x_c, x_p])
         # Losses
+        f_p_c1 = torch.stack(f_p_c1)
         f_p_c2 = torch.stack(f_p_c2)
         pose_sim_loss = self.pose_sim_loss(f_p_c1, f_p_c2)
         cano_cons_loss /= num_frames
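Note: the torch.cat change in step 3 matters because torch.cat expects a sequence of tensors; passing x_c and x_p positionally treats x_p as the dim argument and raises a TypeError. A small sketch with illustrative part counts and channel size:

import torch

x_c = torch.randn(7, 4, 128)    # p, n, c  (HPM output, illustrative sizes)
x_p = torch.randn(16, 4, 128)   # p, n, c  (Part Net output, illustrative sizes)

x = torch.cat([x_c, x_p])       # concatenate parts along dim 0 -> (23, 4, 128)
print(x.size())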