# Copyright (c) Alibaba, Inc. and its affiliates.
import types
import warnings
from contextlib import contextmanager
from typing import Any, Dict, Generator, List, Optional, Union

import torch
import torch.distributed as dist
import transformers
from packaging import version
from torch import nn
from transformers import PreTrainedModel
from transformers.generation import GreedySearchDecoderOnlyOutput  # noqa
from transformers.generation import (GreedySearchEncoderDecoderOutput,
                                     LogitsProcessorList,
                                     SampleDecoderOnlyOutput,
                                     SampleEncoderDecoderOutput,
                                     StoppingCriteriaList,
                                     validate_stopping_criteria)

from modelscope.pipelines.base import Input
from modelscope.utils.constant import Frameworks
from modelscope.utils.device import device_placement


class StreamingOutputMixin:

    def stream_generate(self, *args, **kwargs) -> Generator:
        """
        Supports the inputs of both Model and Pipeline.
        The output is a `Generator` whose items conform
        to the modelscope output standard.
        """
        raise NotImplementedError


class PipelineStreamingOutputMixin(StreamingOutputMixin):

    def stream_generate(self, input: Union[Input, List[Input]], *args,
                        **kwargs) -> Generator:
        """
        Similar to the `Pipeline.__call__` method:
        it accepts any input the pipeline can accept,
        and also supports batch (list) input.

        `self.model` must be an instance of `StreamingOutputMixin`
        and implement the `stream_generate` method.
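
        Example (a minimal sketch; the task and model id are illustrative
        and assume a pipeline whose model mixes in `StreamingOutputMixin`):

            >>> from modelscope.pipelines import pipeline
            >>> pipe = pipeline('text-generation', model='some-namespace/some-streaming-model')
            >>> for out in pipe.stream_generate('Hello'):
            ...     print(out)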
        """
        assert isinstance(self.model, StreamingOutputMixin
                          ), 'pipeline.model must be StreamingOutputMixin!'
        if (self.model or (self.has_multiple_models and self.models[0])):
            if not self._model_prepare:
                self.prepare_model()

        batch_size = kwargs.pop('batch_size', None)
        preprocess_params, forward_params, postprocess_params = self._sanitize_parameters(
            **kwargs)

        if isinstance(input, list):
            model_input_list = [
                self._preprocess_with_check(i, preprocess_params)
                for i in input
            ]

            if batch_size is None:
                output = []
                for ele in model_input_list:
                    output.append(
                        self._stream_single(ele, forward_params,
                                            postprocess_params))
            else:
                output = self._stream_batch(model_input_list, batch_size,
                                            forward_params, postprocess_params)

        else:
            model_input = self._preprocess_with_check(input, preprocess_params)
            output = self._stream_single(model_input, forward_params,
                                         postprocess_params)
        return output

    def _preprocess_with_check(
            self, input: Input,
            preprocess_params: Dict[str, Any]) -> Dict[str, Any]:
        self._check_input(input)
        return self.preprocess(input, **preprocess_params)

    def _stream_single(self, model_input: Dict[str, Any],
                       forward_params: Dict[str, Any],
                       postprocess_params: Dict[str, Any]) -> Generator:
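        """Run a single preprocessed input through `self.model.stream_generate`,
        post-processing and yielding every streamed output."""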

        with device_placement(self.framework, self.device_name):
            if self.framework == Frameworks.torch:
                with torch.no_grad():
                    if self._auto_collate:
                        model_input = self._collate_fn(model_input)
                    stream = self.model.stream_generate(
                        model_input, **forward_params)
            else:
                stream = self.model.stream_generate(model_input,
                                                    **forward_params)

            for out in stream:
                out = self.postprocess(out, **postprocess_params)
                self._check_output(out)
                yield out

    def _stream_batch(self, model_input_list: List[Dict[str, Any]],
                      batch_size: int, forward_params: Dict[str, Any],
                      postprocess_params: Dict[str, Any]) -> Generator:
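        """Split `model_input_list` into chunks of at most `batch_size`, start
        one `stream_generate` per chunk, and yield `output_list` after each
        round: a list aligned with `model_input_list` whose entries hold the
        latest post-processed output of each sample (entries stay None until
        the corresponding stream has produced something)."""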

        stream_list = []
        real_batch_sizes = []
        with device_placement(self.framework, self.device_name):
            for i in range(0, len(model_input_list), batch_size):
                end = min(i + batch_size, len(model_input_list))
                real_batch_size = end - i
                real_batch_sizes.append(real_batch_size)

                batched_out = self._batch(model_input_list[i:end])
                if self.framework == Frameworks.torch:
                    with torch.no_grad():
                        if self._auto_collate:
                            batched_out = self._collate_fn(batched_out)
                        stream_list.append(
                            self.model.stream_generate(batched_out,
                                                       **forward_params))
                else:
                    stream_list.append(
                        self.model.stream_generate(batched_out,
                                                   **forward_params))

            output_list = [None] * len(model_input_list)
            stop_streams = 0
            while stop_streams < len(stream_list):
                stop_streams = 0
                for i, (stream, real_batch_size) in enumerate(
                        zip(stream_list, real_batch_sizes)):
                    try:
                        batched_out = next(stream)
                        for batch_idx in range(real_batch_size):
                            out = {}
                            for k, element in batched_out.items():
                                if element is not None:
                                    if isinstance(element, (tuple, list)):
                                        if isinstance(element[0],
                                                      torch.Tensor):
                                            out[k] = type(element)(
                                                e[batch_idx:batch_idx + 1]
                                                for e in element)
                                        else:
                                            # Compatible with traditional pipelines
                                            out[k] = element[batch_idx]
                                    else:
                                        out[k] = element[batch_idx:batch_idx
                                                         + 1]
                            out = self.postprocess(out, **postprocess_params)
                            self._check_output(out)
                            output_index = i * batch_size + batch_idx
                            output_list[output_index] = out
                    except StopIteration:
                        stop_streams += 1
                yield output_list

        return output_list


class PretrainedModelStreamingOutputMixin(StreamingOutputMixin):

    def stream_generate(self, *args, **kwargs) -> Generator:
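        """Stream results from `PreTrainedModel.generate`.

        The model's greedy-search / sample methods are temporarily replaced by
        the streaming variants below, so `generate` returns a `Generator`
        that yields the partial result after every decoding step. See
        `add_stream_generate` for a usage sketch.
        """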
        model = self if isinstance(self, PreTrainedModel) else self.model
        assert isinstance(model, PreTrainedModel), \
            'self or self.model must be `PretrainedModel`!'
        with self._replace_generate(model):
            return model.generate(*args, **kwargs)

    @contextmanager
    def _replace_generate(self, model: PreTrainedModel) -> Generator:
        if version.parse(transformers.__version__) >= version.parse('4.43.0'):
            # transformers>=4.43 removed the standalone greedy-search method;
            # greedy decoding is handled by `_sample`.
            greedy_search_name = 'stream_greedy_search'
            sample_name = '_sample'
        elif version.parse(
                transformers.__version__) >= version.parse('4.39.0'):
            greedy_search_name = '_greedy_search'
            sample_name = '_sample'
        else:
            greedy_search_name = 'greedy_search'
            sample_name = 'sample'
        origin_greedy_search = getattr(model, greedy_search_name)
        origin_sample = getattr(model, sample_name)
        setattr(model, greedy_search_name,
                types.MethodType(self.stream_greedy_search, model))
        setattr(model, sample_name, types.MethodType(self.stream_sample,
                                                     model))
        try:
            yield
        finally:
            # Restore the original methods even if generation raises.
            setattr(model, greedy_search_name, origin_greedy_search)
            setattr(model, sample_name, origin_sample)

    @staticmethod
    def stream_greedy_search(
        self,
        input_ids: torch.LongTensor,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: bool = False,
        **model_kwargs,
    ) -> Generator:
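        """Streaming variant of the `transformers` greedy search: the same
        decoding loop, but it yields the partial result after every step."""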
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList(
        )
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList(
        )
        if max_length is not None:
            warnings.warn(
                '`max_length` is deprecated in this function, use'
                ' `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.',
                UserWarning,
            )
            stopping_criteria = validate_stopping_criteria(
                stopping_criteria, max_length)
        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        eos_token_id_tensor = torch.tensor(eos_token_id).to(
            input_ids.device) if eos_token_id is not None else None
        output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
        output_attentions = (
            output_attentions if output_attentions is not None else
            self.generation_config.output_attentions)
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else
            self.generation_config.output_hidden_states)
        return_dict_in_generate = (
            return_dict_in_generate if return_dict_in_generate is not None else
            self.generation_config.return_dict_in_generate)

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        decoder_attentions = () if (return_dict_in_generate
                                    and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate
                                  and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate
                                       and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs['encoder_outputs'].get(
                'attentions') if output_attentions else None
            encoder_hidden_states = (
                model_kwargs['encoder_outputs'].get('hidden_states')
                if output_hidden_states else None)

        # keep track of which sequences are already finished
        unfinished_sequences = torch.ones(
            input_ids.shape[0], dtype=torch.long, device=input_ids.device)

        this_peer_finished = False  # used by synced_gpus only
        while True:
            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(
                    0.0 if this_peer_finished else 1.0).to(input_ids.device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break

            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(
                input_ids, **model_kwargs)

            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            if synced_gpus and this_peer_finished:
                continue  # don't waste resources running the code we don't need

            next_token_logits = outputs.logits[:, -1, :]

            # pre-process distribution
            next_tokens_scores = logits_processor(input_ids, next_token_logits)

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_tokens_scores, )
                if output_attentions:
                    decoder_attentions += ((outputs.decoder_attentions, ) if
                                           self.config.is_encoder_decoder else
                                           (outputs.attentions, ))
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions, )

                if output_hidden_states:
                    decoder_hidden_states += ((outputs.decoder_hidden_states, )
                                              if self.config.is_encoder_decoder
                                              else (outputs.hidden_states, ))

            # argmax
            next_tokens = torch.argmax(next_tokens_scores, dim=-1)

            # finished sentences should have their next token be a padding token
            if eos_token_id is not None:
                if pad_token_id is None:
                    raise ValueError(
                        'If `eos_token_id` is defined, make sure that `pad_token_id` is defined.'
                    )
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
                    1 - unfinished_sequences)

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)

            # yield the intermediate result for streaming
            if return_dict_in_generate:
                if self.config.is_encoder_decoder:
                    yield GreedySearchEncoderDecoderOutput(
                        sequences=input_ids,
                        scores=scores,
                        encoder_attentions=encoder_attentions,
                        encoder_hidden_states=encoder_hidden_states,
                        decoder_attentions=decoder_attentions,
                        cross_attentions=cross_attentions,
                        decoder_hidden_states=decoder_hidden_states,
                    )
                else:
                    yield GreedySearchDecoderOnlyOutput(
                        sequences=input_ids,
                        scores=scores,
                        attentions=decoder_attentions,
                        hidden_states=decoder_hidden_states,
                    )
            else:
                yield input_ids

            model_kwargs = self._update_model_kwargs_for_generation(
                outputs,
                model_kwargs,
                is_encoder_decoder=self.config.is_encoder_decoder)

            # if eos_token was found in one sentence, set sentence to finished
            if eos_token_id_tensor is not None:
                unfinished_sequences = unfinished_sequences.mul(
                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(
                        eos_token_id_tensor.unsqueeze(1)).prod(dim=0))

                # stop when each sentence is finished
                if unfinished_sequences.max() == 0:
                    this_peer_finished = True

            # stop if we exceed the maximum length
            if stopping_criteria(input_ids, scores):
                this_peer_finished = True

            if this_peer_finished and not synced_gpus:
                break

    @staticmethod
    def stream_sample(
        self,
        input_ids: torch.LongTensor,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        logits_warper: Optional[LogitsProcessorList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: bool = False,
        **model_kwargs,
    ) -> Generator:
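        """Streaming variant of `transformers` multinomial sampling: the same
        decoding loop, but it yields the partial result after every step."""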
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList(
        )
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList(
        )
        if max_length is not None:
            warnings.warn(
                '`max_length` is deprecated in this function, use'
                ' `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.',
                UserWarning,
            )
            stopping_criteria = validate_stopping_criteria(
                stopping_criteria, max_length)
        logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList(
        )
        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        eos_token_id_tensor = torch.tensor(eos_token_id).to(
            input_ids.device) if eos_token_id is not None else None
        output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
        output_attentions = (
            output_attentions if output_attentions is not None else
            self.generation_config.output_attentions)
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else
            self.generation_config.output_hidden_states)
        return_dict_in_generate = (
            return_dict_in_generate if return_dict_in_generate is not None else
            self.generation_config.return_dict_in_generate)

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        decoder_attentions = () if (return_dict_in_generate
                                    and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate
                                  and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate
                                       and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs['encoder_outputs'].get(
                'attentions') if output_attentions else None
            encoder_hidden_states = (
                model_kwargs['encoder_outputs'].get('hidden_states')
                if output_hidden_states else None)

        # keep track of which sequences are already finished
        unfinished_sequences = torch.ones(
            input_ids.shape[0], dtype=torch.long, device=input_ids.device)

        this_peer_finished = False  # used by synced_gpus only
        # auto-regressive generation
        while True:
            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(
                    0.0 if this_peer_finished else 1.0).to(input_ids.device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break

            # prepare model inputs
            model_kwargs = self._get_initial_cache_position(
                input_ids, model_kwargs)
            model_inputs = self.prepare_inputs_for_generation(
                input_ids, **model_kwargs)

            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            if synced_gpus and this_peer_finished:
                continue  # don't waste resources running the code we don't need

            next_token_logits = outputs.logits[:, -1, :]

            # pre-process distribution
            next_token_scores = logits_processor(input_ids, next_token_logits)
            next_token_scores = logits_warper(input_ids, next_token_scores)

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores, )
                if output_attentions:
                    decoder_attentions += ((outputs.decoder_attentions, ) if
                                           self.config.is_encoder_decoder else
                                           (outputs.attentions, ))
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions, )

                if output_hidden_states:
                    decoder_hidden_states += ((outputs.decoder_hidden_states, )
                                              if self.config.is_encoder_decoder
                                              else (outputs.hidden_states, ))

            # sample
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

            # finished sentences should have their next token be a padding token
            if eos_token_id is not None:
                if pad_token_id is None:
                    raise ValueError(
                        'If `eos_token_id` is defined, make sure that `pad_token_id` is defined.'
                    )
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
                    1 - unfinished_sequences)

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)

            # yield the intermediate result for streaming
            if return_dict_in_generate:
                if self.config.is_encoder_decoder:
                    yield SampleEncoderDecoderOutput(
                        sequences=input_ids,
                        scores=scores,
                        encoder_attentions=encoder_attentions,
                        encoder_hidden_states=encoder_hidden_states,
                        decoder_attentions=decoder_attentions,
                        cross_attentions=cross_attentions,
                        decoder_hidden_states=decoder_hidden_states,
                    )
                else:
                    yield SampleDecoderOnlyOutput(
                        sequences=input_ids,
                        scores=scores,
                        attentions=decoder_attentions,
                        hidden_states=decoder_hidden_states,
                    )
            else:
                yield input_ids

            model_kwargs = self._update_model_kwargs_for_generation(
                outputs,
                model_kwargs,
                is_encoder_decoder=self.config.is_encoder_decoder)

            # if eos_token was found in one sentence, set sentence to finished
            if eos_token_id_tensor is not None:
                unfinished_sequences = unfinished_sequences.mul(
                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(
                        eos_token_id_tensor.unsqueeze(1)).prod(dim=0))

                # stop when each sentence is finished
                if unfinished_sequences.max() == 0:
                    this_peer_finished = True

            # stop if we exceed the maximum length
            if stopping_criteria(input_ids, scores):
                this_peer_finished = True

            if this_peer_finished and not synced_gpus:
                break


def add_stream_generate(model: PreTrainedModel) -> PreTrainedModel:
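    """Return a copy of `model` whose class additionally inherits
    `PretrainedModelStreamingOutputMixin`, adding `stream_generate`.

    A new instance of the combined class is built from `model.config`, and
    the original model's `__dict__` (weights, buffers, config) is copied
    onto it.

    Example (a minimal sketch; the checkpoint name is illustrative):

        >>> from transformers import AutoModelForCausalLM, AutoTokenizer
        >>> model = AutoModelForCausalLM.from_pretrained('some/causal-lm')
        >>> tokenizer = AutoTokenizer.from_pretrained('some/causal-lm')
        >>> model = add_stream_generate(model)
        >>> inputs = tokenizer('Hello', return_tensors='pt')
        >>> for ids in model.stream_generate(**inputs, max_new_tokens=16):
        ...     print(tokenizer.decode(ids[0], skip_special_tokens=True))
    """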
    pretrained_class = type(model)
    parent_classes = (pretrained_class, PretrainedModelStreamingOutputMixin)
    new_model = type(pretrained_class.__name__, parent_classes, {})(
        model.config)
    new_model.__dict__.update(model.__dict__)
    return new_model
