path: root/preprocess.py

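"""Pre-process raw CASIA-B gait videos with a pre-trained Mask R-CNN.

Each frame of every .avi clip is run through the detector; the first
'person' detection that passes the score and aspect-ratio filters is
cropped to a fixed ratio, and both the silhouette mask (SIL) and the
mask-applied RGB crop (SEG) are saved as PNGs.
"""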
import glob
import os
from typing import Optional, Tuple

import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# Fall back to CPU when no GPU is available
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 5  # frames per Mask R-CNN forward pass

RAW_VIDEO_PATH = os.path.join('data', 'CASIA-B-RAW', 'video')
OUTPUT_PATH = '/tmp/CASIA-B-MRCNN-V2'
SCORE_THRESHOLD = 0.9  # minimum detection confidence to accept a person box
BOX_RATIO_THRESHOLD = (1.25, 5)  # plausible height/width range for a person box
MASK_BOX_RATIO = 1.7  # target height/width ratio of the output crop


class CASIABClip(Dataset):
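    """A single CASIA-B video clip, exposed frame-by-frame as a Dataset."""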

    def __init__(self, filename):
        super().__init__()
        # read_video yields (T, H, W, C) uint8; convert to (T, C, H, W)
        # floats in [0, 1]
        video, *_ = torchvision.io.read_video(filename, pts_unit='sec')
        self.frames = video.permute(0, 3, 1, 2) / 255

    def __getitem__(self, index) -> Tuple[int, torch.Tensor]:
        return index, self.frames[index]

    def __len__(self) -> int:
        return len(self.frames)


# Off-the-shelf Mask R-CNN pre-trained on COCO; COCO class 1 is 'person'
model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
model = model.to(DEVICE)
model.eval()


def result_handler(
        frame_: torch.Tensor, result: dict
) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
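    """Return (silhouette mask, masked RGB crop) for the first person
    detection that passes the filters, or None if none qualifies."""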
    for box, label, score, mask in zip(
            result['boxes'], result['labels'], result['scores'],
            result['masks']
    ):
        x0, y0, x1, y1 = box
        height, width = y1 - y0, x1 - x0
        # Keep confident 'person' (label 1) boxes with a plausible
        # standing aspect ratio
        if (BOX_RATIO_THRESHOLD[0] < height / width < BOX_RATIO_THRESHOLD[1]) \
                and (score > SCORE_THRESHOLD and label == 1):
            # Pad the box vertically by small empirical margins, then widen
            # it about its centre to the fixed crop ratio MASK_BOX_RATIO
            mask_height_offset_0 = (0.043 * height) / 2
            mask_height_offset_1 = (0.027 * height) / 2
            mask_y0 = (y0 - mask_height_offset_0).floor().int()
            mask_y1 = (y1 + mask_height_offset_1).ceil().int()
            mask_half_width = ((mask_y1 - mask_y0) / MASK_BOX_RATIO) / 2
            mask_xc = (x0 + x1) / 2
            mask_x0 = (mask_xc - mask_half_width).floor().int()
            mask_x1 = (mask_xc + mask_half_width).ceil().int()

            # Skip detections that are too small or whose padded crop falls
            # outside the 320x240 CASIA-B frame
            if (height < 64 or width < 64 / BOX_RATIO_THRESHOLD[1]) \
                    or (mask_x0 < 0 or mask_x1 > 320) \
                    or (mask_y0 < 0 or mask_y1 > 240):
                continue

            cropped_frame = frame_[:, mask_y0:mask_y1 + 1, mask_x0:mask_x1 + 1]
            cropped_mask = mask[:, mask_y0:mask_y1 + 1, mask_x0:mask_x1 + 1]
            # Soft-mask the RGB crop with the predicted probability mask
            filtered_frame = cropped_frame * cropped_mask

            return cropped_mask, filtered_frame
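
    return None  # no detection qualified; the caller skips this frame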


# SIL holds per-frame silhouette masks, SEG the mask-applied RGB crops
SIL_PATH = os.path.join(OUTPUT_PATH, 'SIL')
SEG_PATH = os.path.join(OUTPUT_PATH, 'SEG')
os.makedirs(SIL_PATH, exist_ok=True)
os.makedirs(SEG_PATH, exist_ok=True)
# CASIA-B clips are named <subject>-<condition>-<sequence>-<view>.avi
RAW_VIDEO_GLOB = os.path.join(RAW_VIDEO_PATH, '*-*-*-*.avi')
for clip_filename in sorted(glob.glob(RAW_VIDEO_GLOB)):
    clip_name, _ = os.path.splitext(os.path.basename(clip_filename))
    clip_sil_dir = os.path.join(SIL_PATH, clip_name)
    clip_seg_dir = os.path.join(SEG_PATH, clip_name)
    # Resume support: skip clips whose output directories are already
    # populated; otherwise make sure both directories exist
    if (os.path.isdir(clip_sil_dir) and os.listdir(clip_sil_dir)) \
            or (os.path.isdir(clip_seg_dir) and os.listdir(clip_seg_dir)):
        continue
    os.makedirs(clip_sil_dir, exist_ok=True)
    os.makedirs(clip_seg_dir, exist_ok=True)

    # Decode the whole clip into memory, then run it through the detector
    # in small batches
    clip = CASIABClip(clip_filename)
    clip_loader = DataLoader(clip, batch_size=BATCH_SIZE, pin_memory=True)

    with torch.no_grad():
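        # Inference only; disabling autograd saves memory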
        for frame_ids, frames in tqdm(
                clip_loader, desc=clip_name, unit='batch'
        ):
            frames = frames.to(DEVICE)
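            # Single batched forward pass; iterate over per-frame outputs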
            for frame_id, frame, result in zip(
                    frame_ids, frames, model(frames)
            ):
                if len(result['boxes']) == 0:
                    continue  # detector returned nothing for this frame
                if processed := result_handler(frame, result):
                    sil, seg = processed
                    frame_basename = f'{frame_id:04d}.png'
                    sil_filename = os.path.join(clip_sil_dir, frame_basename)
                    seg_filename = os.path.join(clip_seg_dir, frame_basename)
                    torchvision.utils.save_image(sil, sil_filename)
                    torchvision.utils.save_image(seg, seg_filename)