summaryrefslogtreecommitdiff
path: root/preprocess.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-02-17 20:27:33 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-02-17 20:27:33 +0800
commit2988c1b9afd4e869b629a8629abedbf63d2452aa (patch)
tree39c4ca790e4ae0cd79795dabb669133136eaec03 /preprocess.py
parentef41498844a41e275f5d3501307d6ee3359c5ead (diff)
parent31af04b70fcc0e7b46cbe18a52d150eb4e274f0e (diff)
Merge branch 'python3.8' into data_parallel_py3.8
Diffstat (limited to 'preprocess.py')
-rw-r--r--preprocess.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/preprocess.py b/preprocess.py
index 8af58f2..894c7c1 100644
--- a/preprocess.py
+++ b/preprocess.py
@@ -1,5 +1,6 @@
import glob
import os
+from typing import Tuple
import torch
import torchvision
@@ -23,7 +24,7 @@ class CASIABClip(Dataset):
video, *_ = torchvision.io.read_video(filename, pts_unit='sec')
self.frames = video.permute(0, 3, 1, 2) / 255
- def __getitem__(self, index) -> tuple[int, torch.Tensor]:
+ def __getitem__(self, index) -> Tuple[int, torch.Tensor]:
return index, self.frames[index]
def __len__(self) -> int:
@@ -35,7 +36,7 @@ model = model.to(DEVICE)
model.eval()
-def result_handler(frame_: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+def result_handler(frame_: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
for (box, label, score, mask) in zip(*result.values()):
x0, y0, x1, y1 = box
height, width = y1 - y0, x1 - x0