diff options
Diffstat (limited to 'models/model.py')
-rw-r--r-- | models/model.py | 14 |
1 files changed, 7 insertions, 7 deletions
diff --git a/models/model.py b/models/model.py index 4deced0..725988a 100644 --- a/models/model.py +++ b/models/model.py @@ -1,5 +1,5 @@ import os -from typing import Union, Optional, Tuple, List +from typing import Union, Optional, Tuple, List, Dict import numpy as np import torch @@ -166,7 +166,7 @@ class Model: popped_keys=['root_dir', 'cache_on'] ) self.log_name = '_'.join((self.log_name, self._dataset_sig)) - config: dict = dataset_config.copy() + config: Dict = dataset_config.copy() name = config.pop('name') if name == 'CASIA-B': return CASIAB(**config, is_train=self.is_train) @@ -180,7 +180,7 @@ class Model: dataset: Union[CASIAB], dataloader_config: DataloaderConfiguration ) -> DataLoader: - config: dict = dataloader_config.copy() + config: Dict = dataloader_config.copy() if self.is_train: (self.pr, self.k) = config.pop('batch_size') self.log_name = '_'.join((self.log_name, str(self.pr), str(self.k))) @@ -195,9 +195,9 @@ class Model: def _batch_splitter( self, - batch: List[dict[str, Union[np.int64, str, torch.Tensor]]] - ) -> Tuple[dict[str, Union[List[str], torch.Tensor]], - dict[str, Union[List[str], torch.Tensor]]]: + batch: List[Dict[str, Union[np.int64, str, torch.Tensor]]] + ) -> Tuple[Dict[str, Union[List[str], torch.Tensor]], + Dict[str, Union[List[str], torch.Tensor]]]: """ Disentanglement need two random conditions, this function will split pr * k * 2 samples to 2 dicts each containing pr * k @@ -211,7 +211,7 @@ class Model: return default_collate(_batch[0]), default_collate(_batch[1]) def _make_signature(self, - config: dict, + config: Dict, popped_keys: Optional[List] = None) -> str: _config = config.copy() if popped_keys: |