diff options
Diffstat (limited to 'models/auto_encoder.py')
-rw-r--r-- | models/auto_encoder.py | 26 |
1 files changed, 5 insertions, 21 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py index 35cb629..f04ffdb 100644 --- a/models/auto_encoder.py +++ b/models/auto_encoder.py @@ -95,15 +95,14 @@ class Decoder(nn.Module): self.trans_conv4 = DCGANConvTranspose2d(feature_channels, out_channels, is_last_layer=True) - def forward(self, f_appearance, f_canonical, f_pose, no_trans_conv=False): + def forward(self, f_appearance, f_canonical, f_pose, cano_only=False): x = torch.cat((f_appearance, f_canonical, f_pose), dim=1) x = self.fc(x) x = F.relu(x.view(-1, self.feature_channels * 8, 4, 2), inplace=True) - # Decode canonical features without transpose convolutions - if no_trans_conv: - return x x = self.trans_conv1(x) x = self.trans_conv2(x) + if cano_only: + return x x = self.trans_conv3(x) x = torch.sigmoid(self.trans_conv4(x)) @@ -125,21 +124,6 @@ class AutoEncoder(nn.Module): # x_c1_t2 is the frame for later module (f_a_c1_t2, f_c_c1_t2, f_p_c1_t2) = self.encoder(x_c1_t2) - with torch.no_grad(): - # Decode canonical features for HPM - x_c_c1_t2 = self.decoder( - torch.zeros_like(f_a_c1_t2), - f_c_c1_t2, - torch.zeros_like(f_p_c1_t2), - no_trans_conv=True - ) - # Decode pose features for Part Net - x_p_c1_t2 = self.decoder( - torch.zeros_like(f_a_c1_t2), - torch.zeros_like(f_c_c1_t2), - f_p_c1_t2 - ) - if self.training: # t1 is random time step, c2 is another condition (f_a_c1_t1, f_c_c1_t1, _) = self.encoder(x_c1_t1) @@ -151,9 +135,9 @@ class AutoEncoder(nn.Module): + F.mse_loss(f_c_c1_t2, f_c_c2_t2)) return ( - (x_c_c1_t2, x_p_c1_t2), + (f_a_c1_t2, f_c_c1_t2, f_p_c1_t2), (f_p_c1_t2, f_p_c2_t2), (xrecon_loss_t2, cano_cons_loss_t2) ) else: # evaluating - return x_c_c1_t2, x_p_c1_t2 + return f_c_c1_t2, f_p_c1_t2 |