diff options
Diffstat (limited to 'models/auto_encoder.py')
-rw-r--r-- | models/auto_encoder.py | 31 |
1 file changed, 25 insertions, 6 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index c84061c..234111a 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -95,10 +95,13 @@ class Decoder(nn.Module):
         self.trans_conv4 = DCGANConvTranspose2d(feature_channels, out_channels,
                                                 is_last_layer=True)
 
-    def forward(self, f_appearance, f_canonical, f_pose):
+    def forward(self, f_appearance, f_canonical, f_pose, no_trans_conv=False):
         x = torch.cat((f_appearance, f_canonical, f_pose), dim=1)
         x = self.fc(x)
         x = F.relu(x.view(-1, self.feature_channels * 8, 4, 2), inplace=True)
+        # Decode canonical features without transpose convolutions
+        if no_trans_conv:
+            return x
         x = self.trans_conv1(x)
         x = self.trans_conv2(x)
         x = self.trans_conv3(x)
@@ -131,16 +134,32 @@ class AutoEncoder(nn.Module):
     def forward(self, x_c1_t1, x_c1_t2, x_c2_t2, y):
         # t1 is random time step
         (f_a_c1_t1, f_c_c1_t1, _) = self.encoder(x_c1_t1)
-        (_, f_c_c1_t2, f_p_c1_t2) = self.encoder(x_c1_t2)
+        (f_a_c1_t2, f_c_c1_t2, f_p_c1_t2) = self.encoder(x_c1_t2)
         (_, f_c_c2_t2, f_p_c2_t2) = self.encoder(x_c2_t2)
 
         x_c1_t2_ = self.decoder(f_a_c1_t1, f_c_c1_t1, f_p_c1_t2)
         xrecon_loss_t2 = self.mse_loss(x_c1_t2, x_c1_t2_)
 
-        y_ = self.classifier(f_c_c1_t2)
+        y_ = self.classifier(f_c_c1_t2.contiguous())
         cano_cons_loss_t2 = (self.mse_loss(f_c_c1_t1, f_c_c1_t2)
                              + self.mse_loss(f_c_c1_t2, f_c_c2_t2)
-                             + self.xent_loss(y, y_))
+                             + self.xent_loss(y_, y))
 
-        return ((f_c_c1_t2, f_p_c1_t2, f_p_c2_t2),
-                xrecon_loss_t2, cano_cons_loss_t2)
+        f_a_size, f_c_size, f_p_size = (
+            f_a_c1_t2.size(), f_c_c1_t2.size(), f_p_c1_t2.size()
+        )
+        # Decode canonical features for HPM
+        x_c_c1_t2 = self.decoder(
+            torch.zeros(f_a_size), f_c_c1_t1, torch.zeros(f_p_size),
+            no_trans_conv=True
+        )
+        # Decode pose features for Part Net
+        x_p_c1_t2 = self.decoder(
+            torch.zeros(f_a_size), torch.zeros(f_c_size), f_p_c1_t2
+        )
+
+        return (
+            (x_c_c1_t2, x_p_c1_t2),
+            (f_p_c1_t2, f_p_c2_t2),
+            (xrecon_loss_t2, cano_cons_loss_t2)
+        )