# Copyright (c) Alibaba, Inc. and its affiliates.

import os
from multiprocessing.dummy import Pool as ThreadPool

from tqdm.auto import tqdm

from modelscope.msdatasets.utils.oss_utils import OssUtilities
from modelscope.utils.constant import UploadMode


class DatasetUploadManager(object):
    """Upload single files or whole directory trees for a dataset via OSS.

    All transfers are delegated to an ``OssUtilities`` instance that is
    configured in ``__init__`` from the dataset access config returned by
    ``HubApi``.
    """

    def __init__(self, dataset_name: str, namespace: str, version: str):
        """Create the OSS helper used by the upload methods.

        Args:
            dataset_name: Name of the target dataset.
            namespace: Namespace the dataset belongs to.
            version: Dataset revision; passed as ``revision`` both to the
                access-config request and to ``OssUtilities``.
        """
        # Imported locally, as in the original code (presumably to avoid a
        # circular import between the hub and msdatasets packages).
        from modelscope.hub.api import HubApi
        _hub_api = HubApi()
        _oss_config = _hub_api.get_dataset_access_config_session(
            dataset_name=dataset_name,
            namespace=namespace,
            check_cookie=False,
            revision=version)

        self.oss_utilities = OssUtilities(
            oss_config=_oss_config,
            dataset_name=dataset_name,
            namespace=namespace,
            revision=version)

    def upload(self, object_name: str, local_file_path: str,
               upload_mode: UploadMode) -> str:
        """Upload one local file.

        Args:
            object_name: Object name handed to ``OssUtilities.upload`` as
                ``oss_object_name``.
            local_file_path: Path of the local file to upload.
            upload_mode: Upload mode forwarded unchanged to the OSS helper.

        Returns:
            The value returned by ``OssUtilities.upload`` (the object key).
        """
        object_key = self.oss_utilities.upload(
            oss_object_name=object_name,
            local_file_path=local_file_path,
            indicate_individual_progress=True,
            upload_mode=upload_mode)
        return object_key

    @staticmethod
    def _build_object_key(object_dir_name: str, local_dir_path: str,
                          root: str, file_name: str) -> str:
        """Build the '/'-separated object key for a file found under `root`.

        Args:
            object_dir_name: Remote directory prefix of the key.
            local_dir_path: Top-level local directory being uploaded.
            root: Directory (as yielded by ``os.walk``) containing the file.
            file_name: Base name of the file.

        Returns:
            ``object_dir_name/<root relative to local_dir_path>/file_name``
            joined with forward slashes on every platform.
        """
        import posixpath

        relative_dir = os.path.relpath(root, local_dir_path)
        if relative_dir == os.curdir:
            # File sits directly in the top-level directory.
            relative_dir = ''
        else:
            # Object keys always use '/', whereas os.walk yields the native
            # separator (a backslash on Windows).
            relative_dir = relative_dir.replace(os.sep, '/')
        return posixpath.join(object_dir_name, relative_dir, file_name)

    def upload_dir(self, object_dir_name: str, local_dir_path: str,
                   num_processes: int, chunksize: int,
                   filter_hidden_files: bool, upload_mode: UploadMode) -> int:
        """Upload every file below a local directory using a thread pool.

        Args:
            object_dir_name: Remote directory prefix for all object keys.
            local_dir_path: Local directory that is walked recursively.
            num_processes: Number of worker threads in the pool.
            chunksize: Chunk size passed to ``pool.imap``.
            filter_hidden_files: If True, skip files whose name starts with
                a dot. Only file names are checked, so files inside hidden
                directories are still uploaded.
            upload_mode: Upload mode forwarded unchanged to the OSS helper.

        Returns:
            The number of files that were processed.
        """

        def run_upload(args):
            # args is an (object_name, local_file_path) tuple.
            self.oss_utilities.upload(
                oss_object_name=args[0],
                local_file_path=args[1],
                indicate_individual_progress=False,
                upload_mode=upload_mode)

        files_list = []
        for root, _dirs, files in os.walk(local_dir_path):
            for file_name in files:
                if filter_hidden_files and file_name.startswith('.'):
                    continue
                # Concatenate directory name and relative path into oss object key. e.g., train/001/1_1230.png
                object_name = self._build_object_key(object_dir_name,
                                                     local_dir_path, root,
                                                     file_name)

                local_file_path = os.path.join(root, file_name)
                files_list.append((object_name, local_file_path))

        with ThreadPool(processes=num_processes) as pool:
            # Draining the imap iterator inside the ``with`` block makes sure
            # all uploads finish before the pool is torn down; tqdm shows the
            # overall progress across files.
            result = list(
                tqdm(
                    pool.imap(run_upload, files_list, chunksize=chunksize),
                    total=len(files_list)))

        return len(result)
