author | Jordan Gong <jordan.gong@protonmail.com> | 2021-02-08 18:31:52 +0800
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-02-08 18:31:52 +0800
commit | d380e04df37593e414bd5641db100613fb2ad882 (patch)
tree | 1e3b3ea55a464d59d790711372bbca42cb203d0a /models/auto_encoder.py
parent | a040400d7caa267d4bfbe8e5520568806f92b3d4 (diff)
parent | 99ddd7c142a4ec97cb8bd14b204651790b3cf4ee (diff)
Merge branch 'master' into python3.8
# Conflicts:
# models/hpm.py
# models/layers.py
# models/model.py
# models/rgb_part_net.py
# utils/configuration.py
Diffstat (limited to 'models/auto_encoder.py')
-rw-r--r-- | models/auto_encoder.py | 36
1 files changed, 7 insertions, 29 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index befd2d3..69dae4e 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -97,15 +97,14 @@ class Decoder(nn.Module):
         self.trans_conv4 = DCGANConvTranspose2d(feature_channels, out_channels,
                                                  is_last_layer=True)
 
-    def forward(self, f_appearance, f_canonical, f_pose, no_trans_conv=False):
+    def forward(self, f_appearance, f_canonical, f_pose, cano_only=False):
         x = torch.cat((f_appearance, f_canonical, f_pose), dim=1)
         x = self.fc(x)
         x = F.relu(x.view(-1, self.feature_channels * 8, 4, 2), inplace=True)
-        # Decode canonical features without transpose convolutions
-        if no_trans_conv:
-            return x
         x = self.trans_conv1(x)
         x = self.trans_conv2(x)
+        if cano_only:
+            return x
         x = self.trans_conv3(x)
         x = torch.sigmoid(self.trans_conv4(x))
 
@@ -115,7 +114,6 @@ class Decoder(nn.Module):
 class AutoEncoder(nn.Module):
     def __init__(
             self,
-            num_class: int = 74,
             channels: int = 3,
             feature_channels: int = 64,
             embedding_dims: Tuple[int, int, int] = (128, 128, 64)
@@ -124,27 +122,10 @@ class AutoEncoder(nn.Module):
         self.encoder = Encoder(channels, feature_channels, embedding_dims)
         self.decoder = Decoder(embedding_dims, feature_channels, channels)
 
-        f_c_dim = embedding_dims[1]
-        self.classifier = nn.Sequential(
-            nn.LeakyReLU(0.2, inplace=True),
-            BasicLinear(f_c_dim, num_class)
-        )
-
-    def forward(self, x_c1_t2, x_c1_t1=None, x_c2_t2=None, y=None):
+    def forward(self, x_c1_t2, x_c1_t1=None, x_c2_t2=None):
         # x_c1_t2 is the frame for later module
         (f_a_c1_t2, f_c_c1_t2, f_p_c1_t2) = self.encoder(x_c1_t2)
 
-        with torch.no_grad():
-            # Decode canonical features for HPM
-            x_c_c1_t2 = self.decoder(
-                torch.zeros_like(f_a_c1_t2), f_c_c1_t2, torch.zeros_like(f_p_c1_t2),
-                no_trans_conv=True
-            )
-            # Decode pose features for Part Net
-            x_p_c1_t2 = self.decoder(
-                torch.zeros_like(f_a_c1_t2), torch.zeros_like(f_c_c1_t2), f_p_c1_t2
-            )
-
         if self.training:
             # t1 is random time step, c2 is another condition
             (f_a_c1_t1, f_c_c1_t1, _) = self.encoder(x_c1_t1)
@@ -152,16 +133,13 @@ class AutoEncoder(nn.Module):
 
             x_c1_t2_ = self.decoder(f_a_c1_t1, f_c_c1_t1, f_p_c1_t2)
             xrecon_loss_t2 = F.mse_loss(x_c1_t2, x_c1_t2_)
-
-            y_ = self.classifier(f_c_c1_t2.contiguous())
             cano_cons_loss_t2 = (F.mse_loss(f_c_c1_t1, f_c_c1_t2)
-                                 + F.mse_loss(f_c_c1_t2, f_c_c2_t2)
-                                 + F.cross_entropy(y_, y))
+                                 + F.mse_loss(f_c_c1_t2, f_c_c2_t2))
 
             return (
-                (x_c_c1_t2, x_p_c1_t2),
+                (f_a_c1_t2, f_c_c1_t2, f_p_c1_t2),
                 (f_p_c1_t2, f_p_c2_t2),
                 (xrecon_loss_t2, cano_cons_loss_t2)
             )
         else:  # evaluating
-            return x_c_c1_t2, x_p_c1_t2
+            return f_c_c1_t2, f_p_c1_t2
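For reference, below is a minimal, self-contained sketch (not the repository code) of the control flow this commit introduces: `Decoder.forward` gains a `cano_only` flag that exits early with the feature map produced by the first two transposed convolutions, instead of the old `no_trans_conv` path that skipped decoding entirely, and `AutoEncoder.forward` now hands the raw `(f_a, f_c, f_p)` embeddings to downstream modules rather than pre-decoded canonical/pose images. The layer widths, the `in_dim=320` input size, and the use of plain `nn.ConvTranspose2d` in place of the project's `DCGANConvTranspose2d` are illustrative assumptions, not the actual implementation.

```python
# Sketch only: mirrors the cano_only early-return introduced in this commit.
# Layer widths and strides are assumptions; the real code uses DCGANConvTranspose2d.
import torch
import torch.nn as nn
import torch.nn.functional as F


class SketchDecoder(nn.Module):
    def __init__(self, in_dim=320, feature_channels=64, out_channels=3):
        super().__init__()
        self.feature_channels = feature_channels
        self.fc = nn.Linear(in_dim, feature_channels * 8 * 4 * 2)
        # Plain ConvTranspose2d stands in for DCGANConvTranspose2d here.
        self.trans_conv1 = nn.ConvTranspose2d(feature_channels * 8,
                                              feature_channels * 4, 4, 2, 1)
        self.trans_conv2 = nn.ConvTranspose2d(feature_channels * 4,
                                              feature_channels * 2, 4, 2, 1)
        self.trans_conv3 = nn.ConvTranspose2d(feature_channels * 2,
                                              feature_channels, 4, 2, 1)
        self.trans_conv4 = nn.ConvTranspose2d(feature_channels,
                                              out_channels, 4, 2, 1)

    def forward(self, f_appearance, f_canonical, f_pose, cano_only=False):
        # Concatenate the three disentangled embeddings and project to a 4x2 map.
        x = torch.cat((f_appearance, f_canonical, f_pose), dim=1)
        x = self.fc(x)
        x = F.relu(x.view(-1, self.feature_channels * 8, 4, 2), inplace=True)
        x = self.trans_conv1(x)
        x = self.trans_conv2(x)
        if cano_only:
            # Early exit with the half-decoded feature map (new behaviour).
            return x
        x = self.trans_conv3(x)
        return torch.sigmoid(self.trans_conv4(x))


if __name__ == '__main__':
    decoder = SketchDecoder()
    # Embedding dims (128, 128, 64) follow the AutoEncoder defaults in the diff.
    f_a, f_c, f_p = torch.randn(8, 128), torch.randn(8, 128), torch.randn(8, 64)
    print(decoder(f_a, f_c, f_p, cano_only=True).shape)  # (8, 128, 16, 8)
    print(decoder(f_a, f_c, f_p).shape)                  # (8, 3, 64, 32)
```

Under these assumed strides, the `cano_only` branch returns a coarser feature map after `trans_conv2` and skips the cost of the last two decoding stages; the exact 16×8 resolution is an artifact of this sketch. With the auto-encoder change, any partial decoding (e.g. for HPM or Part Net) is now left to the calling module, since only the embeddings and losses are returned.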