# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors.
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import (absolute_import, division, print_function,
                        unicode_literals)
import logging
import math

import torch
import torch.nn.functional as F
from megatron_util import mpu
from torch import nn

from modelscope.utils.nlp.distributed import (normal_init_method,
                                              scaled_init_method)
from .configuration import PlugNLGConfig, PlugNLUConfig

logger = logging.getLogger(__name__)


def gelu(x):
    """Implementation of the gelu activation function.
        For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
        0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    """
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


def swish(x):
    return x * torch.sigmoid(x)


ACT2FN = {'gelu': gelu, 'relu': torch.nn.functional.relu, 'swish': swish}


class BertLayerNorm(nn.Module):

    def __init__(self, hidden_size, eps=1e-12):
        """Construct a layernorm module in the TF style (epsilon inside the square root).
        """
        super(BertLayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x + self.bias


class BertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings.
    """

    def __init__(self, config):
        super(BertEmbeddings, self).__init__()
        self.word_embeddings = mpu.VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            init_method=normal_init_method(
                mean=0.0, std=config.initializer_range))
        self.position_embeddings = nn.Embedding(config.max_position_embeddings,
                                                config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size,
                                                  config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.fp32_layernorm = config.fp32_layernorm
        self.fp32_embedding = config.fp32_embedding
        self.fp32_tokentypes = config.fp32_tokentypes
        self.LayerNorm = BertLayerNorm(
            config.hidden_size, eps=config.layernorm_epsilon)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids, token_type_ids=None, position_ids=None):
        seq_length = input_ids.size(1)
        if position_ids is None:
            position_ids = torch.arange(
                seq_length, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        if not self.fp32_tokentypes:

            embeddings = words_embeddings + position_embeddings + token_type_embeddings
            if self.fp32_embedding and not self.fp32_layernorm:
                embeddings = embeddings.half()
            previous_type = embeddings.type()
            if self.fp32_layernorm:
                embeddings = embeddings.float()
            embeddings = self.LayerNorm(embeddings)
            if self.fp32_layernorm:
                if self.fp32_embedding:
                    embeddings = embeddings.half()
                else:
                    embeddings = embeddings.type(previous_type)
        else:
            embeddings = words_embeddings.float() + position_embeddings.float(
            ) + token_type_embeddings.float()
            if self.fp32_tokentypes and not self.fp32_layernorm:
                embeddings = embeddings.half()
            previous_type = embeddings.type()
            if self.fp32_layernorm:
                embeddings = embeddings.float()
            embeddings = self.LayerNorm(embeddings)
            if self.fp32_layernorm:
                if self.fp32_tokentypes:
                    embeddings = embeddings.half()
                else:
                    embeddings = embeddings.type(previous_type)
        embeddings = self.dropout(embeddings)
        return embeddings


class BertSelfOutput(nn.Module):

    def __init__(self, config):
        super(BertSelfOutput, self).__init__()
        if hasattr(config, 'deep_init') and config.deep_init:
            init_method = scaled_init_method(
                mean=0.0,
                std=config.initializer_range,
                num_layers=config.num_hidden_layers)
        else:
            init_method = normal_init_method(
                mean=0.0, std=config.initializer_range)
        self.dense = mpu.RowParallelLinear(
            input_size=config.hidden_size,
            output_size=config.hidden_size,
            bias=True,
            input_is_parallel=True,
            stride=1,
            init_method=init_method)
        self.fp32_layernorm = config.fp32_layernorm
        if not config.pre_ln:
            self.LayerNorm = BertLayerNorm(
                config.hidden_size, eps=config.layernorm_epsilon)
        else:
            self.LayerNorm = None
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self,
        hidden_states,
        input_tensor,
    ):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        ln_input = hidden_states + input_tensor
        if self.LayerNorm is not None:
            previous_type = ln_input.type()
            if self.fp32_layernorm:
                ln_input = ln_input.float()
            hidden_states = self.LayerNorm(ln_input)
            if self.fp32_layernorm:
                hidden_states = hidden_states.type(previous_type)
        else:
            hidden_states = ln_input
        return hidden_states


class BertAttention(nn.Module):

    def __init__(self, config):
        super(BertAttention, self).__init__()
        self.fp32_layernorm = config.fp32_layernorm
        if config.pre_ln:
            self.LayerNorm = BertLayerNorm(
                config.hidden_size, eps=config.layernorm_epsilon)
        else:
            self.LayerNorm = None
        self.self = mpu.BertParallelSelfAttention(
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_attention_heads,
            dropout_prob=config.attention_probs_dropout_prob,
            output_parallel=True,
            init_method=normal_init_method(
                mean=0.0, std=config.initializer_range),
            separate=config.attn_separate)
        self.output = BertSelfOutput(config)

    def forward(
        self,
        input_tensor,
        attention_mask,
    ):
        if self.LayerNorm is not None:
            ln_input = input_tensor
            previous_type = input_tensor.type()
            if self.fp32_layernorm:
                ln_input = input_tensor.float()
            ln_output = self.LayerNorm(ln_input)
            if self.fp32_layernorm:
                ln_output = ln_output.type(previous_type)
            self_output = self.self(
                ln_output,
                attention_mask,
            )
        else:
            self_output = self.self(
                input_tensor,
                attention_mask,
            )

        attention_output = self.output(
            self_output,
            input_tensor,
        )
        return attention_output


class BertIntermediate(nn.Module):

    def __init__(self, config):
        super(BertIntermediate, self).__init__()
        self.dense = mpu.ColumnParallelLinear(
            input_size=config.hidden_size,
            output_size=config.intermediate_size,
            bias=True,
            gather_output=False,
            stride=1,
            init_method=normal_init_method(
                mean=0.0, std=config.initializer_range))
        self.intermediate_act_fn = ACT2FN[config.hidden_act] \
            if isinstance(config.hidden_act, str) else config.hidden_act

    def forward(
        self,
        hidden_states,
    ):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class BertOutput(nn.Module):

    def __init__(self, config):
        super(BertOutput, self).__init__()
        if hasattr(config, 'deep_init') and config.deep_init:
            init_method = scaled_init_method(
                mean=0.0,
                std=config.initializer_range,
                num_layers=config.num_hidden_layers)
        else:
            init_method = normal_init_method(
                mean=0.0, std=config.initializer_range)
        self.dense = mpu.RowParallelLinear(
            input_size=config.intermediate_size,
            output_size=config.hidden_size,
            bias=True,
            input_is_parallel=True,
            stride=1,
            init_method=init_method)
        self.fp32_layernorm = config.fp32_layernorm
        if not config.pre_ln:
            self.LayerNorm = BertLayerNorm(
                config.hidden_size, eps=config.layernorm_epsilon)
        else:
            self.LayerNorm = None
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self,
        hidden_states,
        input_tensor,
    ):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        ln_input = hidden_states + input_tensor
        if self.LayerNorm is not None:
            previous_type = ln_input.type()
            if self.fp32_layernorm:
                ln_input = ln_input.float()
            hidden_states = self.LayerNorm(ln_input)
            if self.fp32_layernorm:
                hidden_states = hidden_states.type(previous_type)
        else:
            hidden_states = ln_input
        return hidden_states


class BertLayer(nn.Module):

    def __init__(self, config):
        super(BertLayer, self).__init__()
        self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)
        self.fp32_layernorm = config.fp32_layernorm
        if config.pre_ln:
            self.LayerNorm = BertLayerNorm(
                config.hidden_size, eps=config.layernorm_epsilon)
        else:
            self.LayerNorm = None

    def forward(self, hidden_states, attention_mask):
        attention_output = self.attention(hidden_states, attention_mask)
        if self.LayerNorm is not None:
            ln_input = attention_output
            previous_type = attention_output.type()
            if self.fp32_layernorm:
                ln_input = attention_output.float()
            ln_output = self.LayerNorm(ln_input)
            if self.fp32_layernorm:
                ln_output = ln_output.type(previous_type)
            intermediate_output = self.intermediate(ln_output)
        else:
            intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output


class BertEncoder(nn.Module):

    def __init__(self, config):
        super(BertEncoder, self).__init__()
        self.layer = nn.ModuleList(
            [BertLayer(config) for _ in range(config.num_hidden_layers)])
        self.fp32_layernorm = config.fp32_layernorm
        if config.pre_ln:
            self.LayerNorm = BertLayerNorm(
                config.hidden_size, eps=config.layernorm_epsilon)
        else:
            self.LayerNorm = None

    def forward(
        self,
        hidden_states,
        attention_mask,
        output_all_encoded_layers=True,
        checkpoint_activations=False,
        detach_index=-1,
    ):
        all_encoder_layers = []

        def custom(start, end):

            def custom_forward(*inputs):
                layers = self.layer[start:end]
                x_ = inputs[0]
                for layer in layers:
                    x_ = layer(x_, inputs[1])
                return x_

            return custom_forward

        if checkpoint_activations:
            layer_idx = 0
            num_layers = len(self.layer)
            chunk_length = 1
            while layer_idx < num_layers:
                hidden_states = mpu.checkpoint(
                    custom(layer_idx, layer_idx + chunk_length), hidden_states,
                    attention_mask * 1)
                if detach_index == layer_idx:
                    hidden_states.detach_()
                layer_idx += chunk_length
            # decoder layers
        else:
            for i, layer_module in enumerate(self.layer):
                hidden_states = layer_module(hidden_states, attention_mask)
                if detach_index == i:
                    hidden_states.detach_()
                if i == len(self.layer) - 1 and self.LayerNorm is not None:
                    previous_type = hidden_states.type()
                    if self.fp32_layernorm:
                        hidden_states = hidden_states.float()
                    hidden_states = self.LayerNorm(hidden_states)
                    if self.fp32_layernorm:
                        hidden_states = hidden_states.type(previous_type)
                if output_all_encoded_layers:
                    all_encoder_layers.append(hidden_states)

        if not output_all_encoded_layers or checkpoint_activations:
            if self.LayerNorm is not None:
                previous_type = hidden_states.type()
                if self.fp32_layernorm:
                    hidden_states = hidden_states.float()
                hidden_states = self.LayerNorm(hidden_states)
                if self.fp32_layernorm:
                    hidden_states = hidden_states.type(previous_type)
            all_encoder_layers.append(hidden_states)
        return all_encoder_layers


class BertPooler(nn.Module):

    def __init__(self, config):
        super(BertPooler, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


class BertPredictionHeadTransform(nn.Module):

    def __init__(self, config):
        super(BertPredictionHeadTransform, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.transform_act_fn = ACT2FN[config.hidden_act] \
            if isinstance(config.hidden_act, str) else config.hidden_act
        self.LayerNorm = BertLayerNorm(
            config.hidden_size, eps=config.layernorm_epsilon)
        self.fp32_layernorm = config.fp32_layernorm

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        previous_type = hidden_states.type()
        if self.fp32_layernorm:
            hidden_states = hidden_states.float()
        hidden_states = self.LayerNorm(hidden_states)
        if self.fp32_layernorm:
            hidden_states = hidden_states.type(previous_type)
        return hidden_states


class BertLMPredictionHead(nn.Module):

    def __init__(self, config, bert_model_embedding_weights):
        super(BertLMPredictionHead, self).__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder_weight = bert_model_embedding_weights
        self.bias = nn.Parameter(
            torch.zeros(bert_model_embedding_weights.size(0)))
        self.bias.model_parallel = True
        self.fp32_embedding = config.fp32_embedding
        self.fp32_layernorm = config.fp32_layernorm

        def convert_to_type(tensor):
            if self.fp32_embedding:
                return tensor.half()
            else:
                return tensor

        self.type_converter = convert_to_type
        self.converted = False

    def forward(self, hidden_states):
        if not self.converted:
            self.converted = True
            if self.fp32_embedding:
                self.transform.half()
                if self.fp32_layernorm:
                    self.transform.LayerNorm.float()
        hidden_states = self.transform(self.type_converter(hidden_states))
        hidden_states = mpu.copy_to_model_parallel_region(hidden_states)
        hidden_states = F.linear(
            self.type_converter(hidden_states),
            self.type_converter(self.decoder_weight),
            self.type_converter(self.bias))
        return hidden_states


class BertPreTrainingHeads(nn.Module):

    def __init__(self, config, bert_model_embedding_weights):
        super(BertPreTrainingHeads, self).__init__()
        self.predictions = BertLMPredictionHead(config,
                                                bert_model_embedding_weights)
        self.seq_relationship = nn.Linear(config.hidden_size, 3)

    def forward(self, sequence_output, pooled_output):
        prediction_scores = self.predictions(sequence_output)
        for p in self.seq_relationship.parameters():
            if p is None:
                continue
            pooled_output = pooled_output.type_as(p)
        seq_relationship_score = self.seq_relationship(pooled_output)
        return prediction_scores, seq_relationship_score


class PreTrainedBertModel(nn.Module):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """

    def __init__(self, config, *inputs, **kwargs):
        super(PreTrainedBertModel, self).__init__()
        if not isinstance(config, PlugNLUConfig) and not isinstance(
                config, PlugNLGConfig):
            raise ValueError(
                'Parameter config in `{}(config)` should be an instance of class `BertConfig`. '
                'To create a model from a Google pretrained model use '
                '`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`'.format(
                    self.__class__.__name__, self.__class__.__name__))
        self.config = config

    def init_bert_weights(self, module):
        """ Initialize the weights.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(
                mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, BertLayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()


class BertModel(PreTrainedBertModel):
    """BERT model ("Bidirectional Embedding Representations from a Transformer").

    Params:
        config: a BertConfig class instance with the configuration to build a new model

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
            with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
            `extract_features.py`, `run_classifier.py` and `run_squad.py`)
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
            types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
            a `sentence B` token (see BERT paper for more details).
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
            input sequence length in the current batch. It's the mask that we typically use for attention when
            a batch has varying length sentences.
        `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as
            described below. Default: `True`.

    Outputs: Tuple of (encoded_layers, pooled_output)
        `encoded_layers`: controlled by `output_all_encoded_layers` argument:
            - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
                of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
                encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
            - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
                to the last attention block of shape [batch_size, sequence_length, hidden_size],
        `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
            classifier pretrained on top of the hidden state associated to the first character of the
            input (`CLF`) to train on the Next-Sentence task (see BERT's paper).

    Examples:
        >>> # Already been converted into WordPiece token ids
        >>> input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
        >>> input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
        >>> token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

        >>> config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
        >>>     num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)

        >>> model = modeling.BertModel(config=config)
        >>> all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
    """

    def __init__(self, config):
        super(BertModel, self).__init__(config)
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)
        self.apply(self.init_bert_weights)

    def forward(
        self,
        input_ids,
        token_type_ids=None,
        attention_mask=None,
        output_all_encoded_layers=True,
        checkpoint_activations=False,
        detach_index=-1,
    ):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # this attention mask is more simple than the triangular masking of causal attention
        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(
            dtype=next(self.encoder.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        embedding_output = self.embeddings(input_ids, token_type_ids)
        encoded_layers = self.encoder(
            embedding_output,
            extended_attention_mask,
            output_all_encoded_layers=output_all_encoded_layers,
            checkpoint_activations=checkpoint_activations,
            detach_index=detach_index)
        sequence_output = encoded_layers[-1]
        for p in self.pooler.parameters():
            if p is None:
                continue
            sequence_output = sequence_output.type_as(p)
            break

        pooled_output = sequence_output[:, 0]
        if not output_all_encoded_layers or checkpoint_activations:
            encoded_layers = encoded_layers[-1]
        return encoded_layers, pooled_output


class DecodeLayer(nn.Module):

    def __init__(self, config):
        super(DecodeLayer, self).__init__()
        init_method = normal_init_method(
            mean=0.0, std=config.initializer_range)
        output_layer_init_method = scaled_init_method(
            mean=0.0,
            std=config.initializer_range,
            num_layers=config.num_hidden_layers)

        self.attention = mpu.GPT2ParallelSelfAttention(
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_attention_heads,
            attention_dropout_prob=config.attention_probs_dropout_prob,
            output_dropout_prob=config.hidden_dropout_prob,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
        )

        self.cross_attention = mpu.PalmParallelCrossAttention(
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_attention_heads,
            attention_dropout_prob=config.attention_probs_dropout_prob,
            output_dropout_prob=config.hidden_dropout_prob,
            init_method=init_method,
            attn_separate=False,
            output_layer_init_method=output_layer_init_method,
        )

        self.input_layernorm = BertLayerNorm(
            config.hidden_size, eps=config.layernorm_epsilon)
        self.post_attention_layernorm = BertLayerNorm(
            config.hidden_size, eps=config.layernorm_epsilon)
        self.post_cross_attention_layernorm = BertLayerNorm(
            config.hidden_size, eps=config.layernorm_epsilon)

        self.intermediate = mpu.ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            gather_output=False,
            init_method=init_method,
        )
        self.intermediate_act_fn = ACT2FN[config.hidden_act] \
            if isinstance(config.hidden_act, str) else config.hidden_act
        self.output = mpu.RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
        )

        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.fp32_layernorm = config.fp32_layernorm

        def convert_to_type(tensor):
            if self.fp32_layernorm:
                return tensor.float()
            else:
                return tensor

        self.type_converter = convert_to_type

    # def forward(self, hidden_states, enc_attn_mask, dec_attn_mask):
    def forward(self,
                hidden_states,
                enc_hidden_states,
                enc_attn_mask,
                dec_attn_mask,
                is_infer=False):
        residual = hidden_states
        previous_type = hidden_states.type()
        hidden_states = self.input_layernorm(
            self.type_converter(hidden_states))
        if self.fp32_layernorm:
            hidden_states = hidden_states.type(previous_type)
        hidden_states = self.attention(
            hidden_states, dec_attn_mask, is_infer=is_infer)

        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(
            self.type_converter(hidden_states))
        if self.fp32_layernorm:
            hidden_states = hidden_states.type(previous_type)
        hidden_states = self.cross_attention(hidden_states, enc_hidden_states,
                                             enc_attn_mask)
        hidden_states = residual + hidden_states
        residual = hidden_states
        hidden_states = self.post_cross_attention_layernorm(
            self.type_converter(hidden_states))
        if self.fp32_layernorm:
            hidden_states = hidden_states.type(previous_type)
        hidden_states = self.intermediate(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)

        hidden_states = self.output(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states


class BertDecoder(nn.Module):

    def __init__(self, config):
        super(BertDecoder, self).__init__()
        self.layer = nn.ModuleList(
            [DecodeLayer(config) for _ in range(config.dec_hidden_layers)])

        self.final_layernorm = BertLayerNorm(
            config.hidden_size, eps=config.layernorm_epsilon)
        self.fp32_layernorm = config.fp32_layernorm

    def forward(self,
                hidden_states,
                enc_hidden_states,
                enc_attn_mask,
                dec_attn_mask,
                checkpoint_activations=False,
                output_all_encoded_layers=False,
                is_infer=False):

        def custom(start, end):

            def custom_forward(*inputs):
                layers = self.layer[start:end]
                x_ = inputs[0]
                for layer in layers:
                    x_ = layer(
                        x_,
                        inputs[1],
                        inputs[2],
                        dec_attn_mask * 1,
                        is_infer=is_infer)
                return x_

            return custom_forward

        pre_enc_hidden = enc_hidden_states.data
        if checkpoint_activations:
            layer_idx = 0
            num_layers = len(self.layer)
            chunk_length = 1
            while layer_idx < num_layers:
                hidden_states = mpu.checkpoint(
                    custom(layer_idx, layer_idx + chunk_length), hidden_states,
                    enc_hidden_states, enc_attn_mask * 1)
                enc_hidden_states.data = pre_enc_hidden
                layer_idx += chunk_length
        else:
            for i, layer_module in enumerate(self.layer):
                hidden_states = layer_module(
                    hidden_states,
                    enc_hidden_states,
                    enc_attn_mask,
                    dec_attn_mask,
                    is_infer=is_infer)

        previous_type = hidden_states.type()
        if self.fp32_layernorm:
            hidden_states = hidden_states.float()
        hidden_states = self.final_layernorm(hidden_states)
        if self.fp32_layernorm:
            hidden_states = hidden_states.type(previous_type)

        return [hidden_states]


class DecodeModel(PreTrainedBertModel):

    def __init__(self, config):
        super(DecodeModel, self).__init__(config)
        self.decoder = BertDecoder(config)
        self.apply(self.init_bert_weights)

    def forward(self,
                embeddings,
                sequence_output,
                decode_input_ids,
                position_ids=None,
                enc_attn_mask=None,
                dec_attn_mask=None,
                checkpoint_activations=False,
                is_infer=False):
        extended_attention_mask = enc_attn_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = extended_attention_mask.to(
            dtype=next(self.decoder.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        embedding_output = embeddings(decode_input_ids)
        sequence_output = self.decoder(
            embedding_output,
            sequence_output,
            extended_attention_mask,
            dec_attn_mask,
            checkpoint_activations=False,
            is_infer=is_infer)
        return sequence_output[-1]


class PalmForPreTraining(PreTrainedBertModel):

    def __init__(self, config):
        super(PalmForPreTraining, self).__init__(config)
        self.bert = BertModel(config)
        self.cls = BertPreTrainingHeads(
            config, self.bert.embeddings.word_embeddings.weight)
        self.decoder = DecodeModel(config)
        self.apply(self.init_bert_weights)

    def forward(self,
                input_ids,
                token_type_ids=None,
                attention_mask=None,
                decode_input_ids=None,
                position_ids=None,
                decode_attention_mask=None,
                lm_labels=None,
                checkpoint_activations=False,
                is_infer=False,
                sequence_output=None,
                parallel_output=True):
        if sequence_output is None:
            sequence_output, pooled_output = self.bert(
                input_ids,
                token_type_ids,
                attention_mask,
                output_all_encoded_layers=False,
                checkpoint_activations=checkpoint_activations)
            prediction_scores, seq_relationship_score = self.cls(
                sequence_output, pooled_output)
        else:
            prediction_scores = None
            sequence_output = sequence_output.to(
                dtype=next(self.decoder.parameters()).dtype)
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        decode_output = self.decoder(
            self.bert.embeddings,
            sequence_output,
            decode_input_ids,
            position_ids,
            attention_mask,
            decode_attention_mask,
            checkpoint_activations=checkpoint_activations,
            is_infer=is_infer)

        transformer_output_parallel = mpu.copy_to_model_parallel_region(
            decode_output)

        logits_parallel = F.linear(transformer_output_parallel,
                                   self.bert.embeddings.word_embeddings.weight)

        if parallel_output:
            return prediction_scores, logits_parallel
        if is_infer:
            return prediction_scores, mpu.gather_from_model_parallel_region(
                logits_parallel), sequence_output
        return prediction_scores, mpu.gather_from_model_parallel_region(
            logits_parallel)


class PlugModel(torch.nn.Module):
    """
    The bare Plug Model transformer outputting raw hidden-states without any specific head on top.
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.
    Parameters:
        config ([`PlugNLGConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~DistributedPlug.initialize_model`] method to load the model weights.
    Examples:

    >>> # The PLUG model has 27B parameters and usually need to run on multiple GPUs. The example given
    >>> # here only initializes a slice of the model on a single GPU.
    >>> # Check out the [`~DistributedPipeline.__init__`] method to initialize entire PLUG model.
    >>> from modelscope.models.nlp.plug import PlugNLGConfig, PlugModel

    >>> # Initializing a Plug configuration
    >>> configuration = PlugNLGConfig()

    >>> # Initializing a model from the configuration
    >>> model = PlugModel(configuration)
    """

    def __init__(self, config):
        super(PlugModel, self).__init__()
        self.config = config
        self.model = PalmForPreTraining(self.config)

    def forward(self,
                input_tokens,
                token_type_ids=None,
                attention_mask=None,
                target_tokens=None,
                position_ids=None,
                decode_attention_mask=None,
                checkpoint_activations=False,
                is_infer=False,
                sequence_output=None,
                parallel_output=True):
        """
        Parameters:
            input_tokens (`torch.LongTensor` of shape `(batch_size, input_tokens_length)`):
                `input_tokens_length` = `sequence_length`. Indices of input sequence tokens in the vocabulary.
                Indices can be obtained using transformers [`BertTokenizer`]. See
                [`TextGenerationPreprocessor.__call__`] for details.
            token_type_ids (`torch.LongTensor` of shape `(batch_size, input_tokens_length)`, *optional*, defaults to
            None):
               Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
               1]`:

               - 0 corresponds to a *sentence A* token,
               - 1 corresponds to a *sentence B* token.

            attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*, defaults to None):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

            target_tokens (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*, defaults to None):
                Target token ids(labels) for language modeling. Note that the labels **are shifted** inside the model,
                i.e. you can set `target_tokens = input_tokens` Indices are selected in
                `[-100, 0, ..., config.vocab_size]` All labels set to `-100` are ignored (masked), the loss is only
                computed for labels in `[0, ..., config.vocab_size]`

            position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*, defaults to None):
                Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
                `[0, config.max_position_embeddings - 1]`.

            decode_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*, defaults
            to None):
                Mask to avoid performing attention on padding token indices of target tokens. Mask values selected in
                `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

            checkpoint_activations (`boolean`, *optional*, defaults to `False`):
                Whether gradient checkpointing is activated for this model or not.
            is_infer (`boolean`, *optional*, defaults to `False`):
                Whether or not to perform single inference.
            sequence_output (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*,
            defaults to None):
                Also known as last_hidden_state. Sequence of hidden-states at the output of the last layer of the
                model. A single forward() call can produce one single token. To generate the current token, the
                sequence_output generated by the `forward()` of the previous token is required.
            parallel_output (`boolean`, *optional*, defaults to `True`):
                To parallel return output, or gather it before return.


        """
        return self.model(
            input_tokens,
            token_type_ids,
            attention_mask,
            target_tokens,
            position_ids,
            decode_attention_mask,
            checkpoint_activations=checkpoint_activations,
            is_infer=is_infer,
            sequence_output=sequence_output,
            parallel_output=parallel_output)

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        return self.model.state_dict(
            destination=destination, prefix=prefix, keep_vars=keep_vars)

    def load_state_dict(self, state_dict, strict=True):
        return self.model.load_state_dict(state_dict, strict=strict)
