diff options
Diffstat (limited to 'models/auto_encoder.py')
-rw-r--r-- | models/auto_encoder.py | 10 |
1 files changed, 6 insertions, 4 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py index dc7843a..0c8bd5d 100644 --- a/models/auto_encoder.py +++ b/models/auto_encoder.py @@ -1,3 +1,5 @@ +from typing import Tuple + import torch import torch.nn as nn import torch.nn.functional as F @@ -11,9 +13,9 @@ class Encoder(nn.Module): def __init__( self, in_channels: int = 3, - frame_size: tuple[int, int] = (64, 48), + frame_size: Tuple[int, int] = (64, 48), feature_channels: int = 64, - output_dims: tuple[int, int, int] = (192, 192, 128) + output_dims: Tuple[int, int, int] = (192, 192, 128) ): super().__init__() h_0, w_0 = frame_size @@ -102,9 +104,9 @@ class AutoEncoder(nn.Module): def __init__( self, channels: int = 3, - frame_size: tuple[int, int] = (64, 48), + frame_size: Tuple[int, int] = (64, 48), feature_channels: int = 64, - embedding_dims: tuple[int, int, int] = (192, 192, 128) + embedding_dims: Tuple[int, int, int] = (192, 192, 128) ): super().__init__() self.embedding_dims = embedding_dims |