# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
from typing import List, Optional, Union

from modelscope.hub.api import HubApi
from modelscope.hub.file_download import model_file_download
from modelscope.utils.config import Config
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile
from modelscope.utils.logger import get_logger

# Module-level logger; used below to report config files that fail to parse.
logger = get_logger()


def is_config_has_model(cfg_file):
    """Report whether the config stored at ``cfg_file`` declares a model.

    A config declares a model when, after loading it with
    ``Config.from_file``, it exposes a ``model`` or a ``model_type``
    attribute.

    Args:
        cfg_file: Path to the configuration file to inspect.

    Returns:
        bool: True if either attribute is present. If loading or inspecting
        the file raises, the error is logged and False is returned.
    """
    try:
        config = Config.from_file(cfg_file)
        return any(hasattr(config, attr) for attr in ('model', 'model_type'))
    except Exception as err:
        # Best effort probe: a broken config simply means "no model here".
        logger.error(f'parse config file {cfg_file} failed: {err}')
        return False


def is_official_hub_path(path: Union[str, List],
                         revision: Optional[str] = DEFAULT_MODEL_REVISION):
    """ Whether path is an official hub name or a valid local
    path to official hub directory.

    An existing local path counts as official when it contains the
    ``ModelFile.CONFIGURATION`` file. Any other string is treated as a hub
    model id and is looked up with ``HubApi().get_model``.

    Args:
        path: A single model path/id, or a list of them.
        revision: Hub revision passed to ``HubApi().get_model`` for
            non-local paths.

    Returns:
        bool: For a string, whether it is an official hub path. For a list,
        True only if every element is; an empty list yields False.

    Raises:
        ValueError: If a non-local path cannot be resolved on the hub (the
            underlying error is chained as ``__cause__``), or if a list mixes
            official and non-official entries.
    """

    def is_official_hub_impl(path):
        if osp.exists(path):
            cfg_file = osp.join(path, ModelFile.CONFIGURATION)
            return osp.exists(cfg_file)
        try:
            HubApi().get_model(path, revision=revision)
        except Exception as e:
            # Chain the original error so the real hub failure (network,
            # auth, missing repo, ...) remains visible in the traceback.
            raise ValueError(f'invalid model repo path {e}') from e
        return True

    if isinstance(path, str):
        return is_official_hub_impl(path)

    results = [is_official_hub_impl(m) for m in path]
    if not results:
        # all([]) is vacuously True; an empty list names no hub model.
        return False
    all_true = all(results)
    if any(results) and not all_true:
        raise ValueError(
            f'some model are hub address, some are not, model list: {path}'
        )
    return all_true


def is_model(path: Union[str, List]):
    """ whether path is a valid modelhub path and containing model config

    A config "contains a model" when ``is_config_has_model`` accepts it.

    * Existing local path: ``ModelFile.CONFIGURATION`` is checked; only if
      that file is absent is ``ModelFile.CONFIG`` checked instead.
    * Anything else is treated as a hub model id:
      ``ModelFile.CONFIGURATION`` is downloaded and checked. If it cannot
      be downloaded, or it declares no model, ``ModelFile.CONFIG`` is
      downloaded and checked as a fallback. Download failures are reported
      as False rather than raised.

    Args:
        path: A single model path/id, or a list of them.

    Returns:
        bool: For a string, whether it is a model path. For a list, True
        only if every element is; an empty list yields False.

    Raises:
        ValueError: If a list mixes model and non-model entries.
    """

    def _download_or_none(model_id, file_name):
        # Fetch ``file_name`` from hub repo ``model_id``; any failure maps to
        # None so the caller can move on to the next candidate config file.
        try:
            return model_file_download(model_id, file_name)
        except Exception:
            return None

    def is_modelhub_path_impl(path):
        if osp.exists(path):
            cfg_file = osp.join(path, ModelFile.CONFIGURATION)
            hf_cfg_file = osp.join(path, ModelFile.CONFIG)
            if osp.exists(cfg_file):
                return is_config_has_model(cfg_file)
            if osp.exists(hf_cfg_file):
                return is_config_has_model(hf_cfg_file)
            return False

        cfg_file = _download_or_none(path, ModelFile.CONFIGURATION)
        if cfg_file is not None and is_config_has_model(cfg_file):
            return True
        # Try ModelFile.CONFIG even when CONFIGURATION could not be
        # downloaded, mirroring the local branch, which accepts a directory
        # that only has ModelFile.CONFIG.
        hf_cfg_file = _download_or_none(path, ModelFile.CONFIG)
        return hf_cfg_file is not None and is_config_has_model(hf_cfg_file)

    if isinstance(path, str):
        return is_modelhub_path_impl(path)

    results = [is_modelhub_path_impl(m) for m in path]
    if not results:
        # all([]) is vacuously True; an empty list contains no model.
        return False
    all_true = all(results)
    if any(results) and not all_true:
        raise ValueError(
            f'some models are hub address, some are not, model list: {path}'
        )
    return all_true


def batch_process(model, data):
    """Collate a list of preprocessed samples into one batch for ``model``.

    Only models whose class is named ``OfaForAllTasks`` are collated here,
    because their preprocessed samples are nested dicts. For any other model
    this function returns None.

    Args:
        model: The model instance; only its class name is inspected.
        data: List of preprocessed samples. Each is expected to be a dict
            with ``'samples'`` (only its first element is kept) and
            ``'net_input'`` (a dict of tensors concatenated with
            ``torch.cat``), plus optional ``'w_resize_ratios'`` /
            ``'h_resize_ratios'`` tensors.

    Returns:
        dict | None: For ``OfaForAllTasks``, a dict with ``'nsentences'``,
        ``'samples'``, ``'net_input'`` and any resize-ratio keys present in
        the first sample; otherwise None.

    Raises:
        TypeError: If ``data`` is not a list (OFA models only).
        ValueError: If ``data`` is empty (OFA models only).
    """
    if model.__class__.__name__ != 'OfaForAllTasks':
        return None

    import torch

    # Explicit checks instead of ``assert`` so validation survives ``-O``.
    if not isinstance(data, list):
        raise TypeError('batch_process expects a list of samples, got '
                        f'{type(data).__name__}')
    if not data:
        raise ValueError('batch_process expects at least one sample')

    first = data[0]
    batch_data = {
        'nsentences': len(data),
        # NOTE(review): keeps only the first entry of each item's 'samples',
        # which presumably assumes one sample per item -- confirm.
        'samples': [d['samples'][0] for d in data],
        # The first sample decides which net_input keys are collated.
        'net_input': {
            k: torch.cat([d['net_input'][k] for d in data])
            for k in first['net_input']
        },
    }
    # Optional keys are likewise detected from the first sample only.
    for key in ('w_resize_ratios', 'h_resize_ratios'):
        if key in first:
            batch_data[key] = torch.cat([d[key] for d in data])

    return batch_data
