# Copyright (c) Alibaba, Inc. and its affiliates.

import os
from typing import Dict, Optional, Union
from urllib.parse import urlparse

from modelscope.hub.api import HubApi, ModelScopeConfig
from modelscope.hub.constants import FILE_HASH
from modelscope.hub.git import GitCommandWrapper
from modelscope.hub.utils.caching import ModelFileSystemCache
from modelscope.hub.utils.utils import compute_hash
from modelscope.utils.logger import get_logger

logger = get_logger()


def get_model_id_from_cache(model_root_path: str) -> str:
    """Resolve the hub model id of a locally cached model directory.

    Two layouts are supported:

    * a git clone (``<model_root_path>/.git`` exists): the id is taken from
      the path component of the repo's remote URL, with a trailing ``.git``
      and the leading ``/`` removed;
    * a snapshot download: the id is read back from the
      ``ModelFileSystemCache`` rooted at ``model_root_path``.

    Args:
        model_root_path: Local directory containing the model files.

    Returns:
        The model id string as recorded for this directory.
    """
    if os.path.exists(os.path.join(model_root_path, '.git')):
        # Downloaded with git: derive the id from the remote URL path.
        git_url = GitCommandWrapper().get_repo_remote_url(model_root_path)
        if git_url.endswith('.git'):
            git_url = git_url[:-4]
        # NOTE(review): assumes an http(s)-style remote URL; scp-style ssh
        # remotes (``user@host:ns/name``) would not parse correctly here.
        return urlparse(git_url).path[1:]
    # Downloaded via snapshot_download: the cache records the model id.
    return ModelFileSystemCache(model_root_path).get_model_id()


def check_local_model_is_latest(
    model_root_path: str,
    user_agent: Optional[Union[Dict, str]] = None,
):
    """Check whether a local model repo matches the latest version on the hub.

    The latest revision is the first entry returned by
    ``get_model_branches_and_tags``, falling back to ``'master'`` when none
    is available. Every non-directory file listed for that revision is then
    compared with the local copy:

    * for snapshot downloads (no ``.git`` directory), via
      ``ModelFileSystemCache.exists``;
    * for git clones, by comparing the local file hash with the hub's
      ``FILE_HASH`` field (files without that field are skipped, and a file
      missing locally counts as different).

    On the first mismatch an info message is logged and checking stops.
    The check is best-effort: any error is logged at debug level and
    swallowed, so it never interrupts the caller.

    Args:
        model_root_path: Local directory containing the model files.
        user_agent: Extra user-agent information forwarded to
            ``ModelScopeConfig.get_user_agent``.
    """
    try:
        model_id = get_model_id_from_cache(model_root_path)
        # Cached directory names encode '.' as '___'; restore the hub id.
        model_id = model_id.replace('___', '.')
        headers = {
            'user-agent':
            ModelScopeConfig.get_user_agent(user_agent=user_agent, )
        }
        cookies = ModelScopeConfig.get_cookies()

        # The 'Snapshot' header is omitted when running under CI_TEST.
        snapshot_header = headers if 'CI_TEST' in os.environ else {
            **headers,
            'Snapshot': 'True',
        }
        _api = HubApi(timeout=20)
        try:
            _, revisions = _api.get_model_branches_and_tags(
                model_id=model_id, use_cookies=cookies)
            latest_revision = revisions[0] if revisions else 'master'
        except Exception:
            latest_revision = 'master'

        model_files = _api.get_model_files(
            model_id=model_id,
            revision=latest_revision,
            recursive=True,
            headers=snapshot_header,
            use_cookies=cookies,
        )
        # Only non-git (snapshot) downloads have a file-system cache to query.
        model_cache = None
        if not os.path.exists(os.path.join(model_root_path, '.git')):
            model_cache = ModelFileSystemCache(model_root_path)

        for model_file in model_files:
            if model_file['Type'] == 'tree':
                continue  # directories carry no content to compare
            if model_cache is not None:
                is_latest = model_cache.exists(model_file)
            elif FILE_HASH in model_file:
                local_path = os.path.join(model_root_path, model_file['Path'])
                # A file missing locally is out of date rather than an error.
                is_latest = (os.path.isfile(local_path) and compute_hash(
                    local_path) == model_file[FILE_HASH])
            else:
                continue  # no hash published for this file; cannot compare
            if not is_latest:
                logger.info(
                    f'Model file {model_file["Name"]} is different from the latest version `{latest_revision}`, '
                    f'this is because you are using an older version or the file was updated manually.'
                )
                break
    except Exception as e:
        # Best-effort check: never fail the caller, but leave a trace.
        logger.debug(
            f'Failed to check whether local model {model_root_path} is the latest version: {e}'
        )


def check_model_is_id(model_id: str, token: Optional[str] = None):
    """Tell whether ``model_id`` refers to a model on the hub.

    ``None`` and any path that exists on the local filesystem are never
    treated as hub ids. Otherwise a ``HubApi`` client logs in with ``token``
    and looks the model up. Any exception raised by the lookup yields False.
    An exception raised by ``login`` itself is not caught here.

    Args:
        model_id: Candidate hub model id, or a local path.
        token: Access token passed to ``HubApi.login``.

    Returns:
        True if the hub lookup succeeds, False otherwise.
    """
    is_not_hub_id = model_id is None or os.path.exists(model_id)
    if is_not_hub_id:
        return False

    api = HubApi()
    api.login(token)
    try:
        api.get_model(model_id=model_id)
    except Exception:
        return False
    return True
