import math
import os
import random
import uuid
from os.path import exists
from tempfile import TemporaryDirectory
from urllib.parse import urlparse

import numpy as np
import torch
import torch.utils.data
import torch.utils.dlpack as dlpack
import torchvision.transforms._transforms_video as transforms
from decord import VideoReader
from torchvision.transforms import Compose

from modelscope.hub.file_download import http_get_file
from modelscope.metainfo import Preprocessors
from modelscope.utils.constant import Fields, ModeKeys
from modelscope.utils.type_assert import type_assert
from .base import Preprocessor
from .builder import PREPROCESSORS


def ReadVideoData(cfg,
                  video_path,
                  num_spatial_crops_override=None,
                  num_temporal_views_override=None):
    """ Load a video and turn it into a batch of model-ready clips.

    The video is decoded into temporal clips, and every clip is then passed
    through the kinetics-400 test transform once per spatial crop.

    Args:
        cfg (Config): The global config object.
        video_path (str): path of a local video file, or a location that is
            downloaded into a temporary directory before decoding
        num_spatial_crops_override (int): the spatial crops per clip; when
            None, cfg.TEST.NUM_SPATIAL_CROPS is used
        num_temporal_views_override (int): the temporal clips per video,
            forwarded unchanged to the decoder
    Returns:
        data (Tensor): the transformed clips stacked along a new first
            dimension, one entry per (temporal clip, spatial crop) pair
    """
    parsed = urlparse(video_path)
    is_local_file = parsed.scheme in ('file', '') and exists(parsed.path)

    if is_local_file:
        # NOTE: the original string, not the parsed path, goes to the decoder.
        data = _decode_video(cfg, video_path, num_temporal_views_override)
    else:
        # Download to a throw-away directory; decoding must finish before the
        # directory (and the file inside it) is removed on exit.
        with TemporaryDirectory() as cache_dir:
            download_name = uuid.uuid4().hex
            http_get_file(
                url=video_path,
                local_dir=cache_dir,
                file_name=download_name,
                cookies=None)
            data = _decode_video(cfg, os.path.join(cache_dir, download_name),
                                 num_temporal_views_override)

    if num_spatial_crops_override is None:
        num_spatial_crops = cfg.TEST.NUM_SPATIAL_CROPS
    else:
        num_spatial_crops = num_spatial_crops_override
    transform = kinetics400_tranform(cfg, num_spatial_crops)
    # The second step of the pipeline is the crop whose position is selected
    # through set_spatial_index before each call.
    cropper = transform.transforms[1]

    views = []
    for clip in data:
        for crop_idx in range(num_spatial_crops):
            cropper.set_spatial_index(crop_idx)
            views.append(transform(clip))
    return torch.stack(views, dim=0)


def kinetics400_tranform(cfg, num_spatial_crops):
    """
    Builds the test-time transform for the kinetics-400 dataset.
    The pipeline has three steps, in this order: tensor conversion,
    resize + controlled spatial crop, and in-place normalization.
    Args:
        cfg (Config): The global config object. Reads DATA.TEST_SCALE,
            DATA.TEST_CROP_SIZE, DATA.MEAN and DATA.STD.
        num_spatial_crops (int): the spatial crops per clip
    Returns:
        transform_function (Compose): the transform function for input clips
    """
    test_scale = cfg.DATA.TEST_SCALE
    # Both ends of the range are the same value, so the short side is fixed.
    cropper = KineticsResizedCrop(
        short_side_range=[test_scale, test_scale],
        crop_size=cfg.DATA.TEST_CROP_SIZE,
        num_spatial_crops=num_spatial_crops)
    return Compose([
        transforms.ToTensorVideo(),
        cropper,
        transforms.NormalizeVideo(
            mean=cfg.DATA.MEAN, std=cfg.DATA.STD, inplace=True),
    ])


def _interval_based_sampling(vid_length, vid_fps, target_fps, clip_idx,
                             num_clips, num_frames, interval, minus_interval):
    """
        Generates the frame index list using interval based sampling.

        Args:
            vid_length (int): the length of the whole video (valid selection range).
            vid_fps (int): the original video fps
            target_fps (int): the normalized video fps
            clip_idx (int):
                index of the clip to sample. It scales the start position of
                the clip and is ignored when num_clips is 1 or num_frames is 1.
            num_clips (int):
                the total clips to be sampled from each video. combined with clip_idx,
                the sampled video is the "clip_idx-th" video from "num_clips" videos.
            num_frames (int): number of frames in each sampled clips.
            interval (int): the interval to sample each frame.
            minus_interval (bool): control the end index. When True the last
                index is ``start + clip_length - interval``, otherwise
                ``start + clip_length - 1``.

        Returns:
            index (tensor): 1-D int64 tensor with the sampled frame indexes.
                For num_frames > 1 every index is clamped to
                [0, vid_length - 1].
    """
    if num_frames == 1:
        # A single-frame clip is one random frame. It is wrapped in a tensor
        # so both branches return the same type: callers rely on tensor-only
        # operations such as ``.tolist()`` and ``torch.cat``.
        return torch.tensor([random.randint(0, vid_length - 1)],
                            dtype=torch.long)

    # Length of one clip measured in frames of the source video, i.e. after
    # converting from the target fps to the video's own fps.
    clip_length = num_frames * interval * vid_fps / target_fps

    # Largest start position that still leaves room for a whole clip.
    max_idx = max(vid_length - clip_length, 0)
    if num_clips == 1:
        # A single clip is taken from the middle of the video.
        start_idx = max_idx / 2
    else:
        # Spread the clips evenly so the last one starts at (about) max_idx.
        start_idx = clip_idx * math.floor(max_idx / (num_clips - 1))
    if minus_interval:
        end_idx = start_idx + clip_length - interval
    else:
        end_idx = start_idx + clip_length - 1

    index = torch.linspace(start_idx, end_idx, num_frames)
    # Videos shorter than one clip would otherwise index past the last frame.
    index = torch.clamp(index, 0, vid_length - 1).long()

    return index


def _decode_video_frames_list(cfg,
                              frames_list,
                              vid_fps,
                              num_temporal_views_override=None):
    """
        Samples clips from a video that is already decoded into numpy frames.
        Args:
            cfg          (Config): The global config object. Reads
                TEST.NUM_ENSEMBLE_VIEWS, DATA.TARGET_FPS,
                DATA.NUM_INPUT_FRAMES, DATA.SAMPLING_RATE and
                DATA.MINUS_INTERVAL.
            frames_list  (list):  all frames for a video, the frames should be numpy array.
            vid_fps      (int):  the fps of this video.
            num_temporal_views_override (int): the temporal clips per video;
                when None, cfg.TEST.NUM_ENSEMBLE_VIEWS is used
        Returns:
            frames            (Tensor): the sampled clips stacked along a new
                first dimension, one entry per temporal clip
    """
    assert isinstance(frames_list, list)
    if num_temporal_views_override is not None:
        num_clips_per_video = num_temporal_views_override
    else:
        num_clips_per_video = cfg.TEST.NUM_ENSEMBLE_VIEWS

    clips = []
    for clip_idx in range(num_clips_per_video):
        # for each clip in the video,
        # the frame indexes are generated before picking the frames
        indices = _interval_based_sampling(
            len(frames_list),
            vid_fps,
            cfg.DATA.TARGET_FPS,
            clip_idx,
            num_clips_per_video,
            cfg.DATA.NUM_INPUT_FRAMES,
            cfg.DATA.SAMPLING_RATE,
            cfg.DATA.MINUS_INTERVAL,
        )
        # Accept either a tensor or a plain Python list of indexes.
        picked = [frames_list[i] for i in torch.as_tensor(indices).tolist()]
        clips.append(torch.from_numpy(np.stack(picked, axis=0)))
    # There is no video reader to release here: the frames are supplied by
    # the caller, so the stacked clips are returned directly.
    return torch.stack(clips)


def _decode_video(cfg, path, num_temporal_views_override=None):
    """
        Opens a video file and decodes the sampled clips into tensors.
        Args:
            cfg          (Config): The global config object. Reads
                TEST.NUM_ENSEMBLE_VIEWS, DATA.TARGET_FPS,
                DATA.NUM_INPUT_FRAMES, DATA.SAMPLING_RATE and
                DATA.MINUS_INTERVAL.
            path          (str): video file path.
            num_temporal_views_override (int): the temporal clips per video;
                when None, cfg.TEST.NUM_ENSEMBLE_VIEWS is used
        Returns:
            frames            (Tensor): the decoded clips stacked along a new
                first dimension, one entry per temporal clip
    """
    reader = VideoReader(path)
    if num_temporal_views_override is None:
        num_views = cfg.TEST.NUM_ENSEMBLE_VIEWS
    else:
        num_views = num_temporal_views_override

    clips = []
    for view_idx in range(num_views):
        # Work out which frames make up this clip before touching the decoder.
        wanted = _interval_based_sampling(
            len(reader),
            reader.get_avg_fps(),
            cfg.DATA.TARGET_FPS,
            view_idx,
            num_views,
            cfg.DATA.NUM_INPUT_FRAMES,
            cfg.DATA.SAMPLING_RATE,
            cfg.DATA.MINUS_INTERVAL,
        )
        if not path.endswith('.avi'):
            batch = reader.get_batch(wanted)
            # clone() copies the data out of the buffer shared through DLPack.
            clip = dlpack.from_dlpack(batch.to_dlpack()).clone()
        else:
            # For .avi files every 4th frame before the first wanted index is
            # requested as well and dropped again afterwards.
            # NOTE(review): presumably a workaround for inaccurate seeking in
            # AVI files - confirm.
            lead_in = torch.arange(0, wanted[0], 4)
            batch = reader.get_batch(torch.cat([lead_in, wanted]))
            clip = dlpack.from_dlpack(batch.to_dlpack()).clone()
            clip = clip[lead_in.shape[0]:]
        clips.append(clip)

    stacked = torch.stack(clips)
    del reader
    return stacked


class KineticsResizedCrop(object):
    """Perform resize and crop for kinetics-400 dataset
    Args:
        short_side_range (list): The length of short side range. In inference, this should be [256, 256].
            The controlled crop only reads the first entry; the random crop
            samples uniformly between the two entries.
        crop_size         (int): The cropped size for frames.
        num_spatial_crops (int): The number of the cropped spatial regions in each video.
            The controlled crop supports 1 (center crop) and 3 (crop chosen
            through ``set_spatial_index``).
    """

    def __init__(
        self,
        short_side_range,
        crop_size,
        num_spatial_crops=1,
    ):
        # -1 means that no spatial index has been selected yet.
        self.idx = -1
        self.short_side_range = short_side_range
        self.crop_size = int(crop_size)
        self.num_spatial_crops = num_spatial_crops

    def _get_controlled_crop(self, clip):
        """Perform controlled crop for video tensor.

        The clip is resized so that its shorter spatial side equals
        ``short_side_range[0]`` and a ``crop_size`` x ``crop_size`` window is
        cut out. With one spatial crop the window is centered. With three,
        index 0 / 1 / 2 selects the start / center / end position along the
        longer side.

        Args:
            clip (Tensor): the video data, a 4-D tensor whose last two
                dimensions are height and width
        Returns:
            Tensor: the cropped clip; the first two dimensions are unchanged
        Raises:
            ValueError: if ``num_spatial_crops`` is neither 1 nor 3, if three
                crops are requested but the spatial index is not 0, 1 or 2,
                or if neither resized side equals the target length.
        """
        _, _, clip_height, clip_width = clip.shape

        length = self.short_side_range[0]

        # Scale the shorter side to `length`, keeping the aspect ratio.
        if clip_height < clip_width:
            new_clip_height = int(length)
            new_clip_width = int(clip_width / clip_height * new_clip_height)
        else:
            new_clip_width = int(length)
            new_clip_height = int(clip_height / clip_width * new_clip_width)
        new_clip = torch.nn.functional.interpolate(
            clip, size=(new_clip_height, new_clip_width), mode='bilinear')

        # Largest valid top-left corner of the crop window.
        x_max = int(new_clip_width - self.crop_size)
        y_max = int(new_clip_height - self.crop_size)

        if self.num_spatial_crops == 1:
            x = x_max // 2
            y = y_max // 2
        elif self.num_spatial_crops == 3:
            if self.idx not in (0, 1, 2):
                raise ValueError(
                    'spatial index must be 0, 1 or 2 when num_spatial_crops '
                    'is 3, got {}; call set_spatial_index first'.format(
                        self.idx))
            if self.idx == 1:
                x = x_max // 2
                y = y_max // 2
            elif new_clip_width == length:
                # Width is the short side: slide the window vertically.
                x = x_max // 2
                y = 0 if self.idx == 0 else y_max
            elif new_clip_height == length:
                # Height is the short side: slide the window horizontally.
                x = 0 if self.idx == 0 else x_max
                y = y_max // 2
            else:
                raise ValueError(
                    'neither resized side ({}x{}) equals the target short '
                    'side {}'.format(new_clip_height, new_clip_width, length))
        else:
            raise ValueError(
                'num_spatial_crops must be 1 or 3 for controlled cropping, '
                'got {}'.format(self.num_spatial_crops))
        return new_clip[:, :, y:y + self.crop_size, x:x + self.crop_size]

    def _get_random_crop(self, clip):
        """Perform random resize and random crop for video tensor.

        The short side is resized to a length drawn uniformly from
        ``short_side_range`` and the crop window is placed at a random
        position. This method is not used by ``__call__``.

        Args:
            clip (Tensor): the video data, a 4-D tensor whose last two
                dimensions are height and width
        Returns:
            Tensor: the cropped clip; the first two dimensions are unchanged
        """
        _, _, clip_height, clip_width = clip.shape

        short_side = min(clip_height, clip_width)
        long_side = max(clip_height, clip_width)
        new_short_side = int(random.uniform(*self.short_side_range))
        new_long_side = int(long_side / short_side * new_short_side)
        if clip_height < clip_width:
            new_clip_height = new_short_side
            new_clip_width = new_long_side
        else:
            new_clip_height = new_long_side
            new_clip_width = new_short_side

        new_clip = torch.nn.functional.interpolate(
            clip, size=(new_clip_height, new_clip_width), mode='bilinear')

        x_max = int(new_clip_width - self.crop_size)
        y_max = int(new_clip_height - self.crop_size)
        x = int(random.uniform(0, x_max))
        y = int(random.uniform(0, y_max))
        return new_clip[:, :, y:y + self.crop_size, x:x + self.crop_size]

    def set_spatial_index(self, idx):
        """Set the spatial cropping index for controlled cropping.
        Args:
            idx (int): the spatial index. The value should be in [0, 1, 2], means [left, center, right], respectively.
        """
        self.idx = idx

    def __call__(self, clip):
        """Apply the controlled crop to ``clip`` and return the result."""
        return self._get_controlled_crop(clip)


@PREPROCESSORS.register_module(
    Fields.cv, module_name=Preprocessors.movie_scene_segmentation_preprocessor)
class MovieSceneSegmentationPreprocessor(Preprocessor):
    """Preprocessor for movie scene segmentation.

    Holds one transform for training and one for evaluation and applies the
    one that matches the current mode.
    """

    def __init__(self, *args, **kwargs):
        """
        movie scene segmentation preprocessor

        Keyword Args:
            is_train (bool): start in training mode, default True
            num_keyframe (int): number of keyframes grouped together in the
                output, default 3
            ModeKeys.TRAIN / ModeKeys.EVAL: configs handed to get_transform
                to build the train / test transform, default None
        """
        super().__init__(*args, **kwargs)

        # The options are popped only after the base class has seen them.
        self.is_train = kwargs.pop('is_train', True)
        self.preprocessor_train_cfg = kwargs.pop(ModeKeys.TRAIN, None)
        self.preprocessor_test_cfg = kwargs.pop(ModeKeys.EVAL, None)
        self.num_keyframe = kwargs.pop('num_keyframe', 3)

        from .movie_scene_segmentation import get_transform
        self.train_transform = get_transform(self.preprocessor_train_cfg)
        self.test_transform = get_transform(self.preprocessor_test_cfg)

    def train(self):
        """Switch to the training transform."""
        self.is_train = True

    def eval(self):
        """Switch to the test transform."""
        self.is_train = False

    @type_assert(object, object)
    def __call__(self, results):
        """Transform ``results`` and group the outcome by keyframes.

        Args:
            results: input accepted by the active transform
        Returns:
            Tensor: the stacked transform output viewed as
                (-1, num_keyframe, 3, 224, 224)
        """
        pipeline = self.train_transform if self.is_train else self.test_transform
        stacked = torch.stack(pipeline(results), dim=0)
        # NOTE(review): 3 x 224 x 224 is hard-coded; presumably RGB frames
        # already resized by the transform - confirm.
        return stacked.view(-1, self.num_keyframe, 3, 224, 224)
