diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2020-12-27 20:41:43 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2020-12-27 20:41:43 +0800 |
commit | 5ad5121f702f23dcdcc9beef9aa6104d6269e179 (patch) | |
tree | ab14c631508d5d6b3e8a9c38bcb6571970f6c642 /models/auto_encoder.py | |
parent | 5a95c94e9f250001d0007b5ac238505d0a5f23b5 (diff) |
Fix inconsistency and API deprecation issues in decoder
1. Add default output channels of decoder
2. Replace deprecated torch.nn.functional.sigmoid with torch.sigmoid
Diffstat (limited to 'models/auto_encoder.py')
-rw-r--r-- | models/auto_encoder.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py index 701f299..601348e 100644 --- a/models/auto_encoder.py +++ b/models/auto_encoder.py @@ -67,7 +67,7 @@ class Decoder(nn.Module): def __init__( self, - out_channels: int, + out_channels: int = 3, feature_channels: int = 64, input_dims: tuple[int, int, int] = (128, 128, 64) ): @@ -102,6 +102,6 @@ class Decoder(nn.Module): x = self.trans_conv1(x) x = self.trans_conv2(x) x = self.trans_conv3(x) - x = F.sigmoid(self.trans_conv4(x)) + x = torch.sigmoid(self.trans_conv4(x)) return x |