# Copyright (c) Alibaba, Inc. and its affiliates.
"""
Processor class for GeoLayoutLM.
"""

from collections import defaultdict
from typing import Dict, Iterable, List, Union

import cv2
import numpy as np
import PIL
import torch
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision import transforms

from modelscope.preprocessors.image import LoadImage


def custom_tokenize(tokenizer, text):
    toks = tokenizer.tokenize('pad ' + text)[1:]
    toks2 = toks[1:] if len(toks) > 0 and toks[0] == '▁' else toks
    return toks2


class ImageProcessor(object):
    r"""
    Construct a GeoLayoutLM image processor
    Args:
        do_preprocess (`bool`): whether to do preprocess to unify the image format,
            resize and convert to tensor.
        do_rescale: only works when we disable do_preprocess.
    """

    def __init__(self,
                 do_preprocess: bool = True,
                 do_resize: bool = False,
                 image_size: Dict[str, int] = None,
                 do_rescale: bool = False,
                 rescale_factor: float = 1. / 255,
                 do_normalize: bool = True,
                 image_mean: Union[float, Iterable[float]] = None,
                 image_std: Union[float, Iterable[float]] = None,
                 apply_ocr: bool = True,
                 **kwargs) -> None:
        self.do_preprocess = do_preprocess
        self.do_resize = do_resize
        self.size = image_size if image_size is not None else {
            'height': 768,
            'width': 768
        }
        self.do_rescale = do_rescale and (not do_preprocess)
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        image_mean = IMAGENET_DEFAULT_MEAN if image_mean is None else image_mean
        image_std = IMAGENET_DEFAULT_STD if image_std is None else image_std
        self.image_mean = (image_mean, image_mean, image_mean) if isinstance(
            image_mean, float) else image_mean
        self.image_std = (image_std, image_std, image_std) if isinstance(
            image_std, float) else image_std
        self.apply_ocr = apply_ocr
        self.kwargs = kwargs

        self.totensor = transforms.ToTensor()

    def preprocess(self, image: Union[np.ndarray, PIL.Image.Image]):
        """ unify the image format, resize and convert to tensor.
        """
        image = LoadImage.convert_to_ndarray(image)[:, :, ::-1]
        size_raw = image.shape[:2]
        if self.do_resize:
            image = cv2.resize(image,
                               (self.size['width'], self.size['height']))
        # convert to pytorch tensor
        image_pt = self.totensor(image)
        return image_pt, size_raw

    def __call__(self, images: Union[list, np.ndarray, PIL.Image.Image, str]):
        """
        Args:
            images: list of np.ndarrays, PIL images or image tensors.
        """
        if not isinstance(images, list):
            images = [images]
        sizes_raw = []
        if self.do_preprocess:
            for i in range(len(images)):
                images[i], size_raw = self.preprocess(images[i])
                sizes_raw.append(size_raw)
        images_pt = torch.stack(images, dim=0)  # [b, c, h, w]
        if self.do_rescale:
            images_pt = images_pt * self.rescale_factor
        if self.do_normalize:
            mu = torch.tensor(self.image_mean).view(1, 3, 1, 1)
            std = torch.tensor(self.image_std).view(1, 3, 1, 1)
            images_pt = (images_pt - mu) / (std + 1e-8)

        # TODO: apply OCR
        ocr_infos = None
        if self.apply_ocr:
            raise NotImplementedError('OCR service is not available yet!')
        if len(sizes_raw) == 0:
            sizes_raw = None
        data = {
            'images': images_pt,
            'ocr_infos': ocr_infos,
            'sizes_raw': sizes_raw
        }
        return data


class OCRUtils(object):

    def __init__(self):
        self.version = 'v0'

    def __call__(self, ocr_infos):
        """
        sort boxes, filtering or other preprocesses
        should return sorted ocr_infos
        """
        raise NotImplementedError


def bound_box(box, height, width):
    # box: [x_tl, y_tl, x_br, y_br] or ...
    assert len(box) == 4 or len(box) == 8
    for i in range(len(box)):
        if i & 1:
            box[i] = max(0, min(box[i], height))
        else:
            box[i] = max(0, min(box[i], width))
    return box


def bbox2pto4p(box2p):
    box4p = [
        box2p[0], box2p[1], box2p[2], box2p[1], box2p[2], box2p[3], box2p[0],
        box2p[3]
    ]
    return box4p


def bbox4pto2p(box4p):
    box2p = [
        min(box4p[0], box4p[2], box4p[4], box4p[6]),
        min(box4p[1], box4p[3], box4p[5], box4p[7]),
        max(box4p[0], box4p[2], box4p[4], box4p[6]),
        max(box4p[1], box4p[3], box4p[5], box4p[7]),
    ]
    return box2p


def stack_tensor_dict(tensor_dicts: List[Dict[str, torch.Tensor]]):
    one_dict = defaultdict(list)
    for td in tensor_dicts:
        for k, v in td.items():
            one_dict[k].append(v)
    res_dict = {}
    for k, v in one_dict.items():
        res_dict[k] = torch.stack(v, dim=0)
    return res_dict


class TextLayoutSerializer(object):

    def __init__(self,
                 max_seq_length: int,
                 max_block_num: int,
                 tokenizer,
                 width=768,
                 height=768,
                 use_roberta_tokenizer: bool = True,
                 ocr_utils: OCRUtils = None):
        self.version = 'v0'

        self.max_seq_length = max_seq_length
        self.max_block_num = max_block_num
        self.tokenizer = tokenizer
        self.width = width
        self.height = height
        self.use_roberta_tokenizer = use_roberta_tokenizer
        self.ocr_utils = ocr_utils

        self.pad_token_id = tokenizer.pad_token_id
        self.cls_token_id = tokenizer.bos_token_id
        self.sep_token_id = tokenizer.eos_token_id
        self.unk_token_id = tokenizer.unk_token_id
        self.cls_bbs_word = [0.0] * 8
        self.cls_bbs_line = [0] * 4

    def label2seq(self, ocr_info: list, label_info: list):
        raise NotImplementedError

    def serialize_single(
        self,
        ocr_info: list = None,
        input_ids: list = None,
        bbox_line: List[List] = None,
        bbox_word: List[List] = None,
        width: int = 768,
        height: int = 768,
    ):
        r"""
        Either ocr_info or (input_ids, bbox_line, bbox_word)
            should be provided.
        If (input_ids, bbox_line, bbox_word) is provided,
            convenient plug into the serialization (customization)
            is offered. The tokens must be organised by blocks and words.
        Else, ocr_info must be provided, to be parsed
            to sequences directly (the simplest way).
        Args:
            ocr_info: [
                {"text": "xx", "box": [a,b,c,d],
                 "words": [{"text": "x", "box": [e,f,g,h]}, ...]},
                ...
            ]
            bbox_line: the coordinate value should match the original image
                (i.e., not be normalized).
        """
        if input_ids is not None:
            assert len(input_ids) == len(bbox_line)
            assert len(input_ids) == len(bbox_word)
            input_ids, bbs_word, bbs_line, first_token_idxes, \
                line_rank_ids, line_rank_inner_ids, word_rank_ids = \
                self.halfseq2seq(input_ids, bbox_line, bbox_word, width, height)
        else:
            assert ocr_info is not None
            input_ids, bbs_word, bbs_line, first_token_idxes, \
                line_rank_ids, line_rank_inner_ids, word_rank_ids = \
                self.ocr_info2seq(ocr_info, width, height)

        token_seq = {}
        token_seq['input_ids'] = torch.ones(
            self.max_seq_length, dtype=torch.int64) * self.pad_token_id
        token_seq['attention_mask'] = torch.zeros(
            self.max_seq_length, dtype=torch.int64)
        token_seq['first_token_idxes'] = torch.zeros(
            self.max_block_num, dtype=torch.int64)
        token_seq['first_token_idxes_mask'] = torch.zeros(
            self.max_block_num, dtype=torch.int64)
        token_seq['bbox_4p_normalized'] = torch.zeros(
            self.max_seq_length, 8, dtype=torch.float32)
        token_seq['bbox'] = torch.zeros(
            self.max_seq_length, 4, dtype=torch.float32)
        token_seq['line_rank_id'] = torch.zeros(
            self.max_seq_length, dtype=torch.int64)  # start from 1
        token_seq['line_rank_inner_id'] = torch.ones(
            self.max_seq_length, dtype=torch.int64)  # 1 2 2 3
        token_seq['word_rank_id'] = torch.zeros(
            self.max_seq_length, dtype=torch.int64)  # start from 1

        # expand using cls and sep tokens
        sep_bbs_word = [width, height] * 4
        sep_bbs_line = [width, height] * 2
        input_ids = [self.cls_token_id] + input_ids + [self.sep_token_id]
        bbs_line = [self.cls_bbs_line] + bbs_line + [sep_bbs_line]
        bbs_word = [self.cls_bbs_word] + bbs_word + [sep_bbs_word]

        # assign
        len_tokens = len(input_ids)
        len_lines = len(first_token_idxes)
        token_seq['input_ids'][:len_tokens] = torch.tensor(input_ids)
        token_seq['attention_mask'][:len_tokens] = 1
        token_seq['first_token_idxes'][:len_lines] = torch.tensor(
            first_token_idxes)
        token_seq['first_token_idxes_mask'][:len_lines] = 1
        token_seq['line_rank_id'][1:len_tokens
                                  - 1] = torch.tensor(line_rank_ids)
        token_seq['line_rank_inner_id'][1:len_tokens - 1] = torch.tensor(
            line_rank_inner_ids)
        token_seq['line_rank_inner_id'] = token_seq[
            'line_rank_inner_id'] * token_seq['attention_mask']
        token_seq['word_rank_id'][1:len_tokens
                                  - 1] = torch.tensor(word_rank_ids)

        token_seq['bbox_4p_normalized'][:len_tokens, :] = torch.tensor(
            bbs_word)
        # word bbox normalization -> [0, 1]
        token_seq['bbox_4p_normalized'][:, [0, 2, 4, 6]] = \
            token_seq['bbox_4p_normalized'][:, [0, 2, 4, 6]] / width
        token_seq['bbox_4p_normalized'][:, [1, 3, 5, 7]] = \
            token_seq['bbox_4p_normalized'][:, [1, 3, 5, 7]] / height

        token_seq['bbox'][:len_tokens, :] = torch.tensor(bbs_line)
        # line bbox -> [0, 1000)
        token_seq['bbox'][:,
                          [0, 2]] = token_seq['bbox'][:, [0, 2]] / width * 1000
        token_seq['bbox'][:,
                          [1, 3]] = token_seq['bbox'][:,
                                                      [1, 3]] / height * 1000
        token_seq['bbox'] = token_seq['bbox'].long()

        return token_seq

    def ocr_info2seq(self, ocr_info: list, width: int, height: int):
        input_ids = []
        bbs_word = []
        bbs_line = []
        first_token_idxes = []
        line_rank_ids = []
        line_rank_inner_ids = []
        word_rank_ids = []

        early_stop = False
        for line_idx, line in enumerate(ocr_info):
            if line_idx == self.max_block_num:
                early_stop = True
            if early_stop:
                break
            lbox = line['box']
            lbox = bound_box(lbox, height, width)
            is_first_word = True
            for word_id, word_info in enumerate(line['words']):
                wtext = word_info['text']
                wbox = word_info['box']
                wbox = bound_box(wbox, height, width)
                wbox4p = bbox2pto4p(wbox)
                if self.use_roberta_tokenizer:
                    wtokens = custom_tokenize(self.tokenizer, wtext)
                else:
                    wtokens = self.tokenizer.tokenize(wtext)
                wtoken_ids = self.tokenizer.convert_tokens_to_ids(wtokens)
                if len(wtoken_ids) == 0:
                    wtoken_ids.append(self.unk_token_id)
                n_tokens = len(wtoken_ids)
                # reserve for cls and sep
                if len(input_ids) + n_tokens > self.max_seq_length - 2:
                    early_stop = True
                    break  # chunking early for long documents
                if is_first_word:
                    first_token_idxes.append(len(input_ids) + 1)
                input_ids.extend(wtoken_ids)
                bbs_word.extend([wbox4p] * n_tokens)
                bbs_line.extend([lbox] * n_tokens)
                word_rank_ids.extend([word_id + 1] * n_tokens)
                line_rank_ids.extend([line_idx + 1] * n_tokens)
                if is_first_word:
                    if len(line_rank_inner_ids
                           ) > 0 and line_rank_inner_ids[-1] == 2:
                        line_rank_inner_ids[-1] = 3
                    line_rank_inner_ids.extend([1] + (n_tokens - 1) * [2])
                    is_first_word = False
                else:
                    line_rank_inner_ids.extend(n_tokens * [2])
        if len(line_rank_inner_ids) > 0 and line_rank_inner_ids[-1] == 2:
            line_rank_inner_ids[-1] = 3

        return input_ids, bbs_word, bbs_line, first_token_idxes, line_rank_ids, \
            line_rank_inner_ids, word_rank_ids

    def halfseq2seq(self, input_ids: list, bbox_line: List[List],
                    bbox_word: List[List], width: int, height: int):
        """
        for convenient plug into the serialization, given the 3 customized sequences.
        They should not contain special tokens like [CLS] or [SEP].
        """
        bbs_word = []
        bbs_line = []
        first_token_idxes = []
        line_rank_ids = []
        line_rank_inner_ids = []
        word_rank_ids = []

        n_real_tokens = len(input_ids)
        lb_prev, wb_prev = None, None
        line_id = 0
        word_id = 1
        for i in range(n_real_tokens):
            lb_now = bbox_line[i]
            wb_now = bbox_word[i]
            line_start = lb_prev is None or lb_now != lb_prev
            word_start = wb_prev is None or wb_now != wb_prev
            lb_prev, wb_prev = lb_now, wb_now

            if len(lb_now) == 8:
                lb_now = bbox4pto2p(lb_now)
            assert len(lb_now) == 4
            lb_now = bound_box(lb_now, height, width)
            if len(wb_now) == 4:
                wb_now = bbox2pto4p(wb_now)
            assert len(wb_now) == 8
            wb_now = bound_box(wb_now, height, width)

            bbs_word.append(wb_now)
            bbs_line.append(lb_now)

            if word_start:
                word_id += 1
            if line_start:
                line_id += 1
                first_token_idxes.append(i + 1)
                if len(line_rank_inner_ids
                       ) > 0 and line_rank_inner_ids[-1] == 2:
                    line_rank_inner_ids[-1] = 3
                line_rank_inner_ids.append(1)
                word_id = 1
            else:
                line_rank_inner_ids.append(2)
            line_rank_ids.append(line_id)
            word_rank_ids.append(word_id)

        if len(line_rank_inner_ids) > 0 and line_rank_inner_ids[-1] == 2:
            line_rank_inner_ids[-1] = 3

        return input_ids, bbs_word, bbs_line, first_token_idxes, \
            line_rank_ids, line_rank_inner_ids, word_rank_ids

    def __call__(
        self,
        ocr_infos: List[List] = None,
        input_ids: list = None,
        bboxes_line: List[List] = None,
        bboxes_word: List[List] = None,
        sizes_raw: list = None,
        **kwargs,
    ):
        n_samples = len(ocr_infos) if ocr_infos is not None else len(input_ids)
        if sizes_raw is None:
            sizes_raw = [(self.height, self.width)] * n_samples
        seqs = []
        if input_ids is not None:
            assert len(input_ids) == len(bboxes_line)
            assert len(input_ids) == len(bboxes_word)
            for input_id, bbox_line, bbox_word, size_raw in zip(
                    input_ids, bboxes_line, bboxes_word, sizes_raw):
                height, width = size_raw
                token_seq = self.serialize_single(None, input_id, bbox_line,
                                                  bbox_word, width, height)
                seqs.append(token_seq)
        else:
            assert ocr_infos is not None, 'For serialization, ocr_infos must not be NoneType!'
            if self.ocr_utils is not None:
                ocr_infos = self.ocr_utils(ocr_infos)
            for ocr_info, size_raw in zip(ocr_infos, sizes_raw):
                height, width = size_raw
                token_seq = self.serialize_single(
                    ocr_info, width=width, height=height)
                seqs.append(token_seq)
        pt_seqs = stack_tensor_dict(seqs)
        return pt_seqs


class Processor(object):
    r"""Construct a GeoLayoutLM processor.

    Args:
        max_seq_length: max length for token
        max_block_num: max number of text lines (blocks or segments)
        img_processor: type of ImageProcessor.
        tokenizer: to tokenize strings.
        use_roberta_tokenizer: Whether the tokenizer is originated from RoBerta tokenizer
            (True by default).
        ocr_utils: a tool to preprocess ocr_infos.
        width: default width. It can be used only when all the images are of the same shape.
        height: default height. It can be used only when all the images are of the same shape.

    In `serialize_from_tokens`, the 3 sequences (i.e., `input_ids`, `bboxes_line`, `bboxes_word`)
        must not contain special tokens like [CLS] or [SEP].
    The boxes in `bboxes_line` and `bboxes_word` can be presented by either 2 points or 4 points.
    The value in boxes should keep original.
    Here is an example of the 3 arguments:
        ```
        input_ids ->
        [[6, 2391, 6, 31833, 6, 10132, 6, 2283, 6, 17730, 6, 2698, 152]]
        bboxes_line ->
        [[[230, 1, 353, 38], [230, 1, 353, 38], [230, 1, 353, 38], [230, 1, 353, 38],
            [230, 1, 353, 38], [230, 1, 353, 38], [230, 1, 353, 38], [230, 1, 353, 38],
            [257, 155, 338, 191], [257, 155, 338, 191], [257, 155, 338, 191], [257, 155, 338, 191],
            [257, 155, 338, 191]]]
        bboxes_word ->
        [[[231, 2, 267, 2, 267, 38, 231, 38], [231, 2, 267, 2, 267, 38, 231, 38],
            [264, 7, 298, 7, 298, 36, 264, 36], [264, 7, 298, 7, 298, 36, 264, 36],
            [293, 3, 329, 3, 329, 41, 293, 41], [293, 3, 329, 3, 329, 41, 293, 41],
            [330, 4, 354, 4, 354, 39, 330, 39], [330, 4, 354, 4, 354, 39, 330, 39],
            [258, 156, 289, 156, 289, 193, 258, 193], [258, 156, 289, 156, 289, 193, 258, 193],
            [288, 158, 321, 158, 321, 192, 288, 192], [288, 158, 321, 158, 321, 192, 288, 192],
            [321, 156, 336, 156, 336, 190, 321, 190]]]
        ```

    """

    def __init__(self,
                 max_seq_length,
                 max_block_num,
                 img_processor: ImageProcessor,
                 tokenizer=None,
                 use_roberta_tokenizer: bool = True,
                 ocr_utils: OCRUtils = None,
                 width=768,
                 height=768,
                 **kwargs):
        self.img_processor = img_processor
        self.tokenizer = tokenizer
        self.kwargs = kwargs

        self.serializer = TextLayoutSerializer(
            max_seq_length,
            max_block_num,
            tokenizer,
            width,
            height,
            use_roberta_tokenizer=use_roberta_tokenizer,
            ocr_utils=ocr_utils)

    def __call__(
        self,
        images: Union[list, np.ndarray, PIL.Image.Image, str],
        ocr_infos: List[List] = None,
        token_seqs: dict = None,
        sizes_raw: list = None,
    ):
        img_data = self.img_processor(images)
        images = img_data['images']
        ocr_infos = img_data['ocr_infos'] if ocr_infos is None else ocr_infos
        sizes_raw = img_data['sizes_raw'] if sizes_raw is None else sizes_raw
        if token_seqs is None:
            token_seqs = self.serializer(ocr_infos, sizes_raw=sizes_raw)
        else:
            token_seqs = self.serializer(
                None, sizes_raw=sizes_raw, **token_seqs)
        assert token_seqs is not None, 'token_seqs must not be NoneType!'
        batch = {}
        batch['image'] = images
        for k, v in token_seqs.items():
            batch[k] = token_seqs[k]
        return batch

    def serialize_from_tokens(self,
                              images,
                              input_ids,
                              bboxes_line,
                              bboxes_word,
                              sizes_raw=None):
        half_batch = {}
        half_batch['input_ids'] = input_ids
        half_batch['bboxes_line'] = bboxes_line
        half_batch['bboxes_word'] = bboxes_word
        return self(images, None, half_batch, sizes_raw)
