summaryrefslogtreecommitdiff
path: root/models/model.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/model.py')
-rw-r--r--models/model.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/models/model.py b/models/model.py
index 8992914..1dd195e 100644
--- a/models/model.py
+++ b/models/model.py
@@ -220,7 +220,7 @@ class Model:
def predict_all(
self,
- iters: tuple[int],
+ iters: Tuple[int],
dataset_config: DatasetConfiguration,
dataset_selectors: Dict[
str, Dict[str, Union[ClipClasses, ClipConditions, ClipViews]]
@@ -266,7 +266,7 @@ class Model:
return self._evaluate(gallery_samples, probe_samples)
- def _get_eval_sample(self, sample: dict[str, Union[list, torch.Tensor]]):
+ def _get_eval_sample(self, sample: Dict[str, Union[List, torch.Tensor]]):
label = sample.pop('label').item()
clip = sample.pop('clip').to(self.device)
feature = self.rgb_pn(clip).detach()
@@ -326,7 +326,7 @@ class Model:
def _load_pretrained(
self,
- iters: tuple[int],
+ iters: Tuple[int],
dataset_config: DatasetConfiguration,
dataset_selectors: Dict[
str, Dict[str, Union[ClipClasses, ClipConditions, ClipViews]]