From 5ad5121f702f23dcdcc9beef9aa6104d6269e179 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Sun, 27 Dec 2020 20:41:43 +0800 Subject: Fix inconsistency and API deprecation issues in decoder 1. Add default output channels of decoder 2. Replace deprecated torch.nn.functional.sigmoid with torch.sigmoid --- models/auto_encoder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'models/auto_encoder.py') diff --git a/models/auto_encoder.py b/models/auto_encoder.py index 701f299..601348e 100644 --- a/models/auto_encoder.py +++ b/models/auto_encoder.py @@ -67,7 +67,7 @@ class Decoder(nn.Module): def __init__( self, - out_channels: int, + out_channels: int = 3, feature_channels: int = 64, input_dims: tuple[int, int, int] = (128, 128, 64) ): @@ -102,6 +102,6 @@ class Decoder(nn.Module): x = self.trans_conv1(x) x = self.trans_conv2(x) x = self.trans_conv3(x) - x = F.sigmoid(self.trans_conv4(x)) + x = torch.sigmoid(self.trans_conv4(x)) return x -- cgit v1.2.3