diff options
Diffstat (limited to 'models/auto_encoder.py')
-rw-r--r-- | models/auto_encoder.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py index 701f299..601348e 100644 --- a/models/auto_encoder.py +++ b/models/auto_encoder.py @@ -67,7 +67,7 @@ class Decoder(nn.Module): def __init__( self, - out_channels: int, + out_channels: int = 3, feature_channels: int = 64, input_dims: tuple[int, int, int] = (128, 128, 64) ): @@ -102,6 +102,6 @@ class Decoder(nn.Module): x = self.trans_conv1(x) x = self.trans_conv2(x) x = self.trans_conv3(x) - x = F.sigmoid(self.trans_conv4(x)) + x = torch.sigmoid(self.trans_conv4(x)) return x |