# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import inspect
from typing import List, Optional, Union

import paddle
import paddle.distributed as dist
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import Tensor
from paddle.common_ops_import import convert_dtype
from paddle.utils import map_structure

from ......utils import logging
from ..transformers.model_outputs import ModelOutput
from .configuration_utils import DEFAULT_MAX_NEW_TOKENS, GenerationConfig
from .logits_process import (
    ForcedBOSTokenLogitsProcessor,
    ForcedEOSTokenLogitsProcessor,
    HammingDiversityLogitsProcessor,
    LogitsProcessor,
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    NoRepeatNGramLogitsProcessor,
    RepetitionPenaltyLogitsProcessor,
    TopKProcess,
    TopPProcess,
)
from .stopping_criteria import (
    StoppingCriteria,
    StoppingCriteriaList,
    validate_stopping_criteria,
)

# Public names exported by `from <this module> import *`. The logits-processor
# entries are re-exported from `.logits_process` (imported above).
__all__ = [
    "GenerationMixin",
    "BeamSearchScorer",
    "BeamHypotheses",
    "LogitsProcessorList",
    "LogitsProcessor",
    "MinLengthLogitsProcessor",
    "RepetitionPenaltyLogitsProcessor",
    "TopKProcess",
    "TopPProcess",
    "get_unfinished_flag",
]


def get_scale_by_dtype(
    dtype: Optional[str] = None, return_positive: bool = True
) -> float:
    """Return a large magnitude suitable for masking values of ``dtype``.

    ``float16`` gets ``1e4``; every other dtype gets ``1e6``.

    Args:
        dtype (str, optional): the dtype to look up. Defaults to
            ``paddle.get_default_dtype()`` when None.
        return_positive (bool): return the positive magnitude when True,
            otherwise its negation.

    Returns:
        float: the (optionally negated) scale value
    """
    resolved_dtype = convert_dtype(
        paddle.get_default_dtype() if dtype is None else dtype
    )

    # TODO(wj-Mcaf): support int8, int4 dtypes later
    magnitude = 1e4 if resolved_dtype == "float16" else 1e6

    return magnitude if return_positive else -magnitude


def get_unfinished_flag(
    input_ids: Tensor,
    unfinished_flag: Tensor,
    eos_token_id: Union[int, List[int], List[List[int]]],
) -> Tensor:
    """Update the per-sequence "unfinished" flag after a generation step.

    A sequence becomes finished when its last token (``input_ids[:, -1:]``)
    equals any end-of-sentence id in ``eos_token_id``. Entries of
    ``unfinished_flag`` that are already False stay False.

    Args:
        input_ids (Tensor): the ids generated so far; only the last column is
            inspected.
        unfinished_flag (Tensor): the current unfinished flag; it is combined
            with the new comparison via ``paddle.logical_and``.
        eos_token_id (Union[int, list[int], list[list[int]]]): the
            end-of-sentence id(s), which can be:
            * a single token id, eg: 10
            * a list of token ids, eg: [10, 20]
            * a nested list of token ids, eg: [[10], [20, 20], [30, 30, 30]]
            Every id, at any nesting level, is an alternative stop token
            matched against the last generated token only; multi-token
            entries are NOT matched as sequences. An empty list leaves
            ``unfinished_flag`` unchanged.

    Returns:
        Tensor: the updated unfinished flag tensor
    """
    if isinstance(eos_token_id, int):
        return paddle.logical_and(unfinished_flag, input_ids[:, -1:] != eos_token_id)

    finished_flag = None
    for sub_eos_token_id in eos_token_id:
        # True for rows that were already finished or now end with this id.
        sub_finished = ~get_unfinished_flag(input_ids, unfinished_flag, sub_eos_token_id)
        if finished_flag is None:
            finished_flag = sub_finished
        else:
            finished_flag = paddle.logical_or(finished_flag, sub_finished)

    if finished_flag is None:
        # Empty `eos_token_id`: there is no stop token, so nothing finishes here.
        return unfinished_flag
    return ~finished_flag


class BeamHypotheses:
    def __init__(self, num_beams, length_penalty, early_stopping):
        """
        Create an empty n-best list that keeps at most ``num_beams`` finished
        hypotheses for one batch entry.
        """
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
        self.num_beams = num_beams
        self.beams = []
        # Large positive sentinel, so the first added score (typically a
        # non-positive log-prob sum) becomes the worst score via `min` in `add`.
        self.worst_score = get_scale_by_dtype()

    def __len__(self):
        """
        Return how many hypotheses are currently stored.
        """
        return len(self.beams)

    def _length_normalized_score(self, sum_logprobs, length):
        """Divide ``sum_logprobs`` by ``((length + 5) / 6) ** length_penalty``."""
        return sum_logprobs / (((length + 5) / 6) ** self.length_penalty)

    def add(self, hyp, sum_logprobs, origin_len=0):
        """
        Store a finished hypothesis if the list still has room or the
        hypothesis beats the current worst one; on overflow the lowest-scoring
        entry is evicted.

        The score is length-normalized over the tokens generated after
        ``origin_len`` (``hyp.shape[-1] - origin_len``).
        """
        score = self._length_normalized_score(
            sum_logprobs, hyp.shape[-1] - origin_len
        )
        has_room = len(self) < self.num_beams
        if not (has_room or score > self.worst_score):
            return

        self.beams.append((score, hyp))
        if len(self) <= self.num_beams:
            self.worst_score = min(score, self.worst_score)
            return

        # Over capacity: rank by (score, position), drop the lowest entry and
        # promote the runner-up to be the new worst score.
        ranked = sorted(
            (entry_score, position)
            for position, (entry_score, _) in enumerate(self.beams)
        )
        del self.beams[ranked[0][1]]
        self.worst_score = ranked[1][0]

    def is_done(self, best_sum_logprobs, cur_len, origin_len=0):
        """
        Return True when the list is full and, unless ``early_stopping`` is
        set, no beam still being generated (best score ``best_sum_logprobs`` at
        length ``cur_len``) can outscore the worst stored hypothesis.
        """
        if len(self) < self.num_beams:
            return False
        if self.early_stopping:
            return True
        best_reachable = self._length_normalized_score(
            best_sum_logprobs, cur_len - origin_len
        )
        return self.worst_score >= best_reachable


class BeamSearchScorer(object):
    """
    Implements standard beam search decoding.

    The ``num_beams`` beams may be split into ``num_beam_groups`` groups of
    ``group_size = num_beams // num_beam_groups`` beams. One ``BeamHypotheses``
    n-best list per batch entry collects finished hypotheses. ``process``
    selects the beams to continue at each step, and ``finalize`` assembles
    the best ``num_beam_hyps_to_keep`` sequences per batch entry.
    """

    def __init__(
        self,
        batch_size,
        max_length,
        num_beams,
        length_penalty=1.0,
        do_early_stopping=False,
        num_beam_hyps_to_keep=1,
        num_beam_groups=1,
    ):
        # Validate first: `group_size` below divides by `num_beam_groups`, so
        # invalid values must be rejected before anything is derived from them.
        if not isinstance(num_beams, int) or num_beams <= 1:
            raise ValueError(
                "`num_beams` has to be an integer strictly greater than 1, but "
                "received {}. For `num_beams` == 1, one should make use of "
                "`greedy_search` instead.".format(num_beams)
            )

        if (
            not isinstance(num_beam_groups, int)
            or num_beam_groups < 1
            or (num_beam_groups > num_beams)
            or (num_beams % num_beam_groups != 0)
        ):
            raise ValueError(
                "`num_beam_groups` has to be a positive integer smaller or equal than "
                "`num_beams` and `num_beams` has to be divisible by "
                "`num_beam_groups`, but received num_beam_groups={}, num_beams="
                "{}.".format(num_beam_groups, num_beams)
            )

        self.max_length = max_length
        self.num_beams = num_beams
        self.length_penalty = length_penalty
        self.do_early_stopping = do_early_stopping
        self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
        self.num_beam_groups = num_beam_groups
        self.group_size = self.num_beams // self.num_beam_groups

        self._is_init = False
        self._beam_hyps = [
            BeamHypotheses(
                num_beams=self.num_beams,
                length_penalty=self.length_penalty,
                early_stopping=self.do_early_stopping,
            )
            for _ in range(batch_size)
        ]
        # Per batch entry: 1 once its hypotheses are final, else 0.
        self._done = paddle.to_tensor([0 for _ in range(batch_size)], dtype="int64")

    @property
    def is_done(self):
        """Whether every batch entry is marked done (a Tensor comparison result)."""
        return paddle.min(self._done) == 1

    def process(
        self,
        input_ids,
        next_scores,
        next_tokens,
        next_indices,
        origin_len=0,
        pad_token_id=None,
        eos_token_id=None,
    ):
        """Select the beams to continue with for one decoding step.

        For every batch entry the candidates ``next_tokens[batch_idx]`` (with
        ``next_scores`` / ``next_indices``) are scanned in rank order.
        A candidate equal to ``eos_token_id`` is stored as a finished
        hypothesis, but only if its rank is below ``group_size``. Any other
        candidate fills the next free beam slot, until ``group_size`` slots
        are filled. Rows of ``input_ids`` are addressed as
        ``batch_idx * group_size + next_index``.

        Returns:
            dict: ``next_beam_scores``, ``next_beam_tokens`` and
            ``next_beam_indices``, each flattened to length
            ``batch_size * group_size``.

        Raises:
            ValueError: if fewer than ``group_size`` non-eos candidates are
                available for an unfinished batch entry.
        """
        cur_len = input_ids.shape[-1]
        batch_size = len(self._beam_hyps)
        assert batch_size == (input_ids.shape[0] // self.group_size)

        next_beam_scores = paddle.zeros(
            [batch_size, self.group_size], dtype=next_scores.dtype
        )
        next_beam_tokens = paddle.zeros(
            [batch_size, self.group_size], dtype=next_tokens.dtype
        )
        next_beam_indices = paddle.zeros(
            [batch_size, self.group_size], dtype=next_indices.dtype
        )

        for batch_idx, beam_hyp in enumerate(self._beam_hyps):
            if self._done[batch_idx] == 1:
                assert (
                    len(beam_hyp) >= self.num_beams
                ), "Batch can only be done if at least {} beams have been generated".format(
                    self.num_beams
                )
                assert (
                    eos_token_id is not None and pad_token_id is not None
                ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
                # pad the batch
                next_beam_scores[batch_idx, :] = 0
                next_beam_tokens[batch_idx, :] = pad_token_id
                next_beam_indices[batch_idx, :] = 0
                continue

            # next tokens for this sentence
            beam_idx = 0
            for beam_token_rank, (next_token, next_score, next_index) in enumerate(
                zip(
                    next_tokens[batch_idx],
                    next_scores[batch_idx],
                    next_indices[batch_idx],
                )
            ):
                batch_beam_idx = batch_idx * self.group_size + next_index
                # add to generated hypotheses if end of sentence
                if (eos_token_id is not None) and (next_token.item() == eos_token_id):
                    # If beam_token does not belong to top num_beams tokens,
                    # it should not be added
                    is_beam_token_worse_than_top_num_beams = (
                        beam_token_rank >= self.group_size
                    )
                    if is_beam_token_worse_than_top_num_beams:
                        continue
                    beam_hyp.add(
                        input_ids[batch_beam_idx.item()].clone(),
                        next_score.item(),
                        origin_len,
                    )

                else:
                    # add next predicted token since it is not eos_token
                    next_beam_scores[batch_idx, beam_idx] = next_score
                    next_beam_tokens[batch_idx, beam_idx] = next_token.item()
                    next_beam_indices[batch_idx, beam_idx] = batch_beam_idx.item()
                    beam_idx += 1

                # once the beam for next step is full, don't add more tokens to it.
                if beam_idx == self.group_size:
                    break

            if beam_idx < self.group_size:
                raise ValueError(
                    "At most {} tokens in `next_tokens[batch_idx]` can be equal "
                    "to `eos_token_id: {}`. Make sure `next_tokens[batch_idx]` "
                    "are corrected.".format(self.group_size, eos_token_id)
                )

            # Check if we are done so that we can save a pad step if all(done)
            if beam_hyp.is_done(
                next_scores[batch_idx].max().item(), cur_len, origin_len
            ):
                self._done[batch_idx] = 1

        return {
            "next_beam_scores": next_beam_scores.reshape([-1]),
            "next_beam_tokens": next_beam_tokens.reshape([-1]),
            "next_beam_indices": next_beam_indices.reshape([-1]),
        }

    def finalize(
        self,
        input_ids,
        final_beam_scores,
        final_beam_tokens,
        final_beam_indices,
        origin_len=0,
        pad_token_id=None,
        eos_token_id=None,
    ):
        """Collect the best ``num_beam_hyps_to_keep`` hypotheses per batch entry.

        The still-open beams of unfinished batch entries are added to their
        hypothesis lists first. The kept hypotheses are written into a tensor
        of width ``min(longest + 1, max_length)``. That tensor is pre-filled
        with ``pad_token_id`` when the lengths differ. ``eos_token_id`` is
        appended after a hypothesis shorter than ``max_length`` when
        ``eos_token_id`` is not None.

        Returns:
            tuple: ``(decoded, decoded_score)`` with
            ``batch_size * num_beam_hyps_to_keep`` rows each;
            ``decoded_score`` has one column.
        """
        batch_size = len(self._beam_hyps)

        # finalize all open beam hypotheses and add to generated hypotheses
        for batch_idx, beam_hyp in enumerate(self._beam_hyps):
            if self._done[batch_idx] == 1:
                continue

            # all open beam hypotheses are added to the beam hypothesis
            # beam hypothesis class automatically keeps the best beams
            for beam_id in range(self.num_beams):
                batch_beam_idx = batch_idx * self.num_beams + beam_id
                final_score = final_beam_scores[batch_beam_idx].item()
                final_tokens = input_ids[batch_beam_idx]
                beam_hyp.add(final_tokens, final_score, origin_len=origin_len)

        # select the best hypotheses
        sent_lengths = paddle.zeros(
            [batch_size * self.num_beam_hyps_to_keep], dtype=input_ids.dtype
        )
        best = []

        # retrieve best hypotheses
        for i, beam_hyp in enumerate(self._beam_hyps):
            sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
            for j in range(self.num_beam_hyps_to_keep):
                best_score, best_hyp = sorted_hyps.pop()
                sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)
                best.append([best_hyp, best_score])

        # prepare for adding eos
        sent_max_len = min(sent_lengths.max().item() + 1, self.max_length)
        decoded = paddle.zeros(
            [batch_size * self.num_beam_hyps_to_keep, sent_max_len],
            dtype=input_ids.dtype,
        )
        # shorter batches are padded if needed
        if sent_lengths.min().item() != sent_lengths.max().item():
            assert pad_token_id is not None, "`pad_token_id` has to be defined"
            decoded[:, :] = pad_token_id
        decoded_score = paddle.zeros([batch_size * self.num_beam_hyps_to_keep, 1])

        # fill with hypotheses and eos_token_id if the latter fits in
        for i, (hypo, score) in enumerate(best):
            decoded[i, : sent_lengths[i].item()] = hypo.cpu().numpy()
            decoded_score[i] = score
            # Without an eos id there is nothing to append.
            if eos_token_id is not None and sent_lengths[i] < self.max_length:
                decoded[i, sent_lengths[i].item()] = eos_token_id
        return decoded, decoded_score


class GenerationMixin(object):
    r"""
    This class implements the interface for generation task.

    It's used as the base class of `paddlenlp.transformers.PretrainedModel
    <https://paddlenlp.readthedocs.io/zh/latest/source/paddlenlp.transformers.model_utils.html>`__.
    """

    # enable `to_static` method for CausalLM Model
    enable_to_static_method = False

    @staticmethod
    def prepare_input_ids_for_generation(bos_token_id, encoder_output=None):
        batch_size = 1
        if bos_token_id is None:
            raise ValueError(
                "`bos_token_id` should be defined when no " "`input_ids` are provided."
            )
        if encoder_output is not None:
            batch_size = encoder_output.shape[0]
        return paddle.ones([batch_size, 1], dtype="int64") * bos_token_id

    @staticmethod
    def prepare_attention_mask_for_generation(input_ids, pad_token_id, eos_token_id):
        """Build an additive attention mask from the padding in ``input_ids``.

        Padding is masked only when ``pad_token_id`` occurs in ``input_ids``
        and is not one of the eos ids. If pad and eos coincide, they cannot be
        told apart, so nothing is masked. Masked positions get
        ``get_scale_by_dtype(return_positive=False)`` (a large negative value)
        and all other positions get 0, in the default dtype.

        Args:
            input_ids (Tensor): token ids (assumed ``[batch, seq_len]`` — TODO confirm).
            pad_token_id (int | None): the pad id, or None.
            eos_token_id (int | list[int] | list[list[int]] | None): eos id(s);
                every id at any nesting level is compared with the pad id.

        Returns:
            Tensor: the mask with two axes inserted at positions 1 and 2
            (``[batch, 1, 1, seq_len]`` for 2-D ``input_ids``).
        """
        is_pad_token_in_inputs_ids = (pad_token_id is not None) and paddle.any(
            input_ids == pad_token_id
        ).item()

        if eos_token_id is None:
            is_pad_token_not_equal_to_eos_token_id = True
        else:
            # Accept an int, a list, or a nested list of eos ids.
            eos_items = (
                eos_token_id
                if isinstance(eos_token_id, (list, tuple))
                else [eos_token_id]
            )
            flat_eos_ids = [
                token_id
                for item in eos_items
                for token_id in (item if isinstance(item, (list, tuple)) else [item])
            ]
            is_pad_token_not_equal_to_eos_token_id = pad_token_id not in flat_eos_ids

        if is_pad_token_in_inputs_ids and is_pad_token_not_equal_to_eos_token_id:
            attention_mask = (input_ids == pad_token_id).astype(
                paddle.get_default_dtype()
            ) * get_scale_by_dtype(return_positive=False)
        else:
            attention_mask = paddle.zeros_like(
                input_ids, dtype=paddle.get_default_dtype()
            )
        return paddle.unsqueeze(attention_mask, axis=[1, 2])

    @staticmethod
    def prepare_seq_len_for_generation(input_ids, pad_token_id, eos_token_id):
        """Return the effective length of every row of ``input_ids``.

        When ``pad_token_id`` occurs in ``input_ids`` and is not one of the eos
        ids, a row's length is its number of non-pad tokens (summed over
        axis 1). Otherwise every row gets the full width
        ``input_ids.shape[1]``. The result has a trailing axis of size 1.

        Args:
            input_ids (Tensor): token ids, indexed as ``[batch, seq_len]``.
            pad_token_id (int | None): the pad id, or None.
            eos_token_id (int | list[int] | list[list[int]] | None): eos id(s);
                padding is not detectable when the pad id equals one of them.

        Returns:
            Tensor: the per-row lengths.
        """
        is_pad_token_in_inputs_ids = (pad_token_id is not None) and paddle.any(
            input_ids == pad_token_id
        ).item()

        if eos_token_id is None:
            is_pad_token_not_equal_to_eos_token_id = True
        else:
            # Accept an int, a list, or a nested list of eos ids.
            eos_items = (
                eos_token_id
                if isinstance(eos_token_id, (list, tuple))
                else [eos_token_id]
            )
            flat_eos_ids = [
                token_id
                for item in eos_items
                for token_id in (item if isinstance(item, (list, tuple)) else [item])
            ]
            is_pad_token_not_equal_to_eos_token_id = pad_token_id not in flat_eos_ids

        if is_pad_token_in_inputs_ids and is_pad_token_not_equal_to_eos_token_id:
            seq_len = paddle.sum(input_ids != pad_token_id, axis=1).unsqueeze(-1)
        else:
            seq_len = paddle.full(
                (input_ids.shape[0], 1), input_ids.shape[1], dtype="int64"
            )
        return seq_len

    def get_logits_processor(
        self,
        min_length=None,
        max_length=None,
        eos_token_id=None,
        forced_bos_token_id=None,
        forced_eos_token_id=None,
        num_beams=1,
        num_beam_groups=1,
        diversity_rate=0.0,
        repetition_penalty=None,
        no_repeat_ngram_size=None,
        logits_processors=None,
    ):
        """Assemble the logits processors for one generation call.

        Built-in processors are added in a fixed order (min-length, Hamming
        diversity, repetition penalty, no-repeat-ngram, forced BOS, forced
        EOS), each only when its arguments enable it.

        When ``logits_processors`` is given, a built-in processor whose type
        also appears among the custom ones is dropped. The custom processors
        are then appended after the remaining built-ins.

        Returns:
            LogitsProcessorList: the processors to apply.
        """
        builtin = LogitsProcessorList()

        wants_min_length = (
            min_length is not None and eos_token_id is not None and min_length > -1
        )
        if wants_min_length:
            builtin.append(MinLengthLogitsProcessor(min_length, eos_token_id))

        if num_beam_groups > 1 and diversity_rate > 0.0:
            builtin.append(
                HammingDiversityLogitsProcessor(
                    diversity_rate=diversity_rate,
                    num_beams=num_beams,
                    num_beam_groups=num_beam_groups,
                )
            )

        if repetition_penalty is not None and repetition_penalty != 1.0:
            builtin.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))

        if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0:
            builtin.append(NoRepeatNGramLogitsProcessor(no_repeat_ngram_size))

        if forced_bos_token_id is not None:
            builtin.append(ForcedBOSTokenLogitsProcessor(forced_bos_token_id))

        if forced_eos_token_id is not None:
            builtin.append(
                ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id)
            )
        # TODO: add more pre-processing for distribution

        if logits_processors is None:
            return builtin

        # A custom processor replaces any built-in processor of the same type.
        overridden_types = {type(custom) for custom in logits_processors}
        merged = LogitsProcessorList()
        for processor in builtin:
            if type(processor) not in overridden_types:
                merged.append(processor)
        merged.extend(logits_processors)
        return merged

    @staticmethod
    def expand_inputs_for_generation(
        input_ids, expand_size, attention_mask=None, **model_kwargs
    ):
        """Repeat every batch row ``expand_size`` times.

        Row ``i`` is copied to rows ``i * expand_size`` to
        ``(i + 1) * expand_size - 1``. The same row gather is applied to
        ``attention_mask`` (when given) and to ``token_type_ids``,
        ``position_ids``, ``seq_len``, ``encoder_output`` and ``role_ids`` in
        ``model_kwargs`` when present and not None.

        Returns:
            tuple: ``(input_ids, model_kwargs)`` with the expanded tensors.
        """
        batch_rows = paddle.arange(input_ids.shape[0], dtype="int64").unsqueeze(-1)
        row_index = paddle.tile(batch_rows, [1, expand_size]).reshape([-1])

        input_ids = paddle.gather(input_ids, row_index)

        if attention_mask is not None:
            model_kwargs["attention_mask"] = paddle.gather(attention_mask, row_index)

        for key in (
            "token_type_ids",
            "position_ids",
            "seq_len",
            "encoder_output",
            "role_ids",
        ):
            value = model_kwargs.get(key)
            if value is not None:
                model_kwargs[key] = paddle.gather(value, row_index)

        return input_ids, model_kwargs

    @staticmethod
    def update_model_kwargs_for_generation(
        outputs, model_kwargs, is_encoder_decoder=False
    ):
        """Update ``model_kwargs`` in place for the next generation step.

        Note that if ``token_type_ids`` or ``attention_mask`` in
        ``model_kwargs`` contain pad values, the vectors produced here may
        differ from what is expected; override this method in that case.

        Steps:
            * cache: taken from ``outputs[1]`` when ``outputs`` is a tuple
              whose second element is not a Tensor, or from
              ``outputs.past_key_values`` when ``outputs`` is a ``ModelOutput``
              holding that key; stored under both ``cache`` and
              ``past_key_values``.
            * ``token_type_ids`` / ``role_ids``: the last column is repeated once.
            * ``position_ids``: the last position plus one is appended.
            * ``attention_mask`` (skipped for encoder-decoder models or when
              None): bool masks are cast to int64 first.
              - A 4-D mask grows by one row and one column. The new row
                replicates the last row (zero padding on NPU), and the new
                column is filled with ``get_scale_by_dtype(return_positive=False)``
                (0 on NPU). The new corner entry is then set to 1 for int
                masks and 0.0 for float masks.
              - Any other mask gets a column of ones appended, in the mask's
                own dtype.

        Args:
            outputs: the model outputs of the current step.
            model_kwargs (dict): keyword arguments fed to the model; mutated.
            is_encoder_decoder (bool): if True, ``attention_mask`` is untouched.

        Returns:
            dict: the same ``model_kwargs`` object.

        Raises:
            ValueError: if a 4-D ``attention_mask`` is neither int nor float
                (after the bool -> int64 cast).
        """
        # update cache
        if (
            isinstance(outputs, tuple)
            and len(outputs) > 1
            and not isinstance(outputs[1], paddle.Tensor)
        ):
            model_kwargs["cache"] = outputs[1]
            model_kwargs["past_key_values"] = outputs[1]

        if isinstance(outputs, ModelOutput) and "past_key_values" in outputs:
            model_kwargs["cache"] = outputs.past_key_values
            model_kwargs["past_key_values"] = outputs.past_key_values

        # update token_type_ids with last value
        if (
            "token_type_ids" in model_kwargs
            and model_kwargs["token_type_ids"] is not None
        ):
            token_type_ids = model_kwargs["token_type_ids"]
            model_kwargs["token_type_ids"] = paddle.concat(
                [token_type_ids, token_type_ids[:, -1:]], axis=-1
            )

        # update position_ids
        if "position_ids" in model_kwargs and model_kwargs["position_ids"] is not None:
            position_ids = model_kwargs["position_ids"]
            model_kwargs["position_ids"] = paddle.concat(
                [position_ids, position_ids[..., -1:] + 1], axis=-1
            )

        # update attention_mask (an explicit None mask is left as is)
        if (
            not is_encoder_decoder
            and "attention_mask" in model_kwargs
            and model_kwargs["attention_mask"] is not None
        ):
            attention_mask = model_kwargs["attention_mask"]
            # nn.Pad2D does not support the `bool` data type
            if convert_dtype(attention_mask.dtype) == "bool":
                attention_mask = paddle.cast(attention_mask, "int64")
            if len(attention_mask.shape) == 4:
                cur_device = paddle.get_device()
                if cur_device.split(":")[0] == "npu":
                    # NPU path uses constant (zero) padding for both the new
                    # row and the new column.
                    attention_mask = nn.Pad2D([0, 0, 0, 1], mode="constant")(
                        attention_mask
                    )
                    attention_mask = nn.Pad2D([0, 1, 0, 0], value=0)(attention_mask)
                else:
                    # Pad2D padding order is [left, right, top, bottom]:
                    # add one replicated row at the bottom, then one masked
                    # column on the right.
                    attention_mask = nn.Pad2D([0, 0, 0, 1], mode="replicate")(
                        attention_mask
                    )
                    attention_mask = nn.Pad2D(
                        [0, 1, 0, 0], value=get_scale_by_dtype(return_positive=False)
                    )(attention_mask)

                # The new token always attends to itself.
                dtype = convert_dtype(attention_mask.dtype)
                if "int" in dtype:
                    attention_mask[:, :, -1, -1] = 1
                elif "float" in dtype:
                    attention_mask[:, :, -1, -1] = 0.0
                else:
                    raise ValueError(
                        "The data type of input `attention_mask` must "
                        "be bool, int or float"
                    )
            else:
                # Match the mask dtype: paddle.concat requires equal dtypes, so
                # hard-coded int64 ones broke int32/float masks.
                attention_mask = paddle.concat(
                    [
                        attention_mask,
                        paddle.ones(
                            [attention_mask.shape[0], 1], dtype=attention_mask.dtype
                        ),
                    ],
                    axis=-1,
                )
            model_kwargs["attention_mask"] = attention_mask

        # update role_ids
        if "role_ids" in model_kwargs and model_kwargs["role_ids"] is not None:
            role_ids = model_kwargs["role_ids"]
            model_kwargs["role_ids"] = paddle.concat(
                [role_ids, role_ids[:, -1:]], axis=-1
            )

        return model_kwargs

    @staticmethod
    def update_scores_for_generation(scores, next_scores, length, unfinished_flag):
        """Fold ``next_scores`` into a running mean of per-token scores.

        Rows where ``unfinished_flag`` is True become
        ``(scores * length + next_scores) / (length + 1)``; the other rows
        keep their current ``scores``.
        """
        length_tensor = paddle.to_tensor(length, dtype=scores.dtype)
        running_mean = (scores * length_tensor + next_scores) / (length_tensor + 1)
        return paddle.where(unfinished_flag, running_mean, scores)

    def prepare_encoder_decoder_kwargs_for_generation(self, input_ids, model_kwargs):
        if "encoder_output" not in model_kwargs:
            # retrieve encoder hidden states
            encoder = self.get_encoder()
            encoder_kwargs = {
                argument: value
                for argument, value in model_kwargs.items()
                if not (
                    argument.startswith("decoder_")
                    or argument.startswith("cross_attn")
                    or argument == "use_cache"
                )
            }
            # Use inputs_embeds as the priority if inputs_embeds exists
            if "inputs_embeds" in encoder_kwargs:
                model_kwargs["encoder_output"] = encoder(**encoder_kwargs)
            else:
                model_kwargs["encoder_output"] = encoder(
                    input_ids=input_ids, **encoder_kwargs
                )
        return model_kwargs

    def prepare_decoder_input_ids_for_generation(
        self, input_ids, decoder_start_token_id=None, bos_token_id=None
    ):
        decoder_start_token_id = (
            decoder_start_token_id
            if decoder_start_token_id is not None
            else self.config.decoder_start_token_id
        )
        decoder_start_token_id = (
            decoder_start_token_id
            if decoder_start_token_id is not None
            else bos_token_id
        )

        decoder_input_ids = (
            paddle.ones([input_ids.shape[0], 1], dtype="int64") * decoder_start_token_id
        )

        return decoder_input_ids

    def get_decoder_start_token_id(
        self, decoder_start_token_id=None, bos_token_id=None
    ):
        decoder_start_token_id = (
            decoder_start_token_id
            if decoder_start_token_id is not None
            else self.config.decoder_start_token_id
        )
        bos_token_id = (
            bos_token_id if bos_token_id is not None else self.config.bos_token_id
        )

        if decoder_start_token_id is not None:
            return decoder_start_token_id
        elif self.config.decoder_start_token_id is not None:
            return self.config.decoder_start_token_id
        elif bos_token_id is not None:
            return bos_token_id
        elif self.config.bos_token_id is not None:
            return self.config.bos_token_id
        raise ValueError(
            "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
        )

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # Implement in subclasses for custom behavior to prepare inputs in the
        # generate method.

        return {"input_ids": input_ids}

    def adjust_logits_during_generation(self, logits):
        # Implement in subclasses for custom behavior to adjust the logits in
        # the generate method.

        return logits

    def prepare_fast_entry(self, kwargs):
        return False

    def _convert_to_fast(self, kwargs):
        # try general convert
        pass

    def _build_fast(self, kwargs):
        self._fast_entry = False
        if kwargs["num_beam_groups"] != 1:
            # not support for group_beam_search yet in the fast version
            raise AttributeError(
                "'num_beam_groups != 1' is not supported yet in the fast version"
            )
        if (
            paddle.get_default_dtype() == "float16"
            and kwargs["use_fp16_decoding"] is False
        ):
            logging.info(
                "Since the default dtype is float16, float16 would be used "
                "though 'use_fp16_decoding=False'."
            )
            kwargs["use_fp16_decoding"] = True
        self.prepare_fast_entry(kwargs)

    def set_pad_token_id(self, pad_token_id, eos_token_id):
        if pad_token_id is None and eos_token_id is not None:
            logging.warning(
                "Setting `pad_token_id` to `eos_token_id`:{} for "
                "open-end generation.".format(eos_token_id)
            )
            if isinstance(eos_token_id, list):
                pad_token_id = eos_token_id[0]
            else:
                pad_token_id = eos_token_id
        return pad_token_id

    @paddle.no_grad()
    def generate(
        self,
        input_ids: paddle.Tensor = None,
        generation_config: GenerationConfig = None,
        stopping_criteria: StoppingCriteria = None,
        streamer=None,
        synced_gpus: Optional[bool] = None,
        **kwargs,
    ):
        r"""
        The interface for generation task. This method can generate sequences
        by using decoding strategy. Currently, there are three decoding
        strategies supported: "greedy_search", "sampling" and "beam_search".

        Args:
            input_ids (Tensor, optional): The input sequence ids for the
                generation. It is a Tensor with shape [batch_size, sequence_length].
                The data type should be int32 or int64. Default to None, which
                we will initialize it as a Tensor with shape [1, 1], filled
                with the value `bos_token_id`.
            generation_config (`~generation.GenerationConfig`, *optional*):
                The generation configuration to be used as base parametrization for the generation call. `**kwargs`
                passed to generate matching the attributes of `generation_config` will override them. If
                `generation_config` is not provided, the default will be used, which had the following loading
                priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
                configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
                default values, whose documentation should be checked to parameterize generation.
            stopping_criteria (`StoppingCriteriaList`, *optional*):
                Custom stopping criteria that complement the default stopping criteria built from arguments and a
                generation config. If a stopping criteria is passed that is already created with the arguments or a
                generation config an error is thrown. This feature is intended for advanced users.
            streamer (`~streamer.BaseStreamer`, *optional*):
                Streamer object that will be used to stream the generated sequences. Generated tokens are passed
                through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
            synced_gpus (`bool`, *optional*):
                Whether to continue running the while loop until max_length. Unless overridden this flag will be set to
                `True` under DeepSpeed ZeRO Stage 3 multiple GPUs environment to avoid hanging if one GPU finished
                generating before other GPUs. Otherwise it'll be set to `False`.
            kwargs (dict): It can be used to specify additional kwargs
                passed to the model.

        Returns:
            tuple[Tensor]: It is a tuple contains two elements: ids and scores.
            Each element is a Tensor.

            With the fields:

            - ids (Tensor):
                The ids of the generated sequences. It is a Tensor with shape
                [batch_size * num_return_sequences, sequence_length]. The data
                type is same as the input `input_ids`.
            - scores (Tensor):
                The scores of the generated sequences. It is a Tensor with shape
                [batch_size * num_return_sequences, 1]. The data type is float32
                or float64, which is the same as the parameters in the model.

        Example:
            .. code-block::

                import paddle
                from paddlenlp.transformers import (
                    UnifiedTransformerLMHeadModel,
                    UnifiedTransformerTokenizer
                )

                paddle.seed(2)

                # Initialize the model and tokenizer
                model_name_or_path = 'unified_transformer-12L-cn-luge'
                model = UnifiedTransformerLMHeadModel.from_pretrained(model_name_or_path)
                tokenizer = UnifiedTransformerTokenizer.from_pretrained(model_name_or_path)

                # Prepare the model inputs.
                history = "早上好，今天空气质量不错。"
                inputs = tokenizer.dialogue_encode(history, task_type='chitchat',
                    add_start_token_as_response=True, return_tensors=True)

            .. code-block::

                # Generate the sequence by using "greedy_search" strategy
                ids, scores = model.generate(
                    **inputs,
                    decode_strategy="greedy_search")
                print(ids.shape, scores.shape)
                # [1, 3] [1, 1]
                sequence_ids = ids.cpu().numpy().tolist()[0]
                sequence_ids = sequence_ids[:sequence_ids.index(tokenizer.sep_token_id)]
                response = tokenizer.convert_ids_to_string(sequence_ids, keep_space=False)
                print(response)
                # 是的

            .. code-block::

                # Generate 2 sequences by using "sampling" strategy (top_k=5)
                generation_config = GenerationConfig(
                    decode_strategy="sampling",
                    top_k=5,
                    num_return_sequences=2
                )
                ids, scores = model.generate(
                    **inputs,
                    generation_config=generation_config,
                    )
                print(ids.shape, scores.shape)
                # [2, 7] [2, 1]
                response = []
                for sequence_ids in ids.cpu().numpy().tolist():
                    sequence_ids = sequence_ids[:sequence_ids.index(tokenizer.sep_token_id)]
                    text = tokenizer.convert_ids_to_string(sequence_ids, keep_space=False)
                    response.append(text)
                print(response)
                # ['天气好,心情也好', '你也是']

            .. code-block::

                # Generate 2 sequences by using "beam_search" strategy (num_beams=5)
                generation_config = GenerationConfig(
                    decode_strategy="beam_search",
                    num_beams=5,
                    num_return_sequences=2
                )
                ids, scores = model.generate(
                    **inputs,
                    generation_config=generation_config,
                    )
                print(ids.shape, scores.shape)
                # [2, 3] [2, 1]
                response = []
                for sequence_ids in ids.cpu().numpy().tolist():
                    sequence_ids = sequence_ids[:sequence_ids.index(tokenizer.sep_token_id)]
                    text = tokenizer.convert_ids_to_string(sequence_ids, keep_space=False)
                    response.append(text)
                print(response)
                # ['是的', '嗯嗯']
        """
        if generation_config is None:
            if self.generation_config is None or (
                self.generation_config._from_model_config
                and self.config._has_non_default_generation_parameters()
            ):
                new_generation_config = GenerationConfig.from_model_config(self.config)
                if new_generation_config != self.generation_config:
                    logging.warning(
                        "model.generation_config is in conflict with model.config, "
                        "model.config is used."
                    )
                    self.generation_config = new_generation_config
            generation_config = self.generation_config

        # without update model.generation_config
        generation_config = copy.deepcopy(generation_config)
        model_kwargs = generation_config.update(**kwargs)

        assert generation_config.decode_strategy in [
            "greedy_search",
            "sampling",
            "beam_search",
        ], "`decode_strategy` must be one of 'greedy_search', 'sampling' or 'beam_search' but received {}.".format(
            generation_config.decode_strategy
        )

        if getattr(self, "deprecated_warnings", None) is None:
            self.deprecated_warnings = {}

        use_fast = False
        if "use_faster" in model_kwargs:
            raise ValueError("`use_faster` is deprecated now.")

        if "use_fast" in model_kwargs:
            raise ValueError("`use_fast` is deprecated now.")

        bos_token_id = (
            generation_config.bos_token_id
            if generation_config.bos_token_id is not None
            else self.config.bos_token_id
        )
        eos_token_id = (
            generation_config.eos_token_id
            if generation_config.eos_token_id is not None
            else self.config.eos_token_id
        )
        pad_token_id = (
            generation_config.pad_token_id
            if generation_config.pad_token_id is not None
            else self.config.pad_token_id
        )
        forced_bos_token_id = (
            generation_config.forced_bos_token_id
            if generation_config.forced_bos_token_id is not None
            else self.config.forced_bos_token_id
        )
        forced_eos_token_id = (
            generation_config.forced_eos_token_id
            if generation_config.forced_eos_token_id is not None
            else self.config.forced_eos_token_id
        )
        decoder_start_token_id = (
            generation_config.decoder_start_token_id
            if generation_config.decoder_start_token_id is not None
            else self.config.decoder_start_token_id
        )
        no_repeat_ngram_size = (
            generation_config.no_repeat_ngram_size
            if generation_config.no_repeat_ngram_size is not None
            else self.config.no_repeat_ngram_size
        )

        if getattr(self, "_fast_entry", None) is not False and use_fast:
            fg_args = locals()
            fg_args.pop("self")
            fg_args.pop("__class__", None)
            model_kwargs = fg_args.pop("model_kwargs")
            fg_args.update(model_kwargs)
            try:
                if getattr(self, "_fast_entry", None) is None:
                    self._build_fast(fg_args)
                if self._fast_entry:
                    output = self._fast_entry(**fg_args)
                    if isinstance(output, tuple):
                        output_ids, dummy_srore = output
                    else:
                        output_ids = output
                        # make result and fast result oneconsistent
                        dummy_srore = None
                    if generation_config.decode_strategy == "beam_search":
                        output_ids = output_ids.transpose([1, 2, 0])
                        output_ids = output_ids[
                            :, : generation_config.num_return_sequences, :
                        ].reshape([-1, output_ids.shape[-1]])
                        if dummy_srore is not None:
                            dummy_srore = dummy_srore[
                                :, : generation_config.num_return_sequences
                            ].flatten()
                    else:
                        output_ids = output_ids.transpose([1, 0])
                    return output_ids, dummy_srore

            except Exception as e:
                fg_args["model_kwargs"] = model_kwargs
                # TODO
                # Prevent self._convert_to_fast to throw Exception
                self._convert_to_fast(fg_args)
                logging.warning(e)
                logging.warning(
                    "FastGeneration is not available, "
                    "and the original version would be used instead."
                )

        # input_ids in model_kwargs is supported
        if "input_ids" in model_kwargs:
            _input_ids = model_kwargs.pop("input_ids")
            if input_ids is None:
                input_ids = _input_ids

        # params check
        if input_ids is None and "inputs_embeds" not in model_kwargs:
            # Init `input_ids` with bos_token_id
            input_ids = self.prepare_input_ids_for_generation(bos_token_id)
        elif "inputs_embeds" in model_kwargs:
            # Add input embeds support
            input_ids = self.prepare_input_ids_for_generation(
                bos_token_id, encoder_output=model_kwargs["inputs_embeds"]
            )

        if model_kwargs.get("attention_mask", None) is None:
            # TODO
            # Init `attention_mask` depending on `pad_token_id`
            model_kwargs["attention_mask"] = self.prepare_attention_mask_for_generation(
                input_ids, pad_token_id, eos_token_id
            )
        self.is_encoder_decoder = self.config.is_encoder_decoder

        if self.is_encoder_decoder:
            model_kwargs = self.prepare_encoder_decoder_kwargs_for_generation(
                input_ids, model_kwargs
            )
            # set input_ids as decoder_input_ids
            if "decoder_input_ids" in model_kwargs:
                input_ids = model_kwargs.pop("decoder_input_ids")
            else:
                input_ids = self.prepare_decoder_input_ids_for_generation(
                    input_ids, decoder_start_token_id, bos_token_id
                )
        # streamer
        if streamer is not None:
            # streamer couldn't support beam_search strategy
            if (
                generation_config.decode_strategy == "beam_search"
                or generation_config.num_beams > 1
            ):
                raise ValueError(
                    "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
                )

        pad_token_id = self.set_pad_token_id(pad_token_id, eos_token_id)

        if (
            generation_config.max_length != 0
            and generation_config.max_new_tokens == DEFAULT_MAX_NEW_TOKENS
        ):
            logging.warning(
                "`max_length` will be deprecated in future releases, use `max_new_tokens` instead."
            )
            generation_config.max_new_tokens = generation_config.max_length

        if generation_config.min_length != 0 and generation_config.min_new_tokens == 0:
            logging.warning(
                "`min_length` will be deprecated in future releases, use `min_new_tokens` instead."
            )
            generation_config.min_new_tokens = generation_config.min_length

        max_length = generation_config.max_new_tokens
        min_length = generation_config.min_new_tokens

        input_len = input_ids.shape[-1]
        min_len = input_len + min_length
        max_len = input_len + max_length

        logits_processors = self.get_logits_processor(
            min_length=min_len if min_length > 0 else None,
            max_length=max_len,
            eos_token_id=eos_token_id,
            forced_bos_token_id=forced_bos_token_id,
            forced_eos_token_id=forced_eos_token_id,
            num_beams=generation_config.num_beams,
            num_beam_groups=generation_config.num_beam_groups,
            diversity_rate=generation_config.diversity_rate,
            repetition_penalty=generation_config.repetition_penalty,
            no_repeat_ngram_size=generation_config.no_repeat_ngram_size,
            logits_processors=(
                model_kwargs["logits_processors"]
                if "logits_processors" in model_kwargs
                and isinstance(model_kwargs["logits_processors"], LogitsProcessorList)
                else None
            ),
        )
        if "logits_processors" in model_kwargs:
            model_kwargs.pop("logits_processors")

        model_kwargs["use_cache"] = generation_config.use_cache

        stopping_criteria = (
            stopping_criteria
            if stopping_criteria is not None
            else StoppingCriteriaList()
        )

        if generation_config.decode_strategy == "greedy_search":
            if generation_config.num_return_sequences > 1:
                raise ValueError(
                    "`num_return_sequences` has to be 1, but is {} "
                    "when doing greedy search.".format(
                        generation_config.num_return_sequences
                    )
                )
            return self.greedy_search(
                input_ids,
                logits_processors,
                max_len,
                pad_token_id,
                eos_token_id,
                stopping_criteria=stopping_criteria,
                streamer=streamer,
                fast_ptq_sampling=generation_config.fast_ptq_sampling,
                trunc_input=generation_config.trunc_input,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

        elif generation_config.decode_strategy == "sampling":
            if generation_config.num_return_sequences > 1:
                input_ids, model_kwargs = self.expand_inputs_for_generation(
                    input_ids,
                    expand_size=generation_config.num_return_sequences,
                    **model_kwargs,
                )

            return self.sample(
                input_ids,
                logits_processors,
                max_len,
                pad_token_id,
                eos_token_id,
                generation_config.top_k,
                generation_config.top_p,
                generation_config.temperature,
                stopping_criteria=stopping_criteria,
                streamer=streamer,
                fast_ptq_sampling=generation_config.fast_ptq_sampling,
                trunc_input=generation_config.trunc_input,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

        elif generation_config.decode_strategy == "beam_search":
            batch_size = input_ids.shape[0]
            if generation_config.num_return_sequences > generation_config.num_beams:
                raise ValueError(
                    "`num_return_sequences` has to be smaller or equal to "
                    "`num_beams`. But received `num_return_sequences` is {}, "
                    "`num_beams` is {}".format(
                        generation_config.num_return_sequences,
                        generation_config.num_beams,
                    )
                )
            if generation_config.num_beams <= 1:
                raise ValueError(
                    "`num_beams` has to be bigger than 1. But received "
                    "`num_beams` is {}. If `num_beams` is 1, `decode_strategy` "
                    "should be 'greedy_search'".format(generation_config.num_beams)
                )
            if generation_config.num_beam_groups > 1:
                diverse_beam_scorer = BeamSearchScorer(
                    batch_size=batch_size,
                    max_length=max_len,
                    num_beams=generation_config.num_beams,
                    length_penalty=generation_config.length_penalty,
                    do_early_stopping=generation_config.early_stopping,
                    num_beam_hyps_to_keep=generation_config.num_return_sequences,
                    num_beam_groups=generation_config.num_beam_groups,
                )

                # interleave with `num_beams`
                input_ids, model_kwargs = self.expand_inputs_for_generation(
                    input_ids, expand_size=generation_config.num_beams, **model_kwargs
                )

                return self.group_beam_search(
                    input_ids,
                    diverse_beam_scorer,
                    logits_processors,
                    max_len,
                    pad_token_id,
                    eos_token_id,
                    stopping_criteria=stopping_criteria,
                    fast_ptq_sampling=generation_config.fast_ptq_sampling,
                    trunc_input=generation_config.trunc_input,
                    synced_gpus=synced_gpus,
                    **model_kwargs,
                )
            else:
                beam_scorer = BeamSearchScorer(
                    batch_size=batch_size,
                    max_length=max_len,
                    num_beams=generation_config.num_beams,
                    length_penalty=generation_config.length_penalty,
                    do_early_stopping=generation_config.early_stopping,
                    num_beam_hyps_to_keep=generation_config.num_return_sequences,
                )

                input_ids, model_kwargs = self.expand_inputs_for_generation(
                    input_ids, expand_size=generation_config.num_beams, **model_kwargs
                )

                return self.beam_search(
                    input_ids,
                    beam_scorer,
                    logits_processors,
                    max_len,
                    generation_config.diversity_rate,
                    pad_token_id,
                    eos_token_id,
                    stopping_criteria=stopping_criteria,
                    fast_ptq_sampling=generation_config.fast_ptq_sampling,
                    trunc_input=generation_config.trunc_input,
                    synced_gpus=synced_gpus,
                    **model_kwargs,
                )

    def greedy_search(
        self,
        input_ids,
        logits_processors,
        max_length,
        pad_token_id,
        eos_token_id,
        stopping_criteria=None,
        streamer=None,
        fast_ptq_sampling=False,
        trunc_input=True,
        synced_gpus=False,
        **model_kwargs,
    ):
        """Decode by always picking the highest-scoring next token.

        Args:
            input_ids (Tensor): prompt ids, unpacked here as ``[batch, seq]``.
            logits_processors (LogitsProcessorList | None): processors applied to
                the last-position logits before the argmax; ``None`` means none.
            max_length (int | None): when not ``None``, merged into
                ``stopping_criteria`` via ``validate_stopping_criteria``.
            pad_token_id (int): token written for rows that already finished
                (only when ``eos_token_id`` is given).
            eos_token_id (int | None): id used to track which rows finished.
            stopping_criteria (StoppingCriteriaList | None): extra stop conditions.
            streamer: optional object receiving each step's tokens via ``put``
                (tensor-parallel rank 0 only) and ``end`` once decoding stops.
            fast_ptq_sampling (bool): stop after a single decoding step.
            trunc_input (bool): if True, strip the prompt from the returned ids.
            synced_gpus (bool): keep running forward passes until every rank
                taking part in ``dist.all_reduce`` has finished.
            **model_kwargs: extra inputs forwarded to
                ``prepare_inputs_for_generation``.

        Returns:
            tuple: ``(ids, scores)``. The score of each step is the value of the
            chosen token in the processed logits, folded in by
            ``update_scores_for_generation`` (scores start as ``[batch, 1]``).
        """
        if logits_processors is None:
            logits_processors = LogitsProcessorList()
        if stopping_criteria is None:
            stopping_criteria = StoppingCriteriaList()
        # max_length is folded into the stopping criteria (MaxLengthCriteria)
        if max_length is not None:
            stopping_criteria = validate_stopping_criteria(
                stopping_criteria, max_length
            )

        batch_size, prompt_len = input_ids.shape
        step = 0  # number of tokens generated so far
        unfinished_flag = paddle.full([batch_size, 1], True, dtype="bool")
        scores = paddle.full([batch_size, 1], 0.0, dtype=paddle.get_default_dtype())
        track_eos = eos_token_id is not None
        finished = False

        def _extract_logits(model_out):
            # Models may return a tuple, a ModelOutput, or the logits directly.
            if isinstance(model_out, tuple):
                return model_out[0]
            if isinstance(model_out, ModelOutput):
                return model_out.logits
            return model_out

        while True:
            if synced_gpus:
                # Every rank contributes 1.0 while still generating and 0.0 once
                # done; a zero sum means all peers are done and we can leave.
                peer_flag = paddle.to_tensor(1.0 if not finished else 0.0)
                dist.all_reduce(peer_flag, op=dist.ReduceOp.SUM)
                if peer_flag.item() == 0.0:
                    break

            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            outputs = self(**model_inputs)

            # A finished rank must keep calling forward to stay in lock-step
            # with its peers, but it can skip the rest of the step.
            if finished and synced_gpus:
                continue

            # last position only: [batch_size, vocab_size]
            last_logits = _extract_logits(outputs)[:, -1, :]
            last_logits = self.adjust_logits_during_generation(last_logits)
            processed = logits_processors(input_ids, last_logits)

            # greedy choice and its processed-logit value
            chosen = paddle.argmax(processed, axis=-1).unsqueeze(-1)
            chosen_scores = paddle.index_sample(processed, chosen)

            if track_eos:
                # rows that already emitted eos keep receiving pad tokens
                chosen = paddle.where(
                    unfinished_flag, chosen, paddle.full_like(chosen, pad_token_id)
                )

            scores = self.update_scores_for_generation(
                scores, chosen_scores, step, unfinished_flag
            )
            step += 1

            input_ids = paddle.concat([input_ids, chosen], axis=1)
            if streamer is not None and self.config.tensor_parallel_rank == 0:
                streamer.put(chosen.cpu())

            if stopping_criteria(input_ids, scores):
                finished = True

            if track_eos:
                unfinished_flag = get_unfinished_flag(
                    input_ids, unfinished_flag, eos_token_id
                )
                # every row has produced eos
                if not paddle.any(unfinished_flag):
                    finished = True

            if finished and not synced_gpus:
                break

            model_kwargs = self.update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            if fast_ptq_sampling:
                break

        if streamer is not None:
            streamer.end()

        generated = input_ids[:, prompt_len:] if trunc_input else input_ids
        return generated, scores

    def sample(
        self,
        input_ids,
        logits_processors,
        max_length,
        pad_token_id,
        eos_token_id,
        top_k=None,
        top_p=None,
        temperature=None,
        min_tokens_to_keep=1,
        stopping_criteria=None,
        streamer=None,
        fast_ptq_sampling=False,
        trunc_input=True,
        synced_gpus=False,
        **model_kwargs,
    ):
        """Decode by sampling the next token from the (filtered) distribution.

        Each step runs the model on the current ids and applies
        ``adjust_logits_during_generation`` and ``logits_processors`` to the
        last-position logits. It then optionally rescales by ``temperature``,
        filters with top-k / top-p, and draws one token per row with
        ``paddle.multinomial``. The per-token score is the log-probability of
        the drawn token under the processed distribution *before*
        temperature / top-k / top-p are applied.

        Args:
            input_ids (Tensor): prompt ids, unpacked here as ``[batch, seq]``.
            logits_processors (LogitsProcessorList | None): processors applied to
                the last-position logits; ``None`` means none.
            max_length (int | None): when not ``None``, merged into
                ``stopping_criteria`` via ``validate_stopping_criteria``.
            pad_token_id (int): token written for rows that already finished
                (only when ``eos_token_id`` is given).
            eos_token_id (int | None): id used to track which rows finished.
            top_k (int | None): top-k filtering; skipped when ``None`` or 0.
            top_p (float | None): nucleus filtering; skipped when ``None`` or
                ``>= 1.0``.
            temperature (float | None): logits divisor; skipped when ``None``
                or 1.0.
            min_tokens_to_keep (int): forwarded to ``TopKProcess`` /
                ``TopPProcess``.
            stopping_criteria (StoppingCriteriaList | None): extra stop conditions.
            streamer: optional object receiving each step's tokens via ``put``
                (tensor-parallel rank 0 only) and ``end`` once decoding stops.
            fast_ptq_sampling (bool): stop after a single decoding step.
            trunc_input (bool): if True, strip the prompt from the returned ids.
            synced_gpus (bool): keep running forward passes until every rank
                taking part in ``dist.all_reduce`` has finished.
            **model_kwargs: extra inputs forwarded to
                ``prepare_inputs_for_generation``.

        Returns:
            tuple: ``(ids, scores)``; ``scores`` starts as a ``[batch, 1]`` tensor
            and is updated by ``update_scores_for_generation``.
        """
        logits_processors = (
            logits_processors
            if logits_processors is not None
            else LogitsProcessorList()
        )

        # max_length will be converted to MaxLengthCriteria
        stopping_criteria = (
            stopping_criteria
            if stopping_criteria is not None
            else StoppingCriteriaList()
        )
        if max_length is not None:
            stopping_criteria = validate_stopping_criteria(
                stopping_criteria, max_length
            )

        # Resolve the tensor-parallel broadcast group once; it does not change
        # while decoding, so there is no need to query fleet on every step.
        tp_broadcast = self.config.tensor_parallel_degree > 1
        tp_group, tp_src = None, 0
        if tp_broadcast:
            from paddle.distributed import fleet

            try:
                hcg = fleet.get_hybrid_communicate_group()
                tp_group = hcg.get_model_parallel_group()
                tp_src = hcg.get_model_parallel_group_src_rank()
            except Exception:
                # Presumably hybrid parallelism is not initialised: fall back to
                # the default group with rank 0 as source. (A bare ``except``
                # would also swallow KeyboardInterrupt / SystemExit.)
                tp_group, tp_src = None, 0

        # Probabilities are cast to float32 before multinomial on GCU / XPU
        # builds; these are build-time properties, so check them once.
        cast_probs_to_fp32 = paddle.device.is_compiled_with_custom_device(
            "gcu"
        ) or paddle.device.is_compiled_with_xpu()

        batch_size, cur_len = input_ids.shape
        origin_len = cur_len
        unfinished_flag = paddle.full([batch_size, 1], True, dtype="bool")
        scores = paddle.full([batch_size, 1], 0.0, dtype=paddle.get_default_dtype())

        generate_end = False
        while True:
            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus
                # complete their sequence. Each rank sends 0.0 once finished and
                # 1.0 otherwise; a zero sum means every peer is done.
                this_peer_finished_flag = paddle.to_tensor(0.0 if generate_end else 1.0)
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                if this_peer_finished_flag.item() == 0.0:
                    break
            # prepare model inputs & get model output
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            # NOTE: to decrease ref-count and clear outdated cache in time
            model_kwargs["cache"] = None
            model_kwargs["past_key_values"] = None
            outputs = self(**model_inputs)
            if synced_gpus and generate_end:
                continue  # don't waste resources running the code we don't need

            if isinstance(outputs, tuple):
                logits = outputs[0]
            elif isinstance(outputs, ModelOutput):
                logits = outputs.logits
            else:
                logits = outputs

            # keep only the last position: [batch_size, vocab_size]
            logits = logits[:, -1, :]

            # pre-process distribution
            logits = self.adjust_logits_during_generation(logits)
            logits = logits_processors(input_ids, logits)

            # Log-probabilities used for scoring, taken before temperature and
            # top-k / top-p. log_softmax is numerically stabler than
            # log(softmax(x)), which returns -inf once a probability underflows.
            origin_log_probs = F.log_softmax(logits)
            if temperature is not None and temperature != 1.0:
                logits = logits / temperature
            probs = F.softmax(logits)
            if top_k is not None and top_k != 0:
                probs = TopKProcess(probs, top_k, min_tokens_to_keep)
            if top_p is not None and top_p < 1.0:
                probs = TopPProcess(probs, top_p, min_tokens_to_keep)
            if cast_probs_to_fp32:
                probs = paddle.cast(probs, "float32")

            # multinomial already support fp16 and bf16 currently, fix issue: https://github.com/PaddlePaddle/Paddle/issues/51852
            next_tokens = paddle.multinomial(probs)

            if tp_broadcast:
                # Maybe no need to broadcast if seed is set correctly.
                paddle.distributed.broadcast(next_tokens, src=tp_src, group=tp_group)
            # config does not include pipeline_parallel_degree, and pipeline parallel
            # uses trainer.model_wrapped to run in both train and predict mode
            # which has pp_group as a attribute
            # TODO(guosheng): only let the last stage of pipeline to do softmax
            # and sampling, and then broadcast to avoid broadcast logits.
            if getattr(self, "pp_group", None) is not None:
                paddle.distributed.broadcast(
                    next_tokens,
                    src=self.pp_group.ranks[0],
                    group=self.pp_group,  # use rank 0 for same seed to check
                )

            next_scores = paddle.index_sample(origin_log_probs, next_tokens)
            if eos_token_id is not None:
                # rows that already emitted eos keep receiving pad tokens
                next_tokens = paddle.where(
                    unfinished_flag,
                    next_tokens,
                    paddle.full_like(next_tokens, pad_token_id),
                )

            scores = self.update_scores_for_generation(
                scores, next_scores, cur_len - origin_len, unfinished_flag
            )

            cur_len += 1
            input_ids = paddle.concat([input_ids, next_tokens], axis=1)
            if streamer is not None and self.config.tensor_parallel_rank == 0:
                streamer.put(next_tokens.cpu())

            if stopping_criteria(input_ids, scores):
                generate_end = True

            if eos_token_id is not None:
                unfinished_flag = get_unfinished_flag(
                    input_ids, unfinished_flag, eos_token_id
                )
                if not paddle.any(unfinished_flag):
                    generate_end = True

            # Stop when there is a </s> in all sentences
            if generate_end and not synced_gpus:
                break

            # Read the flag from config (as greedy_search does): the
            # ``self.is_encoder_decoder`` attribute only exists after generate().
            model_kwargs = self.update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            if fast_ptq_sampling:
                break

        if streamer is not None:
            streamer.end()

        return input_ids[:, origin_len:] if trunc_input else input_ids, scores

    def _get_model_inputs_spec(self, dtype: str):
        """Build the ``InputSpec`` mapping that describes model inputs for export.

        ``input_ids`` and ``attention_mask`` are always present as
        ``[None, None]`` int64 specs. ``position_ids`` (same spec) is added only
        when it is one of the positional parameter names of ``self.forward``.

        Args:
            dtype (str): accepted for interface compatibility; unused here.

        Returns:
            dict: input name -> ``paddle.static.InputSpec``, in the order
            input_ids, attention_mask[, position_ids].
        """
        input_names = ["input_ids", "attention_mask"]
        if "position_ids" in inspect.getfullargspec(self.forward).args:
            input_names.append("position_ids")
        return {
            name: paddle.static.InputSpec(shape=[None, None], dtype="int64")
            for name in input_names
        }

    def to_static(self, path: str, config: dict):
        """Export the sampling routine ``sample_d2s`` as a static inference model.

        Args:
            path (str): path prefix passed to ``paddle.jit.save``.
            config (dict): export options. Keys read here:
                pad_token_id (int): used when the model's generation config
                    does not define one.
                eos_token_id (int): used when the model's generation config
                    does not define one.
                use_top_p (bool): defaults to True. If True, ``top_p`` and
                    ``temperature`` become float32 inputs and ``top_k`` is fixed
                    to 0; otherwise ``top_k`` becomes an int64 input and
                    ``top_p`` / ``temperature`` are fixed to 1.0.
                dtype: forwarded to ``_get_model_inputs_spec``.
                logits_processors: placed as-is into the input spec.
        """
        use_top_p = config.get("use_top_p", True)

        if use_top_p:
            top_k_spec = 0
            top_p_spec = paddle.static.InputSpec(shape=[1], dtype="float32")
            temperature_spec = paddle.static.InputSpec(shape=[1], dtype="float32")
        else:
            top_k_spec = paddle.static.InputSpec(shape=[1], dtype="int64")
            top_p_spec = 1.0
            temperature_spec = 1.0

        dtype = config.get("dtype", None)
        logits_processors = config.get("logits_processors", None)
        model_inputs_spec = self._get_model_inputs_spec(dtype)

        generation_config = getattr(self, "generation_config", None)

        def _token_id(name):
            # Prefer the model's generation config and fall back to ``config``.
            # Compare against None explicitly: 0 is a valid token id and must not
            # be treated as "unset", which ``a or b`` would do.
            value = getattr(generation_config, name, None)
            return value if value is not None else config.get(name, None)

        # Positional order mirrors the signature of ``sample_d2s``.
        input_spec = [
            model_inputs_spec["input_ids"],  # input_ids
            model_inputs_spec["attention_mask"],  # attention_mask
            model_inputs_spec.get("position_ids", None),  # position_ids
            logits_processors,  # logits_processors
            paddle.static.InputSpec(shape=[1], dtype="int64"),  # max_new_tokens
            _token_id("pad_token_id"),  # pad_token_id
            _token_id("eos_token_id"),  # eos_token_id
            top_k_spec,  # top_k
            top_p_spec,  # top_p
            temperature_spec,  # temperature
            1,  # min_tokens_to_keep
        ]

        model = paddle.jit.to_static(self.sample_d2s, input_spec=input_spec)

        paddle.jit.save(model, path)

    def sample_d2s(
        self,
        input_ids,
        attention_mask,
        position_ids,
        logits_processors,
        max_new_tokens,
        pad_token_id,
        eos_token_id,
        top_k=None,
        top_p=None,
        temperature=None,
        min_tokens_to_keep=1,
    ):
        """Sampling decode loop intended for conversion with ``paddle.jit.to_static``.

        At most one of ``top_k`` / ``top_p`` may be a Tensor (i.e. declared as an
        ``InputSpec`` when exporting); otherwise both are treated as constants.
        When both constants are usable, top-p sampling is preferred.

        Args:
            input_ids: prompt token ids, unpacked below as ``(batch_size, seq_len)``.
            attention_mask: attention mask forwarded to the model.
            position_ids: position ids forwarded to the model (may be None).
            logits_processors: optional ``LogitsProcessorList``; an empty list is
                used when None.
            max_new_tokens: maximum number of new tokens to generate. The first
                token is always generated before the loop condition is checked.
            pad_token_id: token written into rows that already finished; resolved
                through ``self.set_pad_token_id(pad_token_id, eos_token_id)``.
            eos_token_id: end-of-sequence id. When None, rows are never marked
                finished and decoding only stops at ``max_new_tokens``.
            top_k: top-k value (Python constant or Tensor).
            top_p: top-p value (Python constant or Tensor).
            temperature: when not None, the logits are divided by it before the
                top-p sampling distribution is computed.
            min_tokens_to_keep: forwarded to ``TopKProcess``.

        Returns:
            tuple: ``(generated_ids, scores)``; ``generated_ids`` excludes the prompt.

        Raises:
            ValueError: if both ``top_k`` and ``top_p`` are None, or if the
                sampling strategy cannot be resolved (e.g. both are Tensors).
        """

        pad_token_id = self.set_pad_token_id(pad_token_id, eos_token_id)
        logits_processors = (
            logits_processors
            if logits_processors is not None
            else LogitsProcessorList()
        )

        if paddle.is_tensor(top_k) and not paddle.is_tensor(top_p):
            use_top_p = False
        elif not paddle.is_tensor(top_k) and paddle.is_tensor(top_p):
            use_top_p = True

        # top_k and top_p are the const value
        elif isinstance(top_p, float) or isinstance(top_k, int):
            # Prefer top-p whenever a top_p value is given. A constant top_k with
            # top_p=None must take the top-k path; otherwise the top-p path would
            # call paddle.full(fill_value=None) and fail.
            use_top_p = top_p is not None
        else:
            if top_p is None and top_k is None:
                raise ValueError("top_k and top_p should not be None")
            raise ValueError(
                "you should not specify InputSpec for top_k and top_p parameters, one of InputSpec is expected"
            )

        batch_size, cur_len = input_ids.shape
        # used for compute on gpu, avoid memcpy D2H
        cur_len_gpu = paddle.full([1], cur_len, dtype="int64")

        origin_len = input_ids.shape[1]
        # used for compute on gpu, avoid memcpy D2H
        origin_len_gpu = paddle.full([1], origin_len, dtype="int64")

        unfinished_flag = paddle.full([batch_size, 1], True, dtype="bool")

        scores = paddle.full([batch_size, 1], 0.0, dtype=paddle.get_default_dtype())

        # use_cache is immutable, we split it off other mutable kwargs.
        immutable = {"use_cache": True}
        model_kwargs = {"attention_mask": attention_mask, "position_ids": position_ids}

        def _forward_(**args):
            # Build model inputs from the current ids and run one forward pass.
            model_inputs = self.prepare_inputs_for_generation(
                input_ids, **args, **immutable
            )
            assert "use_cache" in model_inputs
            del model_inputs["use_cache"]
            return self(**model_inputs, **immutable)

        def _post_process_(
            outputs,
            input_ids,
            cur_len,
            origin_len,
            scores,
            unfinished_flag,
            model_kwargs,
            pad_token_id,
        ):
            # Select one next token per row, append it, and update bookkeeping.
            if isinstance(outputs, tuple):
                logits = outputs[0]
            elif isinstance(outputs, ModelOutput):
                logits = outputs.logits
            else:
                logits = outputs

            # [batch_size, vocab_size]
            logits = logits[:, -1, :]

            # pre-process distribution
            logits = self.adjust_logits_during_generation(logits)

            logits = logits_processors(input_ids, logits)
            probs = F.softmax(logits)

            # sample
            # Scores are reported from the un-tempered distribution.
            origin_probs = F.log_softmax(logits)
            # compute next_tokens
            if use_top_p:
                if temperature is not None:
                    # Temperature must rescale the logits *before* the softmax
                    # that feeds the sampler; rescaling `logits` after `probs`
                    # has been computed leaves the sampling distribution unchanged.
                    probs = F.softmax(logits / temperature)
                top_ps_tensor = paddle.full(
                    shape=[probs.shape[0], 1], fill_value=top_p, dtype=probs.dtype
                )
                _, next_tokens = paddle.tensor.top_p_sampling(probs, top_ps_tensor)
            else:
                probs = TopKProcess(probs, top_k, min_tokens_to_keep)
                if top_k == 1:
                    next_tokens = paddle.unsqueeze_(paddle.argmax(probs, axis=-1), -1)
                else:
                    next_tokens = paddle.multinomial(probs)

            next_scores = paddle.index_sample(origin_probs, next_tokens)
            scores = self.update_scores_for_generation(
                scores, next_scores, cur_len - origin_len, unfinished_flag
            )
            if eos_token_id is not None:
                # Rows that already finished keep emitting pad tokens.
                next_tokens = paddle.where(
                    unfinished_flag,
                    next_tokens,
                    paddle.full_like(next_tokens, pad_token_id),
                )

            input_ids = paddle.concat([input_ids, next_tokens], axis=1)

            if eos_token_id is not None:
                unfinished_flag = get_unfinished_flag(
                    input_ids, unfinished_flag, eos_token_id
                )

            model_kwargs = self.update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )

            return input_ids, scores, unfinished_flag, model_kwargs

        # The first step runs outside the loop so that the cache and the
        # attention-mask shape can be fixed up before the static while-loop.
        outputs = _forward_(**model_kwargs)
        input_ids, scores, unfinished_flag, model_kwargs = _post_process_(
            outputs,
            input_ids,
            cur_len_gpu,
            origin_len_gpu,
            scores,
            unfinished_flag,
            model_kwargs,
            pad_token_id,
        )

        cur_len += 1
        cur_len_gpu += 1

        attn_mask = model_kwargs["attention_mask"]
        # make the shape of attention_mask = (-1, -1, -1, -1) in dy2static.
        model_kwargs["attention_mask"] = paddle.reshape(attn_mask, attn_mask.shape)
        model_kwargs["cache"] = outputs[1] if isinstance(outputs, tuple) else None
        # Convert the budget of new tokens into an absolute length bound.
        max_new_tokens = paddle.full([1], max_new_tokens + cur_len - 1, dtype="int64")

        while cur_len < max_new_tokens and paddle.any(unfinished_flag):
            input_ids, scores, unfinished_flag, model_kwargs = _post_process_(
                _forward_(**model_kwargs),
                input_ids,
                cur_len_gpu,
                origin_len_gpu,
                scores,
                unfinished_flag,
                model_kwargs,
                pad_token_id,
            )
            cur_len += 1
            cur_len_gpu += 1

        return input_ids[:, origin_len:], scores

    def reorder_cache(self, cache, beam_idx):
        """Reorder cached states so they follow the beams selected this step.

        ``cache`` may be any nested structure accepted by ``map_structure``.
        Every leaf is replaced by ``paddle.index_select(leaf, beam_idx)``, which
        gathers rows along axis 0.
        """

        def _select_beams(tensor):
            return paddle.index_select(tensor, beam_idx)

        reordered = map_structure(_select_beams, cache)
        return reordered

    def beam_search(
        self,
        input_ids,
        beam_scorer,
        logits_processors,
        max_length,
        diversity_rate,
        pad_token_id,
        eos_token_id,
        stopping_criteria=None,
        fast_ptq_sampling=False,
        trunc_input=True,
        synced_gpus=False,
        **model_kwargs,
    ):
        """Run beam-search decoding step by step in dynamic-graph mode.

        Args:
            input_ids: token ids whose first dimension must equal
                ``beam_scorer.num_beams * len(beam_scorer._beam_hyps)`` (asserted).
            beam_scorer: object providing ``_beam_hyps``, ``num_beams``,
                ``process``, ``is_done`` and ``finalize``.
            logits_processors: optional ``LogitsProcessorList`` applied to the
                log-probabilities each step; an empty list is used when None.
            max_length: when not None, merged into ``stopping_criteria`` through
                ``validate_stopping_criteria``.
            diversity_rate: when non-zero, the k-th (1-based) candidate of each
                beam is penalized by ``k * diversity_rate`` before the final
                per-example top-k selection.
            pad_token_id: forwarded to ``beam_scorer.process`` / ``finalize``.
            eos_token_id: forwarded to ``beam_scorer.process`` / ``finalize``.
            stopping_criteria: optional ``StoppingCriteriaList``.
            fast_ptq_sampling: if True, stop after a single decoding step.
            trunc_input: if True, return the finalized ids with the first
                ``origin_len`` columns removed; otherwise return ``input_ids``.
            synced_gpus: keep running forward passes until all ranks finished,
                coordinated through ``dist.all_reduce``.
            **model_kwargs: forwarded to ``prepare_inputs_for_generation`` and
                updated every step (including cache reordering).

        Returns:
            tuple: ``(ids, scores)``; ``scores`` comes from ``beam_scorer.finalize``.
        """
        logits_processors = (
            logits_processors
            if logits_processors is not None
            else LogitsProcessorList()
        )

        # max_length will be converted to MaxLengthCriteria
        stopping_criteria = (
            stopping_criteria
            if stopping_criteria is not None
            else StoppingCriteriaList()
        )
        if max_length is not None:
            # logging.warning(
            #    "`max_length` is deprecated in this function, use"
            #    " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead."
            # )
            stopping_criteria = validate_stopping_criteria(
                stopping_criteria, max_length
            )

        batch_size = len(beam_scorer._beam_hyps)
        num_beams = beam_scorer.num_beams
        batch_beam_size, cur_len = input_ids.shape
        # Prompt length; used by the scorer and to strip the prompt on return.
        origin_len = cur_len

        assert (
            num_beams * batch_size == batch_beam_size
        ), "Batch dimension of `input_ids` should be {}, but received {}.".format(
            num_beams * batch_size, batch_beam_size
        )

        beam_scores = paddle.zeros(
            (batch_size, num_beams), dtype=paddle.get_default_dtype()
        )

        # Only beam 0 of every example starts at score 0; the other beams get
        # get_scale_by_dtype(return_positive=False) (presumably a large negative
        # value), presumably so the first step does not select duplicate
        # candidates from beams that start with identical prompts.
        beam_scores[:, 1:] = get_scale_by_dtype(return_positive=False)
        beam_scores = paddle.reshape(beam_scores, [-1])

        generate_end = False
        while True:
            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = paddle.to_tensor(0.0 if generate_end else 1.0)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break
            # prepare model inputs & get model output
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            outputs = self(**model_inputs)
            if synced_gpus and generate_end:
                cur_len = cur_len + 1
                continue  # don't waste resources running the code we don't need

            if isinstance(outputs, tuple):
                logits = outputs[0]
            elif isinstance(outputs, ModelOutput):
                logits = outputs.logits
            else:
                logits = outputs

            # [batch_size, vocab_size]
            logits = logits[:, -1, :]

            # pre-process distribution
            logits = self.adjust_logits_during_generation(logits)
            # beam search
            # [batch_size * num_beams, vocab_size]
            next_scores = F.softmax(logits)
            next_scores = paddle.log(next_scores)
            next_scores = logits_processors(input_ids, next_scores)
            # Accumulate: candidate score = running beam score + token log-prob.
            next_scores = next_scores + beam_scores.unsqueeze(-1)

            vocab_size = next_scores.shape[-1]
            if diversity_rate == 0.0:
                # reshape for beam search
                next_scores = next_scores.reshape([batch_size, num_beams * vocab_size])

                # Keep 2 * num_beams candidates per example (rather than
                # num_beams), presumably so enough non-EOS candidates remain
                # after the scorer handles finished ones.
                next_scores, next_tokens = paddle.topk(
                    next_scores, 2 * num_beams, axis=1
                )

                # Split the flattened (beam, token) index back into its parts.
                next_indices = next_tokens // vocab_size
                next_tokens = next_tokens % vocab_size

            else:
                # Per-beam top-k first: [batch_size * num_beams, 2 * num_beams].
                next_scores, next_tokens = paddle.topk(
                    next_scores, 2 * num_beams, axis=1
                )

                # Penalty k * diversity_rate for the k-th sibling of each beam.
                # NOTE(review): arange is int64; this relies on the float
                # diversity_rate promoting the product to float, and on its dtype
                # matching next_scores (e.g. float16/bfloat16) — confirm.
                sibling_score = (
                    paddle.arange(1, 2 * num_beams + 1, dtype="int64").unsqueeze(0)
                    * diversity_rate
                )

                diversed_score = next_scores - sibling_score

                next_scores = next_scores.reshape(
                    [batch_size, 2 * num_beams * num_beams]
                )
                next_tokens = next_tokens.reshape(
                    [batch_size, 2 * num_beams * num_beams]
                )

                diversed_score = diversed_score.reshape(
                    [batch_size, 2 * num_beams * num_beams]
                )
                # Select by penalized score, but keep the unpenalized scores
                # and token ids of the chosen candidates (gathered below).
                diversed_score, diversed_tokens = paddle.topk(
                    diversed_score, 2 * num_beams, axis=1
                )

                # TODO
                # Use gather_nd() to select origan token and score
                next_scores = paddle.stack(
                    [
                        paddle.index_select(next_scores[i], diversed_tokens[i])
                        for i in range(next_scores.shape[0])
                    ]
                )
                next_tokens = paddle.stack(
                    [
                        paddle.index_select(next_tokens[i], diversed_tokens[i])
                        for i in range(next_tokens.shape[0])
                    ]
                )

                # Each beam contributed 2 * num_beams consecutive candidates.
                next_indices = diversed_tokens // (2 * num_beams)

            # stateless
            beam_outputs = beam_scorer.process(
                input_ids,
                next_scores,
                next_tokens,
                next_indices,
                origin_len=origin_len,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
            )
            beam_scores = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]
            # beam_idx may contain element -1 and cause error
            # PR: https://github.com/PaddlePaddle/Paddle/issues/57366
            beam_idx = paddle.maximum(beam_idx, paddle.full_like(beam_idx, 0))

            cur_len += 1
            # Reorder histories to the surviving beams and append their tokens.
            input_ids = paddle.concat(
                [
                    paddle.index_select(input_ids, beam_idx),
                    beam_next_tokens.unsqueeze(-1),
                ],
                axis=-1,
            )

            if beam_scorer.is_done or stopping_criteria(input_ids, beam_scores):
                if not synced_gpus:
                    break
                else:
                    generate_end = True

            model_kwargs = self.update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.is_encoder_decoder
            )
            if "cache" in model_kwargs:
                # reorder the cache
                model_kwargs["cache"] = self.reorder_cache(
                    model_kwargs["cache"], beam_idx
                )
            if "past_key_values" in model_kwargs:
                # reorder the cache
                model_kwargs["past_key_values"] = self.reorder_cache(
                    model_kwargs["past_key_values"], beam_idx
                )
            if fast_ptq_sampling:
                break

        pred_ids, scores = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            origin_len=origin_len,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
        )
        # NOTE(review): with trunc_input=False this returns the raw `input_ids`
        # (all beams, including the prompt) rather than the finalized `pred_ids`,
        # so its row count may differ from `scores` — confirm this is intended.
        return pred_ids[:, origin_len:] if trunc_input else input_ids, scores

    def group_beam_search(
        self,
        input_ids,
        beam_scorer,
        logits_processors,
        max_length,
        pad_token_id,
        eos_token_id,
        stopping_criteria=None,
        fast_ptq_sampling=False,
        trunc_input=True,
        synced_gpus=False,
        **model_kwargs,
    ):
        """Run group (diverse) beam-search decoding in dynamic-graph mode.

        The ``num_beams`` beams of every example are split into
        ``beam_scorer.num_beam_groups`` groups of ``num_beams // num_beam_groups``
        beams. Each step runs one forward pass and then processes the groups in
        order. ``current_tokens`` (tokens already picked by earlier groups in
        this step) and ``beam_group_idx`` are passed to the logits processors,
        e.g. so a diversity penalty can be applied across groups.

        Args:
            input_ids: token ids whose first dimension must equal
                ``beam_scorer.num_beams * len(beam_scorer._beam_hyps)`` (asserted).
            beam_scorer: object providing ``_beam_hyps``, ``num_beams``,
                ``num_beam_groups``, ``process``, ``is_done`` and ``finalize``.
            logits_processors: optional ``LogitsProcessorList``; it is called with
                ``current_tokens`` and ``beam_group_idx`` keyword arguments.
            max_length: when not None, merged into ``stopping_criteria`` through
                ``validate_stopping_criteria``.
            pad_token_id: forwarded to ``beam_scorer.process`` / ``finalize``.
            eos_token_id: forwarded to ``beam_scorer.process`` / ``finalize``.
            stopping_criteria: optional ``StoppingCriteriaList``.
            fast_ptq_sampling: if True, stop after a single decoding step.
            trunc_input: if True, return the finalized ids with the first
                ``origin_len`` columns removed; otherwise return ``input_ids``.
            synced_gpus: keep running forward passes until all ranks finished,
                coordinated through ``dist.all_reduce``.
            **model_kwargs: forwarded to ``prepare_inputs_for_generation`` and
                updated every step (including cache reordering).

        Returns:
            tuple: ``(ids, scores)``; ``scores`` comes from ``beam_scorer.finalize``.
        """
        logits_processors = (
            logits_processors
            if logits_processors is not None
            else LogitsProcessorList()
        )

        # max_length will be converted to MaxLengthCriteria
        stopping_criteria = (
            stopping_criteria
            if stopping_criteria is not None
            else StoppingCriteriaList()
        )
        if max_length is not None:
            # logging.warning(
            #    "`max_length` is deprecated in this function, use"
            #    " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead."
            # )
            stopping_criteria = validate_stopping_criteria(
                stopping_criteria, max_length
            )

        batch_size = len(beam_scorer._beam_hyps)
        num_beams = beam_scorer.num_beams
        num_beam_groups = beam_scorer.num_beam_groups
        num_sub_beams = num_beams // num_beam_groups

        batch_beam_size, cur_len = input_ids.shape
        # Prompt length; used by the scorer and to strip the prompt on return.
        origin_len = cur_len

        assert (
            num_beams * batch_size == batch_beam_size
        ), "Batch dimension of `input_ids` should be {}, but received {}.".format(
            num_beams * batch_size, batch_beam_size
        )

        # NOTE(review): float32 is hard-coded here, while beam_search uses
        # paddle.get_default_dtype() — confirm this difference is intended.
        beam_scores = paddle.full(
            (batch_size, num_beams),
            get_scale_by_dtype(return_positive=False),
            dtype="float32",
        )
        # Initialise the score of the first beam of each group with 0 and the
        # rest with get_scale_by_dtype(return_positive=False) (presumably a large
        # negative value). This ensures that the beams in the same group don't
        # produce the same tokens every time.
        beam_scores[:, ::num_sub_beams] = 0
        beam_scores = paddle.reshape(beam_scores, [-1])

        generate_end = False
        while True:
            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = paddle.to_tensor(0.0 if generate_end else 1.0)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break
            # predicted tokens in cur_len step
            current_tokens = paddle.zeros(
                shape=[batch_size * num_beams], dtype=input_ids.dtype
            )

            # indices which will form the beams in the next time step
            reordering_indices = paddle.zeros(
                shape=[batch_size * num_beams], dtype="int64"
            )
            # prepare model inputs & get model output
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            outputs = self(**model_inputs)
            if synced_gpus and generate_end:
                cur_len = cur_len + 1
                continue  # don't waste resources running the code we don't need

            for beam_group_idx in range(num_beam_groups):
                group_start_idx = beam_group_idx * num_sub_beams
                group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
                group_size = group_end_idx - group_start_idx

                # indices of beams of current group among all sentences in batch
                batch_group_indices = []

                for batch_idx in range(batch_size):
                    batch_group_indices.extend(
                        [
                            batch_idx * num_beams + idx
                            for idx in range(group_start_idx, group_end_idx)
                        ]
                    )

                group_input_ids = input_ids[batch_group_indices]

                # NOTE(review): the logits extraction below does not depend on
                # the group and could be hoisted out of this loop.
                if isinstance(outputs, tuple):
                    logits = outputs[0]
                elif isinstance(outputs, ModelOutput):
                    logits = outputs.logits
                else:
                    logits = outputs

                logits = logits[:, -1, :]
                # Keep only the rows that belong to the current group.
                logits = paddle.index_select(
                    logits, paddle.to_tensor(batch_group_indices)
                )
                logits = self.adjust_logits_during_generation(logits)

                next_scores = F.softmax(logits)
                next_scores = paddle.log(next_scores)
                vocab_size = next_scores.shape[-1]

                next_scores = logits_processors(
                    group_input_ids,
                    next_scores,
                    current_tokens=current_tokens,
                    beam_group_idx=beam_group_idx,
                )

                # Accumulate: candidate score = running beam score + token log-prob.
                next_scores = next_scores + beam_scores[batch_group_indices].unsqueeze(
                    -1
                )

                # reshape for beam search
                next_scores = next_scores.reshape([batch_size, group_size * vocab_size])

                next_scores, next_tokens = paddle.topk(
                    next_scores, 2 * group_size, axis=1
                )

                # Split the flattened (beam, token) index back into its parts.
                next_indices = next_tokens // vocab_size
                next_tokens = next_tokens % vocab_size

                beam_outputs = beam_scorer.process(
                    group_input_ids,
                    next_scores,
                    next_tokens,
                    next_indices,
                    origin_len=origin_len,
                    pad_token_id=pad_token_id,
                    eos_token_id=eos_token_id,
                )

                beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
                beam_next_tokens = beam_outputs["next_beam_tokens"]
                beam_idx = beam_outputs["next_beam_indices"]
                # beam_idx may contain element -1 and cause error
                # PR: https://github.com/PaddlePaddle/Paddle/issues/57366
                beam_idx = paddle.maximum(beam_idx, paddle.full_like(beam_idx, 0))

                # beam_idx indexes rows of group_input_ids (batch-major, group_size per example).
                input_ids[batch_group_indices] = group_input_ids[beam_idx]
                group_input_ids = paddle.concat(
                    [
                        paddle.index_select(group_input_ids, index=beam_idx),
                        beam_next_tokens.unsqueeze(-1),
                    ],
                    axis=-1,
                )
                # Record this group's picks so later groups' processors can see them.
                current_tokens[batch_group_indices] = beam_next_tokens

                # Map the group-local row index back to a global beam row:
                # example = beam_idx // group_size, slot = beam_idx % group_size.
                reordering_indices[batch_group_indices] = (
                    num_beams * (beam_idx // group_size)
                    + group_start_idx
                    + (beam_idx % group_size)
                )

            input_ids = paddle.concat(
                [input_ids, current_tokens.unsqueeze(-1)], axis=-1
            )

            cur_len += 1

            if beam_scorer.is_done or stopping_criteria(input_ids, beam_scores):
                if not synced_gpus:
                    break
                else:
                    generate_end = True

            model_kwargs = self.update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.is_encoder_decoder
            )

            if "cache" in model_kwargs:
                # reorder the cache
                model_kwargs["cache"] = self.reorder_cache(
                    model_kwargs["cache"], reordering_indices
                )
            if "past_key_values" in model_kwargs:
                # reorder the cache
                model_kwargs["past_key_values"] = self.reorder_cache(
                    model_kwargs["past_key_values"], reordering_indices
                )

            if fast_ptq_sampling:
                break

        # NOTE(review): next_tokens / next_indices here are those of the last
        # processed group only.
        pred_ids, scores = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            origin_len=origin_len,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
        )
        # NOTE(review): with trunc_input=False this returns the raw `input_ids`
        # (all beams, including the prompt) rather than the finalized `pred_ids`
        # — confirm this is intended.
        return pred_ids[:, origin_len:] if trunc_input else input_ids, scores
