# noqa: isort:skip_file, yapf: disable
# Copyright (c) Alibaba, Inc. and its affiliates.
# Copyright 2020 The HuggingFace Datasets Authors and the TensorFlow Datasets Authors.

import json
import os
import re
import copy
import shutil
import time
import warnings
from contextlib import contextmanager
from functools import partial
from pathlib import Path
from typing import Optional, Union
from urllib.parse import urljoin, urlparse
import requests
from tqdm.auto import tqdm

from datasets import config
from datasets.utils.file_utils import hash_url_to_filename, \
    get_authentication_headers_for_url, fsspec_head, fsspec_get
from filelock import FileLock

from modelscope.utils.config_ds import MS_DATASETS_CACHE
from modelscope.utils.logger import get_logger
from modelscope.hub.api import ModelScopeConfig

from modelscope import __version__

logger = get_logger()


def get_datasets_user_agent_ms(user_agent: Optional[Union[str, dict]] = None) -> str:
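    """Build the user-agent string sent with dataset-related HTTP requests.

    The string reports the library version together with the Python, pyarrow and (when available)
    torch/tensorflow/jax versions; extra entries can be appended via ``user_agent``.
    """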
    ua = f'datasets/{__version__}'
    ua += f'; python/{config.PY_VERSION}'
    ua += f'; pyarrow/{config.PYARROW_VERSION}'
    if config.TORCH_AVAILABLE:
        ua += f'; torch/{config.TORCH_VERSION}'
    if config.TF_AVAILABLE:
        ua += f'; tensorflow/{config.TF_VERSION}'
    if config.JAX_AVAILABLE:
        ua += f'; jax/{config.JAX_VERSION}'
    if isinstance(user_agent, dict):
        ua += f"; {'; '.join(f'{k}/{v}' for k, v in user_agent.items())}"
    elif isinstance(user_agent, str):
        ua += '; ' + user_agent
    return ua


def _request_with_retry_ms(
    method: str,
    url: str,
    max_retries: int = 2,
    base_wait_time: float = 0.5,
    max_wait_time: float = 2,
    timeout: float = 10.0,
    **params,
) -> requests.Response:
    """Wrapper around requests to retry in case it fails with a ConnectTimeout, with exponential backoff.

    Args:
        method (str): HTTP method, such as 'GET' or 'HEAD'.
        url (str): The URL of the resource to fetch.
        max_retries (int): Maximum number of retries, defaults to 2.
        base_wait_time (float): Duration (in seconds) to wait before retrying the first time. Wait time between
            retries then grows exponentially, capped by max_wait_time.
        max_wait_time (float): Maximum amount of time between two retries, in seconds.
        timeout (float): Timeout (in seconds) passed to :obj:`requests.request`, defaults to 10.0.
        **params (additional keyword arguments): Params to pass to :obj:`requests.request`.
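
    Example:
        A minimal, illustrative call (the URL below is hypothetical):

        >>> response = _request_with_retry_ms('GET', 'https://example.com/data.json', max_retries=3)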
    """
    tries, success = 0, False
    response = None
    while not success:
        tries += 1
        try:
            response = requests.request(method=method.upper(), url=url, timeout=timeout, **params)
            success = True
        except (requests.exceptions.ConnectTimeout, requests.exceptions.ConnectionError) as err:
            if tries > max_retries:
                raise err
            else:
                logger.info(f'{method} request to {url} timed out, retrying... [{tries}/{max_retries}]')
                sleep_time = min(max_wait_time, base_wait_time * 2 ** (tries - 1))  # Exponential backoff
                time.sleep(sleep_time)
    return response


def http_head_ms(
    url, proxies=None, headers=None, cookies=None, allow_redirects=True, timeout=10.0, max_retries=0
) -> requests.Response:
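    """Send a HEAD request with the datasets user-agent attached, retrying on connection errors."""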
    headers = copy.deepcopy(headers) or {}
    headers['user-agent'] = get_datasets_user_agent_ms(user_agent=headers.get('user-agent'))
    response = _request_with_retry_ms(
        method='HEAD',
        url=url,
        proxies=proxies,
        headers=headers,
        cookies=cookies,
        allow_redirects=allow_redirects,
        timeout=timeout,
        max_retries=max_retries,
    )
    return response


def http_get_ms(
    url, temp_file, proxies=None, resume_size=0, headers=None, cookies=None, timeout=100.0, max_retries=0, desc=None
) -> Optional[requests.Response]:
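    """Stream a GET request into ``temp_file``, showing a tqdm progress bar.

    If ``temp_file`` is None, the raw streaming response is returned instead of being consumed.
    When ``resume_size`` > 0, a ``Range`` header is sent so an interrupted download can resume.
    """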
    headers = dict(headers) if headers is not None else {}
    headers['user-agent'] = get_datasets_user_agent_ms(user_agent=headers.get('user-agent'))
    if resume_size > 0:
        headers['Range'] = f'bytes={resume_size:d}-'
    response = _request_with_retry_ms(
        method='GET',
        url=url,
        stream=True,
        proxies=proxies,
        headers=headers,
        cookies=cookies,
        max_retries=max_retries,
        timeout=timeout,
    )
    if temp_file is None:
        return response
    if response.status_code == 416:  # Range not satisfiable
        return
    content_length = response.headers.get('Content-Length')
    total = resume_size + int(content_length) if content_length is not None else None

    progress = tqdm(total=total, initial=resume_size, unit_scale=True, unit='B', desc=desc or 'Downloading')
    for chunk in response.iter_content(chunk_size=1024):
        progress.update(len(chunk))
        temp_file.write(chunk)

    progress.close()


def get_from_cache_ms(
    url,
    cache_dir=None,
    force_download=False,
    proxies=None,
    etag_timeout=100,
    resume_download=False,
    user_agent=None,
    local_files_only=False,
    use_etag=True,
    max_retries=0,
    token=None,
    use_auth_token='deprecated',
    ignore_url_params=False,
    storage_options=None,
    download_desc=None,
    disable_tqdm=None,
) -> str:
    """
    Given a URL, look for the corresponding file in the local cache.
    If it's not there, download it. Then return the path to the cached file.

    Returns:
        Local path (string)

    Raises:
        FileNotFoundError: in case of non-recoverable file
            (non-existent or no cache on disk)
        ConnectionError: in case of unreachable url
            and no cache on disk
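
    Example:
        An illustrative call (the URL and cache directory below are hypothetical):

        >>> path = get_from_cache_ms('https://example.com/data/train.csv', cache_dir='/tmp/ms_cache')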
    """
    if use_auth_token != 'deprecated':
        warnings.warn(
            "'use_auth_token' was deprecated in favor of 'token' in version 2.14.0 and will be removed in 3.0.0.\n"
            f"You can remove this warning by passing 'token={use_auth_token}' instead.",
            FutureWarning,
        )
        token = use_auth_token
    if cache_dir is None:
        cache_dir = MS_DATASETS_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    os.makedirs(cache_dir, exist_ok=True)

    if ignore_url_params:
        # strip all query parameters and #fragments from the URL
        cached_url = urljoin(url, urlparse(url).path)
    else:
        cached_url = url  # additional parameters may be added to the given URL

    connected = False
    response = None
    cookies = None
    etag = None
    head_error = None
    scheme = None

    # Try a first time to find the file on the local file system without eTag (None)
    # if we don't ask for 'force_download' then we spare a request
    filename = hash_url_to_filename(cached_url, etag=None)
    cache_path = os.path.join(cache_dir, filename)
    if download_desc is None:
        download_desc = 'Downloading [' + filename + ']'

    if os.path.exists(cache_path) and not force_download and not use_etag:
        return cache_path

    # Prepare headers for authentication
    headers = get_authentication_headers_for_url(url, token=token)
    if user_agent is not None:
        headers['user-agent'] = user_agent

    # We don't have the file locally or we need an eTag
    if not local_files_only:
        scheme = urlparse(url).scheme
        if scheme not in ('http', 'https'):
            response = fsspec_head(url, storage_options=storage_options)
            # s3fs uses "ETag", gcsfs uses "etag"
            etag = (response.get('ETag', None) or response.get('etag', None)) if use_etag else None
            connected = True
        try:
            cookies = ModelScopeConfig.get_cookies()
            response = http_head_ms(
                url,
                allow_redirects=True,
                proxies=proxies,
                timeout=etag_timeout,
                max_retries=max_retries,
                headers=headers,
                cookies=cookies,
            )
            if response.status_code == 200:  # ok
                etag = response.headers.get('ETag') if use_etag else None
                for k, v in response.cookies.items():
                    # In some edge cases, we need to get a confirmation token
                    if k.startswith('download_warning') and 'drive.google.com' in url:
                        url += '&confirm=' + v
                        cookies = response.cookies
                connected = True
                # Fix Google Drive URL to avoid Virus scan warning
                if 'drive.google.com' in url and 'confirm=' not in url:
                    url += '&confirm=t'
            # In some edge cases, head request returns 400 but the connection is actually ok
            elif (
                (response.status_code == 400 and 'firebasestorage.googleapis.com' in url)
                or (response.status_code == 405 and 'drive.google.com' in url)
                or (
                    response.status_code == 403
                    and (
                        re.match(r'^https?://github.com/.*?/.*?/releases/download/.*?/.*?$', url)
                        or re.match(r'^https://.*?s3.*?amazonaws.com/.*?$', response.url)
                    )
                )
                or (response.status_code == 403 and 'ndownloader.figstatic.com' in url)
            ):
                connected = True
                logger.info(f"Couldn't get ETag version for url {url}")
            elif response.status_code == 401 and config.HF_ENDPOINT in url and token is None:
                raise ConnectionError(
                    f'Unauthorized for URL {url}. '
                    f'Please use the parameter `token=True` after logging in with `huggingface-cli login`'
                )
        except (OSError, requests.exceptions.Timeout) as e:
            # not connected
            head_error = e
            pass

    # connected is False: we don't have a connection, or the url doesn't exist, or it is otherwise inaccessible.
    # try to get the last downloaded one
    if not connected:
        if os.path.exists(cache_path) and not force_download:
            return cache_path
        if local_files_only:
            raise FileNotFoundError(
                f'Cannot find the requested files in the cached path at {cache_path} and outgoing traffic has been'
                " disabled. To enable file online look-ups, set 'local_files_only' to False."
            )
        elif response is not None and response.status_code == 404:
            raise FileNotFoundError(f"Couldn't find file at {url}")
        if head_error is not None:
            raise ConnectionError(f"Couldn't reach {url} ({repr(head_error)})")
        elif response is not None:
            raise ConnectionError(f"Couldn't reach {url} (error {response.status_code})")
        else:
            raise ConnectionError(f"Couldn't reach {url}")

    # Try a second time
    filename = hash_url_to_filename(cached_url, etag)
    cache_path = os.path.join(cache_dir, filename)

    if os.path.exists(cache_path) and not force_download:
        return cache_path

    # From now on, connected is True.
    # Prevent parallel downloads of the same file with a lock.
    lock_path = cache_path + '.lock'
    with FileLock(lock_path):
        # Check again in case another process finished the download while we were waiting for the lock
        if os.path.exists(cache_path) and not force_download:
            return cache_path

        incomplete_path = cache_path + '.incomplete'

        @contextmanager
        def temp_file_manager(mode='w+b'):
            with open(incomplete_path, mode) as f:
                yield f

        resume_size = 0
        if resume_download:
            temp_file_manager = partial(temp_file_manager, mode='a+b')
            if os.path.exists(incomplete_path):
                resume_size = os.stat(incomplete_path).st_size

        # Download to temporary file, then copy to cache path once finished.
        # Otherwise, you get corrupt cache entries if the download gets interrupted.
        with temp_file_manager() as temp_file:

            # GET file object
            if scheme not in ('http', 'https'):
                fsspec_get(url, temp_file, storage_options=storage_options, desc=download_desc)
            else:
                http_get_ms(
                    url,
                    temp_file=temp_file,
                    proxies=proxies,
                    resume_size=resume_size,
                    headers=headers,
                    cookies=cookies,
                    max_retries=max_retries,
                    desc=download_desc,
                )

        logger.info(f'storing {url} in cache at {cache_path}')
        shutil.move(temp_file.name, cache_path)
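        # os.umask returns the previous mask, so this read-and-restore pair fetches the current
        # umask without changing it; the cached file is then given the standard default permissions.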
        umask = os.umask(0o666)
        os.umask(umask)
        os.chmod(cache_path, 0o666 & ~umask)

        logger.info(f'creating metadata file for {cache_path}')
        meta = {'url': url, 'etag': etag}
        meta_path = cache_path + '.json'
        with open(meta_path, 'w', encoding='utf-8') as meta_file:
            json.dump(meta, meta_file)

    return cache_path
