From 7460ce2c8904a009f2f1139b11ec18faf208d6d2 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Wed, 23 Dec 2020 20:15:14 +0800 Subject: Reshape feature before decode --- models/auto_encoder.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'models') diff --git a/models/auto_encoder.py b/models/auto_encoder.py index e35ed23..1708bc9 100644 --- a/models/auto_encoder.py +++ b/models/auto_encoder.py @@ -79,9 +79,10 @@ class Decoder(nn.Module): def forward(self, fa, fgs, fgd): x = torch.cat((fa, fgs, fgd), dim=1).view(-1, self.em_dim) x = F.leaky_relu(self.batch_norm_fc(self.fc(x)), 0.2) + x = x.view(-1, 64 * 8, 4, 2) x = F.leaky_relu(self.batch_norm1(self.trans_conv1(x)), 0.2) x = F.leaky_relu(self.batch_norm2(self.trans_conv2(x)), 0.2) x = F.leaky_relu(self.batch_norm3(self.trans_conv3(x)), 0.2) x = F.sigmoid(self.trans_conv4(x)) - return x \ No newline at end of file + return x -- cgit v1.2.3