From 6b8cd51b008ff7a4384ef176571ffdd6e9e0792e Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Thu, 7 Jan 2021 19:55:00 +0800 Subject: Type hint for python version lower than 3.9 --- models/model.py | 14 +++++++------- utils/configuration.py | 4 ++-- utils/dataset.py | 18 +++++++++--------- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/models/model.py b/models/model.py index 4deced0..725988a 100644 --- a/models/model.py +++ b/models/model.py @@ -1,5 +1,5 @@ import os -from typing import Union, Optional, Tuple, List +from typing import Union, Optional, Tuple, List, Dict import numpy as np import torch @@ -166,7 +166,7 @@ class Model: popped_keys=['root_dir', 'cache_on'] ) self.log_name = '_'.join((self.log_name, self._dataset_sig)) - config: dict = dataset_config.copy() + config: Dict = dataset_config.copy() name = config.pop('name') if name == 'CASIA-B': return CASIAB(**config, is_train=self.is_train) @@ -180,7 +180,7 @@ class Model: dataset: Union[CASIAB], dataloader_config: DataloaderConfiguration ) -> DataLoader: - config: dict = dataloader_config.copy() + config: Dict = dataloader_config.copy() if self.is_train: (self.pr, self.k) = config.pop('batch_size') self.log_name = '_'.join((self.log_name, str(self.pr), str(self.k))) @@ -195,9 +195,9 @@ class Model: def _batch_splitter( self, - batch: List[dict[str, Union[np.int64, str, torch.Tensor]]] - ) -> Tuple[dict[str, Union[List[str], torch.Tensor]], - dict[str, Union[List[str], torch.Tensor]]]: + batch: List[Dict[str, Union[np.int64, str, torch.Tensor]]] + ) -> Tuple[Dict[str, Union[List[str], torch.Tensor]], + Dict[str, Union[List[str], torch.Tensor]]]: """ Disentanglement need two random conditions, this function will split pr * k * 2 samples to 2 dicts each containing pr * k @@ -211,7 +211,7 @@ class Model: return default_collate(_batch[0]), default_collate(_batch[1]) def _make_signature(self, - config: dict, + config: Dict, popped_keys: Optional[List] = None) -> str: _config = config.copy() if popped_keys: diff --git a/utils/configuration.py b/utils/configuration.py index aa04b32..455abe8 100644 --- a/utils/configuration.py +++ b/utils/configuration.py @@ -1,4 +1,4 @@ -from typing import TypedDict, Optional, Union, Tuple +from typing import TypedDict, Optional, Union, Tuple, Dict from utils.dataset import ClipClasses, ClipConditions, ClipViews @@ -15,7 +15,7 @@ class DatasetConfiguration(TypedDict): train_size: int num_sampled_frames: int discard_threshold: int - selector: Optional[dict[str, Union[ClipClasses, ClipConditions, ClipViews]]] + selector: Optional[Dict[str, Union[ClipClasses, ClipConditions, ClipViews]]] num_input_channels: int frame_size: Tuple[int, int] cache_on: bool diff --git a/utils/dataset.py b/utils/dataset.py index 0a33693..e691157 100644 --- a/utils/dataset.py +++ b/utils/dataset.py @@ -1,7 +1,7 @@ import os import random import re -from typing import Optional, NewType, Union, List, Tuple +from typing import Optional, NewType, Union, List, Tuple, Set, Dict import numpy as np import torch @@ -11,9 +11,9 @@ from sklearn.preprocessing import LabelEncoder from torch.utils import data from tqdm import tqdm -ClipClasses = NewType('ClipClasses', set[str]) -ClipConditions = NewType('ClipConditions', set[str]) -ClipViews = NewType('ClipViews', set[str]) +ClipClasses = NewType('ClipClasses', Set[str]) +ClipConditions = NewType('ClipConditions', Set[str]) +ClipViews = NewType('ClipViews', Set[str]) class CASIAB(data.Dataset): @@ -26,7 +26,7 @@ class CASIAB(data.Dataset): train_size: int = 74, num_sampled_frames: int = 30, discard_threshold: int = 15, - selector: Optional[dict[ + selector: Optional[Dict[ str, Union[ClipClasses, ClipConditions, ClipViews] ]] = None, num_input_channels: int = 3, @@ -75,12 +75,12 @@ class CASIAB(data.Dataset): self.views: np.ndarray[np.str_] # Labels, classes, conditions and views in dataset, # set of three attributes above - self.metadata = dict[str, List[np.int64, str]] + self.metadata = Dict[str, List[np.int64, str]] # Dictionaries for indexing frames and frame names by clip name # and chip path when cache is on - self._cached_clips_frame_names: Optional[dict[str, List[str]]] = None - self._cached_clips: Optional[dict[str, torch.Tensor]] = None + self._cached_clips_frame_names: Optional[Dict[str, List[str]]] = None + self._cached_clips: Optional[Dict[str, torch.Tensor]] = None # Video clip directory names self._clip_names: List[str] = [] @@ -170,7 +170,7 @@ class CASIAB(data.Dataset): def __getitem__( self, index: int - ) -> dict[str, Union[np.int64, str, torch.Tensor]]: + ) -> Dict[str, Union[np.int64, str, torch.Tensor]]: label = self.labels[index] condition = self.conditions[index] view = self.views[index] -- cgit v1.2.3