# Copyright (c) Alibaba, Inc. and its affiliates.

import os
from typing import Any, Dict, List, Optional, Union

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.metainfo import DEFAULT_MODEL_FOR_PIPELINE
from modelscope.models.base import Model
from modelscope.utils.config import ConfigDict, check_config
from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, Invoke, Tasks,
                                       ThirdParty)
from modelscope.utils.hub import read_config
from modelscope.utils.import_utils import is_transformers_available
from modelscope.utils.logger import get_logger
from modelscope.utils.plugins import (register_modelhub_repo,
                                      register_plugins_repo)
from modelscope.utils.registry import Registry, build_from_cfg
from modelscope.utils.task_utils import is_embedding_task
from .base import Pipeline
from .util import is_official_hub_path

PIPELINES = Registry('pipelines')
logger = get_logger()


def normalize_model_input(model,
                          model_revision,
                          third_party=None,
                          ignore_file_pattern=None):
    """ normalize the input model, to ensure that a model str is a valid local path: in other words,
    for model represented by a model id, the model shall be downloaded locally
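
    Example (illustrative; the model ids below are hypothetical):
        >>> # a hub model id is resolved to a local snapshot directory
        >>> local_dir = normalize_model_input('damo/some-model', DEFAULT_MODEL_REVISION)
        >>> # an existing local path is returned unchanged
        >>> same_dir = normalize_model_input('/path/to/local/model', None)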
    """
    if isinstance(model, str) and is_official_hub_path(model, model_revision):
        # skip downloading if the model is already a local directory
        if not os.path.exists(model):
            # note that if there is already a local copy, snapshot_download will check and skip downloading
            user_agent = {Invoke.KEY: Invoke.PIPELINE}
            if third_party is not None:
                user_agent[ThirdParty.KEY] = third_party
            model = snapshot_download(
                model,
                revision=model_revision,
                user_agent=user_agent,
                ignore_file_pattern=ignore_file_pattern)
    elif isinstance(model, list) and isinstance(model[0], str):
        for idx in range(len(model)):
            if is_official_hub_path(
                    model[idx],
                    model_revision) and not os.path.exists(model[idx]):
                user_agent = {Invoke.KEY: Invoke.PIPELINE}
                if third_party is not None:
                    user_agent[ThirdParty.KEY] = third_party
                model[idx] = snapshot_download(
                    model[idx],
                    revision=model_revision,
                    user_agent=user_agent,
                    ignore_file_pattern=ignore_file_pattern)
    return model


def build_pipeline(cfg: ConfigDict,
                   task_name: str = None,
                   default_args: dict = None):
    """ build pipeline given model config dict.

    Args:
        cfg (:obj:`ConfigDict`): config dict for the pipeline object.
        task_name (str, optional):  task name, refer to
            :obj:`Tasks` for more details.
        default_args (dict, optional): Default initialization arguments.
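
    Example (illustrative; the pipeline type and model id are hypothetical):
        >>> cfg = ConfigDict({'type': 'some-pipeline', 'model': 'damo/some-model'})
        >>> p = build_pipeline(cfg, task_name=Tasks.text_classification)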
    """
    return build_from_cfg(
        cfg, PIPELINES, group_key=task_name, default_args=default_args)


def pipeline(task: str = None,
             model: Union[str, List[str], Model, List[Model]] = None,
             preprocessor=None,
             config_file: str = None,
             pipeline_name: str = None,
             framework: str = None,
             device: str = None,
             model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
             ignore_file_pattern: List[str] = None,
             **kwargs) -> Pipeline:
    """ Factory method to build an obj:`Pipeline`.


    Args:
        task (str): Task name defining which pipeline will be returned.
        model (str or List[str] or obj:`Model` or obj:list[`Model`]): (list of) model name or model object.
        preprocessor: preprocessor object.
        config_file (str, optional): path to config file.
        pipeline_name (str, optional): pipeline class name or alias name.
        framework (str, optional): framework type.
        device (str, optional): device to run inference on, e.g. 'gpu' or 'cpu'.
        model_revision (str, optional): revision of the model(s) when fetched from the model hub;
            for multiple models, all of them are expected to share the same revision.
        ignore_file_pattern (`str` or `List[str]`, *optional*, defaults to `None`):
            Any file pattern to be ignored when downloading, such as exact file names or file extensions.

    Return:
        pipeline (obj:`Pipeline`): pipeline object for certain task.

    Examples:
        >>> # Using default model for a task
        >>> p = pipeline('image-classification')
        >>> # Using pipeline with a model name
        >>> p = pipeline('text-classification', model='damo/distilbert-base-uncased')
        >>> # Using pipeline with a model object
        >>> resnet = Model.from_pretrained('Resnet')
        >>> p = pipeline('image-classification', model=resnet)
        >>> # Using pipeline with a list of model names
        >>> p = pipeline('audio-kws', model=['damo/audio-tts', 'damo/audio-tts2'])
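        >>> # Pinning a model revision (illustrative model id and revision)
        >>> p = pipeline('text-classification', model='damo/some-model', model_revision='v1.0.0')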
    """
    if task is None and pipeline_name is None:
        raise ValueError('task or pipeline_name is required')

    pipeline_props = None
    if pipeline_name is None:
        # get default pipeline for this task
        if isinstance(model, str) \
           or (isinstance(model, list) and isinstance(model[0], str)):
            if is_official_hub_path(model, revision=model_revision):
                # read config file from hub and parse
                cfg = read_config(
                    model, revision=model_revision) if isinstance(
                        model, str) else read_config(
                            model[0], revision=model_revision)
                if cfg:
                    pipeline_name = cfg.safe_get('pipeline',
                                                 {}).get('type', None)

                if pipeline_name is None:
                    prefer_llm_pipeline = kwargs.get('external_engine_for_llm')
                    # if not specified in either kwargs or configuration.json,
                    # prefer the llm pipeline for the tasks below
                    if task is not None and task.lower() in [
                            Tasks.text_generation, Tasks.chat
                    ]:
                        if prefer_llm_pipeline is None:
                            prefer_llm_pipeline = True
                    # for the llm pipeline, default llm_framework to swift if it is not specified
                    # TODO: port the swift infer based on transformer into ModelScope
                    if prefer_llm_pipeline:
                        if kwargs.get('llm_framework') is None:
                            kwargs['llm_framework'] = 'swift'
                        pipeline_name = external_engine_for_llm_checker(
                            model, model_revision, kwargs)

                if pipeline_name != 'llm':
                    third_party = kwargs.get(ThirdParty.KEY)
                    if third_party is not None:
                        kwargs.pop(ThirdParty.KEY)

                    model = normalize_model_input(
                        model,
                        model_revision,
                        third_party=third_party,
                        ignore_file_pattern=ignore_file_pattern)

                    register_plugins_repo(cfg.safe_get('plugins'))
                    register_modelhub_repo(model,
                                           cfg.get('allow_remote', False))

                if pipeline_name:
                    pipeline_props = {'type': pipeline_name}
                else:
                    try:
                        check_config(cfg)
                        pipeline_props = cfg.pipeline
                    except AssertionError as e:
                        logger.info(str(e))

        elif model is not None:
            # get pipeline info from Model object
            first_model = model[0] if isinstance(model, list) else model
            if not hasattr(first_model, 'pipeline'):
                # the model was instantiated by the user, so parse the config again
                cfg = read_config(first_model.model_dir)
                try:
                    check_config(cfg)
                    first_model.pipeline = cfg.pipeline
                except AssertionError as e:
                    logger.info(str(e))
            if first_model.__dict__.get('pipeline'):
                pipeline_props = first_model.pipeline
        else:
            pipeline_name, default_model_repo = get_default_pipeline_info(task)
            model = normalize_model_input(default_model_repo, model_revision)
            pipeline_props = {'type': pipeline_name}
    else:
        pipeline_props = {'type': pipeline_name}

    if not pipeline_props and is_embedding_task(task):
        try:
            from modelscope.utils.hf_util import sentence_transformers_pipeline
            return sentence_transformers_pipeline(model=model, **kwargs)
        except Exception:
            logger.exception(
                'We could not find a suitable pipeline from ModelScope, so we tried to load it using '
                'sentence_transformers, but that also failed.')
            raise

    if not pipeline_props and is_transformers_available():
        try:
            from modelscope.utils.hf_util import hf_pipeline
            return hf_pipeline(
                task=task,
                model=model,
                framework=framework,
                device=device,
                **kwargs)
        except Exception:
            logger.error(
                'We couldn\'t find a suitable pipeline from ModelScope, so we tried to load it using '
                'the transformers pipeline, but that also failed.')
            raise

    if not device:
        device = 'gpu'
    pipeline_props['model'] = model
    pipeline_props['device'] = device
    cfg = ConfigDict(pipeline_props)

    clear_llm_info(kwargs, pipeline_name)
    if kwargs:
        cfg.update(kwargs)

    if preprocessor is not None:
        cfg.preprocessor = preprocessor

    return build_pipeline(cfg, task_name=task)


def add_default_pipeline_info(task: str,
                              model_name: str,
                              modelhub_name: str = None,
                              overwrite: bool = False):
    """ Add default model for a task.

    Args:
        task (str): task name.
        model_name (str): model_name.
        modelhub_name (str): name for default modelhub.
        overwrite (bool): overwrite default info.
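
    Example (illustrative; the names below are hypothetical):
        >>> add_default_pipeline_info('my-task', 'my-pipeline',
        >>>                           modelhub_name='damo/my-default-model')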
    """
    if not overwrite:
        assert task not in DEFAULT_MODEL_FOR_PIPELINE, \
            f'task {task} already has default model.'

    DEFAULT_MODEL_FOR_PIPELINE[task] = (model_name, modelhub_name)


def get_default_pipeline_info(task):
    """ Get default info for certain task.

    Args:
        task (str): task name.

    Return:
        A tuple: the first element is the default pipeline name, the second
            element is the model id of the default model on the model hub
            (``None`` if no default model is registered).
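
    Example:
        >>> # returns e.g. ('word-segmentation', 'damo/...') for a registered task
        >>> pipeline_name, default_model = get_default_pipeline_info(Tasks.word_segmentation)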
    """

    if task not in DEFAULT_MODEL_FOR_PIPELINE:
        # support pipelines that do not register a default model
        pipeline_name = list(PIPELINES.modules[task].keys())[0]
        default_model = None
    else:
        pipeline_name, default_model = DEFAULT_MODEL_FOR_PIPELINE[task]
    return pipeline_name, default_model


def external_engine_for_llm_checker(model: Union[str, List[str], Model,
                                                 List[Model]],
                                    revision: Optional[str],
                                    kwargs: Dict[str, Any]) -> Optional[str]:
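    """ Check whether the model should be served by the LLM pipeline.

    Returns 'llm' if the configured ``llm_framework`` (currently swift) can
    resolve the model type, or if the model type is registered in
    ``LLMAdapterRegistry``; returns None otherwise.
    """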
    from .nlp.llm_pipeline import ModelTypeHelper, LLMAdapterRegistry
    from ..hub.check_model import get_model_id_from_cache
    if isinstance(model, list):
        model = model[0]
    if not isinstance(model, str):
        model = model.model_dir

    llm_framework = kwargs.get('llm_framework', '')
    if llm_framework == 'swift':
        from swift.llm import get_model_info_meta
        # check whether swift supports this model
        if os.path.exists(model):
            model_id = get_model_id_from_cache(model)
        else:
            model_id = model

        try:
            info = get_model_info_meta(model_id)
            model_type = info[0].model_type
        except Exception as e:
            logger.warning(f'Cannot use llm_framework with {model_id}, '
                           f'ignoring llm_framework={llm_framework} : {e}')
            model_type = None
        if model_type:
            return 'llm'

    model_type = ModelTypeHelper.get(
        model, revision, with_adapter=True, split='-', use_cache=True)
    if LLMAdapterRegistry.contains(model_type):
        return 'llm'


def clear_llm_info(kwargs: Dict, pipeline_name: str):
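    """ Drop llm-specific entries from kwargs before building the pipeline.

    ``external_engine_for_llm`` is always removed; ``llm_framework`` is kept
    only when the selected pipeline is 'llm'. The model-type cache is cleared
    afterwards.
    """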
    from modelscope.utils.model_type_helper import ModelTypeHelper

    kwargs.pop('external_engine_for_llm', None)
    if pipeline_name != 'llm':
        kwargs.pop('llm_framework', None)
    ModelTypeHelper.clear_cache()
