diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2020-12-29 20:59:40 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2020-12-29 21:06:57 +0800 |
commit | 4ee31c0ec4038f9af46959248b90d31569b473d1 (patch) | |
tree | 35624ad483837d73afcc72793be13c475caf4e6b /utils | |
parent | b8653d54efe2a8c94ae408c0c2da9bdd0b43ecdd (diff) |
Add type hint for new label (numpy.int64)
Diffstat (limited to 'utils')
-rw-r--r-- | utils/dataset.py | 7 |
1 files changed, 5 insertions, 2 deletions
diff --git a/utils/dataset.py b/utils/dataset.py index 050ac03..bb4d762 100644 --- a/utils/dataset.py +++ b/utils/dataset.py @@ -75,7 +75,7 @@ class CASIAB(data.Dataset): self.views: np.ndarray[np.str_] # Labels, classes, conditions and views in dataset, # set of three attributes above - self.metadata = dict[str, list[str]] + self.metadata = dict[str, list[np.int64, str]] # Dictionaries for indexing frames and frame names by clip name # and chip path when cache is on @@ -167,7 +167,10 @@ class CASIAB(data.Dataset): def __len__(self) -> int: return len(self.labels) - def __getitem__(self, index: int) -> dict[str, Union[str, torch.Tensor]]: + def __getitem__( + self, + index: int + ) -> dict[str, Union[np.int64, str, torch.Tensor]]: label = self.labels[index] condition = self.conditions[index] view = self.views[index] |