# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
import sys
import tarfile
import tempfile
import time
import zipfile

import requests

# Names exported by `from <this module> import *`; the underscore-prefixed
# helpers below are internal.
__all__ = ["download", "extract", "download_and_extract"]


class _ProgressPrinter:
    """Rate-limited single-line status writer for stderr.

    Each accepted message is written after a carriage return so that it
    overwrites the previous one. Messages arriving sooner than
    `flush_interval` seconds after the last written one are dropped.
    """

    def __init__(self, flush_interval=0.1):
        # A timestamp of 0 makes the first message always overdue.
        self._flush_intvl = flush_interval
        self._last_time = 0

    def print(self, str_, end=False):
        """Write `str_` to stderr if the flush interval has elapsed.

        Args:
            str_: Text to display.
            end: If true, a newline is appended and the throttle timestamp is
                reset to 0 before the interval check.
        """
        if end:
            str_ = str_ + "\n"
            self._last_time = 0

        is_due = time.time() - self._last_time >= self._flush_intvl
        if not is_due:
            return

        stream = sys.stderr
        stream.write(f"\r{str_}")
        self._last_time = time.time()
        stream.flush()


def _download(url, save_path, print_progress):
    """Stream the resource at `url` into the file `save_path`.

    The payload is written to a sibling ``<save_path>.part`` file and renamed
    to `save_path` only after the transfer has finished. A failed or
    interrupted download therefore never leaves a truncated file at
    `save_path`, which `download` would otherwise treat as already downloaded.

    Args:
        url: Address fetched with an HTTP GET request (15 second timeout).
        save_path: Destination file path; an existing file is replaced.
        print_progress: If true, report progress on stderr. The progress bar
            is drawn only when the response has a ``content-length`` header.

    Raises:
        requests.RequestException: On connection problems, or on an HTTP error
            status reported by `raise_for_status`.
        OSError: If the file cannot be written or renamed.
    """
    if print_progress:
        print(f"Connecting to {url} ...", file=sys.stderr)

    part_path = f"{save_path}.part"
    try:
        with requests.get(url, stream=True, timeout=15) as r:
            r.raise_for_status()

            total_length = r.headers.get("content-length")
            if total_length is not None:
                total_length = int(total_length)

            # Without a declared length there is nothing to measure against.
            show_bar = bool(print_progress) and total_length is not None
            if show_bar:
                printer = _ProgressPrinter()
                print(
                    f"Downloading {os.path.basename(save_path)} ...",
                    file=sys.stderr,
                )

            with open(part_path, "wb") as f:
                dl = 0
                # `iter_content` undoes gzip/deflate content-encoding, whereas
                # `r.raw` hands back the bytes as sent; using it for every
                # response keeps the stored payload consistent.
                for data in r.iter_content(chunk_size=4096):
                    dl += len(data)
                    f.write(data)
                    if show_bar and total_length > 0:
                        # Decoded bytes can exceed the declared (encoded)
                        # length, so cap the bar at 100%.
                        ratio = min(dl / total_length, 1.0)
                        done = int(50 * ratio)
                        printer.print(f"[{'=' * done:<50s}] {100 * ratio:.2f}%")
            if show_bar:
                printer.print(f"[{'=' * 50:<50s}] {100:.2f}%", end=True)

        # Publish the finished file in one step.
        os.replace(part_path, save_path)
    finally:
        # Only present here if something above failed before the rename.
        if os.path.exists(part_path):
            os.remove(part_path)


def _extract_zip_file(file_path, extd_dir):
    """Unpack a zip archive into `extd_dir`, one member at a time.

    This is a generator: nothing is extracted until it is iterated.

    Args:
        file_path: Path of the zip archive.
        extd_dir: Directory the members are extracted into.

    Yields:
        tuple: ``(member_count, position)`` after each member has been
        extracted, where `position` is the zero-based index of that member.
    """
    with zipfile.ZipFile(file_path, mode="r") as archive:
        members = archive.namelist()
        member_count = len(members)
        for position, member in enumerate(members):
            archive.extract(member, path=extd_dir)
            yield member_count, position


def _extract_tar_file(file_path, extd_dir):
    """Unpack a tar archive into `extd_dir`, one member at a time.

    The archive is opened with mode ``"r:*"``, so `tarfile` picks the
    compression itself. This is a generator: nothing is extracted until it is
    iterated.

    Errors are reported on stderr instead of being raised. A member whose
    extraction raises `KeyError` is skipped (progress is still yielded for
    it); any other exception ends the extraction early after its message has
    been printed.

    Args:
        file_path: Path of the tar archive.
        extd_dir: Directory the members are extracted into.

    Yields:
        tuple: ``(total_num, index)`` after each member has been processed,
        where `index` is the zero-based position of that member.
    """
    try:
        with tarfile.open(file_path, "r:*") as f:
            # Work on the TarInfo objects directly. Extracting by name makes
            # tarfile search its member list again for every file, which is
            # quadratic in the number of members for large archives.
            members = f.getmembers()
            total_num = len(members)
            for index, member in enumerate(members):
                try:
                    f.extract(member, extd_dir)
                except KeyError:
                    print(
                        f"File {member.name} not found in the archive.",
                        file=sys.stderr,
                    )
                yield total_num, index
    except Exception as e:
        # Deliberately best-effort: report and stop rather than propagate.
        print(f"An error occurred: {e}", file=sys.stderr)


def _extract(file_path, extd_dir, print_progress):
    """Unpack a zip or tar archive, optionally drawing a progress bar.

    The archive type is detected from the file content: zip is tried first,
    then tar.

    Args:
        file_path: Path of the archive.
        extd_dir: Directory the archive is extracted into.
        print_progress: If true, write a progress bar to stderr.

    Raises:
        RuntimeError: If the file is neither a zip nor a tar archive.
    """
    bar = None
    if print_progress:
        bar = _ProgressPrinter()
        print(f"Extracting {os.path.basename(file_path)}", file=sys.stderr)

    is_zip = zipfile.is_zipfile(file_path)
    if not is_zip and not tarfile.is_tarfile(file_path):
        raise RuntimeError("Unsupported file format.")
    unpack = _extract_zip_file if is_zip else _extract_tar_file

    # The handlers are generators; consuming them is what performs the work.
    for total_num, index in unpack(file_path, extd_dir):
        if bar is None:
            continue
        done = int(50 * float(index) / total_num)
        bar.print(f"[{'=' * done:<50s}] {float(100 * index) / total_num:.2f}%")

    if bar is not None:
        bar.print(f"[{'=' * 50:<50s}] {100:.2f}%", end=True)


def _remove_if_exists(path):
    """Delete `path` if present, whether it is a file, directory or symlink.

    Real directories are removed recursively. Symbolic links, including
    dangling ones and links that point at a directory, are unlinked without
    touching their target. A missing path is silently ignored.

    Args:
        path: File system path to remove.
    """
    # `lexists` also reports dangling symlinks, which `exists` would miss.
    if not os.path.lexists(path):
        return
    # `shutil.rmtree` refuses to operate on a symlink, so a link to a
    # directory must go through `os.remove` like any other link or file.
    if os.path.isdir(path) and not os.path.islink(path):
        shutil.rmtree(path)
    else:
        os.remove(path)


def download(url, save_path, print_progress=True, overwrite=False):
    """Download `url` to `save_path` unless something already exists there.

    Args:
        url: Address of the resource to fetch.
        save_path: Destination file path. Missing parent directories are
            created; a bare file name refers to the current directory.
        print_progress: If true, report progress on stderr.
        overwrite: If true, remove whatever exists at `save_path` first so the
            resource is always fetched again.
    """
    # A bare file name has an empty dirname, and `os.makedirs("")` raises
    # FileNotFoundError, so only create the parent when there is one.
    parent_dir = os.path.dirname(save_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    if overwrite:
        _remove_if_exists(save_path)
    if not os.path.exists(save_path):
        _download(url, save_path, print_progress=print_progress)


def extract(file_path, extd_dir, print_progress=True):
    """Unpack the archive at `file_path` into the directory `extd_dir`.

    Public entry point that forwards to `_extract`.

    Args:
        file_path: Path of a zip or tar archive.
        extd_dir: Directory the archive is extracted into.
        print_progress: If true, write a progress bar to stderr.

    Raises:
        RuntimeError: If the file is neither a zip nor a tar archive.
    """
    outcome = _extract(file_path, extd_dir, print_progress=print_progress)
    return outcome


def download_and_extract(
    url, save_dir, dst_name, print_progress=True, overwrite=False, no_interm_dir=True
):
    """Download an archive from `url` and unpack it to ``save_dir/dst_name``.

    Nothing is downloaded when ``save_dir/dst_name`` already exists, unless
    `overwrite` is set. Download and extraction happen in a temporary
    directory; only the final result is copied into `save_dir`, so no other
    entry of `save_dir` is created or modified.

    Args:
        url: Address of a zip or tar archive.
        save_dir: Directory that receives the result; created if missing.
        dst_name: Name of the resulting file or directory inside `save_dir`.
        print_progress: If true, report progress on stderr.
        overwrite: If true, remove an existing ``save_dir/dst_name`` first.
        no_interm_dir: If true, a single top-level entry of the archive (or,
            when there are several, the entry named `dst_name`) becomes
            ``save_dir/dst_name``. If false, the whole extracted tree is
            placed in a directory ``save_dir/dst_name``.

    Raises:
        FileNotFoundError: If `no_interm_dir` is true and the expected entry
            is not present at the top level of the archive.
        RuntimeError: If the downloaded file is neither zip nor tar.
    """
    # NOTE: `url` MUST come from a trusted source, since we do not provide a solution
    # to secure against CVE-2007-4559.
    os.makedirs(save_dir, exist_ok=True)
    dst_path = os.path.join(save_dir, dst_name)
    if overwrite:
        _remove_if_exists(dst_path)

    if os.path.exists(dst_path):
        return

    with tempfile.TemporaryDirectory() as td:
        # Keep the archive and the extracted tree in separate sub-directories
        # so that a URL-derived file name can never collide with the
        # extraction directory. A URL ending in "/" has no file name at all,
        # hence the fallback.
        arc_dir = os.path.join(td, "archive")
        os.makedirs(arc_dir)
        arc_file_path = os.path.join(arc_dir, url.split("/")[-1] or "archive")
        _download(url, arc_file_path, print_progress=print_progress)

        tmp_extd_dir = os.path.join(td, "extract")
        _extract(arc_file_path, tmp_extd_dir, print_progress=print_progress)

        if not no_interm_dir:
            shutil.copytree(tmp_extd_dir, dst_path)
            return

        file_names = os.listdir(tmp_extd_dir)
        file_name = file_names[0] if len(file_names) == 1 else dst_name
        sp = os.path.join(tmp_extd_dir, file_name)
        if not os.path.exists(sp):
            raise FileNotFoundError(
                f"Expected {file_name!r} at the top level of the archive "
                f"downloaded from {url}, found: {sorted(file_names)}"
            )
        # Copy straight to the final location. Staging the entry under its
        # archive name inside `save_dir` first could overwrite or clash with
        # an unrelated file or directory of that name.
        if os.path.isdir(sp):
            shutil.copytree(sp, dst_path, symlinks=True)
        else:
            shutil.copyfile(sp, dst_path)
