# Copyright (c) Alibaba, Inc. and its affiliates.

import copy
import math
import os
from itertools import islice

import datasets
import pandas as pd
from datasets import IterableDataset
from tqdm.auto import tqdm

from modelscope.msdatasets.utils.maxcompute_utils import MaxComputeUtil
from modelscope.utils.constant import (DEFAULT_MAXCOMPUTE_ENDPOINT,
                                       EXTENSIONS_TO_LOAD, MaxComputeEnvs,
                                       VirgoDatasetConfig)
from modelscope.utils.logger import get_logger
from modelscope.utils.url_utils import fetch_csv_with_url, valid_url

logger = get_logger()


class ExternalDataset(object):
    """Dataset class for custom datasets."""

    def __init__(self, split_path_dict, config_kwargs):
        self.split_path_dict = split_path_dict
        self.config_kwargs = copy.deepcopy(config_kwargs)
        self.config_kwargs.update({'split_config': self.split_path_dict})
        # dataset for specific extensions
        self.spec_extension_dataset = None
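        # data files collected for each split, keyed by split name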
        self.split_data_files = {
            k: []
            for k, _ in self.split_path_dict.items()
        }
        self.custom_map = {}

        # the shared file extension of the split data files
        file_ext = ''
        for split_name, split_dir in self.split_path_dict.items():
            if isinstance(split_dir, str) and os.path.isdir(split_dir):
                split_file_names = os.listdir(split_dir)
                set_files_exts = set([
                    os.path.splitext(file_name)[-1].strip('.')
                    for file_name in split_file_names
                ])
                if '' in set_files_exts:
                    continue
                # ensure all files in the split share the same extension
                if len(set_files_exts) != 1:
                    supported_exts = ','.join(EXTENSIONS_TO_LOAD.keys())
                    logger.error(
                        f'Split-{split_name} has been ignored; please flatten your folder structure '
                        f'and make sure all files share the same extension. '
                        f'Supported extensions: {supported_exts}.')
                    continue
                file_ext = list(set_files_exts)[0]
                if file_ext not in EXTENSIONS_TO_LOAD:
                    continue

                split_file_paths = [
                    os.path.join(split_dir, file_name)
                    for file_name in split_file_names
                ]
                self.split_data_files[split_name] = split_file_paths

        if file_ext:
            file_ext = EXTENSIONS_TO_LOAD.get(file_ext)
            self.spec_extension_dataset = datasets.load_dataset(
                file_ext, data_files=self.split_data_files, **config_kwargs)

    def __len__(self):
        if not self.spec_extension_dataset:
            return len(self.split_path_dict)
        return len(self.spec_extension_dataset)

    def __getitem__(self, item):
        if not self.spec_extension_dataset:
            return self.split_path_dict.get(item)
        return self.spec_extension_dataset[item]

    def __iter__(self):
        if not self.spec_extension_dataset:
            for k, v in self.split_path_dict.items():
                yield k, v
        else:
            for k, v in self.spec_extension_dataset.items():
                yield k, v


class NativeIterableDataset(IterableDataset):
    """The modelscope iterable dataset class."""

    def __init__(self, ex_iterable, info, split, stream_batch_size=1):
        super().__init__(ex_iterable=ex_iterable, info=info, split=split)
        self.stream_batch_size = stream_batch_size

    def __iter__(self):
        for item in tqdm(
                self.iter(
                    batch_size=self.stream_batch_size, drop_last_batch=False),
                desc='Overall progress',
                total=self.n_shards,
                dynamic_ncols=True):
            ret = self._download_item(item)

            yield ret

    def __len__(self):
        return self.n_shards

    def __getitem__(self, index):
        """
        Returns the item at index `index` in the dataset. Slice indexing is supported.
        """
        if isinstance(index, int):
            start = index
            stop = index + 1
            step = None
        else:
            start = index.start
            stop = index.stop
            step = index.step

        if step is not None and step <= 0:
            raise ValueError('step must be positive')

        for item in tqdm(
                islice(
                    self.iter(batch_size=1, drop_last_batch=False), start,
                    stop, step),
                desc='Slicing progress',
                dynamic_ncols=True):
            ret = self._download_item(item)

            yield ret

    def _download_item(self, item):
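        """Resolve remote files referenced by ``*:FILE`` columns.

        For each column name ending with ``:FILE``, download (and extract) the
        remote file via the dl_manager, store the local cache path(s) under the
        original column name, and keep the original value under the column name
        without the ``:FILE`` suffix.
        """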
        ret = {}
        if isinstance(item, dict):
            try:
                for k, v in item.items():
                    ret[k] = v
                    if k.endswith(':FILE'):
                        # download (and extract) the remote file, then expose
                        # the local cache path(s) under the original column name
                        dl_manager = self._ex_iterable.kwargs.get('dl_manager')
                        ex_cache_path = dl_manager.download_and_extract(v)
                        if isinstance(ex_cache_path, str):
                            ex_cache_path = [ex_cache_path]
                        ret[k] = ex_cache_path
                        ret[k[:-len(':FILE')]] = v

            except Exception as e:
                logger.error(e)
                ret = item
        else:
            ret = item

        return ret

    def head(self, n=5):
        """
        Returns the first n rows of the dataset.

        Args:
            n (int): Number of rows to return.

        Returns:
            list: The list of results, e.g. [{'id': 'abc123', 'text': 'hello world'}, ...]
        """
        res = []
        if n <= 0:
            return res
        iter_num = 0
        for item in self.__iter__():
            if iter_num >= n:
                break
            res.append(item)
            iter_num += 1
        return res


class VirgoDataset(object):
    """Dataset class for Virgo.

    Attributes:
        _meta_content (str): Virgo meta data content; can be a url pointing to a csv file.
        data_type (int): Virgo dataset type. 0 - standard Virgo dataset; others - user-defined dataset (to be supported).

    Examples:
        >>> from modelscope.msdatasets.dataset_cls.dataset import VirgoDataset
        >>> input_kwargs = {'metaContent': 'http://xxx-xxx/xxx.csv', 'samplingType': 0}
        >>> virgo_dataset = VirgoDataset(**input_kwargs)
        >>> print(virgo_dataset[1])
        >>> print(len(virgo_dataset))
        >>> for line in virgo_dataset:
        >>>     print(line)

        Note: If you set `download_virgo_files` to True via
            MsDataset.load(dataset_name='your-virgo-dataset-id', hub=Hubs.virgo, download_virgo_files=True),
            the cached file path of each sample is available in the `cache_file` column.
        >>> if virgo_dataset.download_virgo_files:
        >>>     print(virgo_dataset[1].get('cache_file'))
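
        Note: A MaxCompute-backed Virgo dataset passes a dict meta content instead (a
            sketch; the table name, partition and count below are illustrative, and the
            MaxCompute access environment variables must be set).
        >>> input_kwargs = {
        >>>     'metaContent': {'odpsTableName': 'your_table',
        >>>                     'odpsTablePartition': 'ds=20240101',
        >>>                     'odpsCount': 1000},
        >>>     'samplingType': 1}
        >>> virgo_dataset = VirgoDataset(**input_kwargs)
        >>> for batch in virgo_dataset:
        >>>     print(batch)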
    """

    def __init__(self, **kwargs):

        self._meta_content: str = ''
        self.data_type: int = 0
        self.odps_table_name: str = ''
        self.odps_table_partition: str = None
        self._odps_utils: MaxComputeUtil = None
        self.config_kwargs = kwargs

        self._meta: pd.DataFrame = pd.DataFrame()

        self._meta_content = self.config_kwargs.pop(
            VirgoDatasetConfig.meta_content, '')
        self.data_type = self.config_kwargs.pop(
            VirgoDatasetConfig.sampling_type, 0)

        self._check_variables()
        self._parse_meta()

        self.meta_content_cache_file = ''
        self.virgo_cache_dir = ''
        self.download_virgo_files: bool = False

        self.odps_table_ins = None
        self.odps_reader_ins = None
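        # MaxCompute read options; taken from the caller's kwargs, with defaults below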
        self.odps_batch_size = self.config_kwargs.pop('odps_batch_size', 100)
        self.odps_limit = self.config_kwargs.pop('odps_limit', None)
        self.odps_drop_last = self.config_kwargs.pop('odps_drop_last', False)
        if self._odps_utils:
            self.odps_table_ins, self.odps_reader_ins = self._odps_utils.get_table_reader_ins(
                self.odps_table_name, self.odps_table_partition)

    def __getitem__(self, index):
        if self.odps_reader_ins:
            return MaxComputeUtil.gen_reader_item(
                reader=self.odps_reader_ins,
                index=index,
                batch_size_in=self.odps_batch_size,
                limit_in=self.odps_limit,
                drop_last_in=self.odps_drop_last,
                partitions=self.odps_table_ins.table_schema.partitions,
                columns=self.odps_table_ins.table_schema.names)
        return self._meta.iloc[index].to_dict()

    def __len__(self):
        if isinstance(self._meta, dict):
            return self._meta.get('odpsCount', 0)
        return len(self._meta)

    def __iter__(self):
        if self.odps_reader_ins:
            odps_batch_data = MaxComputeUtil.gen_reader_batch(
                reader=self.odps_reader_ins,
                batch_size_in=self.odps_batch_size,
                limit_in=self.odps_limit,
                drop_last_in=self.odps_drop_last,
                partitions=self.odps_table_ins.table_schema.partitions,
                columns=self.odps_table_ins.table_schema.names)
            for batch in odps_batch_data:
                yield batch
        else:
            for _, row in self._meta.iterrows():
                yield row.to_dict()

    @property
    def meta(self) -> pd.DataFrame:
        """
        Virgo meta data. Contains the columns: id, meta_info, analysis_result,
        external_info and cache_file (if download_virgo_files is True).
        """
        return self._meta

    def _parse_meta(self):
        # Fetch csv content
        if isinstance(self._meta_content, str) and valid_url(
                self._meta_content):
            meta_content_df = fetch_csv_with_url(self._meta_content)
            self._meta = meta_content_df
        elif isinstance(self._meta_content, dict):
            self._meta = self._meta_content
            self.odps_table_name = self._meta.get('odpsTableName', '')
            self.odps_table_partition = self._meta.get('odpsTablePartition',
                                                       None)
            self._odps_utils = self._get_odps_info()
        else:
            raise ValueError('The meta content must be a url or a dict.')

    @staticmethod
    def _get_odps_info() -> MaxComputeUtil:
        """
        Get MaxComputeUtil instance.

        Args:
            None

        Returns:
            MaxComputeUtil instance.
        """
        access_id = os.environ.get(MaxComputeEnvs.ACCESS_ID, '')
        access_key = os.environ.get(MaxComputeEnvs.ACCESS_SECRET_KEY, '')
        proj_name = os.environ.get(MaxComputeEnvs.PROJECT_NAME, '')
        endpoint = os.environ.get(MaxComputeEnvs.ENDPOINT,
                                  DEFAULT_MAXCOMPUTE_ENDPOINT)

        if not access_id or not access_key or not proj_name:
            raise ValueError(
                f'Please set MaxCompute envs for Virgo: {MaxComputeEnvs.ACCESS_ID}, '
                f'{MaxComputeEnvs.ACCESS_SECRET_KEY}, {MaxComputeEnvs.PROJECT_NAME}, '
                f'{MaxComputeEnvs.ENDPOINT}(default: http://service-corp.odps.aliyun-inc.com/api)'
            )

        return MaxComputeUtil(access_id, access_key, proj_name, endpoint)

    def _check_variables(self):
        """Check member variables in this class.
            1. Condition-1: self._meta_content cannot be empty
            2. Condition-2: self._meta_content must be url when self._data_type is 0
        """
        if not self._meta_content:
            raise 'Them meta content cannot be empty.'
        if self.data_type not in [0, 1]:
            raise 'Supported samplingType should be 0 or 1, others are not supported yet.'
        if self.data_type == 0 and not valid_url(self._meta_content):
            raise 'The meta content must be url when data type is 0.'
        if self.data_type == 1 and not isinstance(self._meta_content, dict):
            raise 'The meta content must be dict when data type is 1.'
