author    Jordan Gong <jordan.gong@protonmail.com>    2021-02-15 11:23:20 +0800
committer Jordan Gong <jordan.gong@protonmail.com>    2021-02-15 11:23:20 +0800
commit    d51312415a32686793d3f0d14eda7fa7cc3990ea (patch)
tree      9ec187d721c97a588f0207efe1311ceeee827d96 /models/auto_encoder.py
parent    be508061aeb3049a547c4e0c92d21c254689c1d5 (diff)
Revert "Memory usage improvement"
This reverts commit be508061
Diffstat (limited to 'models/auto_encoder.py')
-rw-r--r--  models/auto_encoder.py  70
1 file changed, 30 insertions, 40 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index 24a145d..a9312dd 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -117,47 +117,32 @@ class AutoEncoder(nn.Module):
embedding_dims: tuple[int, int, int] = (128, 128, 64)
):
super().__init__()
- self.f_c_c1_t2_ = None
- self.f_p_c1_t2_ = None
- self.f_c_c1_t1_ = None
self.encoder = Encoder(channels, feature_channels, embedding_dims)
self.decoder = Decoder(embedding_dims, feature_channels, channels)
- def forward(self, x_t2, is_c1=True):
- n, t, c, h, w = x_t2.size()
- if is_c1: # condition 1
- # x_c1_t2 is the frame for later module
- x_c1_t2_ = x_t2.view(n * t, c, h, w)
- (f_a_c1_t2_, self.f_c_c1_t2_, self.f_p_c1_t2_) \
- = self.encoder(x_c1_t2_)
-
- if self.training:
- # t1 is random time step
- x_c1_t1 = x_t2[:, torch.randperm(t), :, :, :]
- x_c1_t1_ = x_c1_t1.view(n * t, c, h, w)
- (f_a_c1_t1_, self.f_c_c1_t1_, _) = self.encoder(x_c1_t1_)
-
- x_c1_t2_pred_ = self.decoder(
- f_a_c1_t1_, self.f_c_c1_t1_, self.f_p_c1_t2_
- )
- x_c1_t2_pred = x_c1_t2_pred_.view(n, t, c, h, w)
-
- xrecon_loss = torch.stack([
- F.mse_loss(x_t2[:, i, :, :, :], x_c1_t2_pred[:, i, :, :, :])
- for i in range(t)
- ]).sum()
-
- return ((f_a_c1_t2_, self.f_c_c1_t2_, self.f_p_c1_t2_),
- xrecon_loss)
- else: # evaluating
- return self.f_c_c1_t2_, self.f_p_c1_t2_
- else: # condition 2
- # c2 is another condition
- x_c2_t2_ = x_t2.view(n * t, c, h, w)
- (_, f_c_c2_t2_, f_p_c2_t2_) = self.encoder(x_c2_t2_)
-
- f_c_c1_t1 = self.f_c_c1_t1_.view(n, t, -1)
- f_c_c1_t2 = self.f_c_c1_t2_.view(n, t, -1)
+ def forward(self, x_c1_t2, x_c1_t1=None, x_c2_t2=None):
+ n, t, c, h, w = x_c1_t2.size()
+ # x_c1_t2 is the frame for later module
+ x_c1_t2_ = x_c1_t2.view(n * t, c, h, w)
+ (f_a_c1_t2_, f_c_c1_t2_, f_p_c1_t2_) = self.encoder(x_c1_t2_)
+
+ if self.training:
+ # t1 is random time step, c2 is another condition
+ x_c1_t1 = x_c1_t1.view(n * t, c, h, w)
+ (f_a_c1_t1_, f_c_c1_t1_, _) = self.encoder(x_c1_t1)
+ x_c2_t2 = x_c2_t2.view(n * t, c, h, w)
+ (_, f_c_c2_t2_, f_p_c2_t2_) = self.encoder(x_c2_t2)
+
+ x_c1_t2_pred_ = self.decoder(f_a_c1_t1_, f_c_c1_t1_, f_p_c1_t2_)
+ x_c1_t2_pred = x_c1_t2_pred_.view(n, t, c, h, w)
+
+ xrecon_loss = torch.stack([
+ F.mse_loss(x_c1_t2[:, i, :, :, :], x_c1_t2_pred[:, i, :, :, :])
+ for i in range(t)
+ ]).sum()
+
+ f_c_c1_t1 = f_c_c1_t1_.view(n, t, -1)
+ f_c_c1_t2 = f_c_c1_t2_.view(n, t, -1)
f_c_c2_t2 = f_c_c2_t2_.view(n, t, -1)
cano_cons_loss = torch.stack([
F.mse_loss(f_c_c1_t1[:, i, :], f_c_c1_t2[:, i, :])
@@ -165,8 +150,13 @@ class AutoEncoder(nn.Module):
for i in range(t)
]).mean()
- f_p_c1_t2 = self.f_p_c1_t2_.view(n, t, -1)
+ f_p_c1_t2 = f_p_c1_t2_.view(n, t, -1)
f_p_c2_t2 = f_p_c2_t2_.view(n, t, -1)
pose_sim_loss = F.mse_loss(f_p_c1_t2.mean(1), f_p_c2_t2.mean(1))
- return cano_cons_loss, pose_sim_loss * 10
+ return (
+ (f_a_c1_t2_, f_c_c1_t2_, f_p_c1_t2_),
+ (xrecon_loss, cano_cons_loss, pose_sim_loss * 10)
+ )
+ else: # evaluating
+ return f_c_c1_t2_, f_p_c1_t2_
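
For reference, a minimal usage sketch of the reverted forward() signature (not part of this commit). Only the argument order and the return structure follow the diff above; the bare AutoEncoder() construction, the tensor shapes, and the unweighted loss sum are assumptions for illustration.

# Sketch assuming the constructor's defaults are sufficient and that inputs
# are clips of shape (n, t, c, h, w); values below are placeholders.
import torch

from models.auto_encoder import AutoEncoder

ae = AutoEncoder()

n, t, c, h, w = 4, 8, 3, 64, 32
x_c1_t2 = torch.rand(n, t, c, h, w)  # condition 1, time step t2
x_c1_t1 = torch.rand(n, t, c, h, w)  # condition 1, a random time step t1
x_c2_t2 = torch.rand(n, t, c, h, w)  # condition 2, time step t2

# Training path: all three clips are passed explicitly (no cached state),
# returning the c1/t2 features and the three losses.
ae.train()
features, losses = ae(x_c1_t2, x_c1_t1, x_c2_t2)
f_a_c1_t2_, f_c_c1_t2_, f_p_c1_t2_ = features
xrecon_loss, cano_cons_loss, pose_sim_loss = losses
total_loss = xrecon_loss + cano_cons_loss + pose_sim_loss  # weighting is a guess

# Evaluation path: only the c1/t2 clip is needed; canonical and pose
# features are returned directly.
ae.eval()
with torch.no_grad():
    f_c, f_p = ae(x_c1_t2)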