# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors.
# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import os
from typing import Optional, Union

import addict
import torch
from torch import nn
from torch.nn import functional as F
from transformers.modeling_utils import PreTrainedModel

from modelscope.outputs import TokenGeneratorOutput
from modelscope.utils.constant import ModelFile
from .configuration import GPT3Config
from .distributed_gpt3 import sample


class GPT3SelfAttention(nn.Module):
    """Multi-head self-attention with a left-to-right mask.

    The layer maps hidden states of size [b, s, h] to an output of the
    same size.
    """

    def __init__(self, config):
        super().__init__()

        width = config.hidden_size
        heads = config.num_attention_heads
        self.hidden_size = width
        self.num_attention_heads = heads
        # Width of one attention head.
        self.hidden_size_per_attention_head = width // heads

        # A single projection produces query, key and value together.
        self.query_key_value = nn.Linear(width, 3 * width)
        self.softmax = nn.Softmax(dim=-1)
        attention_drop_prob = config.attention_probs_dropout_prob
        self.attention_dropout = nn.Dropout(attention_drop_prob)

        # Output projection and its dropout.
        self.dense = nn.Linear(width, width)
        self.output_dropout = nn.Dropout(config.hidden_dropout_prob)

    def _transpose_for_scores(self, tensor):
        """Split the last dimension into heads and move heads forward.

        A [b, s, np*hn] tensor becomes a [b, np, s, hn] tensor.
        """
        per_head = tensor.view(*tensor.size()[:-1],
                               self.num_attention_heads,
                               self.hidden_size_per_attention_head)
        return per_head.permute(0, 2, 1, 3)

    def _split_tensor_along_last_dim(self,
                                     tensor,
                                     num_partitions,
                                     contiguous_split_chunks=False):
        """Cut ``tensor`` into pieces along its last dimension.

        Each piece has ``last_dim_size // num_partitions`` entries on the
        last axis. With ``contiguous_split_chunks`` set, every piece is
        made contiguous and the pieces are returned as a tuple.
        """
        axis = tensor.dim() - 1
        piece_width = tensor.size(axis) // num_partitions
        pieces = torch.split(tensor, piece_width, dim=axis)
        if not contiguous_split_chunks:
            # torch.split returns views, which are not contiguous in general.
            return pieces
        return tuple(piece.contiguous() for piece in pieces)

    def forward(self, hidden_states, ltor_mask, is_infer=False):
        """Apply masked self-attention.

        Args:
            hidden_states: tensor of size [b, s, h].
            ltor_mask: left-to-right mask; it is reshaped to [1, 1, s, s].
                Ones keep a score, zeros suppress it.
            is_infer: when true, ``ltor_mask`` is replaced by a
                lower-triangular mask built inside this call.

        Returns:
            Tensor of size [b, s, h].
        """
        seq_len = hidden_states.size(1)
        ltor_mask = torch.reshape(ltor_mask, [1, 1, seq_len, seq_len])

        # Project once, slice into three, then reshape each slice to
        # [b, np, s, hn].
        fused = self.query_key_value(hidden_states)
        query, key, value = (
            self._transpose_for_scores(part)
            for part in self._split_tensor_along_last_dim(fused, 3))

        tensor_type = value.type()

        # Scaled dot-product scores. [b, np, s, s]
        scale = math.sqrt(self.hidden_size_per_attention_head)
        scores = torch.matmul(query, key.transpose(-1, -2)) / scale

        if is_infer:
            key_len = key.size(2)
            lower = torch.tril(
                torch.ones((1, seq_len, key_len),
                           device=hidden_states.device))
            ltor_mask = lower.view(1, 1, seq_len,
                                   key_len).type(tensor_type)

        # Suppressed positions are zeroed and then lowered by 10000 so the
        # softmax assigns them a negligible weight.
        penalty = 10000.0 * (1.0 - ltor_mask)
        scores = (torch.mul(scores, ltor_mask) - penalty).type(tensor_type)

        # Attention weights [b, np, s, s]; dropout here removes whole
        # attended-to positions.
        weights = self.attention_dropout(self.softmax(scores))

        # Weighted sum of values [b, np, s, hn], rearranged to [b, s, np, hn]
        # and merged back to [b, s, h].
        context = torch.matmul(weights, value)
        context = context.permute(0, 2, 1, 3).contiguous()
        context = context.view(*context.size()[:-2], self.hidden_size)

        return self.output_dropout(self.dense(context))


class GPT3MLP(nn.Module):
    """Position-wise feed-forward network.

    The input with hidden size h is projected to 4*h, passed through GELU,
    projected back to h, and finally passed through dropout.
    """

    def __init__(self, config):
        super().__init__()

        width = config.hidden_size
        expanded = 4 * width
        # Expansion h -> 4h.
        self.dense_h_to_4h = nn.Linear(width, expanded)
        self.activation_func = F.gelu
        # Contraction 4h -> h.
        self.dense_4h_to_h = nn.Linear(expanded, width)

        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states):
        """Return the feed-forward output; the last dimension stays h."""
        # Last dimension is 4h after the expansion and activation.
        expanded = self.activation_func(self.dense_h_to_4h(hidden_states))
        # Back to a last dimension of h, then dropout.
        return self.dropout(self.dense_4h_to_h(expanded))


class GPT3TransformerLayer(nn.Module):
    """One transformer block: self-attention followed by an MLP.

    Each sub-layer is preceded by a layer norm and wrapped in a residual
    connection. The output has the same size as the input.
    """

    def __init__(self, config):
        super().__init__()

        width = config.hidden_size
        eps = config.layernorm_epsilon

        # Norm applied before attention.
        self.input_layernorm = nn.LayerNorm(width, eps=eps)
        self.attention = GPT3SelfAttention(config)

        # Norm applied before the MLP.
        self.post_attention_layernorm = nn.LayerNorm(width, eps=eps)
        self.mlp = GPT3MLP(config)

    def forward(self, hidden_states, ltor_mask):
        """Run the block.

        Args:
            hidden_states: input tensor, passed through layer norm and
                then to the attention sub-layer.
            ltor_mask: left-to-right mask forwarded to the attention
                sub-layer unchanged.

        Returns:
            Tensor with the same size as ``hidden_states``.
        """
        # Attention sub-layer with its residual connection.
        normed = self.input_layernorm(hidden_states)
        after_attention = hidden_states + self.attention(normed, ltor_mask)

        # MLP sub-layer with its residual connection.
        normed = self.post_attention_layernorm(after_attention)
        return after_attention + self.mlp(normed)


class GPT3Transformer(nn.Module):
    """A stack of transformer layers followed by a final layer norm."""

    def __init__(self, config):
        super().__init__()

        self.input_tensor = None

        # Depth of the stack.
        self.num_layers = config.num_hidden_layers

        self.layers = torch.nn.ModuleList(
            GPT3TransformerLayer(config) for _ in range(self.num_layers))

        # Normalization applied to the output of the last layer.
        self.final_layernorm = nn.LayerNorm(
            config.hidden_size, eps=config.layernorm_epsilon)

    def _get_layer(self, layer_number):
        """Return the layer stored at index ``layer_number``."""
        return self.layers[layer_number]

    def forward(self, hidden_states, attention_mask):
        """Pass ``hidden_states`` through every layer, then normalize.

        ``attention_mask`` is handed to each layer unchanged.
        """
        for layer in map(self._get_layer, range(self.num_layers)):
            hidden_states = layer(hidden_states, attention_mask)

        return self.final_layernorm(hidden_states)


class GPT3TransformerLanguageModel(nn.Module):
    """Embeddings, transformer stack and tied output projection.

    Arguments:
        config: object providing ``vocab_size``, ``hidden_size``,
                ``max_position_embeddings`` and ``hidden_dropout_prob``;
                it is also passed on to ``GPT3Transformer``.
    """

    def __init__(self, config):
        super().__init__()

        width = config.hidden_size

        # Token and position lookup tables plus dropout on their sum.
        self.word_embeddings = nn.Embedding(config.vocab_size, width)
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, width)
        self.embedding_dropout = nn.Dropout(config.hidden_dropout_prob)

        # Transformer stack.
        self.transformer = GPT3Transformer(config)

    def forward(self, input_ids, attention_mask, position_ids):
        """Return vocabulary logits for ``input_ids``.

        Word and position embeddings are summed, passed through dropout
        and the transformer, and the result is projected with the word
        embedding matrix.
        """
        summed = (
            self.word_embeddings(input_ids)
            + self.position_embeddings(position_ids))
        hidden = self.transformer(
            self.embedding_dropout(summed), attention_mask)

        # The output projection reuses the word embedding weights.
        return F.linear(hidden, self.word_embeddings.weight)


class GPT3Model(PreTrainedModel):
    """GPT-3 language model with loss computation and sampled generation."""

    config_class = GPT3Config

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(
                mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(
                mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def __init__(self, config):
        super().__init__(config)
        self.language_model = GPT3TransformerLanguageModel(config)

    def forward(self,
                input_ids,
                attention_mask=None,
                position_ids=None,
                labels=None,
                **kwargs):
        """Compute logits and, when ``labels`` is given, the loss.

        Args:
            input_ids: token ids of size [b, s].
            attention_mask: accepted for interface compatibility; the value
                is discarded and a lower-triangular [1, 1, s, s] mask is
                always built here.
            position_ids: position indices; defaults to ``0 .. s-1`` for
                every row.
            labels: optional targets. They are compared with the logits at
                the same positions; no shifting is done in this method.
            **kwargs: ignored.

        Returns:
            ``addict.Dict`` with ``loss`` (``None`` without labels) and
            ``logits``.
        """
        seq_length = input_ids.size(1)
        attention_mask = torch.tril(
            torch.ones((1, 1, seq_length, seq_length),
                       dtype=torch.long,
                       device=input_ids.device))
        if position_ids is None:
            position_ids = torch.arange(
                seq_length, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

        logits = self.language_model(input_ids, attention_mask, position_ids)
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                logits.view(-1, self.config.vocab_size), labels.view(-1))
        return addict.Dict(loss=loss, logits=logits)

    @classmethod
    def from_pretrained(
            cls, pretrained_model_name_or_path: Optional[Union[str,
                                                               os.PathLike]]):
        """Build a model from a directory holding a config and a checkpoint.

        The checkpoint may be a bare state dict or a dict that stores it
        under the ``'state_dict'`` key. A ``model.language_model`` key
        prefix is rewritten to ``language_model`` before loading.
        """
        config = cls.config_class.from_pretrained(
            pretrained_model_name_or_path)
        model = cls(config)
        state_dict_file = os.path.join(pretrained_model_name_or_path,
                                       ModelFile.TORCH_MODEL_BIN_FILE)
        # Load onto CPU so that a checkpoint saved from GPU tensors can be
        # restored on a host without that device; load_state_dict copies
        # the values into the freshly built parameters afterwards.
        # NOTE(review): torch.load unpickles the file, so only trusted
        # checkpoints must be passed here.
        state_dict = torch.load(state_dict_file, map_location='cpu')
        if 'state_dict' in state_dict:
            state_dict = state_dict['state_dict']
        state_dict = {
            k.replace('model.language_model', 'language_model'): v
            for k, v in state_dict.items()
        }
        model.load_state_dict(state_dict)
        return model

    def streaming_generate(self, tokens, temperature=1.0, **kwargs):
        """Generate tokens one position at a time, yielding after each step.

        Args:
            tokens: prompt token ids of size [b, s].
            temperature: forwarded to ``sample``.
            **kwargs: optional ``top_k`` and ``top_p`` (default to the
                config values), ``max_length`` (defaults to ``s + 100``)
                and ``prompt_length``, a tensor with one prompt length per
                row (defaults to ``s`` for every row).

        Yields:
            ``TokenGeneratorOutput`` whose ``sequences`` holds all tokens
            up to and including the position filled in this step.

        Raises:
            ValueError: if the shortest prompt already reaches the maximum
                sequence length.
        """
        top_k = kwargs.pop('top_k', self.config.top_k)
        top_p = kwargs.pop('top_p', self.config.top_p)
        max_length = kwargs.pop('max_length', tokens.size(1) + 100)

        batch_size = tokens.size(0)
        lengths = kwargs.pop('prompt_length', None)
        if lengths is None:
            # One entry per row: the boolean row index used below must
            # match the batch dimension, also for batch sizes above one.
            lengths = torch.full((batch_size, ),
                                 tokens.size(1),
                                 dtype=torch.long,
                                 device=tokens.device)

        min_prompt_length = lengths.min().item()
        max_sequence_length = min(max_length,
                                  self.config.max_position_embeddings)

        # If the context is too big, this happens
        if min_prompt_length >= max_sequence_length:
            raise ValueError('context length too large')

        # Reserve room for the tokens that will be generated.
        pad_length = max_sequence_length - tokens.size(1)
        if pad_length > 0:
            pads = torch.zeros(
                batch_size, pad_length, device=tokens.device).long()
            tokens = torch.cat((tokens, pads), dim=-1)

        # Generation stops once every row has produced this id.
        termination_id = self.config.eod_id

        # Whether each row has reached a termination id.
        is_generation_done = torch.zeros(
            batch_size, dtype=torch.bool, device=tokens.device)

        for context_length in range(min_prompt_length, max_sequence_length):
            # Grad mode is thread-local state. It is disabled only around
            # the forward pass and sampling, and never held across a
            # ``yield``, so the consumer's code keeps its own grad mode.
            with torch.no_grad():
                # Pick the slice that we need to pass through the network.
                tokens2use = tokens[:, :context_length]
                logits = self(tokens2use).logits

                # Sample from the logits of the last position.
                last_token_logits = logits[:, -1, :]
                new_sample = sample(
                    last_token_logits,
                    top_k=top_k,
                    top_p=top_p,
                    temperature=temperature,
                    vocab_size=self.config.vocab_size)

            # If a prompt length is smaller than or equal to the current
            # context length, it means we have started generating tokens.
            started = lengths <= context_length
            # Update the tokens; rows still inside their prompt are kept.
            tokens[started, context_length] = new_sample[started]

            yield TokenGeneratorOutput(
                sequences=tokens[:, :(context_length + 1)])

            done_token = (new_sample == termination_id) & started
            is_generation_done = is_generation_done | done_token

            if torch.all(is_generation_done):
                break

    def generate(self, tokens, temperature=1.0, **kwargs):
        """Run ``streaming_generate`` to the end and return its last output."""
        last_output = None
        for output in self.streaming_generate(tokens, temperature, **kwargs):
            last_output = output
        return last_output
