path: root/models/auto_encoder.py
Diffstat (limited to 'models/auto_encoder.py')
-rw-r--r--  models/auto_encoder.py | 70
1 file changed, 40 insertions(+), 30 deletions(-)
diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index a9312dd..24a145d 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -117,32 +117,47 @@ class AutoEncoder(nn.Module):
             embedding_dims: tuple[int, int, int] = (128, 128, 64)
     ):
         super().__init__()
+        self.f_c_c1_t2_ = None
+        self.f_p_c1_t2_ = None
+        self.f_c_c1_t1_ = None
         self.encoder = Encoder(channels, feature_channels, embedding_dims)
         self.decoder = Decoder(embedding_dims, feature_channels, channels)
 
-    def forward(self, x_c1_t2, x_c1_t1=None, x_c2_t2=None):
-        n, t, c, h, w = x_c1_t2.size()
-        # x_c1_t2 is the frame for later module
-        x_c1_t2_ = x_c1_t2.view(n * t, c, h, w)
-        (f_a_c1_t2_, f_c_c1_t2_, f_p_c1_t2_) = self.encoder(x_c1_t2_)
-
-        if self.training:
-            # t1 is random time step, c2 is another condition
-            x_c1_t1 = x_c1_t1.view(n * t, c, h, w)
-            (f_a_c1_t1_, f_c_c1_t1_, _) = self.encoder(x_c1_t1)
-            x_c2_t2 = x_c2_t2.view(n * t, c, h, w)
-            (_, f_c_c2_t2_, f_p_c2_t2_) = self.encoder(x_c2_t2)
-
-            x_c1_t2_pred_ = self.decoder(f_a_c1_t1_, f_c_c1_t1_, f_p_c1_t2_)
-            x_c1_t2_pred = x_c1_t2_pred_.view(n, t, c, h, w)
-
-            xrecon_loss = torch.stack([
-                F.mse_loss(x_c1_t2[:, i, :, :, :], x_c1_t2_pred[:, i, :, :, :])
-                for i in range(t)
-            ]).sum()
-
-            f_c_c1_t1 = f_c_c1_t1_.view(n, t, -1)
-            f_c_c1_t2 = f_c_c1_t2_.view(n, t, -1)
+    def forward(self, x_t2, is_c1=True):
+        n, t, c, h, w = x_t2.size()
+        if is_c1:  # condition 1
+            # x_c1_t2 is the frame for later module
+            x_c1_t2_ = x_t2.view(n * t, c, h, w)
+            (f_a_c1_t2_, self.f_c_c1_t2_, self.f_p_c1_t2_) \
+                = self.encoder(x_c1_t2_)
+
+            if self.training:
+                # t1 is random time step
+                x_c1_t1 = x_t2[:, torch.randperm(t), :, :, :]
+                x_c1_t1_ = x_c1_t1.view(n * t, c, h, w)
+                (f_a_c1_t1_, self.f_c_c1_t1_, _) = self.encoder(x_c1_t1_)
+
+                x_c1_t2_pred_ = self.decoder(
+                    f_a_c1_t1_, self.f_c_c1_t1_, self.f_p_c1_t2_
+                )
+                x_c1_t2_pred = x_c1_t2_pred_.view(n, t, c, h, w)
+
+                xrecon_loss = torch.stack([
+                    F.mse_loss(x_t2[:, i, :, :, :], x_c1_t2_pred[:, i, :, :, :])
+                    for i in range(t)
+                ]).sum()
+
+                return ((f_a_c1_t2_, self.f_c_c1_t2_, self.f_p_c1_t2_),
+                        xrecon_loss)
+            else:  # evaluating
+                return self.f_c_c1_t2_, self.f_p_c1_t2_
+        else:  # condition 2
+            # c2 is another condition
+            x_c2_t2_ = x_t2.view(n * t, c, h, w)
+            (_, f_c_c2_t2_, f_p_c2_t2_) = self.encoder(x_c2_t2_)
+
+            f_c_c1_t1 = self.f_c_c1_t1_.view(n, t, -1)
+            f_c_c1_t2 = self.f_c_c1_t2_.view(n, t, -1)
             f_c_c2_t2 = f_c_c2_t2_.view(n, t, -1)
             cano_cons_loss = torch.stack([
                 F.mse_loss(f_c_c1_t1[:, i, :], f_c_c1_t2[:, i, :])
@@ -150,13 +165,8 @@ class AutoEncoder(nn.Module):
                 for i in range(t)
             ]).mean()
 
-            f_p_c1_t2 = f_p_c1_t2_.view(n, t, -1)
+            f_p_c1_t2 = self.f_p_c1_t2_.view(n, t, -1)
             f_p_c2_t2 = f_p_c2_t2_.view(n, t, -1)
             pose_sim_loss = F.mse_loss(f_p_c1_t2.mean(1), f_p_c2_t2.mean(1))
 
-            return (
-                (f_a_c1_t2_, f_c_c1_t2_, f_p_c1_t2_),
-                (xrecon_loss, cano_cons_loss, pose_sim_loss * 10)
-            )
-        else:  # evaluating
-            return f_c_c1_t2_, f_p_c1_t2_
+            return cano_cons_loss, pose_sim_loss * 10
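
Usage sketch for the reworked forward(): the two conditions are now passed in two separate calls, and the condition-2 call reuses the features cached from the condition-1 call. The batch/clip shapes, the 64x64 frame size, the variable names x_c1 and x_c2, and relying on the default constructor arguments are illustrative assumptions, not taken from this commit.

    import torch

    from models.auto_encoder import AutoEncoder

    model = AutoEncoder()                    # assumes all __init__ args have defaults
    x_c1 = torch.randn(2, 4, 3, 64, 64)      # (n, t, c, h, w) clip under condition 1
    x_c2 = torch.randn(2, 4, 3, 64, 64)      # same clip layout under condition 2

    model.train()
    # Condition-1 pass: encodes the clip, caches the canonical/pose features,
    # and returns the features plus the reconstruction loss built from a
    # randomly permuted time step.
    features, xrecon_loss = model(x_c1, is_c1=True)
    # Condition-2 pass: reuses the cached condition-1 features to compute the
    # canonical-consistency loss and the (x10 weighted) pose-similarity loss.
    cano_cons_loss, pose_sim_loss = model(x_c2, is_c1=False)

    model.eval()
    with torch.no_grad():
        f_c, f_p = model(x_c1)               # evaluation returns canonical & pose features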