# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from huggingface/transformers.
import ast
import functools
import importlib
import inspect
import logging
import os
import os.path as osp
import sys
from collections import OrderedDict
from importlib import import_module
from itertools import chain
from pathlib import Path
from types import ModuleType
from typing import Any

from modelscope.utils.ast_utils import (INDEX_KEY, MODULE_KEY, REQUIREMENT_KEY,
                                        load_index)
from modelscope.utils.error import *  # noqa
from modelscope.utils.logger import get_logger

if sys.version_info < (3, 8):
    import importlib_metadata
else:
    import importlib.metadata as importlib_metadata

logger = get_logger(log_level=logging.WARNING)


def import_modules_from_file(py_file: str):
    """Import a module from a certain Python source file.

    The directory containing the file is temporarily placed at the front of
    ``sys.path`` so the file can be imported by its bare module name, and is
    removed again afterwards, even when the import fails.

    Args:
        py_file: path to a python file to be imported

    Returns:
        tuple: ``(module_name, module)`` where ``module_name`` is the file
            name without its extension and ``module`` is the imported module.

    Raises:
        SyntaxError: if the file does not contain valid Python source.
    """
    dirname, basefile = os.path.split(py_file)
    if dirname == '':
        # sys.path entries must be strings; a Path object would be ignored
        # by the import machinery.
        dirname = str(Path.cwd())
    module_name = osp.splitext(basefile)[0]
    # Validate before touching sys.path so a syntax error leaves it intact.
    validate_py_syntax(py_file)
    sys.path.insert(0, dirname)
    try:
        mod = import_module(module_name)
    finally:
        # Remove the entry we added. The imported module may itself have
        # modified sys.path, so index 0 is not guaranteed to be ours.
        try:
            sys.path.remove(dirname)
        except ValueError:
            pass
    return module_name, mod


def is_method_overridden(method, base_class, derived_class):
    """Tell whether ``derived_class`` replaces a method defined on ``base_class``.

    Args:
        method (str): name of the method to compare.
        base_class (type): the base class; must be a class, not an instance.
        derived_class (type | Any): the derived class, or an instance of it.

    Returns:
        bool: True when the attribute looked up on the derived class differs
            from the one looked up on the base class.
    """
    assert isinstance(base_class, type), \
        "base_class doesn't accept instance, Please pass class instead."

    # Accept instances as a convenience by falling back to their class.
    derived_type = derived_class
    if not isinstance(derived_type, type):
        derived_type = derived_type.__class__

    inherited = getattr(base_class, method)
    candidate = getattr(derived_type, method)
    return candidate != inherited


def has_method(obj: object, method: str) -> bool:
    """Report whether ``obj`` exposes a callable attribute named ``method``.

    Args:
        obj (object): The object to inspect.
        method (str): The attribute name to look for.

    Returns:
        bool: True if the attribute exists and is callable, else False.
    """
    if not hasattr(obj, method):
        return False
    return callable(getattr(obj, method))


def import_modules(imports, allow_failed_imports=False):
    """Import modules from the given list of strings.

    Args:
        imports (list | str | None): The given module names to be imported.
        allow_failed_imports (bool): If True, the failed imports will return
            None. Otherwise, the original ImportError is re-raised.
            Default: False.

    Returns:
        list[module] | module | None: The imported modules. A single module
            is returned when ``imports`` is a string, and None when
            ``imports`` is empty.

    Raises:
        TypeError: if ``imports`` is not a list/str, or an entry is not a str.
        ImportError: if a module cannot be imported and
            ``allow_failed_imports`` is False.

    Examples:
        >>> osp, sys = import_modules(
        ...     ['os.path', 'sys'])
        >>> import os.path as osp_
        >>> import sys as sys_
        >>> assert osp == osp_
        >>> assert sys == sys_
    """
    if not imports:
        return
    single_import = False
    if isinstance(imports, str):
        single_import = True
        imports = [imports]
    if not isinstance(imports, list):
        raise TypeError(
            f'custom_imports must be a list but got type {type(imports)}')
    imported = []
    for imp in imports:
        if not isinstance(imp, str):
            raise TypeError(
                f'{imp} is of type {type(imp)} and cannot be imported.')
        try:
            imported_tmp = import_module(imp)
        except ImportError:
            if not allow_failed_imports:
                # Re-raise the original exception so the failing module name
                # and traceback are preserved for the caller.
                raise
            logger.warning(f'{imp} failed to import and is ignored.')
            imported_tmp = None
        imported.append(imported_tmp)
    if single_import:
        imported = imported[0]
    return imported


def validate_py_syntax(filename):
    """Check that ``filename`` contains syntactically valid Python source.

    Args:
        filename: path of the file to parse.

    Raises:
        SyntaxError: if the file content cannot be parsed; the message names
            the file and includes the parser's error.
    """
    # An explicit encoding avoids decoding problems on Windows.
    with open(filename, 'r', encoding='utf-8') as source_file:
        source = source_file.read()
    try:
        ast.parse(source)
    except SyntaxError as err:
        message = f'There are syntax errors in config file {filename}: {err}'
        raise SyntaxError(message)


# The code below is adapted from huggingface/transformers.
# Values (compared after upper-casing) that count as "enabled"; 'AUTO' is the
# default used when the corresponding variable is not set.
ENV_VARS_TRUE_VALUES = {'1', 'ON', 'YES', 'TRUE'}
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES | {'AUTO'}
USE_TF = os.getenv('USE_TF', 'AUTO').upper()
USE_TORCH = os.getenv('USE_TORCH', 'AUTO').upper()

# PyTorch detection, performed once at import time.
# `_torch_version` stays 'N/A' unless an installed torch distribution is found.
_torch_version = 'N/A'
# PyTorch is considered only when USE_TORCH is a true value or AUTO and
# USE_TF has not been explicitly set to a true value.
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
    _torch_available = importlib.util.find_spec('torch') is not None
    if _torch_available:
        try:
            _torch_version = importlib_metadata.version('torch')
            logger.info(f'PyTorch version {_torch_version} Found.')
        except importlib_metadata.PackageNotFoundError:
            # A module named `torch` is findable but no distribution metadata
            # exists for it; treat PyTorch as unavailable.
            _torch_available = False
else:
    # NOTE: this branch is also taken when USE_TORCH itself is set to a
    # non-true value, although the message only mentions USE_TF.
    logger.info('Disabling PyTorch because USE_TF is set')
    _torch_available = False

# timm detection, performed once at import time.
# Default the version to 'N/A' (as done for torch and tensorflow) so that
# `_timm_version` is always defined, even when timm is not installed.
_timm_version = 'N/A'
_timm_available = importlib.util.find_spec('timm') is not None
try:
    _timm_version = importlib_metadata.version('timm')
    logger.debug(f'Successfully imported timm version {_timm_version}')
except importlib_metadata.PackageNotFoundError:
    # No distribution metadata for timm: treat it as unavailable.
    _timm_available = False

# TensorFlow detection, performed once at import time.
_tf_version = 'N/A'
# TensorFlow is considered only when USE_TF is a true value or AUTO and
# USE_TORCH has not been explicitly set to a true value.
if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
    _tf_available = importlib.util.find_spec('tensorflow') is not None
    if _tf_available:
        # Distribution names whose metadata is probed for a version string.
        candidates = (
            'tensorflow',
            'tensorflow-cpu',
            'tensorflow-gpu',
            'tf-nightly',
            'tf-nightly-cpu',
            'tf-nightly-gpu',
            'intel-tensorflow',
            'intel-tensorflow-avx512',
            'tensorflow-rocm',
            'tensorflow-macos',
        )
        # NOTE: from here on `_tf_version` is None (not 'N/A') if no
        # candidate distribution is found.
        _tf_version = None
        # For the metadata, we have to look for both tensorflow and tensorflow-cpu
        for pkg in candidates:
            try:
                _tf_version = importlib_metadata.version(pkg)
                break
            except importlib_metadata.PackageNotFoundError:
                pass
        # The module is findable, but it only counts as available when one of
        # the candidate distributions reports a version.
        _tf_available = _tf_version is not None
    if _tf_available:
        from packaging import version
        # Versions below 2 are accepted silently; only 2+ is logged.
        if version.parse(_tf_version) < version.parse('2'):
            pass
        else:
            logger.info(f'TensorFlow version {_tf_version} Found.')
else:
    # NOTE: this branch is also taken when USE_TF itself is set to a non-true
    # value, although the message only mentions USE_TORCH.
    logger.info('Disabling Tensorflow because USE_TORCH is set')
    _tf_available = False


def is_scipy_available():
    """Return True when a module named ``scipy`` can be located."""
    spec = importlib.util.find_spec('scipy')
    return spec is not None


def is_sklearn_available():
    """Return True when scikit-learn is usable.

    Requires the ``sklearn`` package, its ``sklearn.metrics`` submodule and
    ``scipy`` to all be locatable.

    Returns:
        bool: always a real boolean, never a module spec.
    """
    # Check the parent first: find_spec on a dotted name imports the parent
    # package and raises if it is missing.
    if importlib.util.find_spec('sklearn') is None:
        return False
    return (is_scipy_available()
            and importlib.util.find_spec('sklearn.metrics') is not None)


def is_sentencepiece_available():
    """Return True when a module named ``sentencepiece`` can be located."""
    spec = importlib.util.find_spec('sentencepiece')
    return spec is not None


def is_protobuf_available():
    """Return True when ``google.protobuf`` can be located.

    The parent package ``google`` is probed first; the submodule is only
    looked up when the parent exists.
    """
    for module_name in ('google', 'google.protobuf'):
        if importlib.util.find_spec(module_name) is None:
            return False
    return True


def is_tokenizers_available():
    """Return True when a module named ``tokenizers`` can be located."""
    spec = importlib.util.find_spec('tokenizers')
    return spec is not None


def is_timm_available():
    """Return the module-level ``_timm_available`` flag computed at import time."""
    return _timm_available


def is_torch_available():
    """Return the module-level ``_torch_available`` flag computed at import time."""
    return _torch_available


def is_torch_cuda_available():
    """Return True when PyTorch is available and reports a usable CUDA device.

    ``torch`` is imported lazily, and only when ``is_torch_available()`` is
    True.
    """
    if not is_torch_available():
        return False

    import torch

    return torch.cuda.is_available()


def is_wenetruntime_available():
    """Return True when a module named ``wenetruntime`` can be located."""
    spec = importlib.util.find_spec('wenetruntime')
    return spec is not None


def is_swift_available():
    """Return True when a module named ``swift`` can be located."""
    spec = importlib.util.find_spec('swift')
    return spec is not None


def is_tf_available():
    """Return the module-level ``_tf_available`` flag computed at import time."""
    return _tf_available


def is_opencv_available():
    """Return True when a module named ``cv2`` can be located."""
    spec = importlib.util.find_spec('cv2')
    return spec is not None


def is_pillow_available():
    """Return True when Pillow's ``PIL.Image`` module can be located.

    Returns:
        bool: False when the ``PIL`` package is missing, instead of raising.
    """
    # find_spec on a dotted name imports the parent package and raises
    # ModuleNotFoundError when it is missing, so probe the parent first.
    if importlib.util.find_spec('PIL') is None:
        return False
    return importlib.util.find_spec('PIL.Image') is not None


def _is_package_available_fn(pkg_name):
    """Return True when a module named ``pkg_name`` can be located.

    Args:
        pkg_name (str): importable module name, possibly dotted.

    Returns:
        bool: False when the module, or the parent package of a dotted name,
            cannot be found or imported.
    """
    try:
        return importlib.util.find_spec(pkg_name) is not None
    except ImportError:
        # For dotted names find_spec imports the parent package and raises
        # when that fails; report this as "not available" rather than
        # propagating the error.
        return False


def is_package_available(pkg_name):
    """Build a zero-argument availability check for ``pkg_name``.

    Args:
        pkg_name (str): importable module name to check.

    Returns:
        functools.partial: a callable that takes no arguments and returns
            True when the module can be located.
    """
    return functools.partial(_is_package_available_fn, pkg_name)


def is_espnet_available(pkg_name=None):
    """Return True when both ``espnet2`` and ``espnet`` can be located.

    Args:
        pkg_name: unused; kept for backward compatibility. It is optional so
            the function can be called without arguments, as ``requires``
            does with every availability check.

    Returns:
        bool: always a real boolean, never a module spec.
    """
    return (importlib.util.find_spec('espnet2') is not None
            and importlib.util.find_spec('espnet') is not None)


def is_vllm_available():
    """Return True when a module named ``vllm`` can be located."""
    spec = importlib.util.find_spec('vllm')
    return spec is not None


def is_transformers_available():
    """Return True when a module named ``transformers`` can be located."""
    spec = importlib.util.find_spec('transformers')
    return spec is not None


def is_diffusers_available():
    """Return True when a module named ``diffusers`` can be located."""
    spec = importlib.util.find_spec('diffusers')
    return spec is not None


def is_tensorrt_llm_available():
    """Return True when a module named ``tensorrt_llm`` can be located."""
    spec = importlib.util.find_spec('tensorrt_llm')
    return spec is not None


# Maps a requirement name to a pair ``(availability_check, error_message)``.
# `requires` calls ``availability_check`` with no arguments and, when it
# returns a falsy value, formats ``error_message`` with the name of the object
# that declared the requirement.
# The wenetruntime message has its 'TORCH_VER' placeholder filled in here, at
# import time, with the detected `_torch_version`.
REQUIREMENTS_MAAPING = OrderedDict([
    ('protobuf', (is_protobuf_available, PROTOBUF_IMPORT_ERROR)),
    ('sentencepiece', (is_sentencepiece_available,
                       SENTENCEPIECE_IMPORT_ERROR)),
    ('sklearn', (is_sklearn_available, SKLEARN_IMPORT_ERROR)),
    ('tf', (is_tf_available, TENSORFLOW_IMPORT_ERROR)),
    ('tensorflow', (is_tf_available, TENSORFLOW_IMPORT_ERROR)),
    ('timm', (is_timm_available, TIMM_IMPORT_ERROR)),
    ('tokenizers', (is_tokenizers_available, TOKENIZERS_IMPORT_ERROR)),
    ('torch', (is_torch_available, PYTORCH_IMPORT_ERROR)),
    ('wenetruntime',
     (is_wenetruntime_available,
      WENETRUNTIME_IMPORT_ERROR.replace('TORCH_VER', _torch_version))),
    ('scipy', (is_scipy_available, SCIPY_IMPORT_ERROR)),
    ('cv2', (is_opencv_available, OPENCV_IMPORT_ERROR)),
    ('PIL', (is_pillow_available, PILLOW_IMPORT_ERROR)),
    ('pai-easynlp', (is_package_available('easynlp'), EASYNLP_IMPORT_ERROR)),
    ('espnet2', (is_espnet_available,
                 GENERAL_IMPORT_ERROR.replace('REQ', 'espnet'))),
    ('espnet', (is_espnet_available,
                GENERAL_IMPORT_ERROR.replace('REQ', 'espnet'))),
    ('funasr', (is_package_available('funasr'), AUDIO_IMPORT_ERROR)),
    ('kwsbp', (is_package_available('kwsbp'), AUDIO_IMPORT_ERROR)),
    ('decord', (is_package_available('decord'), DECORD_IMPORT_ERROR)),
    ('deepspeed', (is_package_available('deepspeed'), DEEPSPEED_IMPORT_ERROR)),
    ('fairseq', (is_package_available('fairseq'), FAIRSEQ_IMPORT_ERROR)),
    ('fasttext', (is_package_available('fasttext'), FASTTEXT_IMPORT_ERROR)),
    ('megatron_util', (is_package_available('megatron_util'),
                       MEGATRON_UTIL_IMPORT_ERROR)),
    ('text2sql_lgesql', (is_package_available('text2sql_lgesql'),
                         TEXT2SQL_LGESQL_IMPORT_ERROR)),
    ('mpi4py', (is_package_available('mpi4py'), MPI4PY_IMPORT_ERROR)),
    ('open_clip', (is_package_available('open_clip'), OPENCLIP_IMPORT_ERROR)),
    ('taming', (is_package_available('taming'), TAMING_IMPORT_ERROR)),
    ('xformers', (is_package_available('xformers'), XFORMERS_IMPORT_ERROR)),
    ('swift', (is_package_available('swift'), SWIFT_IMPORT_ERROR)),
])

# Requirement names that `requires` skips without any availability check.
SYSTEM_PACKAGE = {'os', 'sys', 'typing'}


def requires(obj, requirements):
    """Raise ImportError when any requirement of ``obj`` is not available.

    Args:
        obj: the requesting object, or its name as a string. The name is
            substituted into the error messages.
        requirements: a single requirement name, or a list/tuple of them.
            Empty names and names in ``SYSTEM_PACKAGE`` are skipped.

    Raises:
        ImportError: carrying the concatenated messages of every requirement
            whose availability check returned a falsy value.
    """
    if not isinstance(requirements, (list, tuple)):
        requirements = [requirements]

    if isinstance(obj, str):
        name = obj
    elif hasattr(obj, '__name__'):
        name = obj.__name__
    else:
        name = obj.__class__.__name__

    # Resolve each requirement to an (availability check, message) pair.
    pending = []
    for req in requirements:
        if req == '' or req in SYSTEM_PACKAGE:
            continue
        entry = REQUIREMENTS_MAAPING.get(req)
        if entry is None:
            # Unknown requirement: fall back to a generic module lookup.
            entry = (is_package_available(req),
                     GENERAL_IMPORT_ERROR.replace('REQ', req))
        pending.append(entry)

    # Run every check so that all missing requirements are reported at once.
    errors = []
    for is_available, template in pending:
        if not is_available():
            errors.append(template.format(name))
    if errors:
        raise ImportError(''.join(errors))


def torch_required(func):
    """Decorator that makes ``func`` raise ImportError when PyTorch is unavailable.

    Availability is checked on every call, not when the decorator is applied.
    """
    # Named differently from the decorator used in tests so it's clear they
    # are not the same.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not is_torch_available():
            raise ImportError(f'Method `{func.__name__}` requires PyTorch.')
        return func(*args, **kwargs)

    return wrapper


def tf_required(func):
    """Decorator that makes ``func`` raise ImportError when TensorFlow is unavailable.

    Availability is checked on every call, not when the decorator is applied.
    """
    # Named differently from the decorator used in tests so it's clear they
    # are not the same.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not is_tf_available():
            raise ImportError(f'Method `{func.__name__}` requires TF.')
        return func(*args, **kwargs)

    return wrapper


class LazyImportModule(ModuleType):
    """Module object that imports its submodules on first attribute access.

    ``import_structure`` maps a submodule name to the names it exports.
    Accessing a submodule name, or one of the exported names, imports the
    submodule on demand and stores the result on this module object.
    """
    # Shared cache of the AST index; filled lazily by `get_ast_index`.
    _AST_INDEX = None

    def __init__(self,
                 name,
                 module_file,
                 import_structure,
                 module_spec=None,
                 extra_objects=None,
                 try_to_pre_import=False,
                 extra_import_func=None):
        """Set up the lazy module.

        Args:
            name: name of this module.
            module_file: path used for ``__file__``; its directory becomes
                ``__path__``.
            import_structure (dict): submodule name -> iterable of exported
                names.
            module_spec: value stored as ``__spec__``.
            extra_objects (dict | None): names resolved directly, without any
                import.
            try_to_pre_import (bool): if True, resolve every exported name
                right away and log (not raise) failures.
            extra_import_func: optional fallback called with the attribute
                name when it is not found in ``import_structure``.
        """
        super().__init__(name)
        submodule_names = list(import_structure.keys())
        self._modules = set(submodule_names)
        self._class_to_module = {
            symbol: submodule
            for submodule, exported in import_structure.items()
            for symbol in exported
        }
        # Needed for autocompletion in an IDE
        self.__all__ = submodule_names + list(
            chain.from_iterable(import_structure.values()))
        self.__file__ = module_file
        self.__spec__ = module_spec
        self.__path__ = [os.path.dirname(module_file)]
        self._objects = extra_objects if extra_objects is not None else {}
        self._name = name
        self._import_structure = import_structure
        self._extra_import_func = extra_import_func
        if try_to_pre_import:
            self._try_to_import()

    def _try_to_import(self):
        """Resolve every exported name now, logging a warning on failure."""
        for symbol in self._class_to_module:
            try:
                getattr(self, symbol)
            except Exception as err:
                logger.warning(
                    f'pre load module {symbol} error, please check {err}')

    # Needed for autocompletion in an IDE
    def __dir__(self):
        """List regular module attributes plus every name in ``__all__``."""
        listing = super().__dir__()
        # Entries of __all__ may or may not already be present, depending on
        # whether they have been accessed, so add only the missing ones.
        present = set(listing)
        for attr in self.__all__:
            if attr not in present:
                present.add(attr)
                listing.append(attr)
        return listing

    def __getattr__(self, name: str) -> Any:
        """Resolve ``name`` lazily; called only when normal lookup fails."""
        if name in self._objects:
            return self._objects[name]

        if name in self._modules:
            value = self._get_module(name)
        elif name in self._class_to_module:
            owner = self._get_module(self._class_to_module[name])
            value = getattr(owner, name)
        else:
            resolver = self._extra_import_func
            value = None if resolver is None else resolver(name)
            if value is None:
                raise AttributeError(
                    f'module {self.__name__} has no attribute {name}')

        # Store the result so later accesses bypass __getattr__.
        setattr(self, name, value)
        return value

    def _get_module(self, module_name: str):
        """Import the submodule ``module_name`` relative to this module.

        Unless the module is under one of the exempt ``modelscope.*``
        prefixes, the requirements recorded for it in the AST index are
        checked first. Any failure is re-raised as ``RuntimeError``.
        """
        try:
            qualified = f'{self.__name__}.{module_name}'
            exempt_prefixes = tuple(
                f'modelscope.{prefix}'
                for prefix in ('hub', 'utils', 'version', 'fileio'))
            if not qualified.startswith(exempt_prefixes):
                # check requirements before module import
                requirement_table = self.get_ast_index()[REQUIREMENT_KEY]
                if qualified in requirement_table:
                    requires(qualified, requirement_table[qualified])
            return importlib.import_module('.' + module_name, self.__name__)
        except Exception as e:
            raise RuntimeError(
                f'Failed to import {self.__name__}.{module_name} because of the following error '
                f'(look up to see its traceback):\n{e}') from e

    def __reduce__(self):
        """Pickle support: rebuild from name, file and import structure."""
        rebuild_args = (self._name, self.__file__, self._import_structure)
        return self.__class__, rebuild_args

    @staticmethod
    def get_ast_index():
        """Return the AST index, loading it with ``load_index`` on first use."""
        index = LazyImportModule._AST_INDEX
        if index is None:
            index = load_index()
            LazyImportModule._AST_INDEX = index
        return index

    @staticmethod
    def import_module(signature):
        """ import a lazy import module using signature

        Requirements recorded for the module are checked before it is
        imported. A warning is logged when the signature is not in the index.

        Args:
            signature (tuple): a tuple of str, (registry_name, registry_group_name, module_name)
        """
        ast_index = LazyImportModule.get_ast_index()
        if signature not in ast_index[INDEX_KEY]:
            logger.warning(f'{signature} not found in ast index file')
            return

        module_name = ast_index[INDEX_KEY][signature][MODULE_KEY]
        requirement_table = ast_index[REQUIREMENT_KEY]
        if module_name in requirement_table:
            requires(module_name, requirement_table[module_name])
        importlib.import_module(module_name)


def has_attr_in_class(cls, attribute_name) -> bool:
    """
    Determine whether ``attribute_name`` is a parameter of ``cls.__init__``.

    Args:
        cls: target class.
        attribute_name: the parameter name to look for.

    Returns:
        True if the signature of ``cls.__init__`` has a parameter with that
        name, otherwise False.
    """
    init_signature = inspect.signature(cls.__init__)
    accepted_names = tuple(init_signature.parameters)
    return attribute_name in accepted_names
