summaryrefslogtreecommitdiff
path: root/models/rgb_part_net.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/rgb_part_net.py')
-rw-r--r--models/rgb_part_net.py162
1 files changed, 108 insertions, 54 deletions
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 326ec81..f6dc131 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -14,10 +14,10 @@ from utils.triplet_loss import BatchAllTripletLoss
class RGBPartNet(nn.Module):
def __init__(
self,
- num_class: int = 74,
ae_in_channels: int = 3,
ae_feature_channels: int = 64,
f_a_c_p_dims: Tuple[int, int, int] = (128, 128, 64),
+ hpm_use_1x1conv: bool = False,
hpm_scales: Tuple[int, ...] = (1, 2, 4),
hpm_use_avg_pool: bool = True,
hpm_use_max_pool: bool = True,
@@ -28,11 +28,16 @@ class RGBPartNet(nn.Module):
tfa_squeeze_ratio: int = 4,
tfa_num_parts: int = 16,
embedding_dims: int = 256,
- triplet_margin: float = 0.2
+ triplet_margins: Tuple[float, float] = (0.2, 0.2),
+ image_log_on: bool = False
):
super().__init__()
+ (self.f_a_dim, self.f_c_dim, self.f_p_dim) = f_a_c_p_dims
+ self.hpm_num_parts = sum(hpm_scales)
+ self.image_log_on = image_log_on
+
self.ae = AutoEncoder(
- num_class, ae_in_channels, ae_feature_channels, f_a_c_p_dims
+ ae_in_channels, ae_feature_channels, f_a_c_p_dims
)
self.pn = PartNet(
ae_in_channels, fpfe_feature_channels, fpfe_kernel_sizes,
@@ -40,14 +45,16 @@ class RGBPartNet(nn.Module):
)
out_channels = self.pn.tfa_in_channels
self.hpm = HorizontalPyramidMatching(
- ae_feature_channels * 8, out_channels, hpm_scales,
- hpm_use_avg_pool, hpm_use_max_pool
+ ae_feature_channels * 2, out_channels, hpm_use_1x1conv,
+ hpm_scales, hpm_use_avg_pool, hpm_use_max_pool
)
- total_parts = sum(hpm_scales) + tfa_num_parts
- empty_fc = torch.empty(total_parts, out_channels, embedding_dims)
+ empty_fc = torch.empty(self.hpm_num_parts + tfa_num_parts,
+ out_channels, embedding_dims)
self.fc_mat = nn.Parameter(empty_fc)
- self.ba_triplet_loss = BatchAllTripletLoss(triplet_margin)
+ (hpm_margin, pn_margin) = triplet_margins
+ self.hpm_ba_trip = BatchAllTripletLoss(hpm_margin)
+ self.pn_ba_trip = BatchAllTripletLoss(pn_margin)
def fc(self, x):
return x @ self.fc_mat
@@ -61,13 +68,11 @@ class RGBPartNet(nn.Module):
# Step 1: Disentanglement
# t, n, c, h, w
- ((x_c_c1, x_p_c1), losses) = self._disentangle(x_c1, x_c2, y)
+ ((x_c_c1, x_p_c1), images, losses) = self._disentangle(x_c1, x_c2)
- # Step 2.a: HPM & Static Gait Feature Aggregation
- # t, n, c, h, w
+ # Step 2.a: Static Gait Feature Aggregation & HPM
+ # n, c, h, w
x_c = self.hpm(x_c_c1)
- # p, t, n, c
- x_c = x_c.mean(dim=1)
# p, n, c
# Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation)
@@ -80,44 +85,83 @@ class RGBPartNet(nn.Module):
x = self.fc(x)
if self.training:
- batch_all_triplet_loss = self.ba_triplet_loss(x, y)
- losses = torch.stack((*losses, batch_all_triplet_loss))
- return losses
+ hpm_ba_trip = self.hpm_ba_trip(x[:self.hpm_num_parts], y)
+ pn_ba_trip = self.pn_ba_trip(x[self.hpm_num_parts:], y)
+ losses = torch.stack((*losses, hpm_ba_trip, pn_ba_trip))
+ return losses, images
else:
return x.unsqueeze(1).view(-1)
- def _disentangle(self, x_c1, x_c2=None, y=None):
- num_frames = len(x_c1)
- # Decoded canonical features and Pose images
- x_c_c1, x_p_c1 = [], []
+ def _disentangle(self, x_c1, x_c2=None):
+ t, n, c, h, w = x_c1.size()
+ device = x_c1.device
if self.training:
+ # Encoded appearance, canonical and pose features
+ f_a_c1, f_c_c1, f_p_c1 = [], [], []
# Features required to calculate losses
- f_p_c1, f_p_c2 = [], []
+ f_p_c2 = []
xrecon_loss, cano_cons_loss = [], []
- for t2 in range(num_frames):
- t1 = random.randrange(num_frames)
- output = self.ae(x_c1[t2], x_c1[t1], x_c2[t2], y)
- (x_c1_t2, f_p_t2, losses) = output
-
- # Decoded features or image
- (x_c_c1_t2, x_p_c1_t2) = x_c1_t2
- # Canonical Features for HPM
- x_c_c1.append(x_c_c1_t2)
- # Pose image for Part Net
- x_p_c1.append(x_p_c1_t2)
+ for t2 in range(t):
+ t1 = random.randrange(t)
+ output = self.ae(x_c1[t2], x_c1[t1], x_c2[t2])
+ (f_c1_t2, f_p_t2, losses) = output
+
+ (f_a_c1_t2, f_c_c1_t2, f_p_c1_t2) = f_c1_t2
+ if self.image_log_on:
+ f_a_c1.append(f_a_c1_t2)
+ # Save canonical features and pose features
+ f_c_c1.append(f_c_c1_t2)
+ f_p_c1.append(f_p_c1_t2)
# Losses per time step
# Used in pose similarity loss
- (f_p_c1_t2, f_p_c2_t2) = f_p_t2
- f_p_c1.append(f_p_c1_t2)
+ (_, f_p_c2_t2) = f_p_t2
f_p_c2.append(f_p_c2_t2)
+
# Cross reconstruction loss and canonical loss
(xrecon_loss_t2, cano_cons_loss_t2) = losses
xrecon_loss.append(xrecon_loss_t2)
cano_cons_loss.append(cano_cons_loss_t2)
-
- x_c_c1 = torch.stack(x_c_c1)
- x_p_c1 = torch.stack(x_p_c1)
+ if self.image_log_on:
+ f_a_c1 = torch.stack(f_a_c1)
+ f_c_c1_mean = torch.stack(f_c_c1).mean(0)
+ f_p_c1 = torch.stack(f_p_c1)
+ f_p_c2 = torch.stack(f_p_c2)
+
+ # Decode features
+ appearance_image, canonical_image, pose_image = None, None, None
+ with torch.no_grad():
+ # Decode average canonical features to higher dimension
+ x_c_c1 = self.ae.decoder(
+ torch.zeros((n, self.f_a_dim), device=device),
+ f_c_c1_mean,
+ torch.zeros((n, self.f_p_dim), device=device),
+ cano_only=True
+ )
+ # Decode pose features to images
+ f_p_c1_ = f_p_c1.view(t * n, -1)
+ x_p_c1_ = self.ae.decoder(
+ torch.zeros((t * n, self.f_a_dim), device=device),
+ torch.zeros((t * n, self.f_c_dim), device=device),
+ f_p_c1_
+ )
+ x_p_c1 = x_p_c1_.view(t, n, c, h, w)
+
+ if self.image_log_on:
+ # Decode appearance features
+ f_a_c1_ = f_a_c1.view(t * n, -1)
+ appearance_image_ = self.ae.decoder(
+ f_a_c1_,
+ torch.zeros((t * n, self.f_c_dim), device=device),
+ torch.zeros((t * n, self.f_p_dim), device=device)
+ )
+ appearance_image = appearance_image_.view(t, n, c, h, w)
+ # Continue decoding canonical features
+ canonical_image = self.ae.decoder.trans_conv3(x_c_c1)
+ canonical_image = torch.sigmoid(
+ self.ae.decoder.trans_conv4(canonical_image)
+ )
+ pose_image = x_p_c1
# Losses
xrecon_loss = torch.sum(torch.stack(xrecon_loss))
@@ -125,26 +169,36 @@ class RGBPartNet(nn.Module):
cano_cons_loss = torch.mean(torch.stack(cano_cons_loss))
return ((x_c_c1, x_p_c1),
+ (appearance_image, canonical_image, pose_image),
(xrecon_loss, pose_sim_loss, cano_cons_loss))
else: # evaluating
- for t2 in range(num_frames):
- x_c1_t2 = self.ae(x_c1[t2])
- # Decoded features or image
- (x_c_c1_t2, x_p_c1_t2) = x_c1_t2
- # Canonical Features for HPM
- x_c_c1.append(x_c_c1_t2)
- # Pose image for Part Net
- x_p_c1.append(x_p_c1_t2)
-
- x_c_c1 = torch.stack(x_c_c1)
- x_p_c1 = torch.stack(x_p_c1)
-
- return (x_c_c1, x_p_c1), None
+ x_c1_ = x_c1.view(t * n, c, h, w)
+ (f_c_c1_, f_p_c1_) = self.ae(x_c1_)
+
+ # Canonical features
+ f_c_c1 = f_c_c1_.view(t, n, -1)
+ f_c_c1_mean = f_c_c1.mean(0)
+ x_c_c1 = self.ae.decoder(
+ torch.zeros((n, self.f_a_dim)),
+ f_c_c1_mean,
+ torch.zeros((n, self.f_p_dim)),
+ cano_only=True
+ )
+
+ # Pose features
+ x_p_c1_ = self.ae.decoder(
+ torch.zeros((t * n, self.f_a_dim)),
+ torch.zeros((t * n, self.f_c_dim)),
+ f_p_c1_
+ )
+ x_p_c1 = x_p_c1_.view(t, n, c, h, w)
+
+ return (x_c_c1, x_p_c1), None, None
@staticmethod
- def _pose_sim_loss(f_p_c1: List[torch.Tensor],
- f_p_c2: List[torch.Tensor]) -> torch.Tensor:
- f_p_c1_mean = torch.stack(f_p_c1).mean(dim=0)
- f_p_c2_mean = torch.stack(f_p_c2).mean(dim=0)
+ def _pose_sim_loss(f_p_c1: torch.Tensor,
+ f_p_c2: torch.Tensor) -> torch.Tensor:
+ f_p_c1_mean = f_p_c1.mean(dim=0)
+ f_p_c2_mean = f_p_c2.mean(dim=0)
return F.mse_loss(f_p_c1_mean, f_p_c2_mean)