From 02aaefaba26b6842d2feb403edfd71aaa75904da Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Sat, 2 Jan 2021 19:10:08 +0800 Subject: Correct feature dims after disentanglement and HPM backbone removal 1. Features used in HPM is decoded canonical embedding without transpose convolution 2. Decode pose embedding to image for Part Net 3. Backbone seems to be redundant, we can use feature map given by auto-decoder --- models/auto_encoder.py | 31 +++++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 6 deletions(-) (limited to 'models/auto_encoder.py') diff --git a/models/auto_encoder.py b/models/auto_encoder.py index c84061c..234111a 100644 --- a/models/auto_encoder.py +++ b/models/auto_encoder.py @@ -95,10 +95,13 @@ class Decoder(nn.Module): self.trans_conv4 = DCGANConvTranspose2d(feature_channels, out_channels, is_last_layer=True) - def forward(self, f_appearance, f_canonical, f_pose): + def forward(self, f_appearance, f_canonical, f_pose, no_trans_conv=False): x = torch.cat((f_appearance, f_canonical, f_pose), dim=1) x = self.fc(x) x = F.relu(x.view(-1, self.feature_channels * 8, 4, 2), inplace=True) + # Decode canonical features without transpose convolutions + if no_trans_conv: + return x x = self.trans_conv1(x) x = self.trans_conv2(x) x = self.trans_conv3(x) @@ -131,16 +134,32 @@ class AutoEncoder(nn.Module): def forward(self, x_c1_t1, x_c1_t2, x_c2_t2, y): # t1 is random time step (f_a_c1_t1, f_c_c1_t1, _) = self.encoder(x_c1_t1) - (_, f_c_c1_t2, f_p_c1_t2) = self.encoder(x_c1_t2) + (f_a_c1_t2, f_c_c1_t2, f_p_c1_t2) = self.encoder(x_c1_t2) (_, f_c_c2_t2, f_p_c2_t2) = self.encoder(x_c2_t2) x_c1_t2_ = self.decoder(f_a_c1_t1, f_c_c1_t1, f_p_c1_t2) xrecon_loss_t2 = self.mse_loss(x_c1_t2, x_c1_t2_) - y_ = self.classifier(f_c_c1_t2) + y_ = self.classifier(f_c_c1_t2.contiguous()) cano_cons_loss_t2 = (self.mse_loss(f_c_c1_t1, f_c_c1_t2) + self.mse_loss(f_c_c1_t2, f_c_c2_t2) - + self.xent_loss(y, y_)) + + self.xent_loss(y_, y)) - return ((f_c_c1_t2, f_p_c1_t2, f_p_c2_t2), - xrecon_loss_t2, cano_cons_loss_t2) + f_a_size, f_c_size, f_p_size = ( + f_a_c1_t2.size(), f_c_c1_t2.size(), f_p_c1_t2.size() + ) + # Decode canonical features for HPM + x_c_c1_t2 = self.decoder( + torch.zeros(f_a_size), f_c_c1_t1, torch.zeros(f_p_size), + no_trans_conv=True + ) + # Decode pose features for Part Net + x_p_c1_t2 = self.decoder( + torch.zeros(f_a_size), torch.zeros(f_c_size), f_p_c1_t2 + ) + + return ( + (x_c_c1_t2, x_p_c1_t2), + (f_p_c1_t2, f_p_c2_t2), + (xrecon_loss_t2, cano_cons_loss_t2) + ) -- cgit v1.2.3