# Copyright (c) Alibaba, Inc. and its affiliates.
from collections import OrderedDict
from typing import Dict, Generator

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from modelscope.metainfo import Models
from modelscope.models.base import Tensor, TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
from modelscope.utils.streaming_output import StreamingOutputMixin

# Public API of this module: only the model class is exported.
__all__ = ['PolyLMForTextGeneration']

# Module-level logger from modelscope's logging helper.
logger = get_logger()


@MODELS.register_module(Tasks.text_generation, module_name=Models.polylm)
class PolyLMForTextGeneration(TorchModel, StreamingOutputMixin):
    """PolyLM causal language model wrapped for the text-generation task.

    Loads a Hugging Face tokenizer and ``AutoModelForCausalLM`` from a local
    model directory, and exposes ``forward``/``generate`` that map a prompt
    string to a decoded output string.
    """

    def __init__(self, model_dir: str, *args, **kwargs):
        """Initialize the text generation model from the `model_dir` path.

        The tokenizer is loaded as the slow (``use_fast=False``) variant with
        ``legacy=False``. The model weights are placed with
        ``device_map='auto'`` and the model is switched to eval mode.

        Args:
            model_dir (str): the model path.
        """
        super().__init__(model_dir, *args, **kwargs)
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_dir, legacy=False, use_fast=False)

        # Gate execution of custom modeling code shipped in `model_dir`.
        # NOTE(review): `self.trust_remote_code` is presumably set by this
        # base-class helper; confirm against TorchModel.
        self.check_trust_remote_code(
            info_str=
            f'Use trust_remote_code=True. Will invoke codes from {model_dir}. Please make sure '
            'that you can trust the external codes.',
            model_dir=model_dir)

        self.model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            device_map='auto',
            trust_remote_code=self.trust_remote_code)
        self.model.eval()

    def forward(self, input: str, **kwargs) -> str:
        """Generate text for a prompt; thin alias of :meth:`generate`.

        Args:
            input (str): the prompt text. It is tokenized inside
                :meth:`generate`, so it must be raw text rather than
                pre-tokenized tensors.
            **kwargs: forwarded unchanged to ``self.model.generate``
                (e.g. ``max_new_tokens``).

        Returns:
            str: the decoded text of the first generated sequence.
        """
        return self.generate(input, **kwargs)

    def generate(self, input: str, **kwargs) -> str:
        """Tokenize ``input``, run ``self.model.generate`` and decode the result.

        Args:
            input (str): the prompt text, passed directly to the tokenizer.
            **kwargs: generation options forwarded unchanged to
                ``self.model.generate``.

        Returns:
            str: ``outputs[0]`` decoded with special tokens skipped. Only the
            first returned sequence is decoded; any additional sequences
            (e.g. from ``num_return_sequences``) are discarded. For causal LMs
            the generated ids normally start with the prompt ids, so the text
            presumably includes the prompt — confirm for this checkpoint.
        """
        # With device_map='auto' the weights may be spread over devices;
        # inputs go to `model.device` (presumably where the embeddings live).
        device = self.model.device
        inputs = self.tokenizer(input, return_tensors='pt')
        outputs = self.model.generate(
            inputs.input_ids.to(device),
            attention_mask=inputs.attention_mask.to(device),
            **kwargs)
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
