diff options
Diffstat (limited to 'models/auto_encoder.py')
| -rw-r--r-- | models/auto_encoder.py | 31 | 
1 files changed, 25 insertions, 6 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index c84061c..234111a 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -95,10 +95,13 @@ class Decoder(nn.Module):
         self.trans_conv4 = DCGANConvTranspose2d(feature_channels, out_channels,
                                                 is_last_layer=True)
 
-    def forward(self, f_appearance, f_canonical, f_pose):
+    def forward(self, f_appearance, f_canonical, f_pose, no_trans_conv=False):
         x = torch.cat((f_appearance, f_canonical, f_pose), dim=1)
         x = self.fc(x)
         x = F.relu(x.view(-1, self.feature_channels * 8, 4, 2), inplace=True)
+        # Decode canonical features without transpose convolutions
+        if no_trans_conv:
+            return x
         x = self.trans_conv1(x)
         x = self.trans_conv2(x)
         x = self.trans_conv3(x)
@@ -131,16 +134,32 @@ class AutoEncoder(nn.Module):
     def forward(self, x_c1_t1, x_c1_t2, x_c2_t2, y):
         # t1 is random time step
         (f_a_c1_t1, f_c_c1_t1, _) = self.encoder(x_c1_t1)
-        (_, f_c_c1_t2, f_p_c1_t2) = self.encoder(x_c1_t2)
+        (f_a_c1_t2, f_c_c1_t2, f_p_c1_t2) = self.encoder(x_c1_t2)
         (_, f_c_c2_t2, f_p_c2_t2) = self.encoder(x_c2_t2)
 
         x_c1_t2_ = self.decoder(f_a_c1_t1, f_c_c1_t1, f_p_c1_t2)
         xrecon_loss_t2 = self.mse_loss(x_c1_t2, x_c1_t2_)
 
-        y_ = self.classifier(f_c_c1_t2)
+        y_ = self.classifier(f_c_c1_t2.contiguous())
         cano_cons_loss_t2 = (self.mse_loss(f_c_c1_t1, f_c_c1_t2)
                              + self.mse_loss(f_c_c1_t2, f_c_c2_t2)
-                             + self.xent_loss(y, y_))
+                             + self.xent_loss(y_, y))
 
-        return ((f_c_c1_t2, f_p_c1_t2, f_p_c2_t2),
-                xrecon_loss_t2, cano_cons_loss_t2)
+        f_a_size, f_c_size, f_p_size = (
+            f_a_c1_t2.size(), f_c_c1_t2.size(), f_p_c1_t2.size()
+        )
+        # Decode canonical features for HPM
+        x_c_c1_t2 = self.decoder(
+            torch.zeros(f_a_size), f_c_c1_t1, torch.zeros(f_p_size),
+            no_trans_conv=True
+        )
+        # Decode pose features for Part Net
+        x_p_c1_t2 = self.decoder(
+            torch.zeros(f_a_size), torch.zeros(f_c_size), f_p_c1_t2
+        )
+
+        return (
+            (x_c_c1_t2, x_p_c1_t2),
+            (f_p_c1_t2, f_p_c2_t2),
+            (xrecon_loss_t2, cano_cons_loss_t2)
+        )
