From 1deebf6b25ea2885609f43b56316fd4be1303381 Mon Sep 17 00:00:00 2001 From: Jordan Gong <jordan.gong@protonmail.com> Date: Wed, 17 Feb 2021 17:02:50 +0800 Subject: Add new preprocess script --- preprocess.py | 105 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 105 insertions(+) create mode 100644 preprocess.py (limited to 'preprocess.py') diff --git a/preprocess.py b/preprocess.py new file mode 100644 index 0000000..b7a81a3 --- /dev/null +++ b/preprocess.py @@ -0,0 +1,105 @@ +import glob +import os + +import torch +import torchvision +from torch.utils.data import Dataset, DataLoader +from tqdm import tqdm + +DEVICE = torch.device('cuda') +BATCH_SIZE = 5 + +RAW_VIDEO_PATH = os.path.join('data', 'CASIA-B-RAW', 'video') +OUTPUT_PATH = '/tmp/CASIA-B-MRCNN-V2' +SCORE_THRESHOLD = 0.9 +BOX_RATIO_THRESHOLD = (1.25, 5) +MASK_BOX_RATIO = 1.7 + + +class CASIABClip(Dataset): + + def __init__(self, filename) -> None: + super().__init__() + video, *_ = torchvision.io.read_video(filename, pts_unit='sec') + self.frames = video.permute(0, 3, 1, 2) / 255 + + def __getitem__(self, index) -> tuple[int, torch.Tensor]: + return index, self.frames[index] + + def __len__(self) -> int: + return len(self.frames) + + +model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True) +model = model.to(DEVICE) +model.eval() + + +def result_handler(frame_: torch.Tensor): + for (box, label, score, mask) in zip(*result.values()): + x0, y0, x1, y1 = box + height, width = y1 - y0, x1 - x0 + if (BOX_RATIO_THRESHOLD[0] < height / width < BOX_RATIO_THRESHOLD[1]) \ + and (score > SCORE_THRESHOLD and label == 1): + mask_height_offset_0 = (0.043 * height) / 2 + mask_height_offset_1 = (0.027 * height) / 2 + mask_y0 = (y0 - mask_height_offset_0).floor().int() + mask_y1 = (y1 + mask_height_offset_1).ceil().int() + mask_half_width = ((mask_y1 - mask_y0) / MASK_BOX_RATIO) / 2 + mask_xc = (x0 + x1) / 2 + mask_x0 = (mask_xc - mask_half_width).floor().int() + mask_x1 = (mask_xc + mask_half_width).ceil().int() + + # Skip incomplete frames + if mask_x0 < 0 or mask_x1 > 320 or mask_y0 < 0 or mask_y1 > 240: + continue + + cropped_frame = frame_[:, mask_y0:mask_y1 + 1, mask_x0:mask_x1 + 1] + cropped_mask = mask[:, mask_y0:mask_y1 + 1, mask_x0:mask_x1 + 1] + filtered_frame = cropped_frame * cropped_mask + + return cropped_mask, filtered_frame + + +SIL_PATH = os.path.join(OUTPUT_PATH, 'SIL') +SEG_PATH = os.path.join(OUTPUT_PATH, 'SEG') +if not os.path.exists(SIL_PATH): + os.makedirs(SIL_PATH) +if not os.path.exists(SEG_PATH): + os.makedirs(SEG_PATH) +RAW_VIDEO_REGEX = os.path.join(RAW_VIDEO_PATH, '*-*-*-*.avi') +for clip_filename in sorted(glob.glob(RAW_VIDEO_REGEX)): + clip_name, _ = os.path.splitext(os.path.basename(clip_filename)) + clip_sil_dir = os.path.join(SIL_PATH, clip_name) + clip_seg_dir = os.path.join(SEG_PATH, clip_name) + if os.path.exists(clip_sil_dir): + if len(os.listdir(clip_sil_dir)) != 0: + continue + else: + os.mkdir(clip_sil_dir) + if os.path.exists(clip_seg_dir): + if len(os.listdir(clip_seg_dir)) != 0: + continue + else: + os.mkdir(clip_seg_dir) + + clip = CASIABClip(clip_filename) + clip_loader = DataLoader(clip, batch_size=BATCH_SIZE, pin_memory=True) + + with torch.no_grad(): + for frame_ids, frames in tqdm( + clip_loader, desc=clip_name, unit='batch' + ): + frames = frames.to(DEVICE) + for frame_id, frame, result in zip( + frame_ids, frames, model(frames) + ): + if len(result['boxes']) == 0: + continue + if processed := result_handler(frame): + sil, seg = processed + frame_basename = f'{frame_id:04d}.png' + sil_filename = os.path.join(clip_sil_dir, frame_basename) + seg_filename = os.path.join(clip_seg_dir, frame_basename) + torchvision.utils.save_image(sil, sil_filename) + torchvision.utils.save_image(seg, seg_filename) -- cgit v1.2.3 From 641d19c1ebdf44486de139fadeff3276aecdf284 Mon Sep 17 00:00:00 2001 From: Jordan Gong <jordan.gong@protonmail.com> Date: Wed, 17 Feb 2021 17:38:03 +0800 Subject: Fix type hints and add constrains to height and width --- preprocess.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) (limited to 'preprocess.py') diff --git a/preprocess.py b/preprocess.py index b7a81a3..8af58f2 100644 --- a/preprocess.py +++ b/preprocess.py @@ -18,7 +18,7 @@ MASK_BOX_RATIO = 1.7 class CASIABClip(Dataset): - def __init__(self, filename) -> None: + def __init__(self, filename): super().__init__() video, *_ = torchvision.io.read_video(filename, pts_unit='sec') self.frames = video.permute(0, 3, 1, 2) / 255 @@ -35,7 +35,7 @@ model = model.to(DEVICE) model.eval() -def result_handler(frame_: torch.Tensor): +def result_handler(frame_: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: for (box, label, score, mask) in zip(*result.values()): x0, y0, x1, y1 = box height, width = y1 - y0, x1 - x0 @@ -51,7 +51,9 @@ def result_handler(frame_: torch.Tensor): mask_x1 = (mask_xc + mask_half_width).ceil().int() # Skip incomplete frames - if mask_x0 < 0 or mask_x1 > 320 or mask_y0 < 0 or mask_y1 > 240: + if (height < 64 or width < 64 / MASK_BOX_RATIO) \ + or (mask_x0 < 0 or mask_x1 > 320) \ + or (mask_y0 < 0 or mask_y1 > 240): continue cropped_frame = frame_[:, mask_y0:mask_y1 + 1, mask_x0:mask_x1 + 1] -- cgit v1.2.3