summaryrefslogtreecommitdiff
path: root/models/model.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-01-23 00:43:20 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-01-23 00:46:08 +0800
commita040400d7caa267d4bfbe8e5520568806f92b3d4 (patch)
tree708fe30d67183431ceb5427b556bfe7d741913dd /models/model.py
parentf8fe82326479d4ea62d7da3dd8f25f1c09cf11bc (diff)
Type hint fixes
Diffstat (limited to 'models/model.py')
-rw-r--r--models/model.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/models/model.py b/models/model.py
index 8992914..1dd195e 100644
--- a/models/model.py
+++ b/models/model.py
@@ -220,7 +220,7 @@ class Model:
def predict_all(
self,
- iters: tuple[int],
+ iters: Tuple[int],
dataset_config: DatasetConfiguration,
dataset_selectors: Dict[
str, Dict[str, Union[ClipClasses, ClipConditions, ClipViews]]
@@ -266,7 +266,7 @@ class Model:
return self._evaluate(gallery_samples, probe_samples)
- def _get_eval_sample(self, sample: dict[str, Union[list, torch.Tensor]]):
+ def _get_eval_sample(self, sample: Dict[str, Union[List, torch.Tensor]]):
label = sample.pop('label').item()
clip = sample.pop('clip').to(self.device)
feature = self.rgb_pn(clip).detach()
@@ -326,7 +326,7 @@ class Model:
def _load_pretrained(
self,
- iters: tuple[int],
+ iters: Tuple[int],
dataset_config: DatasetConfiguration,
dataset_selectors: Dict[
str, Dict[str, Union[ClipClasses, ClipConditions, ClipViews]]