summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-01-07 19:55:00 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-01-07 20:10:54 +0800
commit6b8cd51b008ff7a4384ef176571ffdd6e9e0792e (patch)
tree6010b45d0d2564582514105322fee2d34cceba66
parent98b6e6dc3be6f88abb72e351c8f2da2b23b8ab85 (diff)
Type hint for python version lower than 3.9
-rw-r--r--models/model.py14
-rw-r--r--utils/configuration.py4
-rw-r--r--utils/dataset.py18
3 files changed, 18 insertions, 18 deletions
diff --git a/models/model.py b/models/model.py
index 4deced0..725988a 100644
--- a/models/model.py
+++ b/models/model.py
@@ -1,5 +1,5 @@
import os
-from typing import Union, Optional, Tuple, List
+from typing import Union, Optional, Tuple, List, Dict
import numpy as np
import torch
@@ -166,7 +166,7 @@ class Model:
popped_keys=['root_dir', 'cache_on']
)
self.log_name = '_'.join((self.log_name, self._dataset_sig))
- config: dict = dataset_config.copy()
+ config: Dict = dataset_config.copy()
name = config.pop('name')
if name == 'CASIA-B':
return CASIAB(**config, is_train=self.is_train)
@@ -180,7 +180,7 @@ class Model:
dataset: Union[CASIAB],
dataloader_config: DataloaderConfiguration
) -> DataLoader:
- config: dict = dataloader_config.copy()
+ config: Dict = dataloader_config.copy()
if self.is_train:
(self.pr, self.k) = config.pop('batch_size')
self.log_name = '_'.join((self.log_name, str(self.pr), str(self.k)))
@@ -195,9 +195,9 @@ class Model:
def _batch_splitter(
self,
- batch: List[dict[str, Union[np.int64, str, torch.Tensor]]]
- ) -> Tuple[dict[str, Union[List[str], torch.Tensor]],
- dict[str, Union[List[str], torch.Tensor]]]:
+ batch: List[Dict[str, Union[np.int64, str, torch.Tensor]]]
+ ) -> Tuple[Dict[str, Union[List[str], torch.Tensor]],
+ Dict[str, Union[List[str], torch.Tensor]]]:
"""
Disentanglement need two random conditions, this function will
split pr * k * 2 samples to 2 dicts each containing pr * k
@@ -211,7 +211,7 @@ class Model:
return default_collate(_batch[0]), default_collate(_batch[1])
def _make_signature(self,
- config: dict,
+ config: Dict,
popped_keys: Optional[List] = None) -> str:
_config = config.copy()
if popped_keys:
diff --git a/utils/configuration.py b/utils/configuration.py
index aa04b32..455abe8 100644
--- a/utils/configuration.py
+++ b/utils/configuration.py
@@ -1,4 +1,4 @@
-from typing import TypedDict, Optional, Union, Tuple
+from typing import TypedDict, Optional, Union, Tuple, Dict
from utils.dataset import ClipClasses, ClipConditions, ClipViews
@@ -15,7 +15,7 @@ class DatasetConfiguration(TypedDict):
train_size: int
num_sampled_frames: int
discard_threshold: int
- selector: Optional[dict[str, Union[ClipClasses, ClipConditions, ClipViews]]]
+ selector: Optional[Dict[str, Union[ClipClasses, ClipConditions, ClipViews]]]
num_input_channels: int
frame_size: Tuple[int, int]
cache_on: bool
diff --git a/utils/dataset.py b/utils/dataset.py
index 0a33693..e691157 100644
--- a/utils/dataset.py
+++ b/utils/dataset.py
@@ -1,7 +1,7 @@
import os
import random
import re
-from typing import Optional, NewType, Union, List, Tuple
+from typing import Optional, NewType, Union, List, Tuple, Set, Dict
import numpy as np
import torch
@@ -11,9 +11,9 @@ from sklearn.preprocessing import LabelEncoder
from torch.utils import data
from tqdm import tqdm
-ClipClasses = NewType('ClipClasses', set[str])
-ClipConditions = NewType('ClipConditions', set[str])
-ClipViews = NewType('ClipViews', set[str])
+ClipClasses = NewType('ClipClasses', Set[str])
+ClipConditions = NewType('ClipConditions', Set[str])
+ClipViews = NewType('ClipViews', Set[str])
class CASIAB(data.Dataset):
@@ -26,7 +26,7 @@ class CASIAB(data.Dataset):
train_size: int = 74,
num_sampled_frames: int = 30,
discard_threshold: int = 15,
- selector: Optional[dict[
+ selector: Optional[Dict[
str, Union[ClipClasses, ClipConditions, ClipViews]
]] = None,
num_input_channels: int = 3,
@@ -75,12 +75,12 @@ class CASIAB(data.Dataset):
self.views: np.ndarray[np.str_]
# Labels, classes, conditions and views in dataset,
# set of three attributes above
- self.metadata = dict[str, List[np.int64, str]]
+ self.metadata = Dict[str, List[np.int64, str]]
# Dictionaries for indexing frames and frame names by clip name
# and chip path when cache is on
- self._cached_clips_frame_names: Optional[dict[str, List[str]]] = None
- self._cached_clips: Optional[dict[str, torch.Tensor]] = None
+ self._cached_clips_frame_names: Optional[Dict[str, List[str]]] = None
+ self._cached_clips: Optional[Dict[str, torch.Tensor]] = None
# Video clip directory names
self._clip_names: List[str] = []
@@ -170,7 +170,7 @@ class CASIAB(data.Dataset):
def __getitem__(
self,
index: int
- ) -> dict[str, Union[np.int64, str, torch.Tensor]]:
+ ) -> Dict[str, Union[np.int64, str, torch.Tensor]]:
label = self.labels[index]
condition = self.conditions[index]
view = self.views[index]