summaryrefslogtreecommitdiff
path: root/models
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-01-07 19:55:00 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-01-07 20:16:09 +0800
commite5a73abd80578aa5e46d8d444466d1e6346ec6ec (patch)
tree9af6211c406e995a4c2129b931ad2a83bc05c0b5 /models
parent98b6e6dc3be6f88abb72e351c8f2da2b23b8ab85 (diff)
Type hint for python version lower than 3.9
Diffstat (limited to 'models')
-rw-r--r--models/model.py14
1 files changed, 7 insertions, 7 deletions
diff --git a/models/model.py b/models/model.py
index 4deced0..725988a 100644
--- a/models/model.py
+++ b/models/model.py
@@ -1,5 +1,5 @@
import os
-from typing import Union, Optional, Tuple, List
+from typing import Union, Optional, Tuple, List, Dict
import numpy as np
import torch
@@ -166,7 +166,7 @@ class Model:
popped_keys=['root_dir', 'cache_on']
)
self.log_name = '_'.join((self.log_name, self._dataset_sig))
- config: dict = dataset_config.copy()
+ config: Dict = dataset_config.copy()
name = config.pop('name')
if name == 'CASIA-B':
return CASIAB(**config, is_train=self.is_train)
@@ -180,7 +180,7 @@ class Model:
dataset: Union[CASIAB],
dataloader_config: DataloaderConfiguration
) -> DataLoader:
- config: dict = dataloader_config.copy()
+ config: Dict = dataloader_config.copy()
if self.is_train:
(self.pr, self.k) = config.pop('batch_size')
self.log_name = '_'.join((self.log_name, str(self.pr), str(self.k)))
@@ -195,9 +195,9 @@ class Model:
def _batch_splitter(
self,
- batch: List[dict[str, Union[np.int64, str, torch.Tensor]]]
- ) -> Tuple[dict[str, Union[List[str], torch.Tensor]],
- dict[str, Union[List[str], torch.Tensor]]]:
+ batch: List[Dict[str, Union[np.int64, str, torch.Tensor]]]
+ ) -> Tuple[Dict[str, Union[List[str], torch.Tensor]],
+ Dict[str, Union[List[str], torch.Tensor]]]:
"""
Disentanglement need two random conditions, this function will
split pr * k * 2 samples to 2 dicts each containing pr * k
@@ -211,7 +211,7 @@ class Model:
return default_collate(_batch[0]), default_collate(_batch[1])
def _make_signature(self,
- config: dict,
+ config: Dict,
popped_keys: Optional[List] = None) -> str:
_config = config.copy()
if popped_keys: