# Copyright (c) Alibaba, Inc. and its affiliates.

import os
from collections import defaultdict
from typing import Optional, Union

import pandas as pd

from modelscope.hub.api import HubApi
from modelscope.msdatasets.context.dataset_context_config import \
    DatasetContextConfig
from modelscope.utils.constant import DEFAULT_DATASET_REVISION, MetaDataFields
from modelscope.utils.logger import get_logger

logger = get_logger()


def format_dataset_structure(dataset_structure):
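    """Keep only the splits that define at least a meta file or a data file.

    Example (illustrative):
        >>> format_dataset_structure({
        ...     'train': {'meta': 'my_train.csv'},
        ...     'test': {}
        ... })
        {'train': {'meta': 'my_train.csv'}}
    """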
    return {
        k: v
        for k, v in dataset_structure.items()
        if (v.get('meta') or v.get('file'))
    }


def get_target_dataset_structure(dataset_structure: dict,
                                 subset_name: Optional[str] = None,
                                 split: Optional[str] = None):
    """
    Args:
        dataset_structure (dict): Dataset Structure, like
         {
            "default":{
                "train":{
                    "meta":"my_train.csv",
                    "file":"pictures.zip"
                }
            },
            "subsetA":{
                "test":{
                    "meta":"mytest.csv",
                    "file":"pictures.zip"
                }
            }
        }
        subset_name (str, optional): Name of the dataset subset to load.
        split (str, optional): Which split of the data to load.
    Returns:
           target_subset_name (str): Name of the chosen subset.
           target_dataset_structure (dict): Structure of the chosen split(s), like
           {
               "test":{
                        "meta":"mytest.csv",
                        "file":"pictures.zip"
                    }
            }
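
    Example (illustrative; file names are hypothetical):
        >>> structure = {
        ...     'default': {
        ...         'train': {'meta': 'my_train.csv', 'file': 'pictures.zip'}
        ...     }
        ... }
        >>> get_target_dataset_structure(structure, split='train')
        ('default', {'train': {'meta': 'my_train.csv', 'file': 'pictures.zip'}})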
    """
    # verify dataset subset
    if subset_name and subset_name not in dataset_structure:
        raise ValueError(
            f'subset_name {subset_name} not found. Available: {list(dataset_structure.keys())}'
        )
    if not subset_name and len(dataset_structure) > 1:
        raise ValueError(
            'subset_name must be specified when the dataset has multiple '
            f'subsets. Available: {list(dataset_structure.keys())}')
    target_subset_name = subset_name
    if not subset_name:
        target_subset_name = next(iter(dataset_structure.keys()))
        logger.info(
            f'No subset_name specified, defaulting to subset {target_subset_name}.'
        )
    # verify dataset split
    target_dataset_structure = format_dataset_structure(
        dataset_structure[target_subset_name])
    if split and split not in target_dataset_structure:
        raise ValueError(
            f'split {split} not found. Available: {list(target_dataset_structure.keys())}'
        )
    if split:
        target_dataset_structure = {split: target_dataset_structure[split]}
    return target_subset_name, target_dataset_structure


def list_dataset_objects(hub_api: HubApi, max_limit: int, is_recursive: bool,
                         dataset_name: str, namespace: str,
                         version: str) -> list:
    """
    List all objects of a specific dataset.

    Args:
        hub_api (class HubApi): HubApi instance.
        max_limit (int): Max number of objects.
        is_recursive (bool): Whether to list objects recursively.
        dataset_name (str): Dataset name.
        namespace (str): Namespace.
        version (str): Dataset version.
    Returns:
        res (list): List of object keys, e.g., ['train/images/001.png', 'train/images/002.png', 'val/images/001.png', ...]
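
    Example (illustrative; requires network access, names are hypothetical):
        objects = list_dataset_objects(
            hub_api=HubApi(), max_limit=-1, is_recursive=True,
            dataset_name='my_dataset', namespace='my_namespace',
            version='master')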
    """
    res = []
    objects = hub_api.list_oss_dataset_objects(
        dataset_name=dataset_name,
        namespace=namespace,
        max_limit=max_limit,
        is_recursive=is_recursive,
        is_filter_dir=True,
        revision=version)

    for item in objects:
        object_key = item.get('Key')
        if not object_key:
            continue
        res.append(object_key)

    return res


def contains_dir(file_map) -> bool:
    """
    Check whether the input contains at least one directory, i.e., any value
    that is not a .zip file.

    Args:
        file_map (dict): Structure of data files, e.g., {'train': 'train.zip', 'validation': 'val.zip'}
    Returns:
        True if input contains at least one directory, False otherwise.
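
    Example (illustrative):
        >>> contains_dir({'train': 'train.zip', 'validation': 'val'})
        True
        >>> contains_dir({'train': 'train.zip', 'validation': 'val.zip'})
        False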
    """
    return any(
        isinstance(v, str) and not v.endswith('.zip')
        for v in file_map.values())


def get_subdir_hash_from_split(split: Union[str, list], version: str) -> str:
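    """Build the cache sub-directory path for the given split(s) under a dataset
    version, e.g., ('train', '1.0.0') -> '1.0.0/train' on POSIX systems.
    Note: despite the name, no hashing is involved.
    """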
    if isinstance(split, str):
        split = [split]
    return os.path.join(version, '_'.join(split))


def get_split_list(split: Union[str, list]) -> list:
    """ Unify the split to list-format. """
    if isinstance(split, str):
        return [split]
    elif isinstance(split, list):
        return split
    else:
        raise TypeError(
            f'Expected split to be str or list, but got {type(split)}.')


def get_split_objects_map(file_map, objects):
    """
    Get the mapping between dataset splits and OSS objects.

    Args:
        file_map (dict): Structure of data files, e.g., {'train': 'train', 'validation': 'val'}, where both
            'train' and 'val' are directories.
        objects (list): List of oss objects. e.g., ['train/001/1_123.png', 'train/001/1_124.png', 'val/003/3_38.png']
    Returns:
        A map of split-objects. e.g., {'train': ['train/001/1_123.png', 'train/001/1_124.png'],
            'validation':['val/003/3_38.png']}
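
    Example (illustrative):
        >>> get_split_objects_map(
        ...     {'train': 'train', 'validation': 'val'},
        ...     ['train/001/1_123.png', 'val/003/3_38.png'])
        {'train': ['train/001/1_123.png'], 'validation': ['val/003/3_38.png']}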
    """
    res = {k: [] for k in file_map}

    for obj_key in objects:
        for k, v in file_map.items():
            if obj_key.startswith(v.rstrip('/') + '/'):
                res[k].append(obj_key)

    return res


def get_dataset_files(subset_split_into: dict,
                      dataset_name: str,
                      namespace: str,
                      context_config: DatasetContextConfig,
                      revision: Optional[str] = DEFAULT_DATASET_REVISION):
    """
    Return:
        meta_map: Structure of meta files (.csv), the meta file name will be replaced by url, like
        {
           "test": "https://xxx/mytest.csv"
        }
        file_map: Structure of data files (.zip), like
        {
            "test": "pictures.zip"
        }
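
    Example (illustrative; requires network access, names are hypothetical):
        meta_map, file_map, args_map, custom_type_map = get_dataset_files(
            subset_split_into={'test': {'meta': 'mytest.csv', 'file': 'pictures.zip'}},
            dataset_name='my_dataset',
            namespace='my_namespace',
            context_config=context_config,
            revision='master')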
    """
    meta_map = defaultdict(dict)
    file_map = defaultdict(dict)
    args_map = defaultdict(dict)
    custom_type_map = defaultdict(dict)
    modelscope_api = HubApi()
    meta_cache_dir = context_config.data_meta_config.meta_cache_dir

    for split, info in subset_split_into.items():
        custom_type_map[split] = info.get('custom', '')
        meta_map[split] = modelscope_api.get_dataset_file_url_origin(
            info.get('meta', ''), dataset_name, namespace, revision)
        if info.get('file'):
            file_map[split] = info['file']
        args_map[split] = info.get('args')

    objects = []
    # If the `big_data` flag is set for a split, fetch its object list from the meta-csv file directly.
    for split, args_dict in args_map.items():
        if args_dict and args_dict.get(MetaDataFields.ARGS_BIG_DATA):
            meta_csv_file_url = meta_map[split]

            meta_csv_file_path = HubApi.fetch_meta_files_from_url(
                meta_csv_file_url, meta_cache_dir)

            csv_delimiter = context_config.config_kwargs.get('delimiter', ',')
            csv_df = pd.read_csv(
                meta_csv_file_path,
                iterator=False,
                delimiter=csv_delimiter,
                escapechar='\\')
            target_col = csv_df.columns[csv_df.columns.str.contains(
                ':FILE')].to_list()
            if len(target_col) == 0:
                logger.error(
                    f'No column name contains ":FILE" in {meta_csv_file_path}, '
                    f'falling back to the first column.')
                target_col = csv_df.columns[0]
            else:
                target_col = target_col[0]
            objects = csv_df[target_col].to_list()

            file_map[split] = objects
    # Fallback: list all objects in the dataset repo. More general, but less efficient.
    if not objects:
        objects = list_dataset_objects(
            hub_api=modelscope_api,
            max_limit=-1,
            is_recursive=True,
            dataset_name=dataset_name,
            namespace=namespace,
            version=revision)
        if contains_dir(file_map):
            file_map = get_split_objects_map(file_map, objects)

    return meta_map, file_map, args_map, custom_type_map
