diff options
| author | Jordan Gong <jordan.gong@protonmail.com> | 2021-02-17 17:38:03 +0800 | 
|---|---|---|
| committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-02-17 17:38:03 +0800 | 
| commit | 641d19c1ebdf44486de139fadeff3276aecdf284 (patch) | |
| tree | 2da573b33771786ec19a91dba755065a04fd4489 /preprocess.py | |
| parent | 1deebf6b25ea2885609f43b56316fd4be1303381 (diff) | |
Fix type hints and add constrains to height and width
Diffstat (limited to 'preprocess.py')
| -rw-r--r-- | preprocess.py | 8 | 
1 files changed, 5 insertions, 3 deletions
| diff --git a/preprocess.py b/preprocess.py index b7a81a3..8af58f2 100644 --- a/preprocess.py +++ b/preprocess.py @@ -18,7 +18,7 @@ MASK_BOX_RATIO = 1.7  class CASIABClip(Dataset): -    def __init__(self, filename) -> None: +    def __init__(self, filename):          super().__init__()          video, *_ = torchvision.io.read_video(filename, pts_unit='sec')          self.frames = video.permute(0, 3, 1, 2) / 255 @@ -35,7 +35,7 @@ model = model.to(DEVICE)  model.eval() -def result_handler(frame_: torch.Tensor): +def result_handler(frame_: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:      for (box, label, score, mask) in zip(*result.values()):          x0, y0, x1, y1 = box          height, width = y1 - y0, x1 - x0 @@ -51,7 +51,9 @@ def result_handler(frame_: torch.Tensor):              mask_x1 = (mask_xc + mask_half_width).ceil().int()              # Skip incomplete frames -            if mask_x0 < 0 or mask_x1 > 320 or mask_y0 < 0 or mask_y1 > 240: +            if (height < 64 or width < 64 / MASK_BOX_RATIO) \ +                    or (mask_x0 < 0 or mask_x1 > 320) \ +                    or (mask_y0 < 0 or mask_y1 > 240):                  continue              cropped_frame = frame_[:, mask_y0:mask_y1 + 1, mask_x0:mask_x1 + 1] | 
