Diffstat (limited to 'models/auto_encoder.py')
-rw-r--r--  models/auto_encoder.py  57
1 file changed, 30 insertions(+), 27 deletions(-)
diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index 234111a..ac3cfdf 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -122,35 +122,23 @@ class AutoEncoder(nn.Module):
self.encoder = Encoder(channels, feature_channels, embedding_dims)
self.decoder = Decoder(embedding_dims, feature_channels, channels)
- f_c_dim = embedding_dims[1]
- self.classifier = nn.Sequential(
- nn.LeakyReLU(0.2, inplace=True),
- BasicLinear(f_c_dim, num_class)
- )
-
- self.mse_loss = nn.MSELoss()
- self.xent_loss = nn.CrossEntropyLoss()
-
- def forward(self, x_c1_t1, x_c1_t2, x_c2_t2, y):
- # t1 is random time step
- (f_a_c1_t1, f_c_c1_t1, _) = self.encoder(x_c1_t1)
+ if self.training:
+ f_c_dim = embedding_dims[1]
+ self.classifier = nn.Sequential(
+ nn.LeakyReLU(0.2, inplace=True),
+ BasicLinear(f_c_dim, num_class)
+ )
+
+ def forward(self, x_c1_t1, x_c1_t2, x_c2_t2, y=None):
+ # x_c1_t2 is the frame consumed by the later modules (HPM and Part Net)
(f_a_c1_t2, f_c_c1_t2, f_p_c1_t2) = self.encoder(x_c1_t2)
- (_, f_c_c2_t2, f_p_c2_t2) = self.encoder(x_c2_t2)
-
- x_c1_t2_ = self.decoder(f_a_c1_t1, f_c_c1_t1, f_p_c1_t2)
- xrecon_loss_t2 = self.mse_loss(x_c1_t2, x_c1_t2_)
-
- y_ = self.classifier(f_c_c1_t2.contiguous())
- cano_cons_loss_t2 = (self.mse_loss(f_c_c1_t1, f_c_c1_t2)
- + self.mse_loss(f_c_c1_t2, f_c_c2_t2)
- + self.xent_loss(y_, y))
f_a_size, f_c_size, f_p_size = (
f_a_c1_t2.size(), f_c_c1_t2.size(), f_p_c1_t2.size()
)
# Decode canonical features for HPM
x_c_c1_t2 = self.decoder(
- torch.zeros(f_a_size), f_c_c1_t1, torch.zeros(f_p_size),
+ torch.zeros(f_a_size), f_c_c1_t2, torch.zeros(f_p_size),
no_trans_conv=True
)
# Decode pose features for Part Net
@@ -158,8 +146,23 @@ class AutoEncoder(nn.Module):
torch.zeros(f_a_size), torch.zeros(f_c_size), f_p_c1_t2
)
- return (
- (x_c_c1_t2, x_p_c1_t2),
- (f_p_c1_t2, f_p_c2_t2),
- (xrecon_loss_t2, cano_cons_loss_t2)
- )
+ if self.training:
+ # t1 is a random time step; c2 is another condition
+ (f_a_c1_t1, f_c_c1_t1, _) = self.encoder(x_c1_t1)
+ (_, f_c_c2_t2, f_p_c2_t2) = self.encoder(x_c2_t2)
+
+ x_c1_t2_ = self.decoder(f_a_c1_t1, f_c_c1_t1, f_p_c1_t2)
+ xrecon_loss_t2 = F.mse_loss(x_c1_t2, x_c1_t2_)
+
+ y_ = self.classifier(f_c_c1_t2.contiguous())
+ cano_cons_loss_t2 = (F.mse_loss(f_c_c1_t1, f_c_c1_t2)
+ + F.mse_loss(f_c_c1_t2, f_c_c2_t2)
+ + F.cross_entropy(y_, y))
+
+ return (
+ (x_c_c1_t2, x_p_c1_t2),
+ (f_p_c1_t2, f_p_c2_t2),
+ (xrecon_loss_t2, cano_cons_loss_t2)
+ )
+ else: # evaluating
+ return x_c_c1_t2, x_p_c1_t2
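
For reference, a minimal usage sketch of the changed forward() contract (not part of the commit). The batch size, image shape, and num_class value below are illustrative assumptions, and ae stands for an AutoEncoder instance built elsewhere with this module's Encoder/Decoder; only the call pattern reflects the diff.

import torch

# Illustrative inputs; shapes and the label range are assumptions, not repo values.
x_c1_t1 = torch.randn(8, 3, 64, 32)   # condition 1, random time step t1
x_c1_t2 = torch.randn(8, 3, 64, 32)   # condition 1, time step t2 (fed to HPM / Part Net)
x_c2_t2 = torch.randn(8, 3, 64, 32)   # condition 2, time step t2
y = torch.randint(0, 74, (8,))        # identity labels

ae.train()  # training: the classifier is built and the losses are returned
((x_c_c1_t2, x_p_c1_t2),
 (f_p_c1_t2, f_p_c2_t2),
 (xrecon_loss_t2, cano_cons_loss_t2)) = ae(x_c1_t1, x_c1_t2, x_c2_t2, y)

ae.eval()   # evaluating: only x_c1_t2 is encoded, so the other inputs and y are unused
with torch.no_grad():
    x_c_c1_t2, x_p_c1_t2 = ae(None, x_c1_t2, None)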