diff options
Diffstat (limited to 'models/auto_encoder.py')
-rw-r--r-- | models/auto_encoder.py | 57 |
1 files changed, 30 insertions, 27 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index 234111a..ac3cfdf 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -122,35 +122,23 @@ class AutoEncoder(nn.Module):
         self.encoder = Encoder(channels, feature_channels, embedding_dims)
         self.decoder = Decoder(embedding_dims, feature_channels, channels)
 
-        f_c_dim = embedding_dims[1]
-        self.classifier = nn.Sequential(
-            nn.LeakyReLU(0.2, inplace=True),
-            BasicLinear(f_c_dim, num_class)
-        )
-
-        self.mse_loss = nn.MSELoss()
-        self.xent_loss = nn.CrossEntropyLoss()
-
-    def forward(self, x_c1_t1, x_c1_t2, x_c2_t2, y):
-        # t1 is random time step
-        (f_a_c1_t1, f_c_c1_t1, _) = self.encoder(x_c1_t1)
+        if self.training:
+            f_c_dim = embedding_dims[1]
+            self.classifier = nn.Sequential(
+                nn.LeakyReLU(0.2, inplace=True),
+                BasicLinear(f_c_dim, num_class)
+            )
+
+    def forward(self, x_c1_t1, x_c1_t2, x_c2_t2, y=None):
+        # x_c1_t2 is the frame for later module
         (f_a_c1_t2, f_c_c1_t2, f_p_c1_t2) = self.encoder(x_c1_t2)
-        (_, f_c_c2_t2, f_p_c2_t2) = self.encoder(x_c2_t2)
-
-        x_c1_t2_ = self.decoder(f_a_c1_t1, f_c_c1_t1, f_p_c1_t2)
-        xrecon_loss_t2 = self.mse_loss(x_c1_t2, x_c1_t2_)
-
-        y_ = self.classifier(f_c_c1_t2.contiguous())
-        cano_cons_loss_t2 = (self.mse_loss(f_c_c1_t1, f_c_c1_t2)
-                             + self.mse_loss(f_c_c1_t2, f_c_c2_t2)
-                             + self.xent_loss(y_, y))
 
         f_a_size, f_c_size, f_p_size = (
             f_a_c1_t2.size(), f_c_c1_t2.size(), f_p_c1_t2.size()
         )
         # Decode canonical features for HPM
         x_c_c1_t2 = self.decoder(
-            torch.zeros(f_a_size), f_c_c1_t1, torch.zeros(f_p_size),
+            torch.zeros(f_a_size), f_c_c1_t2, torch.zeros(f_p_size),
             no_trans_conv=True
         )
         # Decode pose features for Part Net
@@ -158,8 +146,23 @@ class AutoEncoder(nn.Module):
             torch.zeros(f_a_size), torch.zeros(f_c_size), f_p_c1_t2
         )
 
-        return (
-            (x_c_c1_t2, x_p_c1_t2),
-            (f_p_c1_t2, f_p_c2_t2),
-            (xrecon_loss_t2, cano_cons_loss_t2)
-        )
+        if self.training:
+            # t1 is random time step, c2 is another condition
+            (f_a_c1_t1, f_c_c1_t1, _) = self.encoder(x_c1_t1)
+            (_, f_c_c2_t2, f_p_c2_t2) = self.encoder(x_c2_t2)
+
+            x_c1_t2_ = self.decoder(f_a_c1_t1, f_c_c1_t1, f_p_c1_t2)
+            xrecon_loss_t2 = F.mse_loss(x_c1_t2, x_c1_t2_)
+
+            y_ = self.classifier(f_c_c1_t2.contiguous())
+            cano_cons_loss_t2 = (F.mse_loss(f_c_c1_t1, f_c_c1_t2)
+                                 + F.mse_loss(f_c_c1_t2, f_c_c2_t2)
+                                 + F.cross_entropy(y_, y))
+
+            return (
+                (x_c_c1_t2, x_p_c1_t2),
+                (f_p_c1_t2, f_p_c2_t2),
+                (xrecon_loss_t2, cano_cons_loss_t2)
+            )
+        else:  # evaluating
+            return x_c_c1_t2, x_p_c1_t2