summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-02-17 17:02:50 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-02-17 17:02:50 +0800
commit1deebf6b25ea2885609f43b56316fd4be1303381 (patch)
tree45bc18f4583b23c6f3913fef4a2229ad4c616dc1
parent5657dd650a8fffab9c8e3096a65c3cd94a5c42f4 (diff)
Add new preprocess script
-rw-r--r--preprocess.py105
1 files changed, 105 insertions, 0 deletions
diff --git a/preprocess.py b/preprocess.py
new file mode 100644
index 0000000..b7a81a3
--- /dev/null
+++ b/preprocess.py
@@ -0,0 +1,105 @@
+import glob
+import os
+
+import torch
+import torchvision
+from torch.utils.data import Dataset, DataLoader
+from tqdm import tqdm
+
+DEVICE = torch.device('cuda')
+BATCH_SIZE = 5
+
+RAW_VIDEO_PATH = os.path.join('data', 'CASIA-B-RAW', 'video')
+OUTPUT_PATH = '/tmp/CASIA-B-MRCNN-V2'
+SCORE_THRESHOLD = 0.9
+BOX_RATIO_THRESHOLD = (1.25, 5)
+MASK_BOX_RATIO = 1.7
+
+
+class CASIABClip(Dataset):
+
+ def __init__(self, filename) -> None:
+ super().__init__()
+ video, *_ = torchvision.io.read_video(filename, pts_unit='sec')
+ self.frames = video.permute(0, 3, 1, 2) / 255
+
+ def __getitem__(self, index) -> tuple[int, torch.Tensor]:
+ return index, self.frames[index]
+
+ def __len__(self) -> int:
+ return len(self.frames)
+
+
+model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
+model = model.to(DEVICE)
+model.eval()
+
+
+def result_handler(frame_: torch.Tensor):
+ for (box, label, score, mask) in zip(*result.values()):
+ x0, y0, x1, y1 = box
+ height, width = y1 - y0, x1 - x0
+ if (BOX_RATIO_THRESHOLD[0] < height / width < BOX_RATIO_THRESHOLD[1]) \
+ and (score > SCORE_THRESHOLD and label == 1):
+ mask_height_offset_0 = (0.043 * height) / 2
+ mask_height_offset_1 = (0.027 * height) / 2
+ mask_y0 = (y0 - mask_height_offset_0).floor().int()
+ mask_y1 = (y1 + mask_height_offset_1).ceil().int()
+ mask_half_width = ((mask_y1 - mask_y0) / MASK_BOX_RATIO) / 2
+ mask_xc = (x0 + x1) / 2
+ mask_x0 = (mask_xc - mask_half_width).floor().int()
+ mask_x1 = (mask_xc + mask_half_width).ceil().int()
+
+ # Skip incomplete frames
+ if mask_x0 < 0 or mask_x1 > 320 or mask_y0 < 0 or mask_y1 > 240:
+ continue
+
+ cropped_frame = frame_[:, mask_y0:mask_y1 + 1, mask_x0:mask_x1 + 1]
+ cropped_mask = mask[:, mask_y0:mask_y1 + 1, mask_x0:mask_x1 + 1]
+ filtered_frame = cropped_frame * cropped_mask
+
+ return cropped_mask, filtered_frame
+
+
+SIL_PATH = os.path.join(OUTPUT_PATH, 'SIL')
+SEG_PATH = os.path.join(OUTPUT_PATH, 'SEG')
+if not os.path.exists(SIL_PATH):
+ os.makedirs(SIL_PATH)
+if not os.path.exists(SEG_PATH):
+ os.makedirs(SEG_PATH)
+RAW_VIDEO_REGEX = os.path.join(RAW_VIDEO_PATH, '*-*-*-*.avi')
+for clip_filename in sorted(glob.glob(RAW_VIDEO_REGEX)):
+ clip_name, _ = os.path.splitext(os.path.basename(clip_filename))
+ clip_sil_dir = os.path.join(SIL_PATH, clip_name)
+ clip_seg_dir = os.path.join(SEG_PATH, clip_name)
+ if os.path.exists(clip_sil_dir):
+ if len(os.listdir(clip_sil_dir)) != 0:
+ continue
+ else:
+ os.mkdir(clip_sil_dir)
+ if os.path.exists(clip_seg_dir):
+ if len(os.listdir(clip_seg_dir)) != 0:
+ continue
+ else:
+ os.mkdir(clip_seg_dir)
+
+ clip = CASIABClip(clip_filename)
+ clip_loader = DataLoader(clip, batch_size=BATCH_SIZE, pin_memory=True)
+
+ with torch.no_grad():
+ for frame_ids, frames in tqdm(
+ clip_loader, desc=clip_name, unit='batch'
+ ):
+ frames = frames.to(DEVICE)
+ for frame_id, frame, result in zip(
+ frame_ids, frames, model(frames)
+ ):
+ if len(result['boxes']) == 0:
+ continue
+ if processed := result_handler(frame):
+ sil, seg = processed
+ frame_basename = f'{frame_id:04d}.png'
+ sil_filename = os.path.join(clip_sil_dir, frame_basename)
+ seg_filename = os.path.join(clip_seg_dir, frame_basename)
+ torchvision.utils.save_image(sil, sil_filename)
+ torchvision.utils.save_image(seg, seg_filename)