# Copyright (c) Alibaba, Inc. and its affiliates.

import os
import warnings
from typing import (Any, Callable, Dict, Iterable, List, Mapping, Optional,
                    Sequence, Union)

import numpy as np
from datasets import (Dataset, DatasetDict, Features, IterableDataset,
                      IterableDatasetDict)
from datasets.packaged_modules import _PACKAGED_DATASETS_MODULES
from datasets.utils.file_utils import is_relative_path

from modelscope.hub.repository import DatasetRepository
from modelscope.msdatasets.context.dataset_context_config import \
    DatasetContextConfig
from modelscope.msdatasets.data_loader.data_loader_manager import (
    LocalDataLoaderManager, LocalDataLoaderType, RemoteDataLoaderManager,
    RemoteDataLoaderType)
from modelscope.msdatasets.dataset_cls import (ExternalDataset,
                                               NativeIterableDataset)
from modelscope.msdatasets.dataset_cls.custom_datasets.builder import \
    build_custom_dataset
from modelscope.msdatasets.utils.delete_utils import DatasetDeleteManager
from modelscope.msdatasets.utils.hf_datasets_util import load_dataset_with_ctx
from modelscope.msdatasets.utils.upload_utils import DatasetUploadManager
from modelscope.preprocessors import build_preprocessor
from modelscope.utils.config import Config, ConfigDict
from modelscope.utils.config_ds import MS_DATASETS_CACHE
from modelscope.utils.constant import (DEFAULT_DATASET_NAMESPACE,
                                       DEFAULT_DATASET_REVISION,
                                       REPO_TYPE_DATASET, ConfigFields,
                                       DatasetFormations, DownloadMode, Hubs,
                                       ModeKeys, Tasks, UploadMode)
from modelscope.utils.import_utils import is_tf_available, is_torch_available
from modelscope.utils.logger import get_logger

logger = get_logger()


def format_list(para) -> List:
    if para is None:
        para = []
    elif isinstance(para, str):
        para = [para]
    elif len(set(para)) < len(para):
        raise ValueError(f'List columns contains duplicates: {para}')
    return para


class MsDataset:
    """
    ModelScope Dataset (aka MsDataset) is backed by a Hugging Face Dataset to
    provide efficient data access and local storage management. On top of
    that, MsDataset supports data integration and interaction with multiple
    remote hubs, in particular ModelScope's own dataset hub. MsDataset also
    abstracts away data-access details of other remote storage, including both
    general web-hosted data and cloud storage such as OSS.
    """
    # the underlying huggingface Dataset
    _hf_ds = None
    _dataset_context_config: DatasetContextConfig = None

    def __init__(self,
                 ds_instance: Union[Dataset, IterableDataset, ExternalDataset,
                                    NativeIterableDataset],
                 target: Optional[str] = None):
        self._hf_ds = ds_instance
        if target is not None and target not in self._hf_ds.features:
            raise TypeError(
                '"target" must be a column of the dataset '
                f'({list(self._hf_ds.features.keys())}), but got {target}')
        self.target = target
        self.is_custom = False

    def __iter__(self):
        for item in self._hf_ds:
            if self.target is not None:
                yield item[self.target]
            else:
                yield item

    def __getitem__(self, key):
        return self._hf_ds[key]

    def __len__(self):
        return len(self._hf_ds)

    @property
    def ds_instance(self):
        return self._hf_ds

    @property
    def config_kwargs(self):
        if isinstance(self._hf_ds, ExternalDataset):
            return self._hf_ds.config_kwargs
        else:
            return None

    @classmethod
    def from_hf_dataset(cls,
                        hf_ds: Union[Dataset, DatasetDict, ExternalDataset],
                        target: str = None) -> Union[dict, 'MsDataset']:
        r"""
        @deprecated
        This method is deprecated and may be removed in future releases, please use `to_ms_dataset()` instead.
        """
        warnings.warn(
            'from_hf_dataset is deprecated, please use to_ms_dataset instead.',
            DeprecationWarning)
        if isinstance(hf_ds, Dataset):
            return cls(hf_ds, target)
        elif isinstance(hf_ds, DatasetDict):
            if len(hf_ds.keys()) == 1:
                return cls(next(iter(hf_ds.values())), target)
            return {k: cls(v, target) for k, v in hf_ds.items()}
        elif isinstance(hf_ds, ExternalDataset):
            return cls(hf_ds)
        else:
            raise TypeError(
                '"hf_ds" must be a Dataset, DatasetDict or ExternalDataset, '
                f'but got {type(hf_ds)}')

    @classmethod
    def to_ms_dataset(cls,
                      ds_instance: Union[Dataset, DatasetDict, ExternalDataset,
                                         NativeIterableDataset,
                                         IterableDataset, IterableDatasetDict],
                      target: str = None) -> Union[dict, 'MsDataset']:
        """Convert input to `MsDataset` instance."""
        if isinstance(ds_instance, Dataset):
            return cls(ds_instance, target)
        elif isinstance(ds_instance, DatasetDict):
            if len(ds_instance.keys()) == 1:
                return cls(next(iter(ds_instance.values())), target)
            return {k: cls(v, target) for k, v in ds_instance.items()}
        elif isinstance(ds_instance, ExternalDataset):
            return cls(ds_instance)
        elif isinstance(ds_instance,
                        (NativeIterableDataset, IterableDataset)):
            return cls(ds_instance)
        elif isinstance(ds_instance, IterableDatasetDict):
            if len(ds_instance.keys()) == 1:
                return cls(next(iter(ds_instance.values())), target)
            return {k: cls(v, target) for k, v in ds_instance.items()}
        else:
            raise TypeError(
                '"ds_instance" must be a Dataset, DatasetDict, IterableDataset, '
                f'IterableDatasetDict or ExternalDataset, but got {type(ds_instance)}'
            )

    @staticmethod
    def load(
        dataset_name: Union[str, list],
        namespace: Optional[str] = DEFAULT_DATASET_NAMESPACE,
        target: Optional[str] = None,
        version: Optional[str] = DEFAULT_DATASET_REVISION,
        hub: Optional[Hubs] = Hubs.modelscope,
        subset_name: Optional[str] = None,
        split: Optional[str] = None,
        data_dir: Optional[str] = None,
        data_files: Optional[Union[str, Sequence[str],
                                   Mapping[str, Union[str,
                                                      Sequence[str]]]]] = None,
        download_mode: Optional[DownloadMode] = DownloadMode.
        REUSE_DATASET_IF_EXISTS,
        cache_dir: Optional[str] = MS_DATASETS_CACHE,
        features: Optional[Features] = None,
        use_streaming: Optional[bool] = False,
        stream_batch_size: Optional[int] = 1,
        custom_cfg: Optional[Config] = Config(),
        token: Optional[str] = None,
        dataset_info_only: Optional[bool] = False,
        trust_remote_code: Optional[bool] = False,
        **config_kwargs,
    ) -> Union[dict, 'MsDataset', NativeIterableDataset]:
        """Load a MsDataset from the ModelScope Hub, Hugging Face Hub, urls, or a local dataset.

            Args:
                dataset_name (str): Path or name of the dataset.
                    The form of `namespace/dataset_name` is also supported.
                namespace (str, optional): Namespace of the dataset. It should not be None if you load a remote
                    dataset from Hubs.modelscope.
                target (str, optional): Name of the column to output.
                version (str, optional): Version of the dataset script to load.
                subset_name (str, optional): Defining the subset_name of the dataset.
                data_dir (str, optional): Defining the data_dir of the dataset configuration.
                data_files (str or Sequence or Mapping, optional): Path(s) to source data file(s).
                split (str, optional): Which split of the data to load.
                hub (Hubs or str, optional): When loading from a remote hub, where it is from.
                    Default: Hubs.modelscope
                download_mode (DownloadMode or str, optional): How to treat existing datasets.
                    Default: DownloadMode.REUSE_DATASET_IF_EXISTS
                cache_dir (str, optional): User-defined local cache directory.
                features (Features, optional): The features to set for the dataset being built.
                use_streaming (bool, optional): If set to True, no need to download all data files.
                    Instead, it streams the data progressively, and returns
                    NativeIterableDataset or a dict of NativeIterableDataset.
                stream_batch_size (int, optional): The batch size of the streaming data.
                custom_cfg (Config, optional): Model configuration, this can be used for custom datasets.
                    See https://modelscope.cn/docs/Configuration%E8%AF%A6%E8%A7%A3
                token (str, optional): SDK token of ModelScope.
                dataset_info_only (bool, optional): If set to True, only return the dataset config and info (dict).
                trust_remote_code (bool, optional): If set to True, trust and execute remote code from the dataset
                    repository. Defaults to False.
                **config_kwargs (additional keyword arguments): Keyword arguments to be passed to the underlying
                    dataset loading function.

            Returns:
                MsDataset (MsDataset): MsDataset object for a certain dataset.
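
            Examples:
                A minimal usage sketch. The dataset id `your_namespace/your_dataset` and the file path below
                are placeholders, not real datasets; substitute data you have access to on the corresponding hub.

                >>> from modelscope.msdatasets import MsDataset
                >>> # Load a dataset hosted on the ModelScope hub
                >>> ds = MsDataset.load('your_namespace/your_dataset', split='train')
                >>> # Stream the same dataset instead of downloading all files first
                >>> ds_stream = MsDataset.load(
                ...     'your_namespace/your_dataset', split='train', use_streaming=True)
                >>> # Load a local csv file through the packaged `csv` builder
                >>> ds_local = MsDataset.load('csv', data_files='/path/to/train.csv')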
            """

        if token:
            from modelscope.hub.api import HubApi
            api = HubApi()
            api.login(token)

        download_mode = DownloadMode(download_mode
                                     or DownloadMode.REUSE_DATASET_IF_EXISTS)
        hub = Hubs(hub or Hubs.modelscope)
        is_huggingface_hub = (hub == Hubs.huggingface)
        if not isinstance(dataset_name, str) and not isinstance(
                dataset_name, list):
            raise TypeError(
                f'dataset_name must be `str` or `list`, but got {type(dataset_name)}'
            )

        if isinstance(dataset_name, list):
            if target is None:
                target = 'target'
            dataset_inst = Dataset.from_dict({target: dataset_name})
            return MsDataset.to_ms_dataset(dataset_inst, target=target)

        dataset_name = os.path.expanduser(dataset_name)
        is_local_path = os.path.exists(dataset_name)
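        # If the name looks like `namespace/dataset_name` (and is neither a local path nor a
        # Hugging Face hub id), split it into namespace and dataset_name.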
        if is_relative_path(dataset_name) and dataset_name.count(
                '/') == 1 and not is_local_path and not is_huggingface_hub:
            dataset_name_split = dataset_name.split('/')
            namespace = dataset_name_split[0].strip()
            dataset_name = dataset_name_split[1].strip()
            if not namespace or not dataset_name:
                raise ValueError(
                    'The dataset_name should be in the form of '
                    '`namespace/dataset_name` or `dataset_name`.')

        if trust_remote_code:
            logger.warning(
                f'Using trust_remote_code=True. Remote code from {dataset_name} will be executed. '
                'Please make sure that you trust it.')

        # Init context config
        dataset_context_config = DatasetContextConfig(
            dataset_name=dataset_name,
            namespace=namespace,
            version=version,
            subset_name=subset_name,
            split=split,
            target=target,
            hub=hub,
            data_dir=data_dir,
            data_files=data_files,
            download_mode=download_mode,
            cache_root_dir=cache_dir,
            use_streaming=use_streaming,
            stream_batch_size=stream_batch_size,
            trust_remote_code=trust_remote_code,
            **config_kwargs)

        # Load from local disk
        if dataset_name in _PACKAGED_DATASETS_MODULES or os.path.isdir(
                dataset_name) or os.path.isfile(dataset_name):
            dataset_inst = LocalDataLoaderManager(
                dataset_context_config).load_dataset(
                    LocalDataLoaderType.HF_DATA_LOADER)
            dataset_inst = MsDataset.to_ms_dataset(dataset_inst, target=target)
            if isinstance(dataset_inst, MsDataset):
                dataset_inst._dataset_context_config = dataset_context_config
                if custom_cfg:
                    dataset_inst.to_custom_dataset(
                        custom_cfg=custom_cfg, **config_kwargs)
                    dataset_inst.is_custom = True
            return dataset_inst
        # Load from the huggingface hub
        elif hub == Hubs.huggingface:
            from datasets import load_dataset
            return load_dataset(
                dataset_name,
                name=subset_name,
                split=split,
                streaming=use_streaming,
                download_mode=download_mode.value,
                trust_remote_code=trust_remote_code,
                **config_kwargs)

        # Load from the modelscope hub
        elif hub == Hubs.modelscope:
            # Get dataset type from ModelScope Hub;  dataset_type->4: General Dataset
            from modelscope.hub.api import HubApi
            _api = HubApi()
            endpoint = _api.get_endpoint_for_read(
                repo_id=namespace + '/' + dataset_name,
                repo_type=REPO_TYPE_DATASET)
            dataset_id_on_hub, dataset_type = _api.get_dataset_id_and_type(
                dataset_name=dataset_name,
                namespace=namespace,
                endpoint=endpoint)

            # Load from the ModelScope Hub for type=4 (general)
            if str(dataset_type) == str(DatasetFormations.general.value):

                with load_dataset_with_ctx(
                        path=namespace + '/' + dataset_name,
                        name=subset_name,
                        data_dir=data_dir,
                        data_files=data_files,
                        split=split,
                        cache_dir=cache_dir,
                        features=features,
                        download_config=None,
                        download_mode=download_mode.value,
                        revision=version,
                        token=token,
                        streaming=use_streaming,
                        dataset_info_only=dataset_info_only,
                        trust_remote_code=trust_remote_code,
                        **config_kwargs) as dataset_res:

                    return dataset_res

            else:

                remote_dataloader_manager = RemoteDataLoaderManager(
                    dataset_context_config)
                dataset_inst = remote_dataloader_manager.load_dataset(
                    RemoteDataLoaderType.MS_DATA_LOADER)
                dataset_inst = MsDataset.to_ms_dataset(
                    dataset_inst, target=target)
                if isinstance(dataset_inst, MsDataset):
                    dataset_inst._dataset_context_config = remote_dataloader_manager.dataset_context_config
                    if custom_cfg:
                        dataset_inst.to_custom_dataset(
                            custom_cfg=custom_cfg, **config_kwargs)
                        dataset_inst.is_custom = True
                return dataset_inst

        elif hub == Hubs.virgo:
            warnings.warn(
                'The option `Hubs.virgo` is deprecated, '
                'will be removed in the future version.', DeprecationWarning)
            from modelscope.msdatasets.data_loader.data_loader import VirgoDownloader
            from modelscope.utils.constant import VirgoDatasetConfig
            # Rewrite the namespace, version and cache_dir for virgo dataset.
            if namespace == DEFAULT_DATASET_NAMESPACE:
                dataset_context_config.namespace = VirgoDatasetConfig.default_virgo_namespace
            if version == DEFAULT_DATASET_REVISION:
                dataset_context_config.version = VirgoDatasetConfig.default_dataset_version
            if cache_dir == MS_DATASETS_CACHE:
                from modelscope.utils.config_ds import CACHE_HOME
                cache_dir = os.path.join(CACHE_HOME, 'virgo', 'hub',
                                         'datasets')
                dataset_context_config.cache_root_dir = cache_dir

            virgo_downloader = VirgoDownloader(dataset_context_config)
            virgo_downloader.process()

            return virgo_downloader.dataset

        else:
            raise ValueError(
                'Please adjust the input args to specify a loading mode. Supported scenarios: '
                'loading from local disk, the Hugging Face hub and the ModelScope hub.')

    @staticmethod
    def upload(
            object_name: str,
            local_file_path: str,
            dataset_name: str,
            namespace: Optional[str] = DEFAULT_DATASET_NAMESPACE,
            version: Optional[str] = DEFAULT_DATASET_REVISION,
            num_processes: Optional[int] = None,
            chunksize: Optional[int] = 1,
            filter_hidden_files: Optional[bool] = True,
            upload_mode: Optional[UploadMode] = UploadMode.OVERWRITE) -> None:
        r"""
        @deprecated
        This method is deprecated and may be removed in future releases, please use git command line instead.
        """
        """Upload dataset file or directory to the ModelScope Hub. Please log in to the ModelScope Hub first.

        Args:
            object_name (str): The object name on ModelScope, in the form of your-dataset-name.zip or your-dataset-name
            local_file_path (str): Local file or directory to upload
            dataset_name (str): Name of the dataset
            namespace(str, optional): Namespace of the dataset
            version: Optional[str]: Version of the dataset
            num_processes: Optional[int]: The number of processes used for multiprocess uploading.
                This is only applicable when local_file_path is a directory, and we are uploading mutliple-files
                inside the directory. When None provided, the number returned by os.cpu_count() is used as default.
            chunksize: Optional[int]: The chunksize of objects to upload.
                For very long iterables using a large value for chunksize can make the job complete much faster than
                using the default value of 1. Available if local_file_path is a directory.
            filter_hidden_files: Optional[bool]: Whether to filter hidden files.
                Available if local_file_path is a directory.
            upload_mode: Optional[UploadMode]: How to upload objects from local.  Default: UploadMode.OVERWRITE, upload
                all objects from local, existing remote objects may be overwritten.

        Returns:
            None

        """
        warnings.warn(
            'The function `upload` is deprecated, '
            'please use git command '
            'or modelscope.hub.api.HubApi.upload_folder '
            'or modelscope.hub.api.HubApi.upload_file.', DeprecationWarning)

        if not object_name:
            raise ValueError('object_name cannot be empty!')

        _upload_manager = DatasetUploadManager(
            dataset_name=dataset_name, namespace=namespace, version=version)

        upload_mode = UploadMode(upload_mode or UploadMode.OVERWRITE)

        if os.path.isfile(local_file_path):
            _upload_manager.upload(
                object_name=object_name,
                local_file_path=local_file_path,
                upload_mode=upload_mode)
        elif os.path.isdir(local_file_path):
            _upload_manager.upload_dir(
                object_dir_name=object_name,
                local_dir_path=local_file_path,
                num_processes=num_processes,
                chunksize=chunksize,
                filter_hidden_files=filter_hidden_files,
                upload_mode=upload_mode)
        else:
            raise ValueError(
                f'{local_file_path} is not a valid file path or directory')

    @staticmethod
    def clone_meta(dataset_work_dir: str,
                   dataset_id: str,
                   revision: Optional[str] = DEFAULT_DATASET_REVISION,
                   auth_token: Optional[str] = None,
                   git_path: Optional[str] = None) -> None:
        """Clone meta-file of dataset from the ModelScope Hub.

        Args:
            dataset_work_dir (str): Current git working directory.
            dataset_id (str): Dataset id, in the form of your-namespace/your-dataset-name .
            revision (str, optional):
                Revision of the dataset you want to clone from. Can be any of a branch, tag or commit hash.
            auth_token (str, optional):
                Token obtained when calling `HubApi.login()`. Usually you can safely ignore this parameter,
                as the token is already saved when you log in the first time. If None, the saved token is used.
            git_path (str, optional):
                The git command line path. If None, 'git' is used.
        Returns:
            None
        """

        warnings.warn(
            'The function `clone_meta` is deprecated, please use git command line to clone the repo.',
            DeprecationWarning)

        _repo = DatasetRepository(
            repo_work_dir=dataset_work_dir,
            dataset_id=dataset_id,
            revision=revision,
            auth_token=auth_token,
            git_path=git_path)
        clone_work_dir = _repo.clone()
        if clone_work_dir:
            logger.info('Already cloned repo to: {}'.format(clone_work_dir))
        else:
            logger.warning(
                'Repo dir already exists: {}'.format(dataset_work_dir))

    @staticmethod
    def upload_meta(dataset_work_dir: str,
                    commit_message: str,
                    revision: Optional[str] = DEFAULT_DATASET_REVISION,
                    auth_token: Optional[str] = None,
                    git_path: Optional[str] = None,
                    force: bool = False) -> None:
        """Upload meta-file of dataset to the ModelScope Hub. Please clone the meta-data from the ModelScope Hub first.

        Args:
            dataset_work_dir (str): Current working directory.
            commit_message (str): Commit message.
            revision (str, optional):
                Revision of the dataset, used as the branch to push to. Can be any of a branch, tag or commit hash.
            auth_token (str, optional):
                Token obtained when calling `HubApi.login()`. Usually you can safely ignore this parameter,
                as the token is already saved when you log in the first time. If None, the saved token is used.
            git_path (str, optional):
                The git command line path. If None, 'git' is used.
            force (bool, optional): Whether to use force-push.

        Returns:
            None

        """
        warnings.warn(
            'The function `upload_meta` is deprecated, '
            'please use git command '
            'or CLI `modelscope upload owner_name/repo_name ...`.',
            DeprecationWarning)

        _repo = DatasetRepository(
            repo_work_dir=dataset_work_dir,
            dataset_id='',
            revision=revision,
            auth_token=auth_token,
            git_path=git_path)
        _repo.push(commit_message=commit_message, branch=revision, force=force)

    @staticmethod
    def delete(object_name: str,
               dataset_name: str,
               namespace: Optional[str] = DEFAULT_DATASET_NAMESPACE,
               version: Optional[str] = DEFAULT_DATASET_REVISION) -> str:
        """ Delete object of dataset. Please log in first and make sure you have permission to manage the dataset.

        Args:
            object_name (str): The object name of the dataset to be deleted. Could be the name of a file or a
                directory; a directory name must end with `/`.
                For example: your-data-name.zip, train/001/img_001.png, train/, ...
            dataset_name (str): Path or name of the dataset.
            namespace(str, optional): Namespace of the dataset.
            version (str, optional): Version of the dataset.

        Returns:
            res_msg (str): Response message.
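
        Example (illustrative; the object and dataset names are placeholders):
            >>> MsDataset.delete(object_name='train/',
            ...                  dataset_name='your_dataset',
            ...                  namespace='your_namespace')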

        """
        _delete_manager = DatasetDeleteManager(
            dataset_name=dataset_name, namespace=namespace, version=version)
        resp_msg = _delete_manager.delete(object_name=object_name)
        logger.info(f'Object {object_name} successfully removed!')
        return resp_msg

    def to_torch_dataset(
        self,
        columns: Union[str, List[str]] = None,
        preprocessors: Union[Callable, List[Callable]] = None,
        task_name: str = None,
        data_config: ConfigDict = None,
        to_tensor: bool = True,
        **format_kwargs,
    ):
        """Create a torch.utils.data.Dataset from the MS Dataset. The torch.utils.data.Dataset can be passed to
           torch.utils.data.DataLoader.

        Args:
            preprocessors (Callable or List[Callable], default None): (list of) Preprocessor object used to process
                every sample of the dataset. The output type of processors is dict, and each (numeric) field of the dict
                will be used as a field of torch.utils.data.Dataset.
            columns (str or List[str], default None): Dataset column(s) to be loaded (numeric data only if
                `to_tensor` is True). If the preprocessor is None, the arg columns must have at least one column.
                If the `preprocessors` is not None, the output fields of processors will also be added.
            task_name (str, default None): Task name, refer to :obj:`Tasks` for more details.
            data_config (ConfigDict, default None): config dict for model object.
                Attributes of ConfigDict:
                    `preprocessor` (Callable, List[Callable], optional): preprocessors to deal with dataset
                    `type` (str): the type of task
                    `split_config` (dict, optional): get the split config for ExternalDataset
                    `test_mode` (bool, optional): is test mode or not
            to_tensor (bool, default True): whether to convert the data types of dataset column(s) to torch.tensor
                or not.
            format_kwargs: A `dict` of arguments to be passed to the `torch.tensor`.

        Returns:
            :class:`torch.utils.data.Dataset`
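
        Example (a minimal sketch; assumes this MsDataset wraps a Hugging Face Dataset that has a numeric
        column named 'label', the column name being an illustrative assumption):
            >>> torch_ds = ms_ds.to_torch_dataset(columns=['label'])
            >>> from torch.utils.data import DataLoader
            >>> loader = DataLoader(torch_ds, batch_size=8)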

        """
        if not is_torch_available():
            raise ImportError(
                'The function to_torch_dataset requires pytorch to be installed'
            )
        if isinstance(self._hf_ds, ExternalDataset):
            data_config.update({'preprocessor': preprocessors})
            data_config.update(self._hf_ds.config_kwargs)
            return build_custom_dataset(data_config, task_name)
        if preprocessors is not None:
            return self._to_torch_dataset_with_processors(
                preprocessors, columns=columns, to_tensor=to_tensor)
        else:
            self._hf_ds.reset_format()
            self._hf_ds.set_format(
                type='torch', columns=columns, format_kwargs=format_kwargs)
            return self._hf_ds

    def to_tf_dataset(
        self,
        batch_size: int,
        shuffle: bool,
        preprocessors: Union[Callable, List[Callable]] = None,
        columns: Union[str, List[str]] = None,
        collate_fn: Callable = None,
        drop_remainder: bool = None,
        collate_fn_args: Dict[str, Any] = None,
        label_cols: Union[str, List[str]] = None,
        prefetch: bool = True,
    ):
        """Create a tf.data.Dataset from the MS Dataset. This tf.data.Dataset can be passed to tf methods like
           model.fit() or model.predict().

        Args:
            batch_size (int): Number of samples in a single batch.
            shuffle(bool): Shuffle the dataset order.
            preprocessors (Callable or List[Callable], default None): (list of) Preprocessor object used to process
                every sample of the dataset. The output type of processors is dict, and each field of the dict will be
                used as a field of the tf.data.Dataset. If the `preprocessors` is None, the `collate_fn`
                shouldn't be None.
            columns (str or List[str], default None): Dataset column(s) to be loaded. If the preprocessor is None,
                the arg columns must have at least one column. If the `preprocessors` is not None, the output fields of
                processors will also be added.
            collate_fn(Callable, default None): A callable object used to collect lists of samples into a batch. If
                the `preprocessors` is None, the `collate_fn` shouldn't be None.
            drop_remainder(bool, default None): Drop the last incomplete batch when loading.
            collate_fn_args (Dict, optional): A `dict` of arguments to be passed to the`collate_fn`.
            label_cols (str or List[str], default None): Dataset column(s) to load as labels.
            prefetch (bool, default True): Prefetch data.

        Returns:
            :class:`tf.data.Dataset`
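
        Example (a minimal sketch; the column names 'feature' and 'label' and the callable `my_collate_fn`
        are illustrative assumptions, not part of this API):
            >>> tf_ds = ms_ds.to_tf_dataset(
            ...     batch_size=8,
            ...     shuffle=True,
            ...     columns=['feature'],
            ...     label_cols=['label'],
            ...     collate_fn=my_collate_fn)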

        """
        if not is_tf_available():
            raise ImportError(
                'The function to_tf_dataset requires Tensorflow to be installed.'
            )
        if preprocessors is not None:
            return self._to_tf_dataset_with_processors(
                batch_size,
                shuffle,
                preprocessors,
                drop_remainder=drop_remainder,
                prefetch=prefetch,
                label_cols=label_cols,
                columns=columns)

        if collate_fn is None:
            logger.error(
                'The `preprocessors` and the `collate_fn` should not both be None.'
            )
            return None
        self._hf_ds.reset_format()
        return self._hf_ds.to_tf_dataset(
            columns,
            batch_size,
            shuffle,
            collate_fn,
            drop_remainder=drop_remainder,
            collate_fn_args=collate_fn_args,
            label_cols=label_cols,
            prefetch=prefetch)

    def to_hf_dataset(self) -> Dataset:
        self._hf_ds.reset_format()
        return self._hf_ds

    def remap_columns(self, column_mapping: Dict[str, str]) -> Dataset:
        """
        Rename columns and return the underlying hf dataset directly
        TODO: support native MsDataset column rename.
        Args:
            column_mapping (Dict[str, str]): The mapping from original to new column names.
        Returns:
            The underlying hf dataset with renamed columns.
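
        Example (column names are illustrative):
            >>> hf_ds = ms_ds.remap_columns({'text': 'sentence'})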
        """
        self._hf_ds.reset_format()
        return self._hf_ds.rename_columns(column_mapping)

    def _to_torch_dataset_with_processors(
        self,
        preprocessors: Union[Callable, List[Callable]],
        columns: Union[str, List[str]] = None,
        to_tensor: bool = True,
    ):
        preprocessor_list = preprocessors if isinstance(
            preprocessors, list) else [preprocessors]

        columns = format_list(columns)

        columns = [
            key for key in self._hf_ds.features.keys() if key in columns
        ]
        retained_numeric_columns = []
        retained_unumeric_columns = []
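        # When converting to tensors, run the preprocessors on a single probe sample to decide which
        # output fields are numeric (convertible to torch tensors) and which are not.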
        if to_tensor:
            sample = next(iter(self._hf_ds))

            sample_res = {k: np.array(sample[k]) for k in columns}
            for processor in preprocessor_list:
                sample_res.update(
                    {k: np.array(v)
                     for k, v in processor(sample).items()})

            def is_numpy_number(value):
                return np.issubdtype(value.dtype, np.integer) or np.issubdtype(
                    value.dtype, np.floating)

            for k in sample_res.keys():
                if not is_numpy_number(sample_res[k]):
                    logger.warning(
                        f'Data of column {k} is non-numeric, will be removed')
                    retained_unumeric_columns.append(k)
                    continue
                retained_numeric_columns.append(k)

        import torch

        class MsMapDataset(torch.utils.data.Dataset):

            def __init__(self, dataset: Iterable, preprocessor_list,
                         retained_numeric_columns, retained_unumeric_columns,
                         columns, to_tensor):
                super().__init__()
                self.dataset = dataset
                self.preprocessor_list = preprocessor_list
                self.to_tensor = to_tensor
                self.retained_numeric_columns = retained_numeric_columns
                self.retained_unumeric_columns = retained_unumeric_columns
                self.columns = columns

            def __len__(self):
                return len(self.dataset)

            def type_converter(self, x):
                if self.to_tensor:
                    return torch.as_tensor(x)
                else:
                    return x

            def __getitem__(self, index):
                item_dict = self.dataset[index]
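                # Keep the requested columns; when to_tensor is set, only the columns probed as
                # numeric are kept and converted via torch.as_tensor.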
                res = {
                    k: self.type_converter(item_dict[k])
                    for k in self.columns if (not self.to_tensor)
                    or k in self.retained_numeric_columns
                }
                for preprocessor in self.preprocessor_list:
                    for k, v in preprocessor(item_dict).items():
                        if (not self.to_tensor) or \
                                k in self.retained_numeric_columns:
                            res[k] = self.type_converter(v)
                        elif k in self.retained_unumeric_columns:
                            res[k] = v
                return res

        return MsMapDataset(self._hf_ds, preprocessor_list,
                            retained_numeric_columns,
                            retained_unumeric_columns, columns, to_tensor)

    def _to_tf_dataset_with_processors(
        self,
        batch_size: int,
        shuffle: bool,
        preprocessors: Union[Callable, List[Callable]],
        drop_remainder: bool = None,
        prefetch: bool = True,
        label_cols: Union[str, List[str]] = None,
        columns: Union[str, List[str]] = None,
    ):
        preprocessor_list = preprocessors if isinstance(
            preprocessors, list) else [preprocessors]

        label_cols = format_list(label_cols)
        columns = format_list(columns)
        cols_to_retain = list(set(label_cols + columns))
        retained_columns = [
            key for key in self._hf_ds.features.keys() if key in cols_to_retain
        ]
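        # Build an index-based tf.data pipeline: each element is a row index, and the fetch
        # function below materializes the corresponding (preprocessed) sample via tf.numpy_function.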
        import tensorflow as tf
        tf_dataset = tf.data.Dataset.from_tensor_slices(
            np.arange(len(self._hf_ds), dtype=np.int64))
        if shuffle:
            tf_dataset = tf_dataset.shuffle(buffer_size=len(self._hf_ds))

        def func(i, return_dict=False):
            i = int(i)
            res = {k: np.array(self._hf_ds[i][k]) for k in retained_columns}
            for preprocessor in preprocessor_list:
                # TODO preprocessor output may have the same key
                res.update({
                    k: np.array(v)
                    for k, v in preprocessor(self._hf_ds[i]).items()
                })
            if return_dict:
                return res
            return tuple(list(res.values()))

        sample_res = func(0, True)

        @tf.function(input_signature=[tf.TensorSpec(None, tf.int64)])
        def fetch_function(i):
            output = tf.numpy_function(
                func,
                inp=[i],
                Tout=[
                    tf.dtypes.as_dtype(val.dtype)
                    for val in sample_res.values()
                ],
            )
            return {key: output[i] for i, key in enumerate(sample_res)}

        from tensorflow.data.experimental import AUTOTUNE
        tf_dataset = tf_dataset.map(
            fetch_function, num_parallel_calls=AUTOTUNE)
        if label_cols:

            def split_features_and_labels(input_batch):
                labels = {
                    key: tensor
                    for key, tensor in input_batch.items() if key in label_cols
                }
                if len(input_batch) == 1:
                    input_batch = next(iter(input_batch.values()))
                if len(labels) == 1:
                    labels = next(iter(labels.values()))
                return input_batch, labels

            tf_dataset = tf_dataset.map(split_features_and_labels)

        elif len(columns) == 1:
            tf_dataset = tf_dataset.map(lambda x: next(iter(x.values())))
        if batch_size > 1:
            tf_dataset = tf_dataset.batch(
                batch_size, drop_remainder=drop_remainder)

        if prefetch:
            tf_dataset = tf_dataset.prefetch(AUTOTUNE)
        return tf_dataset

    def to_custom_dataset(self,
                          custom_cfg: Config,
                          preprocessor=None,
                          mode=None,
                          **kwargs):
        """Convert the input datasets to specific custom datasets by given model configuration and preprocessor.

        Args:
            custom_cfg (Config): The model configuration for custom datasets.
            preprocessor (Preprocessor, Optional): Preprocessor for data samples.
            mode (str, Optional): See modelscope.utils.constant.ModeKeys

        Returns:
            None. The underlying dataset of this `MsDataset` is converted in place.
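
        Example (an illustrative sketch; assumes `cfg` is a `Config` read from a model's configuration.json
        that contains `task` and `dataset` fields):
            >>> from modelscope.utils.constant import ModeKeys
            >>> ms_ds.to_custom_dataset(custom_cfg=cfg, mode=ModeKeys.TRAIN)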
        """

        if not is_torch_available():
            raise ImportError(
                'The function to_custom_dataset requires pytorch to be installed'
            )
        if not custom_cfg:
            return

        # Set the flag that it has been converted to custom dataset
        self.is_custom = True

        # Check mode
        if mode is None:
            if 'mode' in kwargs:
                mode = kwargs.get('mode')

        # Parse cfg
        ds_cfg_key = 'train' if mode == ModeKeys.TRAIN else 'val'
        data_cfg = custom_cfg.safe_get(f'dataset.{ds_cfg_key}')
        if data_cfg is None:
            data_cfg = ConfigDict(type=custom_cfg.model.type) if hasattr(
                custom_cfg, ConfigFields.model) else ConfigDict(type=None)
        data_cfg.update(dict(mode=mode))

        # Get preprocessors from custom_cfg
        task_name = custom_cfg.task
        if 'task' in kwargs:
            task_name = kwargs.pop('task')
        field_name = Tasks.find_field_by_task(task_name)
        if 'field' in kwargs:
            field_name = kwargs.pop('field')
        if preprocessor is None and hasattr(custom_cfg, 'preprocessor'):
            preprocessor_cfg = custom_cfg.preprocessor
            if preprocessor_cfg:
                preprocessor = build_preprocessor(preprocessor_cfg, field_name)

        # Build custom dataset
        if isinstance(self._hf_ds, ExternalDataset):
            data_cfg.update(dict(preprocessor=preprocessor))
            data_cfg.update(self._hf_ds.config_kwargs)
            self._hf_ds = build_custom_dataset(
                cfg=data_cfg, task_name=custom_cfg.task)
            return

        if preprocessor is not None:
            to_tensor = kwargs.get('to_tensor', True)
            self._hf_ds = self._to_torch_dataset_with_processors(
                preprocessors=preprocessor, to_tensor=to_tensor)
        else:
            self._hf_ds.reset_format()
            self._hf_ds.set_format(type='torch')
        return
