From 8f5bef7f3d10ba0994ce51d9f84100c26218d6ee Mon Sep 17 00:00:00 2001
From: Jordan Gong
Date: Sat, 23 Jan 2021 13:44:12 +0800
Subject: Transform all frames together in evaluation

---
 models/rgb_part_net.py | 26 ++++++++++----------------
 1 file changed, 10 insertions(+), 16 deletions(-)

(limited to 'models/rgb_part_net.py')

diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index e707c26..2cc0958 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -86,15 +86,15 @@ class RGBPartNet(nn.Module):
             return x.unsqueeze(1).view(-1)

     def _disentangle(self, x_c1, x_c2=None, y=None):
-        num_frames = len(x_c1)
-        # Decoded canonical features and Pose images
-        x_c_c1, x_p_c1 = [], []
+        t, n, c, h, w = x_c1.size()
         if self.training:
+            # Decoded canonical features and Pose images
+            x_c_c1, x_p_c1 = [], []
             # Features required to calculate losses
             f_p_c1, f_p_c2 = [], []
             xrecon_loss, cano_cons_loss = [], []
-            for t2 in range(num_frames):
-                t1 = random.randrange(num_frames)
+            for t2 in range(t):
+                t1 = random.randrange(t)
                 output = self.ae(x_c1[t2], x_c1[t1], x_c2[t2], y)
                 (x_c1_t2, f_p_t2, losses) = output

@@ -127,17 +127,11 @@ class RGBPartNet(nn.Module):
                     (xrecon_loss, pose_sim_loss, cano_cons_loss))
         else:  # evaluating
-            for t2 in range(num_frames):
-                x_c1_t2 = self.ae(x_c1[t2])
-                # Decoded features or image
-                (x_c_c1_t2, x_p_c1_t2) = x_c1_t2
-                # Canonical Features for HPM
-                x_c_c1.append(x_c_c1_t2)
-                # Pose image for Part Net
-                x_p_c1.append(x_p_c1_t2)
-
-            x_c_c1 = torch.stack(x_c_c1)
-            x_p_c1 = torch.stack(x_p_c1)
+            x_c1 = x_c1.view(-1, c, h, w)
+            x_c_c1, x_p_c1 = self.ae(x_c1)
+            _, c_c, h_c, w_c = x_c_c1.size()
+            x_c_c1 = x_c_c1.view(t, n, c_c, h_c, w_c)
+            x_p_c1 = x_p_c1.view(t, n, c, h, w)

             return (x_c_c1, x_p_c1), None

--
cgit v1.2.3

From 507e1d163aaa6ea4be23e7f08ff6ce0ef58c830b Mon Sep 17 00:00:00 2001
From: Jordan Gong
Date: Sat, 23 Jan 2021 22:19:51 +0800
Subject: Remove the third term in canonical consistency loss

---
 models/rgb_part_net.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

(limited to 'models/rgb_part_net.py')

diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 2cc0958..755d5dc 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -13,7 +13,6 @@ from utils.triplet_loss import BatchAllTripletLoss
 class RGBPartNet(nn.Module):
     def __init__(
             self,
-            num_class: int = 74,
             ae_in_channels: int = 3,
             ae_feature_channels: int = 64,
             f_a_c_p_dims: tuple[int, int, int] = (128, 128, 64),
@@ -31,7 +30,7 @@ class RGBPartNet(nn.Module):
     ):
         super().__init__()
         self.ae = AutoEncoder(
-            num_class, ae_in_channels, ae_feature_channels, f_a_c_p_dims
+            ae_in_channels, ae_feature_channels, f_a_c_p_dims
         )
         self.pn = PartNet(
             ae_in_channels, fpfe_feature_channels, fpfe_kernel_sizes,
@@ -60,7 +59,7 @@ class RGBPartNet(nn.Module):

         # Step 1: Disentanglement
         # t, n, c, h, w
-        ((x_c_c1, x_p_c1), losses) = self._disentangle(x_c1, x_c2, y)
+        ((x_c_c1, x_p_c1), losses) = self._disentangle(x_c1, x_c2)

         # Step 2.a: HPM & Static Gait Feature Aggregation
         # t, n, c, h, w
@@ -85,7 +84,7 @@ class RGBPartNet(nn.Module):
         else:
             return x.unsqueeze(1).view(-1)

-    def _disentangle(self, x_c1, x_c2=None, y=None):
+    def _disentangle(self, x_c1, x_c2=None):
         t, n, c, h, w = x_c1.size()
         if self.training:
             # Decoded canonical features and Pose images
@@ -95,7 +94,7 @@ class RGBPartNet(nn.Module):
             xrecon_loss, cano_cons_loss = [], []
             for t2 in range(t):
                 t1 = random.randrange(t)
-                output = self.ae(x_c1[t2], x_c1[t1], x_c2[t2], y)
+                output = self.ae(x_c1[t2], x_c1[t1], x_c2[t2])
                 (x_c1_t2, f_p_t2, losses) = output

                 # Decoded features or image
--
cgit v1.2.3

From 99ddd7c142a4ec97cb8bd14b204651790b3cf4ee Mon Sep 17 00:00:00 2001
From: Jordan Gong
Date: Mon, 8 Feb 2021 18:11:25 +0800
Subject: Code refactoring, modifications and new features

1. Decode features outside of auto-encoder
2. Turn off HPM 1x1 conv by default
3. Change canonical feature map size from `feature_channels * 8 x 4 x 2`
   to `feature_channels * 2 x 16 x 8`
4. Use mean of canonical embeddings instead of mean of static features
5. Calculate static and dynamic loss separately
6. Calculate mean of parts in triplet loss instead of sum of parts
7. Add switch to log disentangled images
8. Change default configuration
---
 models/rgb_part_net.py | 141 +++++++++++++++++++++++++++++++++++--------------
 1 file changed, 101 insertions(+), 40 deletions(-)

(limited to 'models/rgb_part_net.py')

diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 755d5dc..0e7d8b3 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -16,6 +16,7 @@ class RGBPartNet(nn.Module):
             ae_in_channels: int = 3,
             ae_feature_channels: int = 64,
             f_a_c_p_dims: tuple[int, int, int] = (128, 128, 64),
+            hpm_use_1x1conv: bool = False,
             hpm_scales: tuple[int, ...] = (1, 2, 4),
             hpm_use_avg_pool: bool = True,
             hpm_use_max_pool: bool = True,
@@ -26,9 +27,14 @@ class RGBPartNet(nn.Module):
             tfa_squeeze_ratio: int = 4,
             tfa_num_parts: int = 16,
             embedding_dims: int = 256,
-            triplet_margin: float = 0.2
+            triplet_margins: tuple[float, float] = (0.2, 0.2),
+            image_log_on: bool = False
     ):
         super().__init__()
+        (self.f_a_dim, self.f_c_dim, self.f_p_dim) = f_a_c_p_dims
+        self.hpm_num_parts = sum(hpm_scales)
+        self.image_log_on = image_log_on
+
         self.ae = AutoEncoder(
             ae_in_channels, ae_feature_channels, f_a_c_p_dims
         )
@@ -38,14 +44,16 @@ class RGBPartNet(nn.Module):
         )
         out_channels = self.pn.tfa_in_channels
         self.hpm = HorizontalPyramidMatching(
-            ae_feature_channels * 8, out_channels, hpm_scales,
-            hpm_use_avg_pool, hpm_use_max_pool
+            ae_feature_channels * 2, out_channels, hpm_use_1x1conv,
+            hpm_scales, hpm_use_avg_pool, hpm_use_max_pool
         )
-        total_parts = sum(hpm_scales) + tfa_num_parts
-        empty_fc = torch.empty(total_parts, out_channels, embedding_dims)
+        empty_fc = torch.empty(self.hpm_num_parts + tfa_num_parts,
+                               out_channels, embedding_dims)
         self.fc_mat = nn.Parameter(empty_fc)

-        self.ba_triplet_loss = BatchAllTripletLoss(triplet_margin)
+        (hpm_margin, pn_margin) = triplet_margins
+        self.hpm_ba_trip = BatchAllTripletLoss(hpm_margin)
+        self.pn_ba_trip = BatchAllTripletLoss(pn_margin)

     def fc(self, x):
         return x @ self.fc_mat
@@ -59,13 +67,11 @@ class RGBPartNet(nn.Module):

         # Step 1: Disentanglement
         # t, n, c, h, w
-        ((x_c_c1, x_p_c1), losses) = self._disentangle(x_c1, x_c2)
+        ((x_c_c1, x_p_c1), images, losses) = self._disentangle(x_c1, x_c2)

-        # Step 2.a: HPM & Static Gait Feature Aggregation
-        # t, n, c, h, w
+        # Step 2.a: Static Gait Feature Aggregation & HPM
+        # n, c, h, w
         x_c = self.hpm(x_c_c1)
-        # p, t, n, c
-        x_c = x_c.mean(dim=1)
         # p, n, c

         # Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation)
@@ -78,44 +84,83 @@ class RGBPartNet(nn.Module):
         x = self.fc(x)

         if self.training:
-            batch_all_triplet_loss = self.ba_triplet_loss(x, y)
-            losses = torch.stack((*losses, batch_all_triplet_loss))
-            return losses
+            hpm_ba_trip = self.hpm_ba_trip(x[:self.hpm_num_parts], y)
+            pn_ba_trip = self.pn_ba_trip(x[self.hpm_num_parts:], y)
+            losses = torch.stack((*losses, hpm_ba_trip, pn_ba_trip))
+            return losses, images
         else:
             return x.unsqueeze(1).view(-1)

     def _disentangle(self, x_c1, x_c2=None):
         t, n, c, h, w = x_c1.size()
+        device = x_c1.device
         if self.training:
-            # Decoded canonical features and Pose images
-            x_c_c1, x_p_c1 = [], []
+            # Encoded appearance, canonical and pose features
+            f_a_c1, f_c_c1, f_p_c1 = [], [], []
             # Features required to calculate losses
-            f_p_c1, f_p_c2 = [], []
+            f_p_c2 = []
             xrecon_loss, cano_cons_loss = [], []
             for t2 in range(t):
                 t1 = random.randrange(t)
                 output = self.ae(x_c1[t2], x_c1[t1], x_c2[t2])
-                (x_c1_t2, f_p_t2, losses) = output
+                (f_c1_t2, f_p_t2, losses) = output

-                # Decoded features or image
-                (x_c_c1_t2, x_p_c1_t2) = x_c1_t2
-                # Canonical Features for HPM
-                x_c_c1.append(x_c_c1_t2)
-                # Pose image for Part Net
-                x_p_c1.append(x_p_c1_t2)
+                (f_a_c1_t2, f_c_c1_t2, f_p_c1_t2) = f_c1_t2
+                if self.image_log_on:
+                    f_a_c1.append(f_a_c1_t2)
+                # Save canonical features and pose features
+                f_c_c1.append(f_c_c1_t2)
+                f_p_c1.append(f_p_c1_t2)

                 # Losses per time step
                 # Used in pose similarity loss
-                (f_p_c1_t2, f_p_c2_t2) = f_p_t2
-                f_p_c1.append(f_p_c1_t2)
+                (_, f_p_c2_t2) = f_p_t2
                 f_p_c2.append(f_p_c2_t2)
+                # Cross reconstruction loss and canonical loss
                 (xrecon_loss_t2, cano_cons_loss_t2) = losses
                 xrecon_loss.append(xrecon_loss_t2)
                 cano_cons_loss.append(cano_cons_loss_t2)
-
-            x_c_c1 = torch.stack(x_c_c1)
-            x_p_c1 = torch.stack(x_p_c1)
+            if self.image_log_on:
+                f_a_c1 = torch.stack(f_a_c1)
+            f_c_c1_mean = torch.stack(f_c_c1).mean(0)
+            f_p_c1 = torch.stack(f_p_c1)
+            f_p_c2 = torch.stack(f_p_c2)
+
+            # Decode features
+            appearance_image, canonical_image, pose_image = None, None, None
+            with torch.no_grad():
+                # Decode average canonical features to higher dimension
+                x_c_c1 = self.ae.decoder(
+                    torch.zeros((n, self.f_a_dim), device=device),
+                    f_c_c1_mean,
+                    torch.zeros((n, self.f_p_dim), device=device),
+                    cano_only=True
+                )
+                # Decode pose features to images
+                f_p_c1_ = f_p_c1.view(t * n, -1)
+                x_p_c1_ = self.ae.decoder(
+                    torch.zeros((t * n, self.f_a_dim), device=device),
+                    torch.zeros((t * n, self.f_c_dim), device=device),
+                    f_p_c1_
+                )
+                x_p_c1 = x_p_c1_.view(t, n, c, h, w)
+
+                if self.image_log_on:
+                    # Decode appearance features
+                    f_a_c1_ = f_a_c1.view(t * n, -1)
+                    appearance_image_ = self.ae.decoder(
+                        f_a_c1_,
+                        torch.zeros((t * n, self.f_c_dim), device=device),
+                        torch.zeros((t * n, self.f_p_dim), device=device)
+                    )
+                    appearance_image = appearance_image_.view(t, n, c, h, w)
+                    # Continue decoding canonical features
+                    canonical_image = self.ae.decoder.trans_conv3(x_c_c1)
+                    canonical_image = torch.sigmoid(
+                        self.ae.decoder.trans_conv4(canonical_image)
+                    )
+                    pose_image = x_p_c1

             # Losses
             xrecon_loss = torch.sum(torch.stack(xrecon_loss))
@@ -123,20 +168,36 @@ class RGBPartNet(nn.Module):
             cano_cons_loss = torch.mean(torch.stack(cano_cons_loss))

             return ((x_c_c1, x_p_c1),
+                    (appearance_image, canonical_image, pose_image),
                     (xrecon_loss, pose_sim_loss, cano_cons_loss))
         else:  # evaluating
-            x_c1 = x_c1.view(-1, c, h, w)
-            x_c_c1, x_p_c1 = self.ae(x_c1)
-            _, c_c, h_c, w_c = x_c_c1.size()
-            x_c_c1 = x_c_c1.view(t, n, c_c, h_c, w_c)
-            x_p_c1 = x_p_c1.view(t, n, c, h, w)
-
-            return (x_c_c1, x_p_c1), None
+            x_c1_ = x_c1.view(t * n, c, h, w)
+            (f_c_c1_, f_p_c1_) = self.ae(x_c1_)
+
+            # Canonical features
+            f_c_c1 = f_c_c1_.view(t, n, -1)
+            f_c_c1_mean = f_c_c1.mean(0)
+            x_c_c1 = self.ae.decoder(
+                torch.zeros((n, self.f_a_dim)),
+                f_c_c1_mean,
+                torch.zeros((n, self.f_p_dim)),
+                cano_only=True
+            )
+
+            # Pose features
+            x_p_c1_ = self.ae.decoder(
+                torch.zeros((t * n, self.f_a_dim)),
+                torch.zeros((t * n, self.f_c_dim)),
+                f_p_c1_
+            )
+            x_p_c1 = x_p_c1_.view(t, n, c, h, w)
+
+            return (x_c_c1, x_p_c1), None, None

     @staticmethod
-    def _pose_sim_loss(f_p_c1: list[torch.Tensor],
-                       f_p_c2: list[torch.Tensor]) -> torch.Tensor:
-        f_p_c1_mean = torch.stack(f_p_c1).mean(dim=0)
-        f_p_c2_mean = torch.stack(f_p_c2).mean(dim=0)
+    def _pose_sim_loss(f_p_c1: torch.Tensor,
+                       f_p_c2: torch.Tensor) -> torch.Tensor:
+        f_p_c1_mean = f_p_c1.mean(dim=0)
+        f_p_c2_mean = f_p_c2.mean(dim=0)
         return F.mse_loss(f_p_c1_mean, f_p_c2_mean)
--
cgit v1.2.3
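The patches above share one evaluation-time idea: instead of looping over frames, fold the time dimension into the batch dimension so every frame goes through the encoder in a single forward pass (first commit), then restore the (t, n, ...) layout and take the mean of the per-frame canonical features over time (third commit). What follows is a minimal, self-contained sketch of that pattern, assuming a toy encoder; ToyEncoder and disentangle_eval are hypothetical stand-ins written for illustration, not the repository's AutoEncoder or RGBPartNet._disentangle.

import torch
import torch.nn as nn


class ToyEncoder(nn.Module):
    # Hypothetical stand-in producing per-frame "canonical" and "pose" outputs.
    def __init__(self, in_channels: int = 3, feature_channels: int = 64):
        super().__init__()
        self.canonical = nn.Conv2d(in_channels, feature_channels * 2, 3, padding=1)
        self.pose = nn.Conv2d(in_channels, in_channels, 3, padding=1)

    def forward(self, x):
        return self.canonical(x), self.pose(x)


def disentangle_eval(encoder: nn.Module, x_c1: torch.Tensor):
    # x_c1 is (t, n, c, h, w); merge time into the batch axis so the encoder
    # sees one (t * n, c, h, w) batch, then reshape the outputs back.
    t, n, c, h, w = x_c1.size()
    x_flat = x_c1.view(t * n, c, h, w)      # merge time into batch
    f_c, x_p = encoder(x_flat)              # single forward pass for all frames
    _, c_c, h_c, w_c = f_c.size()
    f_c = f_c.view(t, n, c_c, h_c, w_c)     # per-frame canonical features
    x_p = x_p.view(t, n, c, h, w)           # per-frame pose images
    return f_c.mean(dim=0), x_p             # mean canonical feature over time


if __name__ == '__main__':
    frames = torch.rand(30, 4, 3, 64, 32)   # t=30 frames, n=4 sequences
    f_c_mean, x_p = disentangle_eval(ToyEncoder(), frames)
    print(f_c_mean.shape)                   # torch.Size([4, 128, 64, 32])
    print(x_p.shape)                        # torch.Size([30, 4, 3, 64, 32])

Batching the frames this way removes the Python-level loop from the evaluation path, which is usually noticeably faster on GPU than encoding one frame at a time.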