summaryrefslogtreecommitdiff
path: root/models/model.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/model.py')
-rw-r--r--models/model.py22
1 files changed, 11 insertions, 11 deletions
diff --git a/models/model.py b/models/model.py
index 54f3441..80d4499 100644
--- a/models/model.py
+++ b/models/model.py
@@ -1,5 +1,5 @@
import os
-from typing import Union, Optional
+from typing import Union, Optional, Tuple, List, Dict, Set
import numpy as np
import torch
@@ -166,7 +166,7 @@ class Model:
popped_keys=['root_dir', 'cache_on']
)
self.log_name = '_'.join((self.log_name, self._dataset_sig))
- config: dict = dataset_config.copy()
+ config: Dict = dataset_config.copy()
name = config.pop('name')
if name == 'CASIA-B':
return CASIAB(**config, is_train=self.is_train)
@@ -180,7 +180,7 @@ class Model:
dataset: Union[CASIAB],
dataloader_config: DataloaderConfiguration
) -> DataLoader:
- config: dict = dataloader_config.copy()
+ config: Dict = dataloader_config.copy()
if self.is_train:
(self.pr, self.k) = config.pop('batch_size')
self.log_name = '_'.join((self.log_name, str(self.pr), str(self.k)))
@@ -195,9 +195,9 @@ class Model:
def _batch_splitter(
self,
- batch: list[dict[str, Union[np.int64, str, torch.Tensor]]]
- ) -> tuple[dict[str, Union[list[str], torch.Tensor]],
- dict[str, Union[list[str], torch.Tensor]]]:
+ batch: List[Dict[str, Union[np.int64, str, torch.Tensor]]]
+ ) -> Tuple[Dict[str, Union[List[str], torch.Tensor]],
+ Dict[str, Union[List[str], torch.Tensor]]]:
"""
Disentanglement need two random conditions, this function will
split pr * k * 2 samples to 2 dicts each containing pr * k
@@ -211,8 +211,8 @@ class Model:
return default_collate(_batch[0]), default_collate(_batch[1])
def _make_signature(self,
- config: dict,
- popped_keys: Optional[list] = None) -> str:
+ config: Dict,
+ popped_keys: Optional[List] = None) -> str:
_config = config.copy()
if popped_keys:
for key in popped_keys:
@@ -220,14 +220,14 @@ class Model:
return self._gen_sig(list(_config.values()))
- def _gen_sig(self, values: Union[tuple, list, set, str, int, float]) -> str:
+ def _gen_sig(self, values: Union[Tuple, List, Set, str, int, float]) -> str:
strings = []
for v in values:
if isinstance(v, str):
strings.append(v)
- elif isinstance(v, (tuple, list, set)):
+ elif isinstance(v, (Tuple, List, Set)):
strings.append(self._gen_sig(v))
- elif isinstance(v, dict):
+ elif isinstance(v, Dict):
strings.append(self._gen_sig(list(v.values())))
else:
strings.append(str(v))