summaryrefslogtreecommitdiff
path: root/models
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2020-12-27 20:41:43 +0800
committerJordan Gong <jordan.gong@protonmail.com>2020-12-27 20:41:43 +0800
commit5ad5121f702f23dcdcc9beef9aa6104d6269e179 (patch)
treeab14c631508d5d6b3e8a9c38bcb6571970f6c642 /models
parent5a95c94e9f250001d0007b5ac238505d0a5f23b5 (diff)
Fix inconsistency and API deprecation issues in decoder
1. Add default output channels of decoder 2. Replace deprecated torch.nn.functional.sigmoid with torch.sigmoid
Diffstat (limited to 'models')
-rw-r--r--models/auto_encoder.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index 701f299..601348e 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -67,7 +67,7 @@ class Decoder(nn.Module):
def __init__(
self,
- out_channels: int,
+ out_channels: int = 3,
feature_channels: int = 64,
input_dims: tuple[int, int, int] = (128, 128, 64)
):
@@ -102,6 +102,6 @@ class Decoder(nn.Module):
x = self.trans_conv1(x)
x = self.trans_conv2(x)
x = self.trans_conv3(x)
- x = F.sigmoid(self.trans_conv4(x))
+ x = torch.sigmoid(self.trans_conv4(x))
return x