summaryrefslogtreecommitdiff
path: root/models/rgb_part_net.py
diff options
context:
space:
mode:
author: Jordan Gong <jordan.gong@protonmail.com> 2021-02-08 18:11:25 +0800
committer: Jordan Gong <jordan.gong@protonmail.com> 2021-02-08 18:25:42 +0800
commit: 99ddd7c142a4ec97cb8bd14b204651790b3cf4ee (patch)
tree: a4ccbd08a7155e90df63aba60eb93ab2b7969c9b /models/rgb_part_net.py
parent: 507e1d163aaa6ea4be23e7f08ff6ce0ef58c830b (diff)
Code refactoring, modifications and new features
1. Decode features outside of auto-encoder
2. Turn off HPM 1x1 conv by default
3. Change canonical feature map size from `feature_channels * 8 x 4 x 2` to `feature_channels * 2 x 16 x 8`
4. Use mean of canonical embeddings instead of mean of static features
5. Calculate static and dynamic loss separately
6. Calculate mean of parts in triplet loss instead of sum of parts
7. Add switch to log disentangled images
8. Change default configuration
Diffstat (limited to 'models/rgb_part_net.py')
-rw-r--r--models/rgb_part_net.py141
1 files changed, 101 insertions, 40 deletions
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 755d5dc..0e7d8b3 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -16,6 +16,7 @@ class RGBPartNet(nn.Module):
ae_in_channels: int = 3,
ae_feature_channels: int = 64,
f_a_c_p_dims: tuple[int, int, int] = (128, 128, 64),
+ hpm_use_1x1conv: bool = False,
hpm_scales: tuple[int, ...] = (1, 2, 4),
hpm_use_avg_pool: bool = True,
hpm_use_max_pool: bool = True,
@@ -26,9 +27,14 @@ class RGBPartNet(nn.Module):
tfa_squeeze_ratio: int = 4,
tfa_num_parts: int = 16,
embedding_dims: int = 256,
- triplet_margin: float = 0.2
+ triplet_margins: tuple[float, float] = (0.2, 0.2),
+ image_log_on: bool = False
):
super().__init__()
+ (self.f_a_dim, self.f_c_dim, self.f_p_dim) = f_a_c_p_dims
+ self.hpm_num_parts = sum(hpm_scales)
+ self.image_log_on = image_log_on
+
self.ae = AutoEncoder(
ae_in_channels, ae_feature_channels, f_a_c_p_dims
)
@@ -38,14 +44,16 @@ class RGBPartNet(nn.Module):
)
out_channels = self.pn.tfa_in_channels
self.hpm = HorizontalPyramidMatching(
- ae_feature_channels * 8, out_channels, hpm_scales,
- hpm_use_avg_pool, hpm_use_max_pool
+ ae_feature_channels * 2, out_channels, hpm_use_1x1conv,
+ hpm_scales, hpm_use_avg_pool, hpm_use_max_pool
)
- total_parts = sum(hpm_scales) + tfa_num_parts
- empty_fc = torch.empty(total_parts, out_channels, embedding_dims)
+ empty_fc = torch.empty(self.hpm_num_parts + tfa_num_parts,
+ out_channels, embedding_dims)
self.fc_mat = nn.Parameter(empty_fc)
- self.ba_triplet_loss = BatchAllTripletLoss(triplet_margin)
+ (hpm_margin, pn_margin) = triplet_margins
+ self.hpm_ba_trip = BatchAllTripletLoss(hpm_margin)
+ self.pn_ba_trip = BatchAllTripletLoss(pn_margin)
def fc(self, x):
return x @ self.fc_mat
@@ -59,13 +67,11 @@ class RGBPartNet(nn.Module):
# Step 1: Disentanglement
# t, n, c, h, w
- ((x_c_c1, x_p_c1), losses) = self._disentangle(x_c1, x_c2)
+ ((x_c_c1, x_p_c1), images, losses) = self._disentangle(x_c1, x_c2)
- # Step 2.a: HPM & Static Gait Feature Aggregation
- # t, n, c, h, w
+ # Step 2.a: Static Gait Feature Aggregation & HPM
+ # n, c, h, w
x_c = self.hpm(x_c_c1)
- # p, t, n, c
- x_c = x_c.mean(dim=1)
# p, n, c
# Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation)
@@ -78,44 +84,83 @@ class RGBPartNet(nn.Module):
x = self.fc(x)
if self.training:
- batch_all_triplet_loss = self.ba_triplet_loss(x, y)
- losses = torch.stack((*losses, batch_all_triplet_loss))
- return losses
+ hpm_ba_trip = self.hpm_ba_trip(x[:self.hpm_num_parts], y)
+ pn_ba_trip = self.pn_ba_trip(x[self.hpm_num_parts:], y)
+ losses = torch.stack((*losses, hpm_ba_trip, pn_ba_trip))
+ return losses, images
else:
return x.unsqueeze(1).view(-1)
def _disentangle(self, x_c1, x_c2=None):
t, n, c, h, w = x_c1.size()
+ device = x_c1.device
if self.training:
- # Decoded canonical features and Pose images
- x_c_c1, x_p_c1 = [], []
+ # Encoded appearance, canonical and pose features
+ f_a_c1, f_c_c1, f_p_c1 = [], [], []
# Features required to calculate losses
- f_p_c1, f_p_c2 = [], []
+ f_p_c2 = []
xrecon_loss, cano_cons_loss = [], []
for t2 in range(t):
t1 = random.randrange(t)
output = self.ae(x_c1[t2], x_c1[t1], x_c2[t2])
- (x_c1_t2, f_p_t2, losses) = output
+ (f_c1_t2, f_p_t2, losses) = output
- # Decoded features or image
- (x_c_c1_t2, x_p_c1_t2) = x_c1_t2
- # Canonical Features for HPM
- x_c_c1.append(x_c_c1_t2)
- # Pose image for Part Net
- x_p_c1.append(x_p_c1_t2)
+ (f_a_c1_t2, f_c_c1_t2, f_p_c1_t2) = f_c1_t2
+ if self.image_log_on:
+ f_a_c1.append(f_a_c1_t2)
+ # Save canonical features and pose features
+ f_c_c1.append(f_c_c1_t2)
+ f_p_c1.append(f_p_c1_t2)
# Losses per time step
# Used in pose similarity loss
- (f_p_c1_t2, f_p_c2_t2) = f_p_t2
- f_p_c1.append(f_p_c1_t2)
+ (_, f_p_c2_t2) = f_p_t2
f_p_c2.append(f_p_c2_t2)
+
# Cross reconstruction loss and canonical loss
(xrecon_loss_t2, cano_cons_loss_t2) = losses
xrecon_loss.append(xrecon_loss_t2)
cano_cons_loss.append(cano_cons_loss_t2)
-
- x_c_c1 = torch.stack(x_c_c1)
- x_p_c1 = torch.stack(x_p_c1)
+ if self.image_log_on:
+ f_a_c1 = torch.stack(f_a_c1)
+ f_c_c1_mean = torch.stack(f_c_c1).mean(0)
+ f_p_c1 = torch.stack(f_p_c1)
+ f_p_c2 = torch.stack(f_p_c2)
+
+ # Decode features
+ appearance_image, canonical_image, pose_image = None, None, None
+ with torch.no_grad():
+ # Decode average canonical features to higher dimension
+ x_c_c1 = self.ae.decoder(
+ torch.zeros((n, self.f_a_dim), device=device),
+ f_c_c1_mean,
+ torch.zeros((n, self.f_p_dim), device=device),
+ cano_only=True
+ )
+ # Decode pose features to images
+ f_p_c1_ = f_p_c1.view(t * n, -1)
+ x_p_c1_ = self.ae.decoder(
+ torch.zeros((t * n, self.f_a_dim), device=device),
+ torch.zeros((t * n, self.f_c_dim), device=device),
+ f_p_c1_
+ )
+ x_p_c1 = x_p_c1_.view(t, n, c, h, w)
+
+ if self.image_log_on:
+ # Decode appearance features
+ f_a_c1_ = f_a_c1.view(t * n, -1)
+ appearance_image_ = self.ae.decoder(
+ f_a_c1_,
+ torch.zeros((t * n, self.f_c_dim), device=device),
+ torch.zeros((t * n, self.f_p_dim), device=device)
+ )
+ appearance_image = appearance_image_.view(t, n, c, h, w)
+ # Continue decoding canonical features
+ canonical_image = self.ae.decoder.trans_conv3(x_c_c1)
+ canonical_image = torch.sigmoid(
+ self.ae.decoder.trans_conv4(canonical_image)
+ )
+ pose_image = x_p_c1
# Losses
xrecon_loss = torch.sum(torch.stack(xrecon_loss))
@@ -123,20 +168,36 @@ class RGBPartNet(nn.Module):
cano_cons_loss = torch.mean(torch.stack(cano_cons_loss))
return ((x_c_c1, x_p_c1),
+ (appearance_image, canonical_image, pose_image),
(xrecon_loss, pose_sim_loss, cano_cons_loss))
else: # evaluating
- x_c1 = x_c1.view(-1, c, h, w)
- x_c_c1, x_p_c1 = self.ae(x_c1)
- _, c_c, h_c, w_c = x_c_c1.size()
- x_c_c1 = x_c_c1.view(t, n, c_c, h_c, w_c)
- x_p_c1 = x_p_c1.view(t, n, c, h, w)
-
- return (x_c_c1, x_p_c1), None
+ x_c1_ = x_c1.view(t * n, c, h, w)
+ (f_c_c1_, f_p_c1_) = self.ae(x_c1_)
+
+ # Canonical features
+ f_c_c1 = f_c_c1_.view(t, n, -1)
+ f_c_c1_mean = f_c_c1.mean(0)
+ x_c_c1 = self.ae.decoder(
+ torch.zeros((n, self.f_a_dim)),
+ f_c_c1_mean,
+ torch.zeros((n, self.f_p_dim)),
+ cano_only=True
+ )
+
+ # Pose features
+ x_p_c1_ = self.ae.decoder(
+ torch.zeros((t * n, self.f_a_dim)),
+ torch.zeros((t * n, self.f_c_dim)),
+ f_p_c1_
+ )
+ x_p_c1 = x_p_c1_.view(t, n, c, h, w)
+
+ return (x_c_c1, x_p_c1), None, None
@staticmethod
- def _pose_sim_loss(f_p_c1: list[torch.Tensor],
- f_p_c2: list[torch.Tensor]) -> torch.Tensor:
- f_p_c1_mean = torch.stack(f_p_c1).mean(dim=0)
- f_p_c2_mean = torch.stack(f_p_c2).mean(dim=0)
+ def _pose_sim_loss(f_p_c1: torch.Tensor,
+ f_p_c2: torch.Tensor) -> torch.Tensor:
+ f_p_c1_mean = f_p_c1.mean(dim=0)
+ f_p_c2_mean = f_p_c2.mean(dim=0)
return F.mse_loss(f_p_c1_mean, f_p_c2_mean)