summaryrefslogtreecommitdiff
path: root/utils
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2020-12-29 20:59:40 +0800
committerJordan Gong <jordan.gong@protonmail.com>2020-12-29 21:06:57 +0800
commit4ee31c0ec4038f9af46959248b90d31569b473d1 (patch)
tree35624ad483837d73afcc72793be13c475caf4e6b /utils
parentb8653d54efe2a8c94ae408c0c2da9bdd0b43ecdd (diff)
Add type hint for new label (numpy.int64)
Diffstat (limited to 'utils')
-rw-r--r--utils/dataset.py7
1 files changed, 5 insertions, 2 deletions
diff --git a/utils/dataset.py b/utils/dataset.py
index 050ac03..bb4d762 100644
--- a/utils/dataset.py
+++ b/utils/dataset.py
@@ -75,7 +75,7 @@ class CASIAB(data.Dataset):
self.views: np.ndarray[np.str_]
# Labels, classes, conditions and views in dataset,
# set of three attributes above
- self.metadata = dict[str, list[str]]
+ self.metadata = dict[str, list[np.int64, str]]
# Dictionaries for indexing frames and frame names by clip name
# and chip path when cache is on
@@ -167,7 +167,10 @@ class CASIAB(data.Dataset):
def __len__(self) -> int:
return len(self.labels)
- def __getitem__(self, index: int) -> dict[str, Union[str, torch.Tensor]]:
+ def __getitem__(
+ self,
+ index: int
+ ) -> dict[str, Union[np.int64, str, torch.Tensor]]:
label = self.labels[index]
condition = self.conditions[index]
view = self.views[index]