From 1deebf6b25ea2885609f43b56316fd4be1303381 Mon Sep 17 00:00:00 2001
From: Jordan Gong <jordan.gong@protonmail.com>
Date: Wed, 17 Feb 2021 17:02:50 +0800
Subject: Add new preprocess script

---
 preprocess.py | 105 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 105 insertions(+)
 create mode 100644 preprocess.py

diff --git a/preprocess.py b/preprocess.py
new file mode 100644
index 0000000..b7a81a3
--- /dev/null
+++ b/preprocess.py
@@ -0,0 +1,105 @@
+import glob
+import os
+
+import torch
+import torchvision
+from torch.utils.data import Dataset, DataLoader
+from tqdm import tqdm
+
+DEVICE = torch.device('cuda')
+BATCH_SIZE = 5
+
+RAW_VIDEO_PATH = os.path.join('data', 'CASIA-B-RAW', 'video')
+OUTPUT_PATH = '/tmp/CASIA-B-MRCNN-V2'
+SCORE_THRESHOLD = 0.9  # minimum detection confidence to keep a person
+BOX_RATIO_THRESHOLD = (1.25, 5)  # plausible height/width range for a person box
+MASK_BOX_RATIO = 1.7  # height/width ratio of the silhouette crop
+
+
+class CASIABClip(Dataset):
+
+    def __init__(self, filename) -> None:
+        super().__init__()
+        video, *_ = torchvision.io.read_video(filename, pts_unit='sec')
+        self.frames = video.permute(0, 3, 1, 2) / 255  # (T, H, W, C) -> (T, C, H, W) in [0, 1]
+
+    def __getitem__(self, index) -> tuple[int, torch.Tensor]:
+        return index, self.frames[index]
+
+    def __len__(self) -> int:
+        return len(self.frames)
+
+
+model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
+model = model.to(DEVICE)
+model.eval()
+
+
+def result_handler(frame_: torch.Tensor, result: dict):
+    for (box, label, score, mask) in zip(*result.values()):
+        x0, y0, x1, y1 = box
+        height, width = y1 - y0, x1 - x0
+        if (BOX_RATIO_THRESHOLD[0] < height / width < BOX_RATIO_THRESHOLD[1]) \
+                and (score > SCORE_THRESHOLD and label == 1):
+            mask_height_offset_0 = (0.043 * height) / 2  # small margin above the box
+            mask_height_offset_1 = (0.027 * height) / 2  # small margin below the box
+            mask_y0 = (y0 - mask_height_offset_0).floor().int()
+            mask_y1 = (y1 + mask_height_offset_1).ceil().int()
+            mask_half_width = ((mask_y1 - mask_y0) / MASK_BOX_RATIO) / 2  # fixed aspect ratio
+            mask_xc = (x0 + x1) / 2  # horizontal center of the box
+            mask_x0 = (mask_xc - mask_half_width).floor().int()
+            mask_x1 = (mask_xc + mask_half_width).ceil().int()
+
+            # Skip incomplete frames
+            if mask_x0 < 0 or mask_x1 > 320 or mask_y0 < 0 or mask_y1 > 240:
+                continue
+
+            cropped_frame = frame_[:, mask_y0:mask_y1 + 1, mask_x0:mask_x1 + 1]
+            cropped_mask = mask[:, mask_y0:mask_y1 + 1, mask_x0:mask_x1 + 1]
+            filtered_frame = cropped_frame * cropped_mask
+
+            return cropped_mask, filtered_frame  # first (highest-scoring) qualifying detection
+
+
+SIL_PATH = os.path.join(OUTPUT_PATH, 'SIL')
+SEG_PATH = os.path.join(OUTPUT_PATH, 'SEG')
+if not os.path.exists(SIL_PATH):
+    os.makedirs(SIL_PATH)
+if not os.path.exists(SEG_PATH):
+    os.makedirs(SEG_PATH)
+RAW_VIDEO_GLOB = os.path.join(RAW_VIDEO_PATH, '*-*-*-*.avi')  # e.g. 001-nm-01-090.avi
+for clip_filename in sorted(glob.glob(RAW_VIDEO_GLOB)):
+    clip_name, _ = os.path.splitext(os.path.basename(clip_filename))
+    clip_sil_dir = os.path.join(SIL_PATH, clip_name)
+    clip_seg_dir = os.path.join(SEG_PATH, clip_name)
+    if os.path.exists(clip_sil_dir):
+        if len(os.listdir(clip_sil_dir)) != 0:
+            continue
+    else:
+        os.mkdir(clip_sil_dir)
+    if os.path.exists(clip_seg_dir):
+        if len(os.listdir(clip_seg_dir)) != 0:
+            continue
+    else:
+        os.mkdir(clip_seg_dir)
+
+    clip = CASIABClip(clip_filename)
+    clip_loader = DataLoader(clip, batch_size=BATCH_SIZE, pin_memory=True)
+
+    with torch.no_grad():
+        for frame_ids, frames in tqdm(
+                clip_loader, desc=clip_name, unit='batch'
+        ):
+            frames = frames.to(DEVICE)
+            for frame_id, frame, result in zip(
+                    frame_ids, frames, model(frames)
+            ):
+                if len(result['boxes']) == 0:
+                    continue
+                if processed := result_handler(frame, result):
+                    sil, seg = processed
+                    frame_basename = f'{frame_id:04d}.png'
+                    sil_filename = os.path.join(clip_sil_dir, frame_basename)
+                    seg_filename = os.path.join(clip_seg_dir, frame_basename)
+                    torchvision.utils.save_image(sil, sil_filename)
+                    torchvision.utils.save_image(seg, seg_filename)
-- 
cgit v1.2.3

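For context on the detection API used above: in eval mode, maskrcnn_resnet50_fpn takes a list of CHW float images and returns one dict per image containing 'boxes', 'labels', 'scores', and 'masks' tensors, in that insertion order, which is what zip(*result.values()) in result_handler relies on. A minimal sketch of the same person filter on a single image; the random input is only for illustration, and the 0.9 threshold mirrors SCORE_THRESHOLD:

    import torch
    import torchvision

    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
    model.eval()

    image = torch.rand(3, 240, 320)  # one CHW image in [0, 1]
    with torch.no_grad():
        result, = model([image])  # one output dict per input image

    # result['boxes']: (N, 4) xyxy; result['labels']: (N,) COCO class ids;
    # result['scores']: (N,) sorted descending; result['masks']: (N, 1, H, W)
    keep = (result['labels'] == 1) & (result['scores'] > 0.9)  # class 1 is 'person'
    person_boxes = result['boxes'][keep]
    person_masks = result['masks'][keep]

Since scores come back sorted in descending order, the first detection that passes the filters in result_handler is also the most confident one.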

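The script writes one directory per clip under SIL (soft silhouette masks) and SEG (masked RGB crops), with zero-padded frame indices as file names. A sketch of how a downstream loader might read a clip back; load_clip_silhouettes and the example clip name are illustrative assumptions, not part of the commit:

    import glob
    import os

    import torchvision

    OUTPUT_PATH = '/tmp/CASIA-B-MRCNN-V2'  # as defined in preprocess.py


    def load_clip_silhouettes(clip_name):
        """Read one clip's silhouettes back as float tensors in [0, 1]."""
        clip_dir = os.path.join(OUTPUT_PATH, 'SIL', clip_name)
        frames = []
        # Names are f'{frame_id:04d}.png', so a lexical sort restores order
        for png in sorted(glob.glob(os.path.join(clip_dir, '*.png'))):
            image = torchvision.io.read_image(png)  # uint8, (C, H, W)
            frames.append(image.float() / 255)
        return frames


    # e.g. load_clip_silhouettes('001-nm-01-090')
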
From 641d19c1ebdf44486de139fadeff3276aecdf284 Mon Sep 17 00:00:00 2001
From: Jordan Gong <jordan.gong@protonmail.com>
Date: Wed, 17 Feb 2021 17:38:03 +0800
Subject: Fix type hints and add constraints on height and width

---
 preprocess.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/preprocess.py b/preprocess.py
index b7a81a3..8af58f2 100644
--- a/preprocess.py
+++ b/preprocess.py
@@ -18,7 +18,7 @@ MASK_BOX_RATIO = 1.7  # height/width ratio of the silhouette crop
 
 class CASIABClip(Dataset):
 
-    def __init__(self, filename) -> None:
+    def __init__(self, filename):
         super().__init__()
         video, *_ = torchvision.io.read_video(filename, pts_unit='sec')
         self.frames = video.permute(0, 3, 1, 2) / 255  # (T, H, W, C) -> (T, C, H, W) in [0, 1]
@@ -35,7 +35,7 @@ model = model.to(DEVICE)
 model.eval()
 
 
-def result_handler(frame_: torch.Tensor, result: dict):
+def result_handler(frame_: torch.Tensor, result: dict) -> tuple[torch.Tensor, torch.Tensor]:
     for (box, label, score, mask) in zip(*result.values()):
         x0, y0, x1, y1 = box
         height, width = y1 - y0, x1 - x0
@@ -51,7 +51,9 @@ def result_handler(frame_: torch.Tensor, result: dict):
             mask_x1 = (mask_xc + mask_half_width).ceil().int()
 
             # Skip incomplete frames
-            if mask_x0 < 0 or mask_x1 > 320 or mask_y0 < 0 or mask_y1 > 240:
+            if (height < 64 or width < 64 / MASK_BOX_RATIO) \
+                    or (mask_x0 < 0 or mask_x1 > 320) \
+                    or (mask_y0 < 0 or mask_y1 > 240):
                 continue
 
             cropped_frame = frame_[:, mask_y0:mask_y1 + 1, mask_x0:mask_x1 + 1]
-- 
cgit v1.2.3
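
To make the crop arithmetic and the new size constraints concrete, the following sketch reproduces the same computation on plain floats with example numbers; crop_window is an illustrative name, not a function defined in the patch:

    import math

    MASK_BOX_RATIO = 1.7  # height/width ratio of the final crop, as in the script


    def crop_window(x0, y0, x1, y1):
        """Mirrors the box-to-crop arithmetic of result_handler on plain floats."""
        height, width = y1 - y0, x1 - x0
        # Minimum size added by this commit: boxes under 64 px tall
        # (or the corresponding width) yield no useful silhouette.
        if height < 64 or width < 64 / MASK_BOX_RATIO:
            return None
        mask_y0 = math.floor(y0 - 0.043 * height / 2)  # pad top
        mask_y1 = math.ceil(y1 + 0.027 * height / 2)  # pad bottom
        half_width = (mask_y1 - mask_y0) / MASK_BOX_RATIO / 2
        xc = (x0 + x1) / 2
        mask_x0 = math.floor(xc - half_width)
        mask_x1 = math.ceil(xc + half_width)
        if mask_x0 < 0 or mask_x1 > 320 or mask_y0 < 0 or mask_y1 > 240:
            return None  # crop would leave the 320x240 CASIA-B frame
        return mask_x0, mask_y0, mask_x1, mask_y1


    # A 170 px tall, 40 px wide box centered at x = 160:
    print(crop_window(140.0, 40.0, 180.0, 210.0))  # -> (107, 36, 213, 213)

With MASK_BOX_RATIO = 1.7 the crop is noticeably taller than wide, and the new minimum of 64 px of box height (about 38 px of width) rejects subjects too far from the camera to produce a usable silhouette.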