diff options
| author | Jordan Gong <jordan.gong@protonmail.com> | 2021-02-08 18:34:31 +0800 | 
|---|---|---|
| committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-02-08 18:34:31 +0800 | 
| commit | 31e0294cdb2ffd5241c7e85a6e1e98a4ee20ae28 (patch) | |
| tree | 709ddcf8ba175d09e9be4a91aab4c0eb61679c74 /models/rgb_part_net.py | |
| parent | afe615408c4003a513811d900fe3edd119a735a5 (diff) | |
| parent | d380e04df37593e414bd5641db100613fb2ad882 (diff) | |
Merge branch 'python3.8' into python3.7
# Conflicts:
#	utils/configuration.py
Diffstat (limited to 'models/rgb_part_net.py')
| -rw-r--r-- | models/rgb_part_net.py | 162 | 
1 file changed, 108 insertions, 54 deletions
| diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 326ec81..f6dc131 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -14,10 +14,10 @@ from utils.triplet_loss import BatchAllTripletLoss  class RGBPartNet(nn.Module):      def __init__(              self, -            num_class: int = 74,              ae_in_channels: int = 3,              ae_feature_channels: int = 64,              f_a_c_p_dims: Tuple[int, int, int] = (128, 128, 64), +            hpm_use_1x1conv: bool = False,              hpm_scales: Tuple[int, ...] = (1, 2, 4),              hpm_use_avg_pool: bool = True,              hpm_use_max_pool: bool = True, @@ -28,11 +28,16 @@ class RGBPartNet(nn.Module):              tfa_squeeze_ratio: int = 4,              tfa_num_parts: int = 16,              embedding_dims: int = 256, -            triplet_margin: float = 0.2 +            triplet_margins: Tuple[float, float] = (0.2, 0.2), +            image_log_on: bool = False      ):          super().__init__() +        (self.f_a_dim, self.f_c_dim, self.f_p_dim) = f_a_c_p_dims +        self.hpm_num_parts = sum(hpm_scales) +        self.image_log_on = image_log_on +          self.ae = AutoEncoder( -            num_class, ae_in_channels, ae_feature_channels, f_a_c_p_dims +            ae_in_channels, ae_feature_channels, f_a_c_p_dims          )          self.pn = PartNet(              ae_in_channels, fpfe_feature_channels, fpfe_kernel_sizes, @@ -40,14 +45,16 @@ class RGBPartNet(nn.Module):          )          out_channels = self.pn.tfa_in_channels          self.hpm = HorizontalPyramidMatching( -            ae_feature_channels * 8, out_channels, hpm_scales, -            hpm_use_avg_pool, hpm_use_max_pool +            ae_feature_channels * 2, out_channels, hpm_use_1x1conv, +            hpm_scales, hpm_use_avg_pool, hpm_use_max_pool          ) -        total_parts = sum(hpm_scales) + tfa_num_parts -        empty_fc = torch.empty(total_parts, out_channels, embedding_dims) +        empty_fc = 
torch.empty(self.hpm_num_parts + tfa_num_parts, +                               out_channels, embedding_dims)          self.fc_mat = nn.Parameter(empty_fc) -        self.ba_triplet_loss = BatchAllTripletLoss(triplet_margin) +        (hpm_margin, pn_margin) = triplet_margins +        self.hpm_ba_trip = BatchAllTripletLoss(hpm_margin) +        self.pn_ba_trip = BatchAllTripletLoss(pn_margin)      def fc(self, x):          return x @ self.fc_mat @@ -61,13 +68,11 @@ class RGBPartNet(nn.Module):          # Step 1: Disentanglement          # t, n, c, h, w -        ((x_c_c1, x_p_c1), losses) = self._disentangle(x_c1, x_c2, y) +        ((x_c_c1, x_p_c1), images, losses) = self._disentangle(x_c1, x_c2) -        # Step 2.a: HPM & Static Gait Feature Aggregation -        # t, n, c, h, w +        # Step 2.a: Static Gait Feature Aggregation & HPM +        # n, c, h, w          x_c = self.hpm(x_c_c1) -        # p, t, n, c -        x_c = x_c.mean(dim=1)          # p, n, c          # Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation) @@ -80,44 +85,83 @@ class RGBPartNet(nn.Module):          x = self.fc(x)          if self.training: -            batch_all_triplet_loss = self.ba_triplet_loss(x, y) -            losses = torch.stack((*losses, batch_all_triplet_loss)) -            return losses +            hpm_ba_trip = self.hpm_ba_trip(x[:self.hpm_num_parts], y) +            pn_ba_trip = self.pn_ba_trip(x[self.hpm_num_parts:], y) +            losses = torch.stack((*losses, hpm_ba_trip, pn_ba_trip)) +            return losses, images          else:              return x.unsqueeze(1).view(-1) -    def _disentangle(self, x_c1, x_c2=None, y=None): -        num_frames = len(x_c1) -        # Decoded canonical features and Pose images -        x_c_c1, x_p_c1 = [], [] +    def _disentangle(self, x_c1, x_c2=None): +        t, n, c, h, w = x_c1.size() +        device = x_c1.device          if self.training: +            # Encoded appearance, canonical and pose features +            
f_a_c1, f_c_c1, f_p_c1 = [], [], []              # Features required to calculate losses -            f_p_c1, f_p_c2 = [], [] +            f_p_c2 = []              xrecon_loss, cano_cons_loss = [], [] -            for t2 in range(num_frames): -                t1 = random.randrange(num_frames) -                output = self.ae(x_c1[t2], x_c1[t1], x_c2[t2], y) -                (x_c1_t2, f_p_t2, losses) = output - -                # Decoded features or image -                (x_c_c1_t2, x_p_c1_t2) = x_c1_t2 -                # Canonical Features for HPM -                x_c_c1.append(x_c_c1_t2) -                # Pose image for Part Net -                x_p_c1.append(x_p_c1_t2) +            for t2 in range(t): +                t1 = random.randrange(t) +                output = self.ae(x_c1[t2], x_c1[t1], x_c2[t2]) +                (f_c1_t2, f_p_t2, losses) = output + +                (f_a_c1_t2, f_c_c1_t2, f_p_c1_t2) = f_c1_t2 +                if self.image_log_on: +                    f_a_c1.append(f_a_c1_t2) +                # Save canonical features and pose features +                f_c_c1.append(f_c_c1_t2) +                f_p_c1.append(f_p_c1_t2)                  # Losses per time step                  # Used in pose similarity loss -                (f_p_c1_t2, f_p_c2_t2) = f_p_t2 -                f_p_c1.append(f_p_c1_t2) +                (_, f_p_c2_t2) = f_p_t2                  f_p_c2.append(f_p_c2_t2) +                  # Cross reconstruction loss and canonical loss                  (xrecon_loss_t2, cano_cons_loss_t2) = losses                  xrecon_loss.append(xrecon_loss_t2)                  cano_cons_loss.append(cano_cons_loss_t2) - -            x_c_c1 = torch.stack(x_c_c1) -            x_p_c1 = torch.stack(x_p_c1) +            if self.image_log_on: +                f_a_c1 = torch.stack(f_a_c1) +            f_c_c1_mean = torch.stack(f_c_c1).mean(0) +            f_p_c1 = torch.stack(f_p_c1) +            f_p_c2 = torch.stack(f_p_c2) + +            # Decode 
features +            appearance_image, canonical_image, pose_image = None, None, None +            with torch.no_grad(): +                # Decode average canonical features to higher dimension +                x_c_c1 = self.ae.decoder( +                    torch.zeros((n, self.f_a_dim), device=device), +                    f_c_c1_mean, +                    torch.zeros((n, self.f_p_dim), device=device), +                    cano_only=True +                ) +                # Decode pose features to images +                f_p_c1_ = f_p_c1.view(t * n, -1) +                x_p_c1_ = self.ae.decoder( +                    torch.zeros((t * n, self.f_a_dim), device=device), +                    torch.zeros((t * n, self.f_c_dim), device=device), +                    f_p_c1_ +                ) +                x_p_c1 = x_p_c1_.view(t, n, c, h, w) + +                if self.image_log_on: +                    # Decode appearance features +                    f_a_c1_ = f_a_c1.view(t * n, -1) +                    appearance_image_ = self.ae.decoder( +                        f_a_c1_, +                        torch.zeros((t * n, self.f_c_dim), device=device), +                        torch.zeros((t * n, self.f_p_dim), device=device) +                    ) +                    appearance_image = appearance_image_.view(t, n, c, h, w) +                    # Continue decoding canonical features +                    canonical_image = self.ae.decoder.trans_conv3(x_c_c1) +                    canonical_image = torch.sigmoid( +                        self.ae.decoder.trans_conv4(canonical_image) +                    ) +                    pose_image = x_p_c1              # Losses              xrecon_loss = torch.sum(torch.stack(xrecon_loss)) @@ -125,26 +169,36 @@ class RGBPartNet(nn.Module):              cano_cons_loss = torch.mean(torch.stack(cano_cons_loss))              return ((x_c_c1, x_p_c1), +                    (appearance_image, canonical_image, pose_image),                     
 (xrecon_loss, pose_sim_loss, cano_cons_loss))          else:  # evaluating -            for t2 in range(num_frames): -                x_c1_t2 = self.ae(x_c1[t2]) -                # Decoded features or image -                (x_c_c1_t2, x_p_c1_t2) = x_c1_t2 -                # Canonical Features for HPM -                x_c_c1.append(x_c_c1_t2) -                # Pose image for Part Net -                x_p_c1.append(x_p_c1_t2) - -            x_c_c1 = torch.stack(x_c_c1) -            x_p_c1 = torch.stack(x_p_c1) - -            return (x_c_c1, x_p_c1), None +            x_c1_ = x_c1.view(t * n, c, h, w) +            (f_c_c1_, f_p_c1_) = self.ae(x_c1_) + +            # Canonical features +            f_c_c1 = f_c_c1_.view(t, n, -1) +            f_c_c1_mean = f_c_c1.mean(0) +            x_c_c1 = self.ae.decoder( +                torch.zeros((n, self.f_a_dim)), +                f_c_c1_mean, +                torch.zeros((n, self.f_p_dim)), +                cano_only=True +            ) + +            # Pose features +            x_p_c1_ = self.ae.decoder( +                torch.zeros((t * n, self.f_a_dim)), +                torch.zeros((t * n, self.f_c_dim)), +                f_p_c1_ +            ) +            x_p_c1 = x_p_c1_.view(t, n, c, h, w) + +            return (x_c_c1, x_p_c1), None, None      @staticmethod -    def _pose_sim_loss(f_p_c1: List[torch.Tensor], -                       f_p_c2: List[torch.Tensor]) -> torch.Tensor: -        f_p_c1_mean = torch.stack(f_p_c1).mean(dim=0) -        f_p_c2_mean = torch.stack(f_p_c2).mean(dim=0) +    def _pose_sim_loss(f_p_c1: torch.Tensor, +                       f_p_c2: torch.Tensor) -> torch.Tensor: +        f_p_c1_mean = f_p_c1.mean(dim=0) +        f_p_c2_mean = f_p_c2.mean(dim=0)          return F.mse_loss(f_p_c1_mean, f_p_c2_mean) | 
