diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-01-02 19:10:08 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-01-02 19:10:08 +0800 |
commit | 02aaefaba26b6842d2feb403edfd71aaa75904da (patch) | |
tree | 18744854e8d80e0239c0b2f3e7eaf39bc0a7974e /models/auto_encoder.py | |
parent | de8561d1d053730c5af03e1d06850efb60865d3c (diff) |
Correct feature dims after disentanglement and HPM backbone removal
1. Features used in HPM is decoded canonical embedding without transpose convolution
2. Decode pose embedding to image for Part Net
3. Backbone seems to be redundant, we can use feature map given by auto-decoder
Diffstat (limited to 'models/auto_encoder.py')
-rw-r--r-- | models/auto_encoder.py | 31 |
1 files changed, 25 insertions, 6 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py index c84061c..234111a 100644 --- a/models/auto_encoder.py +++ b/models/auto_encoder.py @@ -95,10 +95,13 @@ class Decoder(nn.Module): self.trans_conv4 = DCGANConvTranspose2d(feature_channels, out_channels, is_last_layer=True) - def forward(self, f_appearance, f_canonical, f_pose): + def forward(self, f_appearance, f_canonical, f_pose, no_trans_conv=False): x = torch.cat((f_appearance, f_canonical, f_pose), dim=1) x = self.fc(x) x = F.relu(x.view(-1, self.feature_channels * 8, 4, 2), inplace=True) + # Decode canonical features without transpose convolutions + if no_trans_conv: + return x x = self.trans_conv1(x) x = self.trans_conv2(x) x = self.trans_conv3(x) @@ -131,16 +134,32 @@ class AutoEncoder(nn.Module): def forward(self, x_c1_t1, x_c1_t2, x_c2_t2, y): # t1 is random time step (f_a_c1_t1, f_c_c1_t1, _) = self.encoder(x_c1_t1) - (_, f_c_c1_t2, f_p_c1_t2) = self.encoder(x_c1_t2) + (f_a_c1_t2, f_c_c1_t2, f_p_c1_t2) = self.encoder(x_c1_t2) (_, f_c_c2_t2, f_p_c2_t2) = self.encoder(x_c2_t2) x_c1_t2_ = self.decoder(f_a_c1_t1, f_c_c1_t1, f_p_c1_t2) xrecon_loss_t2 = self.mse_loss(x_c1_t2, x_c1_t2_) - y_ = self.classifier(f_c_c1_t2) + y_ = self.classifier(f_c_c1_t2.contiguous()) cano_cons_loss_t2 = (self.mse_loss(f_c_c1_t1, f_c_c1_t2) + self.mse_loss(f_c_c1_t2, f_c_c2_t2) - + self.xent_loss(y, y_)) + + self.xent_loss(y_, y)) - return ((f_c_c1_t2, f_p_c1_t2, f_p_c2_t2), - xrecon_loss_t2, cano_cons_loss_t2) + f_a_size, f_c_size, f_p_size = ( + f_a_c1_t2.size(), f_c_c1_t2.size(), f_p_c1_t2.size() + ) + # Decode canonical features for HPM + x_c_c1_t2 = self.decoder( + torch.zeros(f_a_size), f_c_c1_t1, torch.zeros(f_p_size), + no_trans_conv=True + ) + # Decode pose features for Part Net + x_p_c1_t2 = self.decoder( + torch.zeros(f_a_size), torch.zeros(f_c_size), f_p_c1_t2 + ) + + return ( + (x_c_c1_t2, x_p_c1_t2), + (f_p_c1_t2, f_p_c2_t2), + (xrecon_loss_t2, cano_cons_loss_t2) + ) |