# Copyright (c) Alibaba, Inc. and its affiliates.

import os

import cv2
import matplotlib
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F
from PIL import Image

from modelscope.outputs import OutputKeys
from modelscope.preprocessors.image import load_image
from modelscope.utils import logger as logging

# Module-level logger obtained from modelscope's logging helper.
logger = logging.get_logger()


class InputPadder:
    """Pad tensors so that their last two dimensions are multiples of 8.

    The padding amounts are derived once from ``dims`` and stored in
    ``self._pad`` as ``[left, right, top, bottom]``, the order expected by
    ``torch.nn.functional.pad`` for the last two dimensions.
    """

    def __init__(self, dims, mode='sintel'):
        """Compute the padding for inputs whose trailing dims are (ht, wd).

        Args:
            dims: shape sequence; only the last two entries are used.
            mode (str): ``'sintel'`` splits the vertical padding between top
                and bottom; any other value puts all of it at the bottom.
        """
        self.ht, self.wd = dims[-2:]
        # Amount missing to reach the next multiple of 8 (0 if already one).
        extra_h = (8 - self.ht % 8) % 8
        extra_w = (8 - self.wd % 8) % 8
        left = extra_w // 2
        right = extra_w - left
        if mode == 'sintel':
            top = extra_h // 2
            self._pad = [left, right, top, extra_h - top]
        else:
            self._pad = [left, right, 0, extra_h]

    def pad(self, *inputs):
        """Return a list with every input padded using replicate mode."""
        padded = []
        for tensor in inputs:
            padded.append(F.pad(tensor, self._pad, mode='replicate'))
        return padded

    def unpad(self, x):
        """Crop away the padding previously added by :meth:`pad`."""
        left, right, top, bottom = self._pad
        full_h, full_w = x.shape[-2:]
        return x[..., top:full_h - bottom, left:full_w - right]


def numpy_to_cv2img(img_array):
    """Convert a np.array with shape (h, w) into a JET-colormapped cv2 image.

    The values are min-max normalized (with a small epsilon in the
    denominator), scaled to 0..255, cast to uint8 and passed through
    ``cv2.applyColorMap`` with ``cv2.COLORMAP_JET``.

    Args:
        img_array (np.array): input data

    Returns:
        cv2 img
    """
    lowest = img_array.min()
    # Epsilon keeps the division defined for constant inputs.
    value_range = img_array.max() - lowest + 1e-5
    normalized = (img_array - lowest) / value_range
    as_bytes = (normalized * 255).astype(np.uint8)
    return cv2.applyColorMap(as_bytes, cv2.COLORMAP_JET)


def draw_joints(image, np_kps, score, threshold=0.2):
    """Draw a skeleton (limbs, then joints) onto ``image`` in place.

    Two layouts are supported, selected by ``np_kps.shape[0]``: 17 or 15
    keypoints. Each joint is connected to its parent joint; a limb is only
    drawn when both of its endpoints reach ``threshold``.

    Args:
        image: cv2 image to draw on (modified in place).
        np_kps (np.ndarray): keypoints, indexed as ``np_kps[i, 0]`` (x) and
            ``np_kps[i, 1]`` (y).
        score: per-keypoint confidence, indexable by keypoint id.
        threshold (float): minimum confidence for a keypoint to be drawn.

    Raises:
        ValueError: if the number of keypoints is neither 17 nor 15.
    """
    # keypoint count -> (parent id per joint, left-side ids, right-side ids)
    layouts = {
        17: ([0, 0, 0, 1, 2, 0, 0, 5, 6, 7, 8, 5, 6, 11, 12, 13, 14],
             [1, 3, 5, 7, 9, 11, 13, 15], [2, 4, 6, 8, 10, 12, 14, 16]),
        15: ([0, 0, 1, 2, 3, 1, 5, 6, 14, 8, 9, 14, 11, 12, 1],
             [2, 3, 4, 8, 9, 10], [5, 6, 7, 11, 12, 13]),
    }

    num_kps = np_kps.shape[0]
    if num_kps not in layouts:
        raise ValueError(
            f'draw_joints supports 17 or 15 keypoints, got {num_kps}')
    lst_parent_ids, lst_left_ids, lst_right_ids = layouts[num_kps]

    for i, pid in enumerate(lst_parent_ids):
        # A joint that is its own parent is a root: no limb to draw.
        if i == pid:
            continue

        # Both endpoints of the limb must be confident enough. The parent
        # joint is checked here (not a fixed joint index).
        if score[i] < threshold or score[pid] < threshold:
            continue

        if i in lst_left_ids and pid in lst_left_ids:
            color = (0, 255, 0)
        elif i in lst_right_ids and pid in lst_right_ids:
            color = (255, 0, 0)
        else:
            color = (0, 255, 255)

        cv2.line(image, (int(np_kps[i, 0]), int(np_kps[i, 1])),
                 (int(np_kps[pid, 0]), int(np_kps[pid, 1])), color, 3)

    for i in range(num_kps):
        if score[i] < threshold:
            continue
        cv2.circle(image, (int(np_kps[i, 0]), int(np_kps[i, 1])), 5,
                   (0, 0, 255), -1)


def draw_box(image, box):
    """Draw ``box`` on ``image`` in place as a 2-pixel (0, 0, 255) rectangle.

    Args:
        image: cv2 image to draw on.
        box: sequence whose first four entries are x1, y1, x2, y2.
    """
    top_left = (int(box[0]), int(box[1]))
    bottom_right = (int(box[2]), int(box[3]))
    cv2.rectangle(image, top_left, bottom_right, (0, 0, 255), 2)


def realtime_object_detection_bbox_vis(image, bboxes):
    """Draw every box in ``bboxes`` on ``image`` and return the image.

    Args:
        image: cv2 image to draw on (modified in place).
        bboxes: iterable of boxes indexed as x1, y1, x2, y2.

    Returns:
        The same ``image`` object, with 2-pixel (255, 0, 0) rectangles drawn.
    """
    for bbox in bboxes:
        # cv2.rectangle needs integer pixel coordinates; cast explicitly so
        # float boxes are accepted, as the other drawing helpers here do.
        top_left = (int(bbox[0]), int(bbox[1]))
        bottom_right = (int(bbox[2]), int(bbox[3]))
        cv2.rectangle(image, top_left, bottom_right, (255, 0, 0), 2)
    return image


def draw_attribute(image, box, labels):
    """Draw a box and a column of attribute captions on ``image`` in place.

    Each entry of ``labels`` is prefixed with the caption at the same
    position in the fixed caption list and written 20 pixels below the
    previous one, starting near the box's top-left corner. The text colour
    is drawn at random once per call.

    Args:
        image: cv2 image to draw on.
        box: sequence whose first four entries are x1, y1, x2, y2.
        labels: sequence of attribute strings, one per caption.
    """
    captions = (
        'gender      : ', 'age         : ', 'orient      : ', 'hat         : ',
        'glass       : ', 'hand_bag    : ', 'shoulder_bag: ', 'back_pack   : ',
        'upper_wear  : ', 'lower_wear  : ', 'upper_color : ', 'lower_color : '
    )
    line_height = 20

    cv2.rectangle(image, (int(box[0]), int(box[1])),
                  (int(box[2]), int(box[3])), (0, 0, 255), 2)

    # One random colour shared by all captions of this box.
    text_color = tuple(np.random.randint(0, 255) for _ in range(3))

    text_x = int(box[0] + 5)
    text_y = int(box[1] + 20)
    for row, attribute in enumerate(labels):
        caption = captions[row] + attribute
        cv2.putText(image, f'{caption}', (text_x, text_y + row * line_height),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, text_color, 1)


def draw_keypoints(output, original_image):
    """Draw detected boxes and body keypoints on an image loaded from disk.

    Args:
        output (dict): result holding ``OutputKeys.KEYPOINTS``,
            ``OutputKeys.SCORES`` and ``OutputKeys.BOXES`` with one entry
            per detection.
        original_image (str): path of the image to read with ``cv2.imread``.

    Returns:
        The image with every box and skeleton drawn on it.

    Raises:
        AssertionError: if the three result lists differ in length or the
            image cannot be read.
    """
    poses = np.array(output[OutputKeys.KEYPOINTS])
    scores = np.array(output[OutputKeys.SCORES])
    boxes = np.array(output[OutputKeys.BOXES])
    assert len(poses) == len(scores) and len(poses) == len(boxes)
    # Flag -1 loads the image unchanged (cv2.IMREAD_UNCHANGED).
    image = cv2.imread(original_image, -1)
    # cv2.imread returns None instead of raising; fail with a clear message,
    # in the same way as the face drawing helpers in this module.
    assert image is not None, f"Can't read img: {original_image}"
    for pose, score, box in zip(poses, scores, boxes):
        draw_box(image, np.array(box))
        draw_joints(image, np.array(pose), np.array(score))
    return image


def draw_pedestrian_attribute(output, original_image):
    """Draw pedestrian boxes with their attribute captions on an image.

    Args:
        output (dict): result holding ``OutputKeys.LABELS`` and
            ``OutputKeys.BOXES`` with one entry per detection.
        original_image (str): path of the image to read with ``cv2.imread``.

    Returns:
        The image with every box and its attribute text drawn on it.

    Raises:
        AssertionError: if labels and boxes differ in length or the image
            cannot be read.
    """
    labels = np.array(output[OutputKeys.LABELS])
    boxes = np.array(output[OutputKeys.BOXES])
    assert len(labels) == len(boxes)
    # Flag -1 loads the image unchanged (cv2.IMREAD_UNCHANGED).
    image = cv2.imread(original_image, -1)
    # cv2.imread returns None instead of raising; fail with a clear message,
    # in the same way as the face drawing helpers in this module.
    assert image is not None, f"Can't read img: {original_image}"
    for box, label in zip(boxes, labels):
        draw_attribute(image, np.array(box), label)
    return image


def draw_106face_keypoints(in_path,
                           keypoints,
                           boxes,
                           scale=4.0,
                           save_path=None):
    """Draw face boxes, facial contours and numbered landmarks on an image.

    The boxes are drawn on the image at its original size; the image is then
    resized by ``scale`` and the contours and landmark indices are drawn on
    the enlarged copy.

    Args:
        in_path (str): path of the image to read.
        keypoints: one landmark sequence per face; each landmark is indexed
            as ``[0]`` (x) and ``[1]`` (y) in original-image coordinates.
        boxes: one box (x1, y1, x2, y2) per face.
        scale (float): resize factor applied before drawing landmarks.
        save_path (str, optional): if given, the result is written there.

    Returns:
        The enlarged, annotated image.

    Raises:
        AssertionError: if the image cannot be read.
    """
    # Landmark index chains; each chain is drawn as a polyline. Chains that
    # end with their first index are closed loops.
    contours = (
        list(range(0, 33)),  # face contour
        [33, 34, 35, 36, 37, 38, 39, 40, 41, 33],  # left eyebrow
        [42, 43, 44, 45, 46, 47, 48, 49, 50, 42],  # right eyebrow
        [66, 67, 68, 69, 70, 71, 72, 73, 66],  # left eye
        [75, 76, 77, 78, 79, 80, 81, 82, 75],  # right eye
        [51, 52, 53, 54],  # nose bridge
        [55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65],  # nose contour
        [84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 84],  # outer mouth
        [96, 97, 98, 99, 100, 101, 102, 103, 96],  # inner mouth
    )

    img = cv2.imread(in_path)
    # cv2.imread returns None instead of raising on an unreadable file.
    assert img is not None, f"Can't read img: {in_path}"

    for box in boxes:
        draw_box(img, np.array(box))

    image = cv2.resize(img, dsize=None, fx=scale, fy=scale)

    def to_pixel(pt):
        # Scale first, then truncate, so lines and dots land on the same
        # pixel of the enlarged image.
        return (int(pt[0] * scale), int(pt[1] * scale))

    def draw_line(point_index, image, point):
        for cur_index, next_index in zip(point_index, point_index[1:]):
            cv2.line(
                image,
                to_pixel(point[cur_index]),
                to_pixel(point[next_index]), (0, 0, 255),
                thickness=2)

    for points in keypoints:
        for contour in contours:
            draw_line(contour, image, points)

        for kp_id in range(len(points)):
            position = to_pixel(points[kp_id])
            cv2.putText(image, str(kp_id), position,
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
            cv2.circle(image, position, 2, (0, 255, 0), cv2.FILLED)

    if save_path is not None:
        cv2.imwrite(save_path, image)

    return image


def draw_face_detection_no_lm_result(img_path, detection_result):
    """Draw face boxes and their scores (no landmarks) on an image.

    Args:
        img_path (str): path of the image to read.
        detection_result (dict): result holding ``OutputKeys.BOXES`` and
            ``OutputKeys.SCORES``.

    Returns:
        The annotated image. The number of faces is also printed.
    """
    boxes = np.array(detection_result[OutputKeys.BOXES])
    scores = np.array(detection_result[OutputKeys.SCORES])
    img = cv2.imread(img_path)
    assert img is not None, f"Can't read img: {img_path}"
    for face_id, score in enumerate(scores):
        x1, y1, x2, y2 = boxes[face_id].astype(np.int32)
        cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 2)
        # Score is written at the bottom-left corner of the box.
        cv2.putText(
            img,
            f'{score:.2f}', (x1, y2),
            1,
            1.0, (0, 255, 0),
            thickness=1,
            lineType=8)
    print(f'Found {len(scores)} faces')
    return img


def draw_facial_expression_result(img_path, facial_expression_result):
    """Write the highest-scoring facial expression label onto an image.

    Args:
        img_path (str): path of the image to read.
        facial_expression_result (dict): result holding
            ``OutputKeys.SCORES`` and ``OutputKeys.LABELS``.

    Returns:
        The annotated image. The same caption is also printed.
    """
    scores = facial_expression_result[OutputKeys.SCORES]
    labels = facial_expression_result[OutputKeys.LABELS]
    best_label = labels[np.argmax(scores)]
    caption = 'facial expression: {}'.format(best_label)

    img = cv2.imread(img_path)
    assert img is not None, f"Can't read img: {img_path}"
    cv2.putText(
        img,
        caption, (10, 10),
        1,
        1.0, (0, 255, 0),
        thickness=1,
        lineType=8)
    print(caption)
    return img


def draw_face_attribute_result(img_path, face_attribute_result):
    """Write the predicted gender and age interval onto an image.

    Entry 0 of the scores/labels is used for gender and entry 1 for the age
    interval; for each, the label with the highest score is chosen.

    Args:
        img_path (str): path of the image to read.
        face_attribute_result (dict): result holding ``OutputKeys.SCORES``
            and ``OutputKeys.LABELS``.

    Returns:
        The annotated image. Both captions are also logged.
    """
    scores = face_attribute_result[OutputKeys.SCORES]
    labels = face_attribute_result[OutputKeys.LABELS]
    gender_text = 'face gender: {}'.format(labels[0][np.argmax(scores[0])])
    age_text = 'face age interval: {}'.format(
        labels[1][np.argmax(scores[1])])

    img = cv2.imread(img_path)
    assert img is not None, f"Can't read img: {img_path}"

    # (caption, text origin, colour) for each line written on the image.
    overlay = (
        (gender_text, (10, 10), (0, 255, 0)),
        (age_text, (10, 40), (255, 0, 0)),
    )
    for caption, origin, colour in overlay:
        cv2.putText(
            img, caption, origin, 1, 1.0, colour, thickness=1, lineType=8)

    logger.info(gender_text)
    logger.info(age_text)
    return img


def draw_face_detection_result(img_path, detection_result):
    """Draw face boxes, landmarks and scores on an image.

    Args:
        img_path (str): path of the image to read.
        detection_result (dict): result holding ``OutputKeys.BOXES``,
            ``OutputKeys.KEYPOINTS`` and ``OutputKeys.SCORES``.

    Returns:
        The annotated image. The number of faces is also printed.
    """
    boxes = np.array(detection_result[OutputKeys.BOXES])
    landmarks = np.array(detection_result[OutputKeys.KEYPOINTS])
    scores = np.array(detection_result[OutputKeys.SCORES])
    img = cv2.imread(img_path)
    assert img is not None, f"Can't read img: {img_path}"
    for face_id, score in enumerate(scores):
        box = boxes[face_id].astype(np.int32)
        # Flat keypoint list -> one (x, y) row per landmark.
        face_points = landmarks[face_id].reshape(-1, 2).astype(np.int32)
        x1, y1, x2, y2 = box
        cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 2)
        for point in face_points:
            cv2.circle(img, tuple(point), 1, (0, 0, 255), 1)
        cv2.putText(
            img,
            f'{score:.2f}', (x1, y2),
            1,
            1.0, (0, 255, 0),
            thickness=1,
            lineType=8)
    print(f'Found {len(scores)} faces')
    return img


def draw_card_detection_result(img_path, detection_result):
    """Annotate detected cards and produce a rectified crop for each one.

    Args:
        img_path (str): path of the image to read.
        detection_result (dict): result holding ``OutputKeys.BOXES``,
            ``OutputKeys.KEYPOINTS`` (four corner points per card) and
            ``OutputKeys.SCORES``.

    Returns:
        list: the annotated input image followed by one perspective-warped
        card image per detection.
    """

    def warp_img(src_img, kps, ratio):
        """Warp the quadrilateral ``kps`` to an upright rectangle whose
        shorter side is 500 pixels and whose width/height equals ``ratio``."""
        short_size = 500
        if ratio > 1:
            obj_h, obj_w = short_size, int(short_size * ratio)
        else:
            obj_w, obj_h = short_size, int(short_size / ratio)
        corners_in = np.float32([kps[n] for n in range(4)])
        # Destination order: bottom-left, top-left, top-right, bottom-right.
        corners_out = np.float32([[0, obj_h - 1], [0, 0], [obj_w - 1, 0],
                                  [obj_w - 1, obj_h - 1]])
        transform = cv2.getPerspectiveTransform(corners_in, corners_out)
        return cv2.warpPerspective(src_img, transform, (obj_w, obj_h))

    boxes = np.array(detection_result[OutputKeys.BOXES])
    all_corners = np.array(detection_result[OutputKeys.KEYPOINTS])
    scores = np.array(detection_result[OutputKeys.SCORES])
    corner_colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (0, 255, 255)]
    img = cv2.imread(img_path)
    results = [img]
    assert img is not None, f"Can't read img: {img_path}"
    for card_id, score in enumerate(scores):
        box = boxes[card_id].astype(np.int32)
        corners = all_corners[card_id].reshape(-1, 2).astype(np.int32)
        # Squared lengths of the two edges leaving corner 0 decide whether
        # the card is treated as landscape or portrait.
        edge_w = (corners[0][0] - corners[3][0])**2 + (
            corners[0][1] - corners[3][1])**2
        edge_h = (corners[0][0] - corners[1][0])**2 + (
            corners[0][1] - corners[1][1])**2
        if edge_w >= edge_h:
            ratio = 1.59
        else:
            ratio = 1 / 1.59
        # The crop is taken before this card's annotations are drawn.
        results.append(warp_img(img, corners, ratio))
        x1, y1, x2, y2 = box
        cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 4)
        for corner_id, corner in enumerate(corners):
            cv2.circle(
                img,
                tuple(corner),
                1,
                color=corner_colors[corner_id],
                thickness=10)
        cv2.putText(
            img,
            f'{score:.2f}', (x1, y2),
            1,
            1.0, (0, 255, 0),
            thickness=1,
            lineType=8)
    return results


def created_boxed_image(image_in, box):
    """Load an image and draw one (0, 255, 0), 3-pixel box on it.

    Args:
        image_in: input accepted by ``load_image``.
        box: sequence whose first four entries are x1, y1, x2, y2.

    Returns:
        The image converted with ``cv2.COLOR_RGB2BGR``, with the box drawn.
    """
    rgb = np.asarray(load_image(image_in))
    img = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
    top_left = (int(box[0]), int(box[1]))
    bottom_right = (int(box[2]), int(box[3]))
    cv2.rectangle(img, top_left, bottom_right, (0, 255, 0), 3)
    return img


def show_video_tracking_result(video_in_path, bboxes, video_save_path):
    """Write a copy of a video with one tracking box drawn per frame.

    One frame is read for every entry of ``bboxes``; the output uses the
    MJPG codec, the input's FPS and the size of the first frame.

    Args:
        video_in_path (str): path of the input video.
        bboxes: per-frame boxes indexed as x1, y1, x2, y2.
        video_save_path (str): path of the video to write.

    Raises:
        Exception: if a frame cannot be decoded.
    """
    cap = cv2.VideoCapture(video_in_path)
    video_writer = None
    try:
        for box in bboxes:
            success, frame = cap.read()
            if success is False:
                raise Exception(video_in_path,
                                ' can not be correctly decoded by OpenCV.')
            if video_writer is None:
                # The writer needs the frame size, so it is created lazily
                # on the first decoded frame.
                size = (frame.shape[1], frame.shape[0])
                fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G')
                video_writer = cv2.VideoWriter(video_save_path, fourcc,
                                               cap.get(cv2.CAP_PROP_FPS),
                                               size, True)
            cv2.rectangle(frame, (int(box[0]), int(box[1])),
                          (int(box[2]), int(box[3])), (0, 255, 0), 5)
            video_writer.write(frame)
    finally:
        # Release both handles even on failure; the writer does not exist
        # when there were no boxes or the first frame failed to decode.
        if video_writer is not None:
            video_writer.release()
        cap.release()


def show_video_object_detection_result(video_in_path, bboxes_list, labels_list,
                                       video_save_path):
    """Write a copy of a video with labelled, tinted detection boxes.

    For every frame each box is outlined, captioned with its label and
    filled on a separate mask that is blended onto the frame with weight
    0.65. The output uses the MJPG codec, the input's FPS and the size of
    the first frame.

    Args:
        video_in_path (str): path of the input video.
        bboxes_list: per-frame arrays of boxes (x1, y1, x2, y2); each entry
            must support ``astype``.
        labels_list: per-frame label sequences; every label must be a key
            of the colour palette defined below.
        video_save_path (str): path of the video to write.

    Raises:
        Exception: if a frame cannot be decoded.
        KeyError: if a label is not in the palette.
    """
    from tqdm import tqdm
    import math

    PALETTE = {
        'person': [128, 0, 0],
        'bicycle': [128, 128, 0],
        'car': [64, 0, 0],
        'motorcycle': [0, 128, 128],
        'bus': [64, 128, 0],
        'truck': [192, 128, 0],
        'traffic light': [64, 0, 128],
        'stop sign': [192, 0, 128],
    }
    FONT_SCALE = 1e-3  # Adjust for larger font size in all images
    THICKNESS_SCALE = 1e-3  # Adjust for larger thickness in all images
    TEXT_Y_OFFSET_SCALE = 1e-2  # Adjust for larger Y-offset of text and bounding box

    cap = cv2.VideoCapture(video_in_path)
    video_writer = None
    try:
        with tqdm(total=len(bboxes_list)) as pbar:
            pbar.set_description(
                'Writing results to video: {}'.format(video_save_path))
            for frame_boxes, labels in zip(bboxes_list, labels_list):
                bboxes = frame_boxes.astype(int)
                success, frame = cap.read()
                if success is False:
                    raise Exception(
                        video_in_path,
                        ' can not be correctly decoded by OpenCV.')
                if video_writer is None:
                    # Created lazily: the frame size is only known once the
                    # first frame has been decoded.
                    size = (frame.shape[1], frame.shape[0])
                    fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G')
                    video_writer = cv2.VideoWriter(
                        video_save_path, fourcc, cap.get(cv2.CAP_PROP_FPS),
                        size, True)

                H, W, _ = frame.shape
                zeros_mask = np.zeros(frame.shape, dtype=np.uint8)
                for bbox, label in zip(bboxes, labels):
                    color = PALETTE[label]
                    top_left = (int(bbox[0]), int(bbox[1]))
                    bottom_right = (int(bbox[2]), int(bbox[3]))
                    cv2.rectangle(frame, top_left, bottom_right, color, 1)
                    cv2.putText(
                        frame,
                        label,
                        (top_left[0],
                         top_left[1] - int(TEXT_Y_OFFSET_SCALE * H)),
                        fontFace=cv2.FONT_HERSHEY_TRIPLEX,
                        fontScale=min(H, W) * FONT_SCALE,
                        thickness=math.ceil(min(H, W) * THICKNESS_SCALE),
                        color=color)
                    # Filled box on the overlay mask, blended in below.
                    zeros_mask = cv2.rectangle(
                        zeros_mask,
                        top_left,
                        bottom_right,
                        color=color,
                        thickness=-1)

                frame = cv2.addWeighted(frame, 1., zeros_mask, .65, 0)
                video_writer.write(frame)
                pbar.update(1)
    finally:
        # Release both handles even on failure; the writer does not exist
        # when there were no frames or the first frame failed to decode.
        if video_writer is not None:
            video_writer.release()
        cap.release()


def panoptic_seg_masks_to_image(masks):
    """Paint panoptic segmentation masks into one colour image.

    Each mask gets the palette colour at its index; if that colour was
    already used (or is black, the background of the canvas) it is altered
    with mmdet's ``_get_bias_color`` until it is unique.

    Args:
        masks: non-empty sequence of 2-D arrays of identical shape that
            support ``astype(bool)``.

    Returns:
        np.ndarray: float array of shape (H, W, 3) with the painted masks.
    """
    from mmdet.core.visualization.palette import get_palette
    from mmdet.core.visualization.image import _get_bias_color

    height, width = masks[0].shape[0], masks[0].shape[1]
    draw_img = np.zeros([height, width, 3])
    mask_palette = get_palette('coco', 133)

    # Reserve black as a colour tuple. set([0, 0, 0]) would collapse to {0},
    # which can never match a colour tuple, leaving black unreserved.
    taken_colors = {(0, 0, 0)}
    for i, mask in enumerate(masks):
        color_mask = mask_palette[i]
        while tuple(color_mask) in taken_colors:
            color_mask = _get_bias_color(color_mask)
        taken_colors.add(tuple(color_mask))

        draw_img[mask.astype(bool)] = color_mask

    return draw_img


def semantic_seg_masks_to_image(masks):
    """Paint semantic segmentation masks into one colour image.

    Every mask is filled with the palette colour found at its own index.

    Args:
        masks: sequence of 2-D arrays of identical shape that support
            ``astype(bool)``.

    Returns:
        np.ndarray: float array of shape (H, W, 3) with the painted masks.
    """
    from mmdet.core.visualization.palette import get_palette
    palette = get_palette('coco', 133)

    first = masks[0]
    canvas = np.zeros([first.shape[0], first.shape[1], 3])

    for index, segment in enumerate(masks):
        canvas[segment.astype(bool)] = palette[index]
    return canvas


def show_video_summarization_result(video_in_path, result, video_save_path):
    """Write a summary video containing only the selected frames.

    ``result[OutputKeys.OUTPUT]`` holds one flag per frame; a frame is kept
    when its flag equals 1. The output uses the MJPG codec, the input's FPS
    and the size of the first frame.

    Args:
        video_in_path (str): path of the input video.
        result (dict): summarization result holding ``OutputKeys.OUTPUT``.
        video_save_path (str): path of the video to write.

    Raises:
        Exception: if a frame cannot be decoded.
    """
    frame_indexes = result[OutputKeys.OUTPUT]
    cap = cv2.VideoCapture(video_in_path)
    video_writer = None
    try:
        for keep_flag in frame_indexes:
            success, frame = cap.read()
            if success is False:
                raise Exception(video_in_path,
                                ' can not be correctly decoded by OpenCV.')
            if video_writer is None:
                # Created lazily: the frame size is only known once the
                # first frame has been decoded.
                size = (frame.shape[1], frame.shape[0])
                fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G')
                video_writer = cv2.VideoWriter(video_save_path, fourcc,
                                               cap.get(cv2.CAP_PROP_FPS),
                                               size, True)
            if keep_flag == 1:
                video_writer.write(frame)
    finally:
        # Release both handles even on failure; the writer does not exist
        # when the result is empty or the first frame failed to decode.
        if video_writer is not None:
            video_writer.release()
        cap.release()


def show_image_object_detection_auto_result(img_path,
                                            detection_result,
                                            save_path=None):
    """Draw detection boxes with score and label text on an image.

    The score is written at the box's top-left corner and the label at its
    bottom-left corner.

    Args:
        img_path (str): path of the image to read.
        detection_result (dict): result holding ``OutputKeys.SCORES``,
            ``OutputKeys.LABELS`` and ``OutputKeys.BOXES``.
        save_path (str, optional): if given, the result is written there.

    Returns:
        The annotated image.
    """
    scores = detection_result[OutputKeys.SCORES]
    labels = detection_result[OutputKeys.LABELS]
    bboxes = detection_result[OutputKeys.BOXES]
    img = cv2.imread(img_path)
    assert img is not None, f"Can't read img: {img_path}"

    text_color = (0, 255, 0)
    for score, label, box in zip(scores, labels, bboxes):
        left, top = int(box[0]), int(box[1])
        right, bottom = int(box[2]), int(box[3])
        cv2.rectangle(img, (left, top), (right, bottom), (0, 0, 255), 2)
        for caption, origin in ((f'{score:.2f}', (left, top)),
                                (label, (left, bottom))):
            cv2.putText(
                img,
                caption,
                origin,
                1,
                1.0,
                text_color,
                thickness=1,
                lineType=8)

    if save_path is not None:
        cv2.imwrite(save_path, img)
    return img


def depth_to_color(depth):
    """Colour a depth map with matplotlib's 'plasma' colormap.

    The depth is inverted and normalised by its maximum before the colormap
    lookup, so the largest depth maps to the start of the colormap.

    Args:
        depth (np.ndarray): 2-D depth map.

    Returns:
        np.ndarray: uint8 image after ``cv2.COLOR_RGB2BGR`` conversion.
    """
    plasma = plt.get_cmap('plasma')
    peak = depth.max()
    inverted = (peak - depth) / peak
    rgba = (plasma(inverted) * 2**8).astype(np.uint8)
    # Drop the alpha channel returned by the colormap.
    rgb = rgba[:, :, :3]
    return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)


def make_colorwheel():
    """
    Generates a color wheel for optical flow visualization as presented in:
        Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
        URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf

    Code follows the original C++ source code of Daniel Scharstein.
    Code follows the Matlab source code of Deqing Sun.

    The wheel is built from six consecutive segments (RY, YG, GC, CB, BM,
    MR). Within a segment one channel is held at 255 while another ramps
    linearly up from 0 or down from 255.

    Returns:
        np.ndarray: Color wheel of shape [55, 3]
    """
    # (segment length, channel held at 255, ramped channel, ramp rises?)
    segments = (
        (15, 0, 1, True),  # RY
        (6, 1, 0, False),  # YG
        (4, 1, 2, True),  # GC
        (11, 2, 1, False),  # CB
        (13, 2, 0, True),  # BM
        (6, 0, 2, False),  # MR
    )

    total_rows = sum(length for length, _, _, _ in segments)
    wheel = np.zeros((total_rows, 3))

    start = 0
    for length, held_channel, ramp_channel, rising in segments:
        rows = slice(start, start + length)
        ramp = np.floor(255 * np.arange(0, length) / length)
        wheel[rows, held_channel] = 255
        wheel[rows, ramp_channel] = ramp if rising else 255 - ramp
        start += length
    return wheel


def flow_uv_to_colors(u, v, convert_to_bgr=False):
    """
    Applies the flow color wheel to (possibly clipped) flow components u and v.

    According to the C++ source code of Daniel Scharstein
    According to the Matlab source code of Deqing Sun

    The flow angle selects a position on the color wheel (linearly
    interpolated between the two neighbouring entries) and the flow
    magnitude controls how far the colour is pulled away from white.

    Args:
        u (np.ndarray): Input horizontal flow of shape [H,W]
        v (np.ndarray): Input vertical flow of shape [H,W]
        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.

    Returns:
        np.ndarray: Flow visualization image of shape [H,W,3]
    """
    height, width = u.shape[0], u.shape[1]
    flow_image = np.zeros((height, width, 3), np.uint8)
    wheel = make_colorwheel()  # shape [55x3]
    wheel_size = wheel.shape[0]

    magnitude = np.sqrt(np.square(u) + np.square(v))
    angle = np.arctan2(-v, -u) / np.pi
    # Fractional position on the wheel and its two neighbouring entries.
    wheel_pos = (angle + 1) / 2 * (wheel_size - 1)
    lower = np.floor(wheel_pos).astype(np.int32)
    upper = lower + 1
    upper[upper == wheel_size] = 0  # wrap around the wheel
    weight = wheel_pos - lower

    in_range = (magnitude <= 1)
    out_of_range = ~in_range
    for channel in range(wheel.shape[1]):
        column = wheel[:, channel]
        low_color = column[lower] / 255.0
        high_color = column[upper] / 255.0
        blended = (1 - weight) * low_color + weight * high_color
        # Small magnitudes fade towards white.
        blended[in_range] = 1 - magnitude[in_range] * (1 - blended[in_range])
        blended[out_of_range] = blended[out_of_range] * 0.75  # out of range
        # Note the 2-channel => BGR instead of RGB
        target = 2 - channel if convert_to_bgr else channel
        flow_image[:, :, target] = np.floor(255 * blended)
    return flow_image


def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
    """
    Converts a two-channel flow field into a color visualization.

    The flow is optionally clipped, normalised by its largest magnitude and
    mapped to colors with ``flow_uv_to_colors``.

    Args:
        flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
        clip_flow (float, optional): Clip the absolute value of each flow
            component to this maximum. Defaults to None (no clipping).
        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.

    Returns:
        np.ndarray: Flow visualization image of shape [H,W,3]
    """
    assert flow_uv.ndim == 3, 'input flow must have three dimensions'
    assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
    if clip_flow is not None:
        # Flow components are signed: clip symmetrically. A lower bound of 0
        # would erase every negative component and with it the direction.
        flow_uv = np.clip(flow_uv, -clip_flow, clip_flow)
    u = flow_uv[:, :, 0]
    v = flow_uv[:, :, 1]
    rad = np.sqrt(np.square(u) + np.square(v))
    rad_max = np.max(rad)
    # Epsilon keeps the normalisation defined for an all-zero flow.
    epsilon = 1e-5
    u = u / (rad_max + epsilon)
    v = v / (rad_max + epsilon)
    return flow_uv_to_colors(u, v, convert_to_bgr)


def flow_to_color(flow):
    """Visualise the first flow of a batch as a colour image.

    Args:
        flow: batched tensor; element 0 is permuted with ``(1, 2, 0)``,
            moved to CPU and converted to numpy before colouring.

    Returns:
        np.ndarray: result of ``flow_to_image`` for that flow.
    """
    first_flow = flow[0]
    flow_hwc = first_flow.permute(1, 2, 0).cpu().numpy()
    return flow_to_image(flow_hwc)


def show_video_depth_estimation_result(depths, video_save_path):
    """Write a sequence of depth visualisations to a video file.

    The video uses the 'MP4V' codec at 25 FPS and the size of the first
    frame; every frame is cast to uint8 and converted with
    ``cv2.COLOR_RGB2BGR`` before being written.

    Args:
        depths: sequence of 3-D arrays (height, width, channels).
        video_save_path (str): path of the video to write.
    """
    frame_h, frame_w, _ = depths[0].shape
    codec = cv2.VideoWriter_fourcc(*'MP4V')
    writer = cv2.VideoWriter(video_save_path, codec, 25, (frame_w, frame_h))
    for depth_frame in depths:
        bgr_frame = cv2.cvtColor(
            depth_frame.astype(np.uint8), cv2.COLOR_RGB2BGR)
        writer.write(bgr_frame)
    writer.release()


def show_image_driving_perception_result(img,
                                         results,
                                         out_file='result.jpg',
                                         if_draw=(1, 1, 1)):
    """Overlay driving-perception results (boxes and two masks) on an image.

    Args:
        img (np.ndarray): image to annotate; it is modified in place.
        results (dict): result holding ``OutputKeys.BOXES`` and
            ``OutputKeys.MASKS``; mask 0 is painted [0, 255, 0] and mask 1
            [255, 0, 0] wherever the mask equals 1.
        out_file (str, optional): if not None, the annotated image is
            written there with its channel order reversed.
        if_draw (sequence): three flags enabling, in order, the boxes,
            mask 0 and mask 1.

    Returns:
        np.ndarray: the annotated ``img``.
    """
    bboxes = results.get(OutputKeys.BOXES)
    # ``get`` yields None when the result carries no boxes; skip drawing
    # instead of failing on iteration.
    if if_draw[0] and bboxes is not None:
        for x in bboxes:
            c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
            cv2.rectangle(
                img, c1, c2, [255, 255, 0], thickness=2, lineType=cv2.LINE_AA)

    result = results.get(OutputKeys.MASKS)

    color_seg = np.zeros((result[0].shape[0], result[0].shape[1], 3),
                         dtype=np.uint8)

    if if_draw[1]:
        color_seg[result[0] == 1] = [0, 255, 0]
    if if_draw[2]:
        color_seg[result[1] == 1] = [255, 0, 0]

    # Blend only the pixels that received a mask colour, 50/50 with the image.
    color_mask = np.mean(color_seg, 2)
    msk_idx = color_mask != 0
    img[msk_idx] = img[msk_idx] * 0.5 + color_seg[msk_idx] * 0.5
    if out_file is not None:
        cv2.imwrite(out_file, img[:, :, ::-1])
    return img


def masks_visualization(masks, palette):
    """Turn a stack of masks into palettised PIL images.

    Args:
        masks: array whose first axis enumerates the masks; each slice is
            passed to ``Image.fromarray``.
        palette: palette handed to ``Image.putpalette`` for every image.

    Returns:
        list: one PIL image per mask, in the same order.
    """
    colored = []
    frame_count = masks.shape[0]
    for frame_id in range(frame_count):
        mask_image = Image.fromarray(masks[frame_id])
        mask_image.putpalette(palette)
        colored.append(mask_image)
    return colored


# This implementation is adapted from LoFTR,
# made publicly available under the Apache License, Version 2.0,
# at https://github.com/zju3dv/LoFTR


def make_matching_figure(img0,
                         img1,
                         mkpts0,
                         mkpts1,
                         color,
                         kpts0=None,
                         kpts1=None,
                         text=[],
                         dpi=75,
                         path=None):
    """Plot two images side by side with lines joining matched keypoints.

    Args:
        img0, img1: images shown in the left and right axes.
        mkpts0, mkpts1 (np.ndarray): matched keypoints, one (x, y) row per
            match; both must have the same number of rows.
        color: one colour per match, used for the line and both dots.
        kpts0, kpts1 (np.ndarray, optional): all detected keypoints, drawn
            as small white dots when ``kpts0`` is given.
        text (list): lines of text placed at the top-left of the figure.
        dpi (int): figure resolution.
        path (optional): if truthy, the figure is saved there and closed.

    Returns:
        The matplotlib figure when ``path`` is not given, otherwise None.
    """
    # draw image pair
    count0, count1 = mkpts0.shape[0], mkpts1.shape[0]
    assert count0 == count1, f'mkpts0: {count0} v.s. mkpts1: {count1}'
    fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi)
    for axis, picture in zip(axes, (img0, img1)):
        axis.imshow(picture, cmap='gray')
    for axis in axes:  # clear all frames
        axis.get_yaxis().set_ticks([])
        axis.get_xaxis().set_ticks([])
        for spine in axis.spines.values():
            spine.set_visible(False)
    plt.tight_layout(pad=1)

    if kpts0 is not None:
        assert kpts1 is not None
        for axis, points in zip(axes, (kpts0, kpts1)):
            axis.scatter(points[:, 0], points[:, 1], c='w', s=2)

    # draw matches
    if count0 != 0 and count1 != 0:
        # The canvas must be drawn once so the data transforms are final.
        fig.canvas.draw()
        to_figure = fig.transFigure.inverted()
        # Keypoints expressed in figure coordinates, so one line can span
        # both axes.
        start_pts = to_figure.transform(axes[0].transData.transform(mkpts0))
        end_pts = to_figure.transform(axes[1].transData.transform(mkpts1))
        match_lines = []
        for idx in range(len(mkpts0)):
            match_lines.append(
                matplotlib.lines.Line2D(
                    (start_pts[idx, 0], end_pts[idx, 0]),
                    (start_pts[idx, 1], end_pts[idx, 1]),
                    transform=fig.transFigure,
                    c=color[idx],
                    linewidth=1))
        fig.lines = match_lines

        axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color, s=4)
        axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color, s=4)

    # put txts: black text on a bright corner, white text otherwise
    if img0[:100, :200].mean() > 200:
        txt_color = 'k'
    else:
        txt_color = 'w'
    fig.text(
        0.01,
        0.99,
        '\n'.join(text),
        transform=fig.axes[0].transAxes,
        fontsize=15,
        va='top',
        ha='left',
        color=txt_color)

    # save or return figure
    if not path:
        return fig
    plt.savefig(str(path), bbox_inches='tight', pad_inches=0)
    plt.close()


def match_pair_visualization(img_name0,
                             img_name1,
                             kpts0,
                             kpts1,
                             conf,
                             output_filename='quadtree_match.png',
                             method='QuadTreeAttention'):
    """Save a figure showing the matches between two images.

    Args:
        img_name0, img_name1: paths of the two images (read with
            ``cv2.imread``).
        kpts0, kpts1 (np.ndarray): matched keypoints, one (x, y) row per
            match.
        conf: per-match confidence, mapped to line colours with the 'jet'
            colormap.
        output_filename: path of the figure to write (saved at 300 dpi).
        method (str): method name shown in the figure's caption.
    """
    print(f'Found {len(kpts0)} matches')

    # visualize the matches
    img0 = cv2.imread(str(img_name0))
    img1 = cv2.imread(str(img_name1))

    # Draw
    color = cm.jet(conf)
    text = [
        method,
        'Matches: {}'.format(len(kpts0)),
    ]
    fig = make_matching_figure(img0, img1, kpts0, kpts1, color, text=text)

    # save the figure, then close it: pyplot keeps every figure alive until
    # it is closed, so repeated calls would otherwise accumulate figures.
    try:
        fig.savefig(str(output_filename), dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)
