Diffstat (limited to 'models')
| -rw-r--r-- | models/auto_encoder.py | 31 |
| -rw-r--r-- | models/hpm.py          | 22 |
| -rw-r--r-- | models/rgb_part_net.py | 40 |
3 files changed, 57 insertions, 36 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index c84061c..234111a 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -95,10 +95,13 @@ class Decoder(nn.Module):
         self.trans_conv4 = DCGANConvTranspose2d(feature_channels, out_channels,
                                                 is_last_layer=True)
 
-    def forward(self, f_appearance, f_canonical, f_pose):
+    def forward(self, f_appearance, f_canonical, f_pose, no_trans_conv=False):
         x = torch.cat((f_appearance, f_canonical, f_pose), dim=1)
         x = self.fc(x)
         x = F.relu(x.view(-1, self.feature_channels * 8, 4, 2), inplace=True)
+        # Decode canonical features without transpose convolutions
+        if no_trans_conv:
+            return x
         x = self.trans_conv1(x)
         x = self.trans_conv2(x)
         x = self.trans_conv3(x)
@@ -131,16 +134,32 @@ class AutoEncoder(nn.Module):
     def forward(self, x_c1_t1, x_c1_t2, x_c2_t2, y):
         # t1 is random time step
         (f_a_c1_t1, f_c_c1_t1, _) = self.encoder(x_c1_t1)
-        (_, f_c_c1_t2, f_p_c1_t2) = self.encoder(x_c1_t2)
+        (f_a_c1_t2, f_c_c1_t2, f_p_c1_t2) = self.encoder(x_c1_t2)
         (_, f_c_c2_t2, f_p_c2_t2) = self.encoder(x_c2_t2)
 
         x_c1_t2_ = self.decoder(f_a_c1_t1, f_c_c1_t1, f_p_c1_t2)
         xrecon_loss_t2 = self.mse_loss(x_c1_t2, x_c1_t2_)
 
-        y_ = self.classifier(f_c_c1_t2)
+        y_ = self.classifier(f_c_c1_t2.contiguous())
         cano_cons_loss_t2 = (self.mse_loss(f_c_c1_t1, f_c_c1_t2)
                              + self.mse_loss(f_c_c1_t2, f_c_c2_t2)
-                             + self.xent_loss(y, y_))
+                             + self.xent_loss(y_, y))
 
-        return ((f_c_c1_t2, f_p_c1_t2, f_p_c2_t2),
-                xrecon_loss_t2, cano_cons_loss_t2)
+        f_a_size, f_c_size, f_p_size = (
+            f_a_c1_t2.size(), f_c_c1_t2.size(), f_p_c1_t2.size()
+        )
+        # Decode canonical features for HPM
+        x_c_c1_t2 = self.decoder(
+            torch.zeros(f_a_size), f_c_c1_t1, torch.zeros(f_p_size),
+            no_trans_conv=True
+        )
+        # Decode pose features for Part Net
+        x_p_c1_t2 = self.decoder(
+            torch.zeros(f_a_size), torch.zeros(f_c_size), f_p_c1_t2
+        )
+
+        return (
+            (x_c_c1_t2, x_p_c1_t2),
+            (f_p_c1_t2, f_p_c2_t2),
+            (xrecon_loss_t2, cano_cons_loss_t2)
+        )
diff --git a/models/hpm.py b/models/hpm.py
index 5553094..66503e3 100644
--- a/models/hpm.py
+++ b/models/hpm.py
@@ -1,6 +1,5 @@
 import torch
 import torch.nn as nn
-from torchvision.models import resnet50
 
 from models.layers import HorizontalPyramidPooling
 
@@ -8,12 +7,11 @@ from models.layers import HorizontalPyramidPooling
 class HorizontalPyramidMatching(nn.Module):
     def __init__(
             self,
-            in_channels: int = 3,
+            in_channels: int,
             out_channels: int = 128,
-            scales: tuple[int, ...] = (1, 2, 4, 8),
+            scales: tuple[int, ...] = (1, 2, 4),
             use_avg_pool: bool = True,
             use_max_pool: bool = True,
-            use_backbone: bool = False,
             **kwargs
     ):
         super().__init__()
@@ -22,11 +20,6 @@ class HorizontalPyramidMatching(nn.Module):
         self.scales = scales
         self.use_avg_pool = use_avg_pool
         self.use_max_pool = use_max_pool
-        self.use_backbone = use_backbone
-
-        if self.use_backbone:
-            self.backbone = resnet50(pretrained=True)
-            self.in_channels = self.backbone.layer4[-1].conv1.in_channels
 
         self.pyramids = nn.ModuleList([
             self._make_pyramid(scale, **kwargs) for scale in self.scales
@@ -44,15 +37,10 @@ class HorizontalPyramidMatching(nn.Module):
         return pyramid
 
     def forward(self, x):
-        # Flatten frames in all batches
+        # Flatten canonical features in all batches
         t, n, c, h, w = x.size()
-        x = x.view(-1, c, h, w)
-
-        if self.use_backbone:
-            # FIXME Inconsistent dimensions
-            x = self.backbone(x)
+        x = x.view(t * n, c, h, w)
 
-        t_n, _, h, _ = x.size()
         feature = []
         for pyramid_index, pyramid in enumerate(self.pyramids):
             h_per_hpp = h // self.scales[pyramid_index]
@@ -61,7 +49,7 @@ class HorizontalPyramidMatching(nn.Module):
                                         (hpp_index + 1) * h_per_hpp)
                 x_slice = x[:, :, h_filter, :]
                 x_slice = hpp(x_slice)
-                x_slice = x_slice.view(t_n, -1)
+                x_slice = x_slice.view(t * n, -1)
                 feature.append(x_slice)
 
         x = torch.stack(feature)
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 377c108..0ff8251 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -13,7 +13,7 @@ class RGBPartNet(nn.Module):
             ae_in_channels: int = 3,
             ae_feature_channels: int = 64,
             f_a_c_p_dims: tuple[int, int, int] = (128, 128, 64),
-            hpm_scales: tuple[int, ...] = (1, 2, 4, 8),
+            hpm_scales: tuple[int, ...] = (1, 2, 4),
             hpm_use_avg_pool: bool = True,
             hpm_use_max_pool: bool = True,
             fpfe_feature_channels: int = 32,
@@ -32,7 +32,7 @@ class RGBPartNet(nn.Module):
             fpfe_paddings, fpfe_halving, tfa_squeeze_ratio, tfa_num_part
         )
         self.hpm = HorizontalPyramidMatching(
-            ae_in_channels, self.pn.tfa_in_channels, hpm_scales,
+            ae_feature_channels * 8, self.pn.tfa_in_channels, hpm_scales,
            hpm_use_avg_pool, hpm_use_max_pool
         )
 
@@ -54,38 +54,52 @@
         # Step 1: Disentanglement
         # t, n, c, h, w
         num_frames = len(x_c1)
-        f_c_c1, f_p_c1, f_p_c2 = [], [], []
+        # Decoded canonical features and Pose images
+        x_c_c1, x_p_c1 = [], []
+        # Features required to calculate losses
+        f_p_c1, f_p_c2 = [], []
         xrecon_loss, cano_cons_loss = torch.zeros(1), torch.zeros(1)
         for t2 in range(num_frames):
             t1 = random.randrange(num_frames)
             output = self.ae(x_c1[t1], x_c1[t2], x_c2[t2], y)
-            (feature_t2, xrecon_loss_t2, cano_cons_loss_t2) = output
-            (f_c_c1_t2, f_p_c1_t2, f_p_c2_t2) = feature_t2
-            # Features for next step
-            f_c_c1.append(f_c_c1_t2)
-            f_p_c1.append(f_p_c1_t2)
+            (x_c1_t2, f_p_t2, losses) = output
+
+            # Decoded features or image
+            (x_c_c1_t2, x_p_c1_t2) = x_c1_t2
+            # Canonical Features for HPM
+            x_c_c1.append(x_c_c1_t2)
+            # Pose image for Part Net
+            x_p_c1.append(x_p_c1_t2)
+
             # Losses per time step
+            # Used in pose similarity loss
+            (f_p_c1_t2, f_p_c2_t2) = f_p_t2
+            f_p_c1.append(f_p_c1_t2)
             f_p_c2.append(f_p_c2_t2)
+            # Cross reconstruction loss and canonical loss
+            (xrecon_loss_t2, cano_cons_loss_t2) = losses
             xrecon_loss += xrecon_loss_t2
             cano_cons_loss += cano_cons_loss_t2
-        f_c_c1 = torch.stack(f_c_c1)
-        f_p_c1 = torch.stack(f_p_c1)
+
+        x_c_c1 = torch.stack(x_c_c1)
+        x_p_c1 = torch.stack(x_p_c1)
 
         # Step 2.a: HPM & Static Gait Feature Aggregation
         # t, n, c, h, w
-        x_c = self.hpm(f_c_c1)
+        x_c = self.hpm(x_c_c1)
         # p, t, n, c
         x_c = x_c.mean(dim=1)
         # p, n, c
 
         # Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation)
         # t, n, c, h, w
-        x_p = self.pn(f_p_c1)
+        x_p = self.pn(x_p_c1)
         # p, n, c
 
         # Step 3: Cat feature map together and calculate losses
-        x = torch.cat(x_c, x_p)
+        x = torch.cat([x_c, x_p])
         # Losses
+        f_p_c1 = torch.stack(f_p_c1)
         f_p_c2 = torch.stack(f_p_c2)
         pose_sim_loss = self.pose_sim_loss(f_p_c1, f_p_c2)
         cano_cons_loss /= num_frames
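For reference, the shape flow introduced by this commit can be checked with a small standalone sketch (not part of the diff above): it imitates the new `HorizontalPyramidMatching.forward` flattening and horizontal slicing on a dummy tensor shaped like the decoder's `no_trans_conv=True` output, whose channel depth (`ae_feature_channels * 8`) now matches the HPM constructor argument changed in `rgb_part_net.py`. The concrete sizes and the mean pool standing in for `HorizontalPyramidPooling` are illustrative assumptions, not code from the repository.

```python
# Standalone shape sketch. Assumptions: t = 3 frames, n = 2 samples, and a 4x2
# feature map with c = ae_feature_channels * 8 = 512 channels, mirroring the
# decoder output before the transpose convolutions.
import torch

t, n, c, h, w = 3, 2, 64 * 8, 4, 2
scales = (1, 2, 4)

x = torch.randn(t, n, c, h, w)
x = x.view(t * n, c, h, w)                  # flatten frames and batches together

feature = []
for scale in scales:
    h_per_hpp = h // scale                  # strip height at this pyramid scale
    for hpp_index in range(scale):
        h_filter = torch.arange(hpp_index * h_per_hpp,
                                (hpp_index + 1) * h_per_hpp)
        x_slice = x[:, :, h_filter, :]      # horizontal strip of the feature map
        x_slice = x_slice.mean(dim=(2, 3))  # stand-in for HorizontalPyramidPooling
        feature.append(x_slice.view(t * n, -1))

x = torch.stack(feature)
print(x.size())                             # torch.Size([7, 6, 512]): p, t * n, c
```

With scales (1, 2, 4) the stack yields p = 1 + 2 + 4 = 7 part features per flattened frame, which is the tensor the aggregation step in `RGBPartNet.forward` averages over time.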
