# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license which can be found at
# https://github.com/facebookresearch/fairseq/blob/main/LICENSE

import math
from typing import List, Optional

import torch
import torch.nn as nn
from torch import Tensor

from .token_generation_constraints import (ConstraintState,
                                           OrderedConstraintState,
                                           UnorderedConstraintState)


class Search(nn.Module):
    """Base class for the search strategies used at generation time.

    A strategy is driven one decoding step at a time through :meth:`step`.
    The constraint hooks (:meth:`init_constraints`, :meth:`prune_sentences`
    and :meth:`update_constraints`) do nothing here; strategies implementing
    constrained decoding override them and set ``supports_constraints``.
    """

    def __init__(self, tokenizer):
        """Cache the special-token ids and the vocabulary size of *tokenizer*.

        Args:
            tokenizer: object exposing ``pad_token_id``, ``unk_token_id``,
                ``eos_token_id`` as well as ``get_vocab()`` and
                ``get_added_vocab()``, both returning a token -> id mapping.
        """
        super().__init__()
        self.pad = tokenizer.pad_token_id
        self.unk = tokenizer.unk_token_id
        self.eos = tokenizer.eos_token_id
        # The vocabulary size is the number of distinct ids found in the base
        # vocabulary and the added tokens together.
        token_ids = set(tokenizer.get_vocab().values())
        token_ids.update(tokenizer.get_added_vocab().values())
        self.vocab_size = len(token_ids)
        self.src_lengths = torch.tensor(-1)
        self.supports_constraints = False
        self.stop_on_max_len = False
        # Strategies that need ``set_src_lengths`` to be called flip this to
        # True; defining the default here keeps the attribute readable on
        # every strategy instead of raising AttributeError.
        self.needs_src_lengths = False

    def step(self,
             step,
             lprobs,
             scores,
             prev_output_tokens=None,
             original_batch_idxs=None):
        """Take a single search step.

        Args:
            step: the current search step, starting at 0
            lprobs: (bsz x input_beam_size x vocab_size)
                the model's log-probabilities over the vocabulary at the current step
            scores: (bsz x input_beam_size x step)
                the historical model scores of each hypothesis up to this point
            prev_output_tokens: (bsz x step)
                the previously generated output tokens
            original_batch_idxs: (bsz)
                the tensor with the batch indices, in the range [0, bsz)
                this is useful in case there has been applied a re-ordering
                and we need to know the original indices

        Return: A tuple of (scores, indices, beams) where:
            scores: (bsz x output_beam_size)
                the scores of the chosen elements; output_beam_size can be
                larger than input_beam_size, e.g., we may return
                2*input_beam_size to account for EOS
            indices: (bsz x output_beam_size)
                the indices of the chosen elements
            beams: (bsz x output_beam_size)
                the hypothesis ids of the chosen elements, in the range [0, input_beam_size)

        Raises:
            NotImplementedError: always; sub-classes must override this method.
        """
        raise NotImplementedError

    @torch.jit.export
    def set_src_lengths(self, src_lengths):
        """Store *src_lengths* on the instance for strategies that read it."""
        self.src_lengths = src_lengths

    @torch.jit.export
    def init_constraints(self, batch_constraints: Optional[Tensor],
                         beam_size: int):
        """Initialize constraint states for constrained decoding (if supported).

        The base implementation does nothing and returns ``None``.

        Args:
            batch_constraints: (torch.Tensor, optional)
                the list of constraints, in packed form
            beam_size: (int)
                the beam size
        """
        pass

    def prune_sentences(self, batch_idxs: Tensor):
        """
        Removes constraint states for completed sentences (if supported).
        This is called from sequence_generator._generate() when sentences are
        deleted from the batch. The base implementation does nothing.

        Args:
            batch_idxs: Indices of *sentences* whose constraint state should be *kept*.
        """
        pass

    def update_constraints(self, active_hypos: Tensor):
        """
        Updates the constraint states by selecting the beam items that are retained.
        This is called at each time step of sequence_generator._generate() when
        the set of 2 * {beam_size} candidate hypotheses are reduced to the beam size.
        The base implementation does nothing.

        Args:
            active_hypos: (batch size, beam size)
              list of integers denoting, for each sentence, which beam candidate items
              should be kept.
        """
        pass


class BeamSearch(Search):
    r"""
    Plain beam search.

    Each step adds the running hypothesis scores to the current
    log-probabilities, flattens the (beam, vocabulary) axes and keeps the
    best ``k`` entries, where ``k`` is the smaller of ``2 * beam_size`` and
    the number of flattened entries minus one. The flat positions are then
    split back into a beam id and a vocabulary id.
    """

    def __init__(self, tgt_dict):
        super().__init__(tgt_dict)
        self.constraint_states = None

    @torch.jit.export
    def step(
        self,
        step: int,
        lprobs,
        scores: Optional[Tensor],
        prev_output_tokens: Optional[Tensor] = None,
        original_batch_idxs: Optional[Tensor] = None,
    ):
        r"""
        Take a single search step.

        Args:
            step (`int`): Current step, starting at 0.
            lprobs (`Tensor` with size `(bsz, input_beam_size, vocab_size)`):
                the model's log-probabilities over the vocabulary at the current step.
            scores (`Tensor`, may be `None` at step 0):
                Scores of the hypotheses so far; entry `[:, :, step - 1]` is added
                to `lprobs` as the cumulative score of each hypothesis.
            prev_output_tokens (`Tensor`, **optional**, default to `None`):
                Not used by this strategy.
            original_batch_idxs (`Tensor`, **optional**, default to `None`):
                Not used by this strategy.

        Returns: A tuple of (scores_buf, indices_buf, beams_buf), where:
            scores_buf (`Tensor` with size `(bsz, output_beam_size)`):
                Cumulative scores of the selected candidates, in descending order.
            indices_buf (`Tensor` with size `(bsz, output_beam_size)`):
                Vocabulary indices of the selected candidates.
            beams_buf (`Tensor` with size `(bsz, output_beam_size)`):
                Index of the hypothesis each selected candidate extends.
        """
        bsz, beam_size, vocab_size = lprobs.size()

        if step != 0:
            # turn the per-step log-probabilities into cumulative scores
            assert scores is not None
            lprobs = lprobs + scores[:, :, step - 1].unsqueeze(-1)
        else:
            # every hypothesis is identical at the first step, so only the
            # first beam entry of each sentence is considered
            lprobs = lprobs[:, ::beam_size, :].contiguous()

        flat_lprobs = lprobs.view(bsz, -1)
        # Keep 2 x beam_size candidates so that beam_size of them remain after
        # the ones predicting eos are set aside; the "- 1" means pad is never
        # selected.
        num_candidates = min(beam_size * 2, flat_lprobs.size(1) - 1)
        scores_buf, flat_indices = torch.topk(flat_lprobs, k=num_candidates)

        # Split the flat position into (hypothesis id, vocabulary id)
        beams_buf = torch.div(flat_indices, vocab_size, rounding_mode='floor')
        indices_buf = flat_indices.fmod(vocab_size)
        return scores_buf, indices_buf, beams_buf


class PrefixConstrainedBeamSearch(Search):
    r"""
    Prefix constrained beam search.

    step 1. Build an additive mask from `prefix_allowed_tokens_fn`, which is
        called with the original batch index and the tokens generated so far
        for every hypothesis and returns the vocabulary ids that are allowed.
    step 2. Add the mask to `lprobs`, so that disallowed tokens get `-inf`.
    step 3. Select candidates as in plain beam search, except that `k` is the
        smaller of `beam_size` and the number of flattened entries minus one.
    """

    def __init__(self, tgt_dict, prefix_allowed_tokens_fn):
        """
        Args:
            tgt_dict: tokenizer forwarded to ``Search.__init__``.
            prefix_allowed_tokens_fn: callable ``(batch_id, tokens) -> ids``
                returning the vocabulary ids allowed as next token.
        """
        super().__init__(tgt_dict)
        self.prefix_allowed_tokens_fn = prefix_allowed_tokens_fn
        self.stop_on_max_len = True

    @torch.jit.export
    def apply_mask(self, x, prev_output_tokens, original_batch_idxs):
        """Return a tensor shaped like *x* holding 0 for allowed tokens and
        ``-inf`` everywhere else.

        Args:
            x: tensor whose first axis enumerates the hypotheses and whose
                last axis is the vocabulary; only its shape, dtype and device
                are used.
            prev_output_tokens: iterable yielding, per hypothesis, the tokens
                generated so far.
            original_batch_idxs: 1-D tensor of original batch indices, one per
                sentence; each index is repeated for every beam entry.
        """
        beam_size = x.shape[0] // original_batch_idxs.shape[0]
        # one batch index per hypothesis: [b0, b0, ..., b1, b1, ...]
        original_batch_idxs = (
            original_batch_idxs.unsqueeze(-1).repeat(
                (1, beam_size)).flatten().tolist())

        mask = torch.full_like(x, -math.inf)
        for sent_i, (sent, batch_i) in enumerate(
                zip(prev_output_tokens, original_batch_idxs)):
            mask[sent_i, :, self.prefix_allowed_tokens_fn(batch_i, sent)] = 0

        return mask

    @torch.jit.export
    def step(
        self,
        step: int,
        lprobs: Tensor,
        scores: Tensor,
        prev_output_tokens: Tensor,
        original_batch_idxs: Tensor,
    ):
        r"""
        Take a single search step.

        Note that `lprobs` is modified in place: the prefix mask is added to it.

        Args:
            step (`int`): Current step, starting at 0.
            lprobs (`Tensor` with size `(bsz, input_beam_size, vocab_size)`):
                the model's log-probabilities over the vocabulary at the current step.
            scores (`Tensor`):
                Scores of the hypotheses so far; entry `[:, :, step - 1]` is added
                to `lprobs` as the cumulative score of each hypothesis when `step != 0`.
            prev_output_tokens (`Tensor`):
                Tokens generated so far, one row per hypothesis; passed to
                `prefix_allowed_tokens_fn` to compute the mask.
            original_batch_idxs (`Tensor`):
                the tensor with the batch indices, in the range [0, bsz)
                this is useful in case there has been applied a re-ordering
                and we need to know the original indices

        Returns: A tuple of (scores_buf, indices_buf, beams_buf), where:
            scores_buf (`Tensor` with size `(bsz, output_beam_size)`):
                Cumulative scores of the selected candidates, in descending order.
            indices_buf (`Tensor` with size `(bsz, output_beam_size)`):
                Vocabulary indices of the selected candidates.
            beams_buf (`Tensor` with size `(bsz, output_beam_size)`):
                Index of the hypothesis each selected candidate extends.
        """
        bsz, beam_size, vocab_size = lprobs.size()

        # Disallowed tokens receive -inf (in place, as callers may rely on it)
        lprobs += self.apply_mask(
            lprobs.view(bsz * beam_size, 1, vocab_size),
            prev_output_tokens,
            original_batch_idxs,
        ).view(bsz, beam_size, vocab_size)

        if step == 0:
            # at the first step all hypotheses are equally likely, so use
            # only the first beam
            lprobs = lprobs[:, ::beam_size, :].contiguous()
        else:
            # make probs contain cumulative scores for each hypothesis
            assert scores is not None
            lprobs = lprobs + scores[:, :, step - 1].unsqueeze(-1)

        flat_lprobs = lprobs.view(bsz, -1)
        # Unlike plain beam search only beam_size candidates are kept here;
        # the "- 1" means pad is never selected.
        num_candidates = min(beam_size, flat_lprobs.size(1) - 1)
        scores_buf, flat_indices = torch.topk(flat_lprobs, k=num_candidates)

        # Project back into beam ids and vocabulary ids. torch.div with floor
        # rounding matches BeamSearch and avoids the tensor `//` operator,
        # which older torch releases flag as deprecated.
        beams_buf = torch.div(flat_indices, vocab_size, rounding_mode='floor')
        indices_buf = flat_indices.fmod(vocab_size)
        return scores_buf, indices_buf, beams_buf


class LexicallyConstrainedBeamSearch(Search):
    """Implements lexically constrained beam search as described in

        Fast Lexically Constrained Decoding with Dynamic Beam
        Allocation for Neural Machine Translation.  Post & Vilar,
        NAACL 2018.  https://www.aclweb.org/anthology/N18-1119/

    and

        Improved Lexically Constrained Decoding for Translation and
        Monolingual Rewriting. Hu et al, NAACL
        2019. https://www.aclweb.org/anthology/N19-1090/

    This is accomplished by maintaining, for each beam hypothesis, a
    ConstraintState object (see token_generation_constraints.py) that
    tracks which constraints have been generated and using this
    information to shape the beam for each input sentence.
    """

    def __init__(self, tokenizer, representation):
        """
        Args:
            tokenizer: forwarded to ``Search.__init__``, which also derives
                ``vocab_size`` from it.
            representation: ``'ordered'`` or ``'unordered'``; selects the
                constraint-state class created by :meth:`init_constraints`.
        """
        super().__init__(tokenizer)
        self.representation = representation
        self.num_cands = 0
        self.supports_constraints = True
        # One list of per-hypothesis states per sentence. Starting with an
        # empty list lets step() fall back to unconstrained top-k selection
        # when init_constraints() has not been called.
        self.constraint_states = []

    @torch.jit.export
    def init_constraints(self, batch_constraints: Optional[Tensor],
                         beam_size: int):
        """Create the initial constraint state of every sentence.

        Args:
            batch_constraints: iterable yielding one packed constraint tensor
                per sentence; ``None`` leaves the search unconstrained.
            beam_size: number of hypotheses per sentence; every hypothesis
                starts out sharing the same initial state object.

        Raises:
            ValueError: if ``representation`` is neither ``'ordered'`` nor
                ``'unordered'``.
        """
        self.constraint_states = []
        if batch_constraints is None:
            return
        for constraint_tensor in batch_constraints:
            if self.representation == 'ordered':
                constraint_state = OrderedConstraintState.create(
                    constraint_tensor)
            elif self.representation == 'unordered':
                constraint_state = UnorderedConstraintState.create(
                    constraint_tensor)
            else:
                raise ValueError(
                    'Unknown constraint representation: '
                    f'{self.representation!r}')

            self.constraint_states.append([constraint_state] * beam_size)

    @torch.jit.export
    def prune_sentences(self, batch_idxs: Tensor):
        """Keep only the constraint states of the sentences in *batch_idxs*."""
        if self.constraint_states:
            self.constraint_states = [
                self.constraint_states[i] for i in batch_idxs.tolist()
            ]

    @torch.jit.export
    def update_constraints(self, active_hypos: Tensor):
        """For every sentence, keep the states of the candidates listed in
        the matching row of *active_hypos* (batch size, beam size)."""
        if self.constraint_states:
            for sentid, kept in enumerate(active_hypos.tolist()):
                sent_states = self.constraint_states[sentid]
                self.constraint_states[sentid] = [
                    sent_states[i] for i in kept
                ]

    @torch.jit.export
    def step(
        self,
        step: int,
        lprobs: Tensor,
        scores: Optional[Tensor],
        prev_output_tokens: Optional[Tensor] = None,
        original_batch_idxs: Optional[Tensor] = None,
    ):
        """
        A constrained step builds a large candidates list from the following:
        - the top 2 * {beam_size} items over the whole beam
        - for each item in the beam
          - the top {each_k} (default 1)
          - all next constraints
        We then compute the constrained state of each beam item, and assign
        stripe codes: 0 to the best in each bank, 1 to the 2nd-best, and so
        on. We then sort by (stripe, score), and truncate the list at
        2 * beam size.

        The EOS column of `lprobs` is modified in place for hypotheses whose
        constraints are not finished yet, and `self.constraint_states` is
        replaced by the states of the returned candidates.

        Args:
            step: the decoder step
            lprobs: (batch size, beam size, target vocab)
                the target-vocab distributions for each item in the beam.
            scores: scores of the hypotheses so far; entry `[:, :, step - 1]`
                is added to `lprobs` when `step != 0`.
            prev_output_tokens: not used by this strategy.
            original_batch_idxs: not used by this strategy.
        Return: A tuple of (scores, indices, beams) where:
            scores: (batch, output beam size)
                the scores of the chosen elements
            indices: (batch, output beam size)
                the target vocab indices of the chosen elements
            beams: (batch, output beam size)
                the 0-indexed hypothesis ids of the chosen elements
        """
        each_k = 1
        device = lprobs.device

        batch_size, beam_size, vocab_size = lprobs.size()

        # Just take the k-best. We'll get another k from the 1-best from each
        # row, plus more from the constraints. The "- 1" means pad is never
        # selected.
        self.num_cands = min(beam_size * 2, beam_size * vocab_size - 1)

        # STEP 0: Preliminary. Prevent EOS for unfinished hyps across all batch items
        constraint_states = self.constraint_states
        if constraint_states and step > 0:
            not_finished_indices = []
            for sentno, sent_constraints in enumerate(constraint_states):
                for beamno, state in enumerate(sent_constraints):
                    if not state.finished:
                        not_finished_indices.append(sentno * beam_size
                                                    + beamno)
            if not_finished_indices:
                # build the index on the same device as lprobs, as int64
                rows = torch.tensor(
                    not_finished_indices, dtype=torch.long, device=device)
                lprobs.view(batch_size * beam_size, -1)[rows,
                                                        self.eos] = -math.inf

        if step == 0:
            # at the first step all hypotheses are equally likely, so use
            # only the first beam entry for each batch item
            lprobs = lprobs[:, ::beam_size, :].contiguous()
        else:
            # make probs contain cumulative scores for each hypothesis
            assert scores is not None
            lprobs = lprobs + scores[:, :, step - 1].unsqueeze(-1)

        scores_buf, indices_buf = torch.topk(
            lprobs.view(batch_size, -1),
            self.num_cands,
        )
        # Project back into relative indices and beams
        beams_buf = torch.div(indices_buf, vocab_size, rounding_mode='floor')
        indices_buf = indices_buf.fmod(vocab_size)

        # Short circuit if there are no constraints in this batch
        if not constraint_states:
            return scores_buf, indices_buf, beams_buf

        # STEP 1: get top-1 from each hypothesis across all sentences in the batch
        if step > 0:
            top_scores, top_indices = torch.topk(
                lprobs.view(batch_size * beam_size, -1),
                k=each_k,
                dim=1,
            )
            top_scores = top_scores.view(batch_size, -1)
            top_indices = top_indices.view(batch_size, -1)
            scores_buf = torch.cat((scores_buf, top_scores), dim=1)
            indices_buf = torch.cat((indices_buf, top_indices), dim=1)
            new_beams = torch.arange(
                0, beam_size, device=device).repeat(batch_size, 1)
            beams_buf = torch.cat((beams_buf, new_beams), dim=1)

        # Now, process sentences in the batch one by one. step_sentence()
        # truncates its results to num_cands, so the output buffers use the
        # same width (equal to 2 * beam_size unless the vocabulary is tiny).
        new_scores_buf = torch.zeros((batch_size, self.num_cands),
                                     device=device)
        new_indices_buf = torch.zeros((batch_size, self.num_cands),
                                      device=device).long()
        new_beams_buf = torch.zeros((batch_size, self.num_cands),
                                    device=device).long()
        for sentno, sent_states in enumerate(constraint_states):
            (sent_scores, sent_indices, sent_beams,
             new_states) = self.step_sentence(
                 step,
                 sentno,
                 lprobs[sentno],
                 sent_states,
                 beams_buf[sentno].clone(),
                 indices_buf[sentno].clone(),
                 scores_buf[sentno].clone(),
             )
            new_scores_buf[sentno] = sent_scores
            new_indices_buf[sentno] = sent_indices
            new_beams_buf[sentno] = sent_beams
            self.constraint_states[sentno] = new_states

        return new_scores_buf, new_indices_buf, new_beams_buf

    @torch.jit.export
    def step_sentence(
        self,
        step: int,
        sentno: int,
        lprobs: Tensor,
        constraint_states: List[ConstraintState],
        beams_buf: Tensor,
        indices_buf: Tensor,
        scores_buf: Tensor,
    ):
        """Does per-sentence processing. Adds all constraints for each
        hypothesis to the list of candidates; then removes duplicates,
        sorts, and dynamically stripes across the banks. All tensor inputs
        are collapsed to those pertaining to a single input sentence.

        Args:
            step: the decoder step.
            sentno: index of the sentence in the batch (not used here).
            lprobs: (beam size, target vocab) cumulative scores.
            constraint_states: the current state of every hypothesis of
                this sentence.
            beams_buf, indices_buf, scores_buf: 1-D candidate buffers
                holding hypothesis ids, vocabulary ids and scores.

        Returns:
            A tuple (scores, indices, beams, states). The three tensors are
            truncated to ``self.num_cands`` entries; ``states`` holds the
            advanced state of every de-duplicated candidate in the same
            order and is not truncated.
        """
        device = lprobs.device

        # The state used to read the number of constraint tokens in STEP 4:
        # the only one visited at step 0, the last one otherwise.
        reference_state = (
            constraint_states[0] if step == 0 else constraint_states[-1])

        # STEP 2: Add all constraints for each beam item
        for beamno, state in enumerate(constraint_states):
            next_tokens = torch.tensor(
                list(state.next_tokens()), device=device).long()
            if next_tokens.numel() != 0:
                indices_buf = torch.cat((indices_buf, next_tokens))
                next_beams = (
                    torch.tensor(beamno, device=device).repeat(
                        next_tokens.size(0)).long())
                beams_buf = torch.cat((beams_buf, next_beams))
                next_values = lprobs[beamno].take(next_tokens.view(-1))
                scores_buf = torch.cat((scores_buf, next_values))

            # At the 0th time step, there is just one beam item
            if step == 0:
                break

        # STEP 3: Compute the "bank" for each candidate. This is the
        # number of constraints it's generated. We need this so that
        # we can do round-robin allocation of the beam across these
        # banks. If C is the number of constraints, we select the best
        # item in bank C, then the best in bank C-1, etc, followed by
        # the 2nd-best in bank C, the 2nd-best in bank C-1, etc, and so
        # on, until the maximum beam size. We accomplish this by
        # creating a sort key and striping across the banks.

        # Compute the new states for all candidates
        constraint_states = [
            constraint_states[beam].advance(token)
            for beam, token in zip(beams_buf.tolist(), indices_buf)
        ]

        banks = torch.tensor([state.bank for state in constraint_states],
                             device=device)

        # STEP 4: Sort
        num_constraint_tokens = len(reference_state.tokens)

        # Sort by keys (bank, score) (i.e., sort banks together, and scores
        # within banks) by folding both into a single key. NOTE: this only
        # orders banks correctly while scores stay above MAX_SCORE per bank.
        MAX_SCORE = -100
        sort_key = (num_constraint_tokens - banks) * MAX_SCORE + scores_buf
        _, sort_indices = sort_key.sort(dim=0, descending=True)
        scores_buf = scores_buf[sort_indices]
        indices_buf = indices_buf[sort_indices]
        beams_buf = beams_buf[sort_indices]
        banks = banks[sort_indices]

        # Sort the constraints to follow suit
        constraint_states = [
            constraint_states[i] for i in sort_indices.tolist()
        ]

        # STEP 5: Remove duplicates. The topk calls (overall and
        # per-row) plus the per-row generation of constraints will
        # produce duplicates. Here we remove them.
        #
        # We map candidates (beam, token_id) to a single dimension and
        # compare every entry with its predecessor. The first entry has no
        # predecessor and is always kept; comparing it with the last entry
        # (wrap-around) could drop it from the tensors while its state was
        # kept, leaving the buffers and the state list misaligned.
        candidate_keys = beams_buf * (self.vocab_size + 1) + indices_buf
        uniques_mask = torch.ones_like(candidate_keys, dtype=torch.bool)
        uniques_mask[1:] = candidate_keys[1:] != candidate_keys[:-1]

        # Use the mask to pare down the data structures
        scores_buf = torch.masked_select(scores_buf, uniques_mask)
        indices_buf = torch.masked_select(indices_buf, uniques_mask)
        beams_buf = torch.masked_select(beams_buf, uniques_mask)
        banks = torch.masked_select(banks, uniques_mask)
        constraint_states = [
            state
            for state, keep in zip(constraint_states, uniques_mask.tolist())
            if keep
        ]

        # STEP 6: Assign IDs round-robin across banks, sort, and
        # truncate. Now that the candidates are sorted by (bank,
        # score) and uniqed, we dynamically allocate the {beam_size}
        # beam by striping across the candidates. These stripes will
        # be used as sort keys to do round-robin selection. This is
        # accomplished in a single pass with offsets. Sorting by
        # highest-banks (furthest-along hypotheses) first ensures
        # progress through the constraints.
        #
        # e.g., BANKS: 3 3 3 2 2 2 2 1 1 1 0 0
        # OLD STRIPES: 0 1 2 0 1 2 3 0 1 2 0 1
        # NEW STRIPES: 0 1+4 2+8 0+1 1+5 2+9 3+11 0+2 1+6 2+10 0+3 1+7
        #            = 0 5 10 1 6 11 13 2 7 12 3 8
        #
        # Sorting by this then gives the following banks:
        #
        #             3 2 1 0 3 2 1 0 3 2 1 2
        #
        # We'll take the top {beam_size} of these.
        bank_list = banks.tolist()
        # offset between two consecutive ranks inside a bank
        rank_stride = len(bank_list) + 1
        stripe_values = []
        rank_in_bank = -1
        cur_bank = bank_list[0]
        for bank in bank_list:
            if bank != cur_bank:
                rank_in_bank = 0
                cur_bank = bank
            else:
                rank_in_bank += 1
            stripe_values.append(num_constraint_tokens - bank
                                 + rank_in_bank * rank_stride)
        stripes = torch.tensor(
            stripe_values, dtype=banks.dtype, device=device)

        # STEP 7: Sort by the stripes values
        _, sort_indices = stripes.sort(dim=0)
        scores_buf = scores_buf[sort_indices]
        indices_buf = indices_buf[sort_indices]
        beams_buf = beams_buf[sort_indices]
        constraint_states = [
            constraint_states[i] for i in sort_indices.tolist()
        ]

        # STEP 8: Truncate to the candidates size!
        scores_buf = scores_buf[:self.num_cands]
        indices_buf = indices_buf[:self.num_cands]
        beams_buf = beams_buf[:self.num_cands]

        return scores_buf, indices_buf, beams_buf, constraint_states


class LengthConstrainedBeamSearch(Search):
    r"""
    Length constrained beam search for generation.

    step 1. Apply the length constraints to the model's log-probabilities.
        While `step < min_lens` the eos score is set to `-math.inf`, so the
        hypothesis cannot stop. Once `step >= max_lens` the eos score is set
        to `0`, so stopping becomes the preferred choice.
    step 2. Use plain beam search to select the candidates.

    `min_lens` and `max_lens` are linear functions of the source lengths:
    `len_a * src_lengths + len_b`.
    """

    def __init__(self, tgt_dict, min_len_a, min_len_b, max_len_a, max_len_b):
        """
        Args:
            tgt_dict: tokenizer forwarded to ``Search.__init__`` and to the
                wrapped ``BeamSearch``.
            min_len_a, min_len_b: slope and offset of the minimum length.
            max_len_a, max_len_b: slope and offset of the maximum length.
        """
        super().__init__(tgt_dict)
        self.min_len_a = min_len_a
        self.min_len_b = min_len_b
        self.max_len_a = max_len_a
        self.max_len_b = max_len_b
        self.beam = BeamSearch(tgt_dict)
        # set_src_lengths() has to be called before step()
        self.needs_src_lengths = True

    def step(
        self,
        step: int,
        lprobs,
        scores,
        prev_output_tokens: Optional[Tensor] = None,
        original_batch_idxs: Optional[Tensor] = None,
    ):
        r"""
        Take a single search step.

        The eos column of `lprobs` is modified in place.

        Args:
            step (`int`): Current step, starting at 0.
            lprobs (`Tensor` with size `(bsz, input_beam_size, vocab_size)`):
                the model's log-probabilities over the vocabulary at the current step.
            scores (`Tensor`, may be `None` at step 0):
                Scores of the hypotheses so far, forwarded to `BeamSearch.step`.
            prev_output_tokens (`Tensor`, **optional**, default to `None`):
                Not used by this strategy.
            original_batch_idxs (`Tensor`, **optional**, default to `None`):
                Not used by this strategy.

        Returns: A tuple of (scores_buf, indices_buf, beams_buf) as returned
            by `BeamSearch.step`.
        """
        min_lens = self.min_len_a * self.src_lengths + self.min_len_b
        max_lens = self.max_len_a * self.src_lengths + self.max_len_b

        # View on the eos scores, shape (bsz, input_beam_size); in-place
        # writes below go straight into `lprobs`.
        eos_lprobs = lprobs[:, :, self.eos]
        # One flag per sentence, broadcast over the beam axis. Reshaping to a
        # column also covers a 0-d `src_lengths` (a single length shared by
        # the whole batch); a 0-d boolean used directly as an index would
        # shift the remaining indices onto the wrong axes.
        too_short = (step < min_lens).reshape(-1, 1).to(lprobs.device)
        too_long = (step >= max_lens).reshape(-1, 1).to(lprobs.device)
        eos_lprobs.masked_fill_(too_short, -math.inf)
        # applied last, so the maximum length wins if both conditions hold
        eos_lprobs.masked_fill_(too_long, 0)
        return self.beam.step(step, lprobs, scores)


class DiverseBeamSearch(Search):
    """Diverse Beam Search.

    See "Diverse Beam Search: Decoding Diverse Solutions from Neural Sequence
    Models" for details.

    Only the Hamming Diversity penalty is implemented, which performed best
    in the original paper: the beam is split into groups that are searched
    one after another, and every group is penalised for tokens already
    picked by the groups before it at the current step.
    """

    def __init__(self, tgt_dict, num_groups, diversity_strength):
        super().__init__(tgt_dict)
        self.num_groups = num_groups
        # kept negated so that step() can use it directly as the weight of
        # the (penalising) token counts
        self.diversity_strength = -diversity_strength
        self.beam = BeamSearch(tgt_dict)

    @torch.jit.export
    def step(
        self,
        step: int,
        lprobs,
        scores,
        prev_output_tokens: Optional[Tensor] = None,
        original_batch_idxs: Optional[Tensor] = None,
    ):
        r"""
        Take a single search step.

        Args:
            step (`int`): Current step, starting at 0.
            lprobs (`Tensor` with size `(bsz, input_beam_size, vocab_size)`):
                the model's log-probabilities over the vocabulary at the current step.
            scores (`Tensor`, only read when `step > 0`):
                Scores of the hypotheses so far; the slice belonging to each
                group is forwarded to `BeamSearch.step`.
            prev_output_tokens (`Tensor`, **optional**, default to `None`):
                Not used by this strategy.
            original_batch_idxs (`Tensor`, **optional**, default to `None`):
                Not used by this strategy.

        Returns: A tuple of (scores_buf, indices_buf, beams_buf), each of size
            `(bsz, num_groups * k)` where `k` is the number of candidates
            `BeamSearch.step` returns per group; the results of the groups
            are interleaved.
            scores_buf: scores of the selected candidates.
            indices_buf: vocabulary indices of the selected candidates.
            beams_buf: index, within the full beam, of the hypothesis each
                candidate extends.

        Raises:
            ValueError: if the beam size is not a multiple of `num_groups`.
        """
        bsz, beam_size, vocab_size = lprobs.size()
        num_groups = self.num_groups
        if beam_size % num_groups != 0:
            raise ValueError(
                'DiverseBeamSearch requires --beam to be divisible by the number of groups'
            )

        # how often each token was selected by the groups handled so far
        token_counts = lprobs.new_zeros((bsz, vocab_size))

        per_group_scores, per_group_indices, per_group_beams = [], [], []
        for group in range(num_groups):
            # hypotheses group, group + num_groups, ... belong to this group
            group_lprobs = lprobs[:, group::num_groups, :]
            group_history = (
                scores[:, group::num_groups, :] if step > 0 else None)

            if group == 0:
                # nothing selected yet, hence no penalty for the first group
                group_lprobs = group_lprobs.contiguous()
            else:
                group_lprobs = group_lprobs.add(
                    token_counts.unsqueeze(1),
                    alpha=self.diversity_strength,
                )

            top_scores, top_tokens, top_beams = self.beam.step(
                step, group_lprobs, group_history)
            # translate the group-local hypothesis id into a full-beam id
            top_beams = top_beams * num_groups + group

            per_group_scores.append(top_scores)
            per_group_indices.append(top_tokens)
            per_group_beams.append(top_beams)

            # record the tokens picked by this group for the following ones
            token_counts.scatter_add_(
                1, top_tokens,
                torch.ones_like(top_tokens, dtype=token_counts.dtype))

        # interleave results from different groups
        scores_buf = torch.stack(per_group_scores, dim=2).view(bsz, -1)
        indices_buf = torch.stack(per_group_indices, dim=2).view(bsz, -1)
        beams_buf = torch.stack(per_group_beams, dim=2).view(bsz, -1)
        return scores_buf, indices_buf, beams_buf


class Sampling(Search):
    r"""
    Sampling search for generation.

    1. Calculate the sample set.
        1.1 If `sampling_topp` is positive, choose the smallest set of candidates (taken in
            descending order of probability) whose cumulative probability mass exceeds
            `sampling_topp`.
        1.2 Otherwise, if `sampling_topk` is positive, choose the top k candidates by the
            model's log-probability.
        1.3 Otherwise, use the whole vocabulary as the sampling set.
    2. Use multinomial sampling to draw candidates from the sample set as hypotheses.
    3. Compute hypothesis scores, token indices and beam indices for the final result.

    Attributes:
        sampling_topk (`int`, **optional**, default to `-1`):
            The value of k for top-k sampling. Non-positive values disable top-k.
        sampling_topp (`float`, **optional**, default to `-1.0`):
            The value of p for top-p (nucleus) sampling. Non-positive values disable top-p.
            When positive, it takes precedence over `sampling_topk`.
    """
    sampling_topk: int
    sampling_topp: float

    def __init__(self, tgt_dict, sampling_topk=-1, sampling_topp=-1.0):
        """Store the sampling configuration; `tgt_dict` is forwarded to `Search.__init__`."""
        super().__init__(tgt_dict)
        self.sampling_topk = sampling_topk
        self.sampling_topp = sampling_topp

    def _sample_topp(self, lprobs):
        """Sample among the smallest set of elements whose cumulative probability mass exceeds p.

        See `"The Curious Case of Neural Text Degeneration"
        (Holtzman et al., 2019) <https://arxiv.org/abs/1904.09751>`_.

        Note:
            `lprobs` is exponentiated in place, so the caller's tensor holds
            probabilities (not log-probabilities) after this call.

        Args:
            lprobs: (bsz x input_beam_size x vocab_size)
                the model's log-probabilities over the vocabulary at the current step

        Return: A tuple of (trimed_probs, truncated_indices) where:
            trimed_probs: (bsz x input_beam_size x ?)
                the model's probabilities over the elements selected to sample from. The
                width of the third dimension is determined by top-P. Entries outside a
                row's top-P set are zero; the remaining entries are not re-normalized.
            truncated_indices: (bsz x input_beam_size x ?)
                the indices of the chosen elements.
        """
        probs = lprobs.exp_()

        # sort the last dimension (vocab dimension) in descending order
        sorted_probs, sorted_indices = probs.sort(descending=True)

        # compute a mask to indicate the words to be included in the top-P set.
        cumsum_probs = sorted_probs.cumsum(dim=2)
        mask = cumsum_probs.lt(self.sampling_topp)

        # note that mask was computed by 'lt'. One more word needs to be included
        # so that the cumulative probability mass can exceed p.
        # The number of True entries in a row is exactly the position of that
        # extra word in the sorted order.
        cumsum_mask = mask.cumsum(dim=2)
        last_included = cumsum_mask[:, :, -1:]
        # guard the case where every word is already below p (index == vocab size)
        last_included.clamp_(0, mask.size()[2] - 1)
        mask = mask.scatter_(2, last_included, 1)

        # truncate unnecessary dims: keep only as many columns as the widest
        # top-P set over all rows needs.
        max_dim = last_included.max()
        truncated_mask = mask[:, :, :max_dim + 1]
        truncated_probs = sorted_probs[:, :, :max_dim + 1]
        truncated_indices = sorted_indices[:, :, :max_dim + 1]

        # trim the words that are not in top-P by setting their probabilities
        # to 0, so that they would not be sampled later.
        trim_mask = ~truncated_mask
        trimed_probs = truncated_probs.masked_fill_(trim_mask, 0)
        return trimed_probs, truncated_indices

    @torch.jit.export
    def step(
        self,
        step: int,
        lprobs,
        scores,
        prev_output_tokens: Optional[Tensor] = None,
        original_batch_idxs: Optional[Tensor] = None,
    ):
        r"""
        Take a single search step.

        Note:
            The log-probabilities are exponentiated in place. For `step > 0` without
            top-k this modifies the `lprobs` tensor passed by the caller.

        Args:
            step (`int`):  Current step, start with 0.
            lprobs (`Tensor` with size `(bsz, input_beam_size, vocab_size)`):
                the model's log-probabilities over the vocabulary at the current step.
            scores (`Tensor` with size `(bsz, input_beam_size, >= step)`):
                Previous sampling scores for each beam. Only read when `step > 0`,
                where it is indexed at `step - 1` along the last dimension.
            prev_output_tokens (`Tensor`, **optional**. default to `None`):
                Previous output tokens, no usage in this function, will be deprecated in next version.
            original_batch_idxs (`Tensor`, **optional**, default to `None`):
                the tensor with the batch indices, in the range [0, bsz)
                this is useful in case there has been applied a re-ordering
                and we need to know the original indices. Not used in this function.

        Returns: A tuple of (scores_buf, indices_buf, beams_buf), where:
            scores_buf (`Tensor` with size `(bsz, input_beam_size)`):
                The log-probabilities of the sampled elements, accumulated with the
                previous score of the corresponding beam when `step > 0`.
            indices_buf (`Tensor` with size `(bsz, input_beam_size)`):
                The vocabulary indices of the sampled elements.
            beams_buf (`Tensor` with size `(bsz, input_beam_size)`):
                The indices of each beam (all zeros at step 0, otherwise `0..beam_size-1`).
        """
        bsz, beam_size, vocab_size = lprobs.size()

        if step == 0:
            # at the first step all hypotheses are equally likely, so use
            # only the first beam
            lprobs = lprobs[:, ::beam_size, :].contiguous()

        if self.sampling_topp > 0:
            # only sample from the smallest set of words whose cumulative probability mass exceeds p
            probs, top_indices = self._sample_topp(lprobs)
        elif self.sampling_topk > 0:
            # only sample from top-k candidates; k is clamped to the vocabulary
            # size because `topk` raises when asked for more elements than exist
            top_k = min(self.sampling_topk, vocab_size)
            lprobs, top_indices = lprobs.topk(top_k)
            probs = lprobs.exp_()
        else:
            probs = lprobs.exp_()

            # dummy data to be consistent with true branch for type check
            top_indices = torch.empty(0).to(probs)

        # sample
        if step == 0:
            # draw all `beam_size` hypotheses from the single (first) beam
            indices_buf = torch.multinomial(
                probs.view(bsz, -1),
                beam_size,
                replacement=True,
            ).view(bsz, beam_size)
            # expand to beam size so the gathers below see one row per hypothesis
            probs = probs.expand(bsz, beam_size, -1)
        else:
            # draw one continuation for every existing beam
            indices_buf = torch.multinomial(
                probs.view(bsz * beam_size, -1),
                1,
                replacement=True,
            ).view(bsz, beam_size)

        # gather scores
        scores_buf = torch.gather(
            probs, dim=2, index=indices_buf.unsqueeze(-1))
        scores_buf = scores_buf.log_().view(bsz, -1)

        # remap indices if using top-k or top-P sampling: the sampled positions
        # refer to the truncated candidate set, not to the vocabulary
        if self.sampling_topk > 0 or self.sampling_topp > 0:
            indices_buf = torch.gather(
                top_indices.expand(bsz, beam_size, -1),
                dim=2,
                index=indices_buf.unsqueeze(-1),
            ).squeeze(2)

        if step == 0:
            beams_buf = indices_buf.new_zeros(bsz, beam_size)
        else:
            beams_buf = torch.arange(0,
                                     beam_size).to(indices_buf).repeat(bsz, 1)
            # make scores cumulative
            scores_buf.add_(
                torch.gather(scores[:, :, step - 1], dim=1, index=beams_buf))

        return scores_buf, indices_buf, beams_buf


class DiverseSiblingsSearch(Search):
    """
    Beam search with diverse siblings.

    See "A Simple, Fast Diverse Decoding Algorithm for Neural Generation" for details.
    https://arxiv.org/abs/1611.08562

    1/ Calculate hypotheses for each beam
    2/ Intra-sibling ordering
    3/ Rewrite scores
    4/ Choose top K hypotheses

    if diversity_rate == 0 is equivalent to BeamSearch
    """

    def __init__(self, tgt_dict, diversity_rate):
        """Store the diversity rate and build the `BeamSearch` used for the first step."""
        super().__init__(tgt_dict)
        self.diversity_rate = diversity_rate
        self.beam = BeamSearch(tgt_dict)

    def step(
        self,
        step: int,
        lprobs,
        scores,
        prev_output_tokens: Optional[Tensor] = None,
        original_batch_idxs: Optional[Tensor] = None,
    ):
        r"""
        Take a single search step.

        Note:
            For `step > 0`, `lprobs` is modified in place: the previous score of
            each beam is added to it.

        Args:
            step (`int`):  Current step, start with 0.
            lprobs (`Tensor` with size `(bsz, input_beam_size, vocab_size)`):
                the model's log-probabilities over the vocabulary at the current step.
            scores (`Tensor` with size `(bsz, input_beam_size, >= step)`):
                Previous scores for each beam. For `step > 0` it is indexed at
                `step - 1` along the last dimension.
            prev_output_tokens (`Tensor`, **optional**. default to `None`):
                Previous output tokens, no usage in this function, will be deprecated in next version.
            original_batch_idxs (`Tensor`, **optional**, default to `None`):
                the tensor with the batch indices, in the range [0, bsz)
                this is useful in case there has been applied a re-ordering
                and we need to know the original indices. Not used in this function.

        Returns: A tuple of (final_scores, final_indices, final_beams). At step 0 this is
            whatever `BeamSearch.step` returns; afterwards:
            final_scores (`Tensor` with size `(bsz, output_beam_size)`):
                The cumulative scores of the chosen hypotheses after the sibling-rank
                penalty has been subtracted. `output_beam_size` is the minimum of
                `2 * input_beam_size` and `input_beam_size * vocab_size - 1`.
            final_indices (`Tensor` with size `(bsz, output_beam_size)`):
                The vocabulary indices of the chosen elements.
            final_beams (`Tensor` with size `(bsz, output_beam_size)`):
                The index of the beam each chosen element extends.
        """
        bsz, beam_size, vocab_size = lprobs.size()

        if step == 0:
            # no history to diversify against yet: fall back to plain beam
            # search before doing any of the sibling bookkeeping
            return self.beam.step(step, lprobs, scores)

        k = min(
            # Take the best 2 x beam_size predictions. We'll choose the first
            # beam_size of these which don't predict eos to continue with.
            beam_size * 2,
            beam_size * vocab_size - 1,  # -1 so we never select pad
        )

        # make lprobs cumulative over the hypothesis history (in place)
        lprobs.add_(scores[:, :, step - 1].unsqueeze(-1))

        # 1/ Calculate hypotheses for each beam: the k best continuations of
        # every beam at once, each of size (bsz, beam_size, k). The indices
        # are already vocabulary ids because topk runs over the vocab axis.
        sibling_scores, sibling_indices = torch.topk(lprobs, k, dim=-1)

        # 2/ Intra-sibling ordering by default from topk + 3/ Rewrite scores:
        # the r-th best sibling (r starting at 1) is penalized by r * diversity_rate
        sibling_penalty = torch.arange(1, k + 1).to(lprobs) * self.diversity_rate
        sibling_scores.sub_(sibling_penalty)

        # 4/ Choose top K hypotheses among all beam_size * k candidates
        final_scores, flat_positions = torch.topk(
            sibling_scores.view(bsz, -1),
            k,
        )

        # each beam contributed k consecutive candidates to the flattened view
        final_beams = flat_positions // k

        # map flattened candidate positions back to vocabulary indices
        final_indices = torch.gather(
            sibling_indices.view(bsz, -1), 1, flat_positions)

        return final_scores, final_indices, final_beams
