summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--preprocess.py8
1 files changed, 5 insertions, 3 deletions
diff --git a/preprocess.py b/preprocess.py
index b7a81a3..8af58f2 100644
--- a/preprocess.py
+++ b/preprocess.py
@@ -18,7 +18,7 @@ MASK_BOX_RATIO = 1.7
class CASIABClip(Dataset):
- def __init__(self, filename) -> None:
+ def __init__(self, filename):
super().__init__()
video, *_ = torchvision.io.read_video(filename, pts_unit='sec')
self.frames = video.permute(0, 3, 1, 2) / 255
@@ -35,7 +35,7 @@ model = model.to(DEVICE)
model.eval()
-def result_handler(frame_: torch.Tensor):
+def result_handler(frame_: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
for (box, label, score, mask) in zip(*result.values()):
x0, y0, x1, y1 = box
height, width = y1 - y0, x1 - x0
@@ -51,7 +51,9 @@ def result_handler(frame_: torch.Tensor):
mask_x1 = (mask_xc + mask_half_width).ceil().int()
# Skip incomplete frames
- if mask_x0 < 0 or mask_x1 > 320 or mask_y0 < 0 or mask_y1 > 240:
+ if (height < 64 or width < 64 / MASK_BOX_RATIO) \
+ or (mask_x0 < 0 or mask_x1 > 320) \
+ or (mask_y0 < 0 or mask_y1 > 240):
continue
cropped_frame = frame_[:, mask_y0:mask_y1 + 1, mask_x0:mask_x1 + 1]