summaryrefslogtreecommitdiff
path: root/models/auto_encoder.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/auto_encoder.py')
-rw-r--r--models/auto_encoder.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index 701f299..601348e 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -67,7 +67,7 @@ class Decoder(nn.Module):
def __init__(
self,
- out_channels: int,
+ out_channels: int = 3,
feature_channels: int = 64,
input_dims: tuple[int, int, int] = (128, 128, 64)
):
@@ -102,6 +102,6 @@ class Decoder(nn.Module):
x = self.trans_conv1(x)
x = self.trans_conv2(x)
x = self.trans_conv3(x)
- x = F.sigmoid(self.trans_conv4(x))
+ x = torch.sigmoid(self.trans_conv4(x))
return x