diff options
-rw-r--r-- | preprocess.py | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/preprocess.py b/preprocess.py index b7a81a3..8af58f2 100644 --- a/preprocess.py +++ b/preprocess.py @@ -18,7 +18,7 @@ MASK_BOX_RATIO = 1.7 class CASIABClip(Dataset): - def __init__(self, filename) -> None: + def __init__(self, filename): super().__init__() video, *_ = torchvision.io.read_video(filename, pts_unit='sec') self.frames = video.permute(0, 3, 1, 2) / 255 @@ -35,7 +35,7 @@ model = model.to(DEVICE) model.eval() -def result_handler(frame_: torch.Tensor): +def result_handler(frame_: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: for (box, label, score, mask) in zip(*result.values()): x0, y0, x1, y1 = box height, width = y1 - y0, x1 - x0 @@ -51,7 +51,9 @@ def result_handler(frame_: torch.Tensor): mask_x1 = (mask_xc + mask_half_width).ceil().int() # Skip incomplete frames - if mask_x0 < 0 or mask_x1 > 320 or mask_y0 < 0 or mask_y1 > 240: + if (height < 64 or width < 64 / MASK_BOX_RATIO) \ + or (mask_x0 < 0 or mask_x1 > 320) \ + or (mask_y0 < 0 or mask_y1 > 240): continue cropped_frame = frame_[:, mask_y0:mask_y1 + 1, mask_x0:mask_x1 + 1] |