summaryrefslogtreecommitdiff
path: root/models/auto_encoder.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/auto_encoder.py')
-rw-r--r--models/auto_encoder.py3
1 files changed, 2 insertions, 1 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index e35ed23..1708bc9 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -79,9 +79,10 @@ class Decoder(nn.Module):
def forward(self, fa, fgs, fgd):
x = torch.cat((fa, fgs, fgd), dim=1).view(-1, self.em_dim)
x = F.leaky_relu(self.batch_norm_fc(self.fc(x)), 0.2)
+ x = x.view(-1, 64 * 8, 4, 2)
x = F.leaky_relu(self.batch_norm1(self.trans_conv1(x)), 0.2)
x = F.leaky_relu(self.batch_norm2(self.trans_conv2(x)), 0.2)
x = F.leaky_relu(self.batch_norm3(self.trans_conv3(x)), 0.2)
x = F.sigmoid(self.trans_conv4(x))
- return x \ No newline at end of file
+ return x