from typing import List, Optional, Union

from modelscope import get_logger
from modelscope.pipelines.accelerate.base import InferFramework
from modelscope.utils.import_utils import is_vllm_available

logger = get_logger()


class Vllm(InferFramework):

    def __init__(self,
                 model_id_or_dir: str,
                 dtype: str = 'auto',
                 quantization: Optional[str] = None,
                 tensor_parallel_size: int = 1,
                 trust_remote_code: Optional[bool] = None):
        """
        Args:
            model_id_or_dir: The model id on the hub or a local model directory.
            dtype: The dtype to use, one of `auto`, `float16`, `bfloat16`, `float32`.
            quantization: The quantization method recognized by vllm (e.g. `awq`),
                default None means no quantization.
            tensor_parallel_size: The number of GPUs to use for tensor-parallel inference.
            trust_remote_code: Whether to trust remote code when loading the model.
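
        Example (a minimal sketch; the model id below is illustrative only):
            >>> llm = Vllm('qwen/Qwen-7B-Chat', dtype='auto',
            ...            tensor_parallel_size=1)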
        """
        super().__init__(model_id_or_dir)
        if not is_vllm_available():
            raise ImportError(
                'Install vllm by `pip install vllm` before using vllm to accelerate inference'
            )

        from vllm import LLM
        # bfloat16 needs a GPU with compute capability >= 8.0 (Ampere or
        # newer); fall back to float16 on older hardware.
        if not Vllm.check_gpu_compatibility(8) and (dtype
                                                    in ('bfloat16', 'auto')):
            dtype = 'float16'
        self.model = LLM(
            self.model_dir,
            dtype=dtype,
            quantization=quantization,
            trust_remote_code=trust_remote_code,
            tensor_parallel_size=tensor_parallel_size)

    def __call__(self, prompts: Union[List[str], List[List[int]]],
                 **kwargs) -> List[str]:
        """Generate tokens.
        Args:
            prompts(`Union[List[str], List[List[int]]]`):
                The string batch or the token list batch to input to the model.
            kwargs: Sampling parameters.
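
        Example (illustrative only; the sampling values are arbitrary):
            >>> texts = llm(['Hello, my name is'],
            ...             temperature=0.7, top_p=0.9, max_new_tokens=64)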
        """

        # convert hf generate config to vllm
        do_sample = kwargs.pop('do_sample', None)
        num_beams = kwargs.pop('num_beams', 1)
        max_length = kwargs.pop('max_length', None)
        max_new_tokens = kwargs.pop('max_new_tokens', None)

        # vllm defaults to sampling or greedy decoding depending on temperature.
        # hf semantics: do_sample=False, num_beams=1 -> greedy (default)
        #               do_sample=True,  num_beams=1 -> sample
        #               do_sample=False, num_beams>1 -> beam search
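        # e.g. {'do_sample': False, 'num_beams': 4, 'max_new_tokens': 32}
        #   -> SamplingParams(use_beam_search=True, best_of=4, max_tokens=32)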
        if not do_sample and num_beams > 1:
            # vllm's beam search uses best_of as the beam width.
            kwargs['use_beam_search'] = True
            kwargs.setdefault('best_of', num_beams)
        if max_length:
            # Note: for string prompts len(prompts[0]) is a character count,
            # so this is only an approximation of the remaining token budget.
            kwargs['max_tokens'] = max_length - len(prompts[0])
        if max_new_tokens:
            # max_new_tokens takes precedence over max_length.
            kwargs['max_tokens'] = max_new_tokens

        from vllm import SamplingParams
        sampling_params = SamplingParams(**kwargs)
        if isinstance(prompts[0], str):
            return [
                output.outputs[0].text for output in self.model.generate(
                    prompts, sampling_params=sampling_params)
            ]
        else:
            return [
                output.outputs[0].text for output in self.model.generate(
                    prompt_token_ids=prompts, sampling_params=sampling_params)
            ]

    def model_type_supported(self, model_type: str):
        """Check whether the model type is known to be supported by vllm."""
        return any(model in model_type.lower() for model in [
            'llama',
            'baichuan',
            'internlm',
            'mistral',
            'aquila',
            'bloom',
            'falcon',
            'gpt',
            'mpt',
            'opt',
            'qwen',
        ])
