# Copyright (c) Alibaba, Inc. and its affiliates.

from concurrent.futures import ThreadPoolExecutor, as_completed
from functools import wraps

from tqdm.auto import tqdm

from modelscope.hub.constants import DEFAULT_MAX_WORKERS
from modelscope.utils.logger import get_logger

# Module-level logger from ModelScope's logging helper; the code shown in this
# module does not currently emit any log records through it.
logger = get_logger()


def thread_executor(max_workers: int = DEFAULT_MAX_WORKERS,
                    disable_tqdm: bool = False,
                    tqdm_desc: str = None):
    """
    A decorator to execute a function in a threaded manner using ThreadPoolExecutor.

    The decorated function is called once per element of the iterable passed as
    the first positional argument, with any remaining positional/keyword
    arguments forwarded unchanged to every call.

    Args:
        max_workers (int): The maximum number of threads to use.
        disable_tqdm (bool): disable progress bar.
        tqdm_desc (str): Desc of tqdm. Defaults to 'Processing <n> items'.

    Returns:
        function: A wrapped function ``wrapper(iterable, *args, **kwargs)`` that
        executes with threading and a progress bar and returns a list of the
        per-item results, in the same order as the items of ``iterable``.
        ``iterable`` may be any iterable (list, tuple, generator, ...); it is
        consumed exactly once.

    Raises:
        Any exception raised by the decorated function for one of the items is
        re-raised to the caller. Tasks that have not started yet are cancelled
        first; tasks already running are waited for before the error propagates.

    Examples:
        >>> from modelscope.utils.thread_utils import thread_executor
        >>> import time
        >>> @thread_executor(max_workers=8)
        ... def process_item(item, x, y):
        ...     # do something to single item
        ...     time.sleep(1)
        ...     return str(item) + str(x) + str(y)

        >>> items = [1, 2, 3]
        >>> process_item(items, x='abc', y='xyz')
        ['1abcxyz', '2abcxyz', '3abcxyz']
    """

    def decorator(func):

        @wraps(func)
        def wrapper(iterable, *args, **kwargs):
            # Materialize once: generators/iterators have no len() and can only
            # be consumed a single time, and the progress-bar total must match
            # the number of tasks actually submitted.
            items = list(iterable)
            total = len(items)
            # Pre-sized so each result lands at its input index even though
            # tasks finish in arbitrary order.
            results = [None] * total
            # Create a tqdm progress bar with the total number of items to process
            with tqdm(
                    unit_scale=True,
                    unit_divisor=1024,
                    initial=0,
                    total=total,
                    desc=tqdm_desc or f'Processing {total} items',
                    disable=disable_tqdm,
            ) as pbar:
                with ThreadPoolExecutor(max_workers=max_workers) as executor:
                    # Submit all tasks, remembering each one's input position.
                    future_to_index = {
                        executor.submit(func, item, *args, **kwargs): index
                        for index, item in enumerate(items)
                    }

                    try:
                        # Update the progress bar as tasks complete
                        for future in as_completed(future_to_index):
                            pbar.update(1)
                            results[future_to_index[future]] = future.result()
                    except BaseException:
                        # One task failed (or we were interrupted): don't start
                        # the tasks still waiting in the queue. The executor's
                        # shutdown still waits for already-running ones.
                        for pending in future_to_index:
                            pending.cancel()
                        raise
            return results

        return wrapper

    return decorator
