summaryrefslogtreecommitdiff
path: root/models/model.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/model.py')
-rw-r--r--models/model.py14
1 files changed, 7 insertions, 7 deletions
diff --git a/models/model.py b/models/model.py
index 1dc0f23..4deced0 100644
--- a/models/model.py
+++ b/models/model.py
@@ -1,5 +1,5 @@
import os
-from typing import Union, Optional
+from typing import Union, Optional, Tuple, List
import numpy as np
import torch
@@ -195,9 +195,9 @@ class Model:
def _batch_splitter(
self,
- batch: list[dict[str, Union[np.int64, str, torch.Tensor]]]
- ) -> tuple[dict[str, Union[list[str], torch.Tensor]],
- dict[str, Union[list[str], torch.Tensor]]]:
+ batch: List[dict[str, Union[np.int64, str, torch.Tensor]]]
+ ) -> Tuple[dict[str, Union[List[str], torch.Tensor]],
+ dict[str, Union[List[str], torch.Tensor]]]:
"""
Disentanglement need two random conditions, this function will
split pr * k * 2 samples to 2 dicts each containing pr * k
@@ -212,7 +212,7 @@ class Model:
def _make_signature(self,
config: dict,
- popped_keys: Optional[list] = None) -> str:
+ popped_keys: Optional[List] = None) -> str:
_config = config.copy()
if popped_keys:
for key in popped_keys:
@@ -220,12 +220,12 @@ class Model:
return self._gen_sig(list(_config.values()))
- def _gen_sig(self, values: Union[tuple, list, str, int, float]) -> str:
+ def _gen_sig(self, values: Union[Tuple, List, str, int, float]) -> str:
strings = []
for v in values:
if isinstance(v, str):
strings.append(v)
- elif isinstance(v, (tuple, list)):
+ elif isinstance(v, (Tuple, List)):
strings.append(self._gen_sig(v))
else:
strings.append(str(v))