import glob
import os
from typing import Optional, Tuple

import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# Run on the GPU when one is available; otherwise fall back to the CPU so the
# script still runs (slowly) on machines without CUDA instead of crashing when
# the model is moved to the device.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Number of frames fed to the detector in one forward pass.
BATCH_SIZE = 5

# Directory scanned for '*-*-*-*.avi' clips, and the output root under which
# the SIL/ and SEG/ directories are created.
RAW_VIDEO_PATH = os.path.join('data', 'CASIA-B-RAW', 'video')
OUTPUT_PATH = '/tmp/CASIA-B-MRCNN-V2'
# A detection is kept only if its score is strictly greater than this.
SCORE_THRESHOLD = 0.9
# Exclusive (low, high) bounds on a detection box's height / width ratio.
BOX_RATIO_THRESHOLD = (1.25, 5)
# Target height / width ratio of the crop cut around an accepted box.
MASK_BOX_RATIO = 1.7


class CASIABClip(Dataset):
    """Map-style dataset exposing the decoded frames of one video file.

    Every item is an ``(index, frame)`` pair. The frame is taken from the
    decoded video with axes permuted by ``(0, 3, 1, 2)`` and values divided
    by 255. torchvision documents ``read_video`` frames as (T, H, W, C), so
    each stored frame is presumably (C, H, W) in [0, 1] -- confirm for the
    installed torchvision version.
    """

    def __init__(self, filename):
        super().__init__()
        # Only the first element of read_video's result (the video) is used.
        decoded = torchvision.io.read_video(filename, pts_unit='sec')[0]
        reordered = decoded.permute(0, 3, 1, 2)
        self.frames = reordered / 255

    def __getitem__(self, index) -> Tuple[int, torch.Tensor]:
        # The index is returned too so callers can name output files by frame.
        selected = self.frames[index]
        return index, selected

    def __len__(self) -> int:
        return self.frames.size(0)


# Pretrained Mask R-CNN (ResNet-50 FPN) placed on DEVICE and switched to
# evaluation mode. nn.Module.to and nn.Module.eval both return the module
# itself, so the calls are chained.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour
# of `weights=`; it is kept here for compatibility with older versions.
model = (
    torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
    .to(DEVICE)
    .eval()
)


def result_handler(
        frame_: torch.Tensor,
        result_: Optional[dict] = None,
) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
    """Crop and mask the first acceptable person detection in a frame.

    Detections are visited in the order the model returned them. A detection
    is accepted when its box height/width ratio lies strictly inside
    ``BOX_RATIO_THRESHOLD``, its score exceeds ``SCORE_THRESHOLD`` and its
    label is 1 (presumably COCO "person" in torchvision's label map --
    confirm). The box is then padded vertically, and a width centred on the
    box is derived so the crop's height/width ratio is ``MASK_BOX_RATIO``.
    Detections that are too small, or whose crop would not fit inside the
    frame, are skipped.

    Args:
        frame_: Channel-first image tensor; its last two dimensions are used
            as (height, width).
        result_: Detection dict for this frame with ``'boxes'``,
            ``'labels'``, ``'scores'`` and ``'masks'`` entries. If omitted,
            the module-level ``result`` is used, preserving the original
            call signature ``result_handler(frame)``.

    Returns:
        ``(cropped_mask, filtered_frame)`` for the first accepted detection,
        where ``filtered_frame`` is the frame crop multiplied by the mask
        crop. Returns ``None`` if no detection qualifies.
    """
    if result_ is None:
        # Backward compatibility: the original implementation read the
        # global ``result`` bound by the module-level processing loop.
        result_ = result

    frame_height, frame_width = frame_.shape[-2:]
    # Explicit keys instead of relying on the dict's insertion order.
    detections = zip(
        result_['boxes'], result_['labels'], result_['scores'],
        result_['masks'],
    )
    for box, label, score, mask in detections:
        x0, y0, x1, y1 = box
        height, width = y1 - y0, x1 - x0
        if not ((BOX_RATIO_THRESHOLD[0] < height / width < BOX_RATIO_THRESHOLD[1])
                and (score > SCORE_THRESHOLD and label == 1)):
            continue

        # Pad the box above by half of 4.3% of its height and below by half
        # of 2.7%, then centre a width chosen so that the crop's
        # height / width equals MASK_BOX_RATIO.
        mask_height_offset_0 = (0.043 * height) / 2
        mask_height_offset_1 = (0.027 * height) / 2
        mask_y0 = (y0 - mask_height_offset_0).floor().int()
        mask_y1 = (y1 + mask_height_offset_1).ceil().int()
        mask_half_width = ((mask_y1 - mask_y0) / MASK_BOX_RATIO) / 2
        mask_xc = (x0 + x1) / 2
        mask_x0 = (mask_xc - mask_half_width).floor().int()
        mask_x1 = (mask_xc + mask_half_width).ceil().int()

        # Skip incomplete frames: detections that are too small, or whose
        # inclusive crop [mask_y0, mask_y1] x [mask_x0, mask_x1] would extend
        # past the frame. The slices below end at `+ 1`, so the last valid
        # index is size - 1; an end equal to the size would otherwise be
        # silently clipped by slicing.
        if (height < 64 or width < 64 / BOX_RATIO_THRESHOLD[1]) \
                or (mask_x0 < 0 or mask_x1 >= frame_width) \
                or (mask_y0 < 0 or mask_y1 >= frame_height):
            continue

        cropped_frame = frame_[:, mask_y0:mask_y1 + 1, mask_x0:mask_x1 + 1]
        cropped_mask = mask[:, mask_y0:mask_y1 + 1, mask_x0:mask_x1 + 1]
        filtered_frame = cropped_frame * cropped_mask

        return cropped_mask, filtered_frame

    return None


# Output layout: OUTPUT_PATH/SIL/<clip>/<frame>.png holds the cropped mask,
# OUTPUT_PATH/SEG/<clip>/<frame>.png holds the masked frame crop.
SIL_PATH = os.path.join(OUTPUT_PATH, 'SIL')
SEG_PATH = os.path.join(OUTPUT_PATH, 'SEG')
# exist_ok replaces the check-then-create pattern (and its race).
os.makedirs(SIL_PATH, exist_ok=True)
os.makedirs(SEG_PATH, exist_ok=True)
# Glob pattern (not a regex) matching the raw CASIA-B clip files.
RAW_VIDEO_REGEX = os.path.join(RAW_VIDEO_PATH, '*-*-*-*.avi')
for clip_filename in sorted(glob.glob(RAW_VIDEO_REGEX)):
    clip_name, _ = os.path.splitext(os.path.basename(clip_filename))
    clip_sil_dir = os.path.join(SIL_PATH, clip_name)
    clip_seg_dir = os.path.join(SEG_PATH, clip_name)
    # A clip whose SIL or SEG output directory already contains files is
    # treated as done and skipped.
    # NOTE(review): a clip interrupted mid-way is also skipped on rerun,
    # leaving partial output; clear its directories to reprocess it.
    if any(os.path.isdir(out_dir) and os.listdir(out_dir)
           for out_dir in (clip_sil_dir, clip_seg_dir)):
        continue
    os.makedirs(clip_sil_dir, exist_ok=True)
    os.makedirs(clip_seg_dir, exist_ok=True)

    clip = CASIABClip(clip_filename)
    clip_loader = DataLoader(clip, batch_size=BATCH_SIZE, pin_memory=True)

    with torch.no_grad():
        for frame_ids, frames in tqdm(
                clip_loader, desc=clip_name, unit='batch'
        ):
            frames = frames.to(DEVICE)
            # Keep the loop variable named ``result``: result_handler(frame)
            # reads the current detection dict from this module-level name.
            for frame_id, frame, result in zip(
                    frame_ids, frames, model(frames)
            ):
                if not len(result['boxes']):
                    continue
                processed = result_handler(frame)
                if processed is None:
                    continue
                sil, seg = processed
                frame_basename = f'{int(frame_id):04d}.png'
                sil_filename = os.path.join(clip_sil_dir, frame_basename)
                seg_filename = os.path.join(clip_seg_dir, frame_basename)
                torchvision.utils.save_image(sil, sil_filename)
                torchvision.utils.save_image(seg, seg_filename)