diff options
Diffstat (limited to 'models/auto_encoder.py')
-rw-r--r-- | models/auto_encoder.py | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py index 2d715db..c6bc52f 100644 --- a/models/auto_encoder.py +++ b/models/auto_encoder.py @@ -1,3 +1,5 @@ +from typing import Tuple + import torch import torch.nn as nn import torch.nn.functional as F @@ -13,7 +15,7 @@ class Encoder(nn.Module): in_channels: int = 3, frame_size: tuple[int, int] = (64, 48), feature_channels: int = 64, - output_dims: tuple[int, int, int] = (128, 128, 64) + output_dims: Tuple[int, int, int] = (128, 128, 64) ): super().__init__() self.feature_channels = feature_channels @@ -74,7 +76,7 @@ class Decoder(nn.Module): def __init__( self, - input_dims: tuple[int, int, int] = (128, 128, 64), + input_dims: Tuple[int, int, int] = (128, 128, 64), feature_channels: int = 64, feature_size: tuple[int, int] = (4, 3), out_channels: int = 3, @@ -127,7 +129,7 @@ class AutoEncoder(nn.Module): channels: int = 3, frame_size: tuple[int, int] = (64, 48), feature_channels: int = 64, - embedding_dims: tuple[int, int, int] = (128, 128, 64) + embedding_dims: Tuple[int, int, int] = (128, 128, 64) ): super().__init__() self.encoder = Encoder(channels, frame_size, |