summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2020-12-23 21:03:45 +0800
committerJordan Gong <jordan.gong@protonmail.com>2020-12-23 21:03:45 +0800
commit382912b087af409cc20628c711261c6bd3f99836 (patch)
tree601f042cae22f80ac8bb2d26409e160d281c7a89
parent46624a615429232cee01be670d925dd593ceb6a3 (diff)
Modify activation functions after conv or trans-conv in auto-encoder
1. Make activation functions be inplace ops 2. Change Leaky ReLU to ReLU in decoder
-rw-r--r--models/auto_encoder.py16
1 files changed, 8 insertions, 8 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index bb4a377..1be878f 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -36,12 +36,12 @@ class Encoder(nn.Module):
self.batch_norm_fc = nn.BatchNorm1d(self.em_dim)
def forward(self, x):
- x = F.leaky_relu(self.batch_norm1(self.conv1(x)), 0.2)
+ x = F.leaky_relu(self.batch_norm1(self.conv1(x)), 0.2, inplace=True)
x = self.max_pool1(x)
- x = F.leaky_relu(self.batch_norm2(self.conv2(x)), 0.2)
+ x = F.leaky_relu(self.batch_norm2(self.conv2(x)), 0.2, inplace=True)
x = self.max_pool2(x)
- x = F.leaky_relu(self.batch_norm3(self.conv3(x)), 0.2)
- x = F.leaky_relu(self.batch_norm4(self.conv4(x)), 0.2)
+ x = F.leaky_relu(self.batch_norm3(self.conv3(x)), 0.2, inplace=True)
+ x = F.leaky_relu(self.batch_norm4(self.conv4(x)), 0.2, inplace=True)
x = self.max_pool3(x)
x = x.view(-1, (64 * 8) * 2 * 4)
embedding = self.batch_norm_fc(self.fc(x))
@@ -76,11 +76,11 @@ class Decoder(nn.Module):
def forward(self, fa, fgs, fgd):
x = torch.cat((fa, fgs, fgd), dim=1).view(-1, self.em_dim)
- x = F.leaky_relu(self.batch_norm_fc(self.fc(x)), 0.2)
+ x = F.relu(self.batch_norm_fc(self.fc(x)), True)
x = x.view(-1, 64 * 8, 4, 2)
- x = F.leaky_relu(self.batch_norm1(self.trans_conv1(x)), 0.2)
- x = F.leaky_relu(self.batch_norm2(self.trans_conv2(x)), 0.2)
- x = F.leaky_relu(self.batch_norm3(self.trans_conv3(x)), 0.2)
+ x = F.relu(self.batch_norm1(self.trans_conv1(x)), True)
+ x = F.relu(self.batch_norm2(self.trans_conv2(x)), True)
+ x = F.relu(self.batch_norm3(self.trans_conv3(x)), True)
x = F.sigmoid(self.trans_conv4(x))
return x