# Copyright (c) Alibaba, Inc. and its affiliates.
import contextlib
import importlib
import inspect
import os
import re
import sys
from asyncio import Future
from functools import partial
from pathlib import Path
from types import MethodType
from typing import BinaryIO, Dict, Iterable, List, Optional, Union

from modelscope.hub.constants import DEFAULT_MODELSCOPE_DATA_ENDPOINT
from modelscope.utils.repo_utils import (CommitInfo, CommitOperation,
                                         CommitOperationAdd)

# Regular-expression patterns naming weight, checkpoint and archive files.
# `_patch_pretrained_class` forwards this list as the `ignore_file_pattern`
# argument of `snapshot_download` for classes whose name contains 'tokenizer'
# or 'config' (presumably because those classes do not need the weights —
# confirm against `snapshot_download`'s matching rules).
ignore_file_pattern = [
    r'\w+\.bin',
    r'\w+\.safetensors',
    r'\w+\.pth',
    r'\w+\.pt',
    r'\w+\.h5',
    r'\w+\.ckpt',
    r'\w+\.zip',
    r'\w+\.onnx',
    r'\w+\.tar',
    r'\w+\.gz',
]


def _matches_any(compiled_patterns, name):
    """Return True when `name` fully matches at least one compiled pattern."""
    return any(pattern.fullmatch(name) for pattern in compiled_patterns)


def _collect_from_import_structure(package, include_names):
    """Resolve the matching names listed in a lazy package's `_import_structure`.

    `_import_structure` maps a submodule name to the names it exports. Each
    matching name is imported from its submodule; names that fail to import
    are skipped so one broken optional dependency does not abort the scan.
    """
    patterns = [re.compile(name) for name in include_names]
    import_structure = package._import_structure
    collected = []
    for submodule in import_structure:
        if 'dummy' in submodule.lower():
            continue
        for exported_name in import_structure[submodule]:
            if not _matches_any(patterns, exported_name):
                continue
            try:
                module = importlib.import_module(f'.{submodule}',
                                                 package.__name__)
                collected.append(getattr(module, exported_name))
            except Exception:  # best effort: skip names that cannot be imported
                continue
    return collected


def _collect_from_namespace(module, include_names):
    """Return the non-dunder attributes of `module` whose name matches."""
    patterns = [re.compile(name) for name in include_names]
    return [
        getattr(module, attr) for attr in dir(module)
        if not attr.startswith('__') and _matches_any(patterns, attr)
    ]


def get_all_imported_modules():
    """Find all modules in transformers/peft/diffusers

    Each library is inspected only when it is installed and importable. For
    lazy packages exposing `_import_structure` the matching names are imported
    from their submodules; otherwise the package namespace is scanned with
    `dir()`.

    Returns:
        A list with the objects whose names fully match one of the
        per-library include patterns. The list is empty when none of the
        libraries is available.
    """
    # `import importlib` alone does not guarantee that the `util` submodule
    # is bound, so import it explicitly before calling `find_spec`.
    import importlib.util

    all_imported_modules = []
    transformers_include_names = [
        'Auto.*',
        'T5.*',
        'BitsAndBytesConfig',
        'GenerationConfig',
        'Awq.*',
        'GPTQ.*',
        'BatchFeature',
        'Qwen.*',
        'Llama.*',
        'Intern.*',
        'Deepseek.*',
        'PretrainedConfig',
        'PreTrainedTokenizer',
        'PreTrainedModel',
        'PreTrainedTokenizerFast',
    ]
    peft_include_names = ['.*PeftModel.*', '.*Config']
    diffusers_include_names = [
        '^(?!TF|Flax).*Pipeline$', '^(?!TF|Flax).*Autoencoder.*',
        '^(?!TF|Flax).*Model$', '^(?!TF|Flax).*Adapter$', 'ImageProjection',
        '^(?!TF|Flax).*UNet$', '^(?!TF|Flax).*Scheduler$'
    ]

    # transformers and diffusers are handled the same way: prefer the lazy
    # import table, fall back to scanning the namespace.
    for package_name, include_names in (
        ('transformers', transformers_include_names),
        ('peft', peft_include_names),
        ('diffusers', diffusers_include_names),
    ):
        if importlib.util.find_spec(package_name) is None:
            continue
        try:
            importlib.import_module(package_name)
        except Exception:  # installed but broken: treat as unavailable
            continue
        package = sys.modules[package_name]
        # peft is always scanned through its namespace.
        if package_name != 'peft' and hasattr(package, '_import_structure'):
            all_imported_modules.extend(
                _collect_from_import_structure(package, include_names))
        else:
            all_imported_modules.extend(
                _collect_from_namespace(package, include_names))
    return all_imported_modules


def _patch_pretrained_class(all_imported_modules, wrap=False):
    """Patch all class to download from modelscope

    Every class is either monkey patched in place (`from_pretrained`,
    `_get_peft_type` and `get_config_dict` are replaced and the originals are
    kept as `_from_pretrained_origin`, `_get_peft_type_origin` and
    `_get_config_dict_origin`) or replaced by a subclass created by
    `get_wrapped_class`. `get_class_from_dynamic_module` of
    `transformers.dynamic_module_utils` is patched as well.

    Args:
        all_imported_modules: The candidate classes. `None` entries and
            objects without `__name__` are skipped.
        wrap: Wrap the class or monkey patch the original class. Classes whose
            name contains `pipeline` and that have `save_pretrained` are
            always wrapped.

    Returns:
        The classes after patched
    """

    def get_model_dir(pretrained_model_name_or_path,
                      ignore_file_pattern=None,
                      allow_file_pattern=None,
                      **kwargs):
        """Resolve a model id or an existing local path to a local directory.

        An existing path is returned unchanged. Otherwise the repo is fetched
        with `snapshot_download`; `revision` values of `None`/'main' are
        mapped to 'master' and a `subfolder` restricts the download to that
        folder and is appended to the returned directory.
        """
        from modelscope import snapshot_download
        subfolder = kwargs.pop('subfolder', None)
        if os.path.exists(pretrained_model_name_or_path):
            return pretrained_model_name_or_path

        revision = kwargs.pop('revision', None)
        if revision is None or revision == 'main':
            revision = 'master'
        if subfolder:
            # Only the requested subfolder is needed; this takes precedence
            # over a caller supplied allow pattern.
            allow_file_pattern = f'{subfolder}/*'
        model_dir = snapshot_download(
            pretrained_model_name_or_path,
            revision=revision,
            ignore_file_pattern=ignore_file_pattern,
            allow_file_pattern=allow_file_pattern)
        if subfolder:
            model_dir = os.path.join(model_dir, subfolder)
        return model_dir

    def patch_pretrained_model_name_or_path(cls, pretrained_model_name_or_path,
                                            *model_args, **kwargs):
        """Patch all from_pretrained"""
        model_dir = get_model_dir(pretrained_model_name_or_path,
                                  kwargs.pop('ignore_file_pattern', None),
                                  kwargs.pop('allow_file_pattern', None),
                                  **kwargs)
        # The backup is a bound method; call the underlying function so that
        # `cls` (possibly a subclass) is the receiver.
        return cls._from_pretrained_origin.__func__(cls, model_dir,
                                                    *model_args, **kwargs)

    def patch_get_config_dict(cls, pretrained_model_name_or_path, *model_args,
                              **kwargs):
        """Patch all get_config_dict"""
        model_dir = get_model_dir(pretrained_model_name_or_path,
                                  kwargs.pop('ignore_file_pattern', None),
                                  kwargs.pop('allow_file_pattern', None),
                                  **kwargs)
        return cls._get_config_dict_origin.__func__(cls, model_dir,
                                                    *model_args, **kwargs)

    def patch_peft_model_id(cls, model, model_id, *model_args, **kwargs):
        """Patch all peft.from_pretrained"""
        model_dir = get_model_dir(model_id,
                                  kwargs.pop('ignore_file_pattern', None),
                                  kwargs.pop('allow_file_pattern', None),
                                  **kwargs)
        return cls._from_pretrained_origin.__func__(cls, model, model_dir,
                                                    *model_args, **kwargs)

    def patch_get_peft_type(cls, model_id, **kwargs):
        """Patch all _get_peft_type"""
        model_dir = get_model_dir(model_id,
                                  kwargs.pop('ignore_file_pattern', None),
                                  kwargs.pop('allow_file_pattern', None),
                                  **kwargs)
        return cls._get_peft_type_origin.__func__(cls, model_dir, **kwargs)

    def get_wrapped_class(
            module_class: 'PreTrainedModel',
            ignore_file_pattern: Optional[Union[str, List[str]]] = None,
            allow_file_pattern: Optional[Union[str, List[str]]] = None,
            **kwargs):
        """Get a custom wrapper class for  auto classes to download the models from the ModelScope hub
        Args:
            module_class (`PreTrainedModel`): The actual module class
            ignore_file_pattern (`str` or `List`, *optional*, default to `None`):
                Any file pattern to be ignored, like exact file names or file extensions.
            allow_file_pattern (`str` or `List`, *optional*, default to `None`):
                Any file pattern to be included, like exact file names or file extensions.
        Returns:
            The wrapped class
        """

        @contextlib.contextmanager
        def file_pattern_context(kwargs, module_class, cls):
            """Temporarily add the file patterns to `kwargs`.

            While the context is active `kwargs` always holds both
            'allow_file_pattern' and 'ignore_file_pattern', so it can be
            expanded directly into `get_model_dir`. Both keys are removed on
            exit so they are not forwarded to the wrapped library call.
            """
            if 'allow_file_pattern' not in kwargs:
                kwargs['allow_file_pattern'] = allow_file_pattern
            if 'ignore_file_pattern' not in kwargs:
                kwargs['ignore_file_pattern'] = ignore_file_pattern

            if kwargs.get(
                    'allow_file_pattern') is None and module_class is not None:
                # No explicit allow list: derive one from the kind of class.
                extra_allow_file_pattern = None
                if 'GenerationConfig' == module_class.__name__:
                    from transformers.utils import GENERATION_CONFIG_NAME
                    extra_allow_file_pattern = [
                        GENERATION_CONFIG_NAME, r'*.py'
                    ]
                elif 'Config' in module_class.__name__:
                    from transformers import CONFIG_NAME
                    extra_allow_file_pattern = [CONFIG_NAME, r'*.py']
                elif 'Tokenizer' in module_class.__name__:
                    extra_allow_file_pattern = list(
                        (cls.vocab_files_names.values()) if cls is not None
                        and hasattr(cls, 'vocab_files_names') else []) + [
                            'chat_template.jinja', r'*.json', r'*.py',
                            r'*.txt', r'*.model', r'*.tiktoken'
                        ]  # noqa
                elif 'Processor' in module_class.__name__:
                    extra_allow_file_pattern = [
                        'chat_template.jinja', r'*.json', r'*.py', r'*.txt',
                        r'*.model', r'*.tiktoken'
                    ]

                kwargs['allow_file_pattern'] = extra_allow_file_pattern
            try:
                yield
            finally:
                kwargs.pop('ignore_file_pattern', None)
                kwargs.pop('allow_file_pattern', None)

        def from_pretrained(model, model_id, *model_args, **kwargs):
            """peft style `from_pretrained(model, model_id, ...)` replacement."""
            # The context manager needs the class information, not
            # `get_model_dir`; the patterns reach `get_model_dir` via kwargs.
            with file_pattern_context(kwargs, module_class, module_class):
                # model is an instance
                model_dir = get_model_dir(model_id, **kwargs)

            module_obj = module_class.from_pretrained(model, model_dir,
                                                      *model_args, **kwargs)

            return module_obj

        class ClassWrapper(module_class):

            @classmethod
            def from_pretrained(cls, pretrained_model_name_or_path,
                                *model_args, **kwargs):
                with file_pattern_context(kwargs, module_class, cls):
                    model_dir = get_model_dir(pretrained_model_name_or_path,
                                              **kwargs)

                module_obj = module_class.from_pretrained(
                    model_dir, *model_args, **kwargs)

                if module_class.__name__.startswith('AutoModel'):
                    module_obj.model_dir = model_dir
                return module_obj

            @classmethod
            def _get_peft_type(cls, model_id, **kwargs):
                # The patterns are already in kwargs inside the context;
                # passing them again by keyword would raise TypeError.
                with file_pattern_context(kwargs, module_class, cls):
                    model_dir = get_model_dir(model_id, **kwargs)

                module_obj = module_class._get_peft_type(model_dir, **kwargs)
                return module_obj

            @classmethod
            def get_config_dict(cls, pretrained_model_name_or_path,
                                *model_args, **kwargs):
                # See `_get_peft_type`: the patterns travel inside kwargs.
                with file_pattern_context(kwargs, module_class, cls):
                    model_dir = get_model_dir(pretrained_model_name_or_path,
                                              **kwargs)

                module_obj = module_class.get_config_dict(
                    model_dir, *model_args, **kwargs)
                return module_obj

            def save_pretrained(
                self,
                save_directory: Union[str, os.PathLike],
                safe_serialization: bool = True,
                **kwargs,
            ):
                push_to_hub = kwargs.pop('push_to_hub', False)
                if push_to_hub:
                    # Aliased so the boolean flag above is not overwritten.
                    from modelscope.hub.push_to_hub import push_to_hub as _push_to_hub
                    from modelscope.hub.api import HubApi
                    from modelscope.hub.repository import Repository

                    token = kwargs.get('token')
                    commit_message = kwargs.pop('commit_message', None)
                    # os.fspath makes the default work for os.PathLike too.
                    repo_name = kwargs.pop(
                        'repo_id',
                        os.fspath(save_directory).split(os.path.sep)[-1])

                    api = HubApi()
                    api.login(token)
                    api.create_repo(repo_name)
                    # clone the repo
                    Repository(save_directory, repo_name)

                super().save_pretrained(
                    save_directory=save_directory,
                    safe_serialization=safe_serialization,
                    push_to_hub=False,
                    **kwargs)

                # Class members may be unpatched, so push_to_hub is done separately here
                if push_to_hub:
                    _push_to_hub(
                        repo_name=repo_name,
                        output_dir=save_directory,
                        commit_message=commit_message,
                        token=token)

        # Drop the overrides the wrapped class has no counterpart for.
        if not hasattr(module_class, 'from_pretrained'):
            del ClassWrapper.from_pretrained
        else:
            parameters = inspect.signature(
                module_class.from_pretrained).parameters
            if 'model' in parameters and 'model_id' in parameters:
                # peft
                ClassWrapper.from_pretrained = from_pretrained

        if not hasattr(module_class, '_get_peft_type'):
            del ClassWrapper._get_peft_type

        if not hasattr(module_class, 'get_config_dict'):
            del ClassWrapper.get_config_dict

        if not hasattr(module_class, 'save_pretrained'):
            del ClassWrapper.save_pretrained

        ClassWrapper.__name__ = module_class.__name__
        ClassWrapper.__qualname__ = module_class.__qualname__
        return ClassWrapper

    all_available_modules = []
    for var in all_imported_modules:
        if var is None or not hasattr(var, '__name__'):
            continue
        name = var.__name__
        # Tokenizer/config classes get the module level ignore list.
        skip_model = 'tokenizer' in name.lower() or 'config' in name.lower()
        if skip_model:
            ignore_file_pattern_kwargs = {
                'ignore_file_pattern': ignore_file_pattern
            }
        else:
            ignore_file_pattern_kwargs = {}

        try:
            # some TFxxx classes has import errors
            has_from_pretrained = hasattr(var, 'from_pretrained')
            has_get_peft_type = hasattr(var, '_get_peft_type')
            has_get_config_dict = hasattr(var, 'get_config_dict')
            has_save_pretrained = hasattr(var, 'save_pretrained')
        except Exception:
            continue

        # save_pretrained is not a classmethod and cannot be overridden by replacing
        # the class method. It requires replacing the class object method.
        if wrap or ('pipeline' in name.lower() and has_save_pretrained):
            nothing_to_wrap = not (has_from_pretrained or has_get_config_dict
                                   or has_get_peft_type
                                   or has_save_pretrained)
            if nothing_to_wrap:
                all_available_modules.append(var)
                continue
            try:
                wrapped = get_wrapped_class(var, **ignore_file_pattern_kwargs)
            except Exception:
                # Fall back to the original class when wrapping fails.
                wrapped = var
            all_available_modules.append(wrapped)
        else:
            # The `_..._origin` checks keep a second call from patching the
            # already patched methods.
            if has_from_pretrained and not hasattr(var,
                                                   '_from_pretrained_origin'):
                parameters = inspect.signature(var.from_pretrained).parameters
                # different argument names
                is_peft = 'model' in parameters and 'model_id' in parameters
                var._from_pretrained_origin = var.from_pretrained
                if not is_peft:
                    var.from_pretrained = classmethod(
                        partial(patch_pretrained_model_name_or_path,
                                **ignore_file_pattern_kwargs))
                else:
                    var.from_pretrained = classmethod(
                        partial(patch_peft_model_id,
                                **ignore_file_pattern_kwargs))
            if has_get_peft_type and not hasattr(var, '_get_peft_type_origin'):
                var._get_peft_type_origin = var._get_peft_type
                var._get_peft_type = classmethod(
                    partial(patch_get_peft_type, **ignore_file_pattern_kwargs))

            if has_get_config_dict and not hasattr(var,
                                                   '_get_config_dict_origin'):
                var._get_config_dict_origin = var.get_config_dict
                var.get_config_dict = classmethod(
                    partial(patch_get_config_dict,
                            **ignore_file_pattern_kwargs))

            all_available_modules.append(var)

    def get_class_from_dynamic_module(class_reference, *args, **kwargs):
        """Download the repo of a `repo--class` reference before delegating."""
        from transformers.dynamic_module_utils import origin_get_class_from_dynamic_module
        if '--' in class_reference:
            repo_id, class_reference = class_reference.split('--')
            if not os.path.exists(repo_id):
                from modelscope import snapshot_download
                repo_id = snapshot_download(repo_id)
            class_reference = repo_id + '--' + class_reference
        return origin_get_class_from_dynamic_module(class_reference, *args,
                                                    **kwargs)

    from transformers import dynamic_module_utils
    if not hasattr(dynamic_module_utils,
                   'origin_get_class_from_dynamic_module'):
        dynamic_module_utils.origin_get_class_from_dynamic_module = dynamic_module_utils.get_class_from_dynamic_module
        dynamic_module_utils.get_class_from_dynamic_module = get_class_from_dynamic_module
        # configuration_auto holds its own reference to the function.
        from transformers.models.auto import configuration_auto
        configuration_auto.get_class_from_dynamic_module = get_class_from_dynamic_module
    return all_available_modules


def _unpatch_pretrained_class(all_imported_modules):
    """Restore the methods that were monkey patched on the given classes.

    For every class the saved `_from_pretrained_origin`,
    `_get_peft_type_origin` and `_get_config_dict_origin` attributes are
    written back to their public names and then removed. The original
    `get_class_from_dynamic_module` of transformers is reinstated as well.

    Args:
        all_imported_modules: The classes to restore; `None` entries are
            skipped.
    """
    # (patched attribute, attribute holding the saved original)
    restorable = (
        ('from_pretrained', '_from_pretrained_origin'),
        ('_get_peft_type', '_get_peft_type_origin'),
        ('get_config_dict', '_get_config_dict_origin'),
    )

    for module_class in all_imported_modules:
        if module_class is None:
            continue

        # Probe all public names first; classes that fail to resolve their
        # attributes are left alone.
        try:
            is_present = [
                hasattr(module_class, public_name)
                for public_name, _ in restorable
            ]
        except:  # noqa
            continue

        for present, (public_name, backup_name) in zip(is_present,
                                                       restorable):
            if not present or not hasattr(module_class, backup_name):
                continue
            setattr(module_class, public_name,
                    getattr(module_class, backup_name))
            # The backup may not be deletable from this class, ignore that.
            try:
                delattr(module_class, backup_name)
            except:  # noqa
                pass

    from transformers import dynamic_module_utils
    if hasattr(dynamic_module_utils, 'origin_get_class_from_dynamic_module'):
        original = dynamic_module_utils.origin_get_class_from_dynamic_module
        dynamic_module_utils.get_class_from_dynamic_module = original
        from transformers.models.auto import configuration_auto
        configuration_auto.get_class_from_dynamic_module = original
        delattr(dynamic_module_utils, 'origin_get_class_from_dynamic_module')


def _patch_hub():
    """Replace huggingface_hub entry points with ModelScope backed versions.

    The following names are patched: `RepoCard.validate`, `RepoCard.load`,
    `hf_hub_download`, `file_exists`, `whoami`, `create_repo`,
    `upload_folder`, `upload_file` and `create_commit`. Each original is kept
    under an `_<name>_origin` attribute, which also makes a repeated call a
    no-op for the names that are already patched.
    """
    import huggingface_hub
    from huggingface_hub import hf_api
    from huggingface_hub.hf_api import api
    from huggingface_hub.hf_api import future_compatible
    from modelscope import get_logger
    logger = get_logger()

    def _file_exists(
        self,
        repo_id: str,
        filename: str,
        *,
        repo_type: Optional[str] = None,
        revision: Optional[str] = None,
        token: Union[str, bool, None] = None,
    ):
        """Patch huggingface_hub.file_exists"""
        if repo_type is not None:
            logger.warning(
                'The passed in repo_type will not be used in modelscope. Now only model repo can be queried.'
            )
        from modelscope.hub.api import HubApi
        ms_api = HubApi()
        ms_api.login(token)
        if revision is None or revision == 'main':
            revision = 'master'
        return ms_api.file_exists(repo_id, filename, revision=revision)

    def _file_download(repo_id: str,
                       filename: str,
                       *,
                       subfolder: Optional[str] = None,
                       repo_type: Optional[str] = None,
                       revision: Optional[str] = None,
                       cache_dir: Union[str, Path, None] = None,
                       local_dir: Union[str, Path, None] = None,
                       token: Union[bool, str, None] = None,
                       local_files_only: bool = False,
                       **kwargs):
        """Patch huggingface_hub.hf_hub_download"""
        if len(kwargs) > 0:
            logger.warning(
                'The passed in library_name,library_version,user_agent,force_download,proxies,'
                'etag_timeout,headers,endpoint '
                'will not be used in modelscope.')
        assert repo_type in (
            None, 'model',
            'dataset'), f'repo_type={repo_type} is not supported in ModelScope'
        if repo_type in (None, 'model'):
            from modelscope.hub.file_download import model_file_download as file_download
        else:
            from modelscope.hub.file_download import dataset_file_download as file_download
        from modelscope import HubApi
        ms_api = HubApi()
        ms_api.login(token)
        if revision is None or revision == 'main':
            revision = 'master'
        return file_download(
            repo_id,
            file_path=os.path.join(subfolder, filename)
            if subfolder else filename,
            cache_dir=cache_dir,
            local_dir=local_dir,
            local_files_only=local_files_only,
            revision=revision)

    def _whoami(self, token: Union[bool, str, None] = None) -> Dict:
        """Patch huggingface_hub.whoami, only the 'name' key is provided."""
        from modelscope.hub.api import ModelScopeConfig
        from modelscope.hub.api import HubApi
        ms_api = HubApi()
        ms_api.login(token)
        return {'name': ModelScopeConfig.get_user_info()[0] or 'unknown'}

    def create_repo(self,
                    repo_id: str,
                    *,
                    token: Union[str, bool, None] = None,
                    private: bool = False,
                    **kwargs) -> 'RepoUrl':
        """
        Create a new repository on the hub.

        Args:
            repo_id: The ID of the repository to create.
            token: The authentication token to use.
            private: Whether the repository should be private.
            **kwargs: Additional arguments.

        Returns:
            RepoUrl: The URL of the created repository.
        """
        from modelscope.hub.api import HubApi
        ms_api = HubApi()
        visibility = 'private' if private else 'public'
        repo_url = ms_api.create_repo(
            repo_id, token=token, visibility=visibility, **kwargs)
        from modelscope.utils.repo_utils import RepoUrl
        return RepoUrl(url=repo_url, repo_type='model', repo_id=repo_id)

    @future_compatible
    def upload_folder(
        self,
        *,
        repo_id: str,
        folder_path: Union[str, Path],
        path_in_repo: Optional[str] = None,
        commit_message: Optional[str] = None,
        commit_description: Optional[str] = None,
        token: Union[str, bool, None] = None,
        revision: Optional[str] = 'master',
        ignore_patterns: Optional[Union[List[str], str]] = None,
        **kwargs,
    ):
        """Patch huggingface_hub.upload_folder

        NOTE(review): `ignore_patterns` and extra kwargs are accepted but not
        forwarded to `_push_files_to_hub`.
        """
        from modelscope.hub.push_to_hub import _push_files_to_hub
        if revision is None or revision == 'main':
            revision = 'master'
        _push_files_to_hub(
            path_or_fileobj=folder_path,
            path_in_repo=path_in_repo,
            repo_id=repo_id,
            commit_message=commit_message,
            commit_description=commit_description,
            revision=revision,
            token=token)
        from modelscope.utils.repo_utils import CommitInfo
        return CommitInfo(
            commit_url=
            f'{DEFAULT_MODELSCOPE_DATA_ENDPOINT}/models/{repo_id}/files',
            commit_message=commit_message,
            commit_description=commit_description,
            oid=None,
        )

    from modelscope.utils.constant import DEFAULT_REPOSITORY_REVISION

    @future_compatible
    def upload_file(
        self,
        *,
        path_or_fileobj: Union[str, Path, bytes, BinaryIO],
        path_in_repo: str,
        repo_id: str,
        token: Union[str, bool, None] = None,
        revision: Optional[str] = DEFAULT_REPOSITORY_REVISION,
        commit_message: Optional[str] = None,
        commit_description: Optional[str] = None,
        **kwargs,
    ):
        """Patch huggingface_hub.upload_file"""
        if revision is None or revision == 'main':
            revision = 'master'
        from modelscope.hub.push_to_hub import _push_files_to_hub
        _push_files_to_hub(path_or_fileobj, path_in_repo, repo_id, token,
                           revision, commit_message, commit_description)

    @future_compatible
    def create_commit(
        self,
        repo_id: str,
        operations: Iterable[CommitOperation],
        *,
        commit_message: str,
        commit_description: Optional[str] = None,
        token: Union[str, bool, None] = None,
        repo_type: Optional[str] = None,
        revision: Optional[str] = DEFAULT_REPOSITORY_REVISION,
        **kwargs,
    ) -> Union[CommitInfo, Future[CommitInfo]]:
        """Patch huggingface_hub.create_commit, only Add operations are supported.

        Raises:
            ValueError: If an operation whose class name lacks 'Add' is passed.
        """
        from modelscope.hub.api import HubApi
        ms_api = HubApi()
        # `operations` may be a one-shot iterable and is traversed twice below.
        operations = list(operations)
        if any('Add' not in op.__class__.__name__ for op in operations):
            raise ValueError(
                'ModelScope create_commit only support Add operation for now.')
        if revision is None or revision == 'main':
            revision = 'master'
        all_files = [op.path_or_fileobj for op in operations]
        ms_api.upload_folder(
            repo_id=repo_id,
            folder_path=all_files,
            commit_message=commit_message,
            commit_description=commit_description,
            token=token,
            revision=revision,
            repo_type=repo_type or 'model')

    def load(
        cls,
        repo_id_or_path: Union[str, Path],
        repo_type: Optional[str] = None,
        token: Optional[str] = None,
        ignore_metadata_errors: bool = False,
    ):
        """Patch RepoCard.load, the card is read from the repo's README.md.

        Raises:
            ValueError: If `repo_type` is neither `None`, 'model' nor
                'dataset' and `repo_id_or_path` is not an existing path.
        """
        from modelscope.hub.api import HubApi
        ms_api = HubApi()
        ms_api.login(token)
        if os.path.exists(repo_id_or_path):
            file_path = repo_id_or_path
        elif repo_type == 'model' or repo_type is None:
            from modelscope import model_file_download
            file_path = model_file_download(repo_id_or_path, 'README.md')
        elif repo_type == 'dataset':
            from modelscope import dataset_file_download
            file_path = dataset_file_download(repo_id_or_path, 'README.md')
        else:
            raise ValueError(
                f'repo_type should be `model` or `dataset`, but now is {repo_type}'
            )

        # Read as UTF-8 instead of the platform default encoding.
        with open(file_path, 'r', encoding='utf-8') as f:
            repo_card = cls(
                f.read(), ignore_metadata_errors=ignore_metadata_errors)
            if not hasattr(repo_card.data, 'tags'):
                repo_card.data.tags = []
            return repo_card

    # Patch repocard.validate
    from huggingface_hub import repocard
    if not hasattr(repocard.RepoCard, '_validate_origin'):
        repocard.RepoCard._validate_origin = repocard.RepoCard.validate
        repocard.RepoCard.validate = lambda *args, **kwargs: None
        repocard.RepoCard._load_origin = repocard.RepoCard.load
        repocard.RepoCard.load = MethodType(load, repocard.RepoCard)

    if not hasattr(hf_api, '_hf_hub_download_origin'):
        # Patch hf_hub_download
        hf_api._hf_hub_download_origin = huggingface_hub.file_download.hf_hub_download
        huggingface_hub.hf_hub_download = _file_download
        huggingface_hub.file_download.hf_hub_download = _file_download

    if not hasattr(hf_api, '_file_exists_origin'):
        # Patch file_exists
        hf_api._file_exists_origin = hf_api.file_exists
        hf_api.file_exists = MethodType(_file_exists, api)
        huggingface_hub.file_exists = hf_api.file_exists
        huggingface_hub.hf_api.file_exists = hf_api.file_exists

    if not hasattr(hf_api, '_whoami_origin'):
        # Patch whoami
        hf_api._whoami_origin = hf_api.whoami
        hf_api.whoami = MethodType(_whoami, api)
        huggingface_hub.whoami = hf_api.whoami
        huggingface_hub.hf_api.whoami = hf_api.whoami

    if not hasattr(hf_api, '_create_repo_origin'):
        # Patch create_repo
        from transformers.utils import hub
        hf_api._create_repo_origin = hf_api.create_repo
        hf_api.create_repo = MethodType(create_repo, api)
        huggingface_hub.create_repo = hf_api.create_repo
        huggingface_hub.hf_api.create_repo = hf_api.create_repo
        hub.create_repo = hf_api.create_repo

    if not hasattr(hf_api, '_upload_folder_origin'):
        # Patch upload_folder
        hf_api._upload_folder_origin = hf_api.upload_folder
        hf_api.upload_folder = MethodType(upload_folder, api)
        huggingface_hub.upload_folder = hf_api.upload_folder
        huggingface_hub.hf_api.upload_folder = hf_api.upload_folder

    if not hasattr(hf_api, '_upload_file_origin'):
        # Patch upload_file
        hf_api._upload_file_origin = hf_api.upload_file
        hf_api.upload_file = MethodType(upload_file, api)
        huggingface_hub.upload_file = hf_api.upload_file
        huggingface_hub.hf_api.upload_file = hf_api.upload_file
        repocard.upload_file = hf_api.upload_file

    if not hasattr(hf_api, '_create_commit_origin'):
        # Patch create_commit
        hf_api._create_commit_origin = hf_api.create_commit
        hf_api.create_commit = MethodType(create_commit, api)
        huggingface_hub.create_commit = hf_api.create_commit
        huggingface_hub.hf_api.create_commit = hf_api.create_commit
        from transformers.utils import hub
        hub.create_commit = hf_api.create_commit


def _unpatch_hub():
    """Put back every huggingface_hub name replaced by `_patch_hub`.

    A name is restored only when its `_<name>_origin` backup exists; the
    backup is deleted afterwards.
    """
    import huggingface_hub
    from huggingface_hub import hf_api
    from huggingface_hub import repocard

    card_cls = repocard.RepoCard
    for method_name, backup_name in (('validate', '_validate_origin'),
                                     ('load', '_load_origin')):
        if hasattr(card_cls, backup_name):
            setattr(card_cls, method_name, getattr(card_cls, backup_name))
            delattr(card_cls, backup_name)

    if hasattr(hf_api, '_hf_hub_download_origin'):
        original_download = hf_api._hf_hub_download_origin
        huggingface_hub.file_download.hf_hub_download = original_download
        huggingface_hub.hf_hub_download = original_download
        delattr(hf_api, '_hf_hub_download_origin')

    def reinstate(public_name, backup_name):
        """Write the saved original back to hf_api and the package namespace."""
        original = getattr(hf_api, backup_name)
        for namespace in (hf_api, huggingface_hub, huggingface_hub.hf_api):
            setattr(namespace, public_name, original)
        return original

    for public_name, backup_name in (('file_exists', '_file_exists_origin'),
                                     ('whoami', '_whoami_origin')):
        if hasattr(hf_api, backup_name):
            reinstate(public_name, backup_name)
            delattr(hf_api, backup_name)

    if hasattr(hf_api, '_create_repo_origin'):
        from transformers.utils import hub
        hub.create_repo = reinstate('create_repo', '_create_repo_origin')
        delattr(hf_api, '_create_repo_origin')

    if hasattr(hf_api, '_upload_folder_origin'):
        reinstate('upload_folder', '_upload_folder_origin')
        delattr(hf_api, '_upload_folder_origin')

    if hasattr(hf_api, '_upload_file_origin'):
        # repocard keeps its own reference to upload_file.
        repocard.upload_file = reinstate('upload_file', '_upload_file_origin')
        delattr(hf_api, '_upload_file_origin')

    if hasattr(hf_api, '_create_commit_origin'):
        original_commit = reinstate('create_commit', '_create_commit_origin')
        from transformers.utils import hub
        hub.create_commit = original_commit
        delattr(hf_api, '_create_commit_origin')


def patch_hub():
    """Patch the hub functions first, then the discovered model classes."""
    _patch_hub()
    discovered = get_all_imported_modules()
    _patch_pretrained_class(discovered)


def unpatch_hub():
    """Restore the discovered model classes first, then the hub functions."""
    discovered = get_all_imported_modules()
    _unpatch_pretrained_class(discovered)
    _unpatch_hub()


@contextlib.contextmanager
def patch_context():
    """Apply the ModelScope patches for the duration of a `with` block.

    `patch_hub()` runs on entry and `unpatch_hub()` runs on exit, including
    when the body of the `with` block raises.
    """
    patch_hub()
    try:
        yield
    finally:
        # Without the finally clause an exception thrown into the generator
        # at `yield` would skip the unpatching.
        unpatch_hub()
