# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
import zipfile
from typing import Callable, Dict, List, Optional, Tuple, Union

import json

from modelscope.metainfo import Preprocessors, Trainers
from modelscope.models import Model
from modelscope.models.audio.tts import SambertHifigan
from modelscope.msdatasets import MsDataset
from modelscope.preprocessors.builder import build_preprocessor
from modelscope.trainers.base import BaseTrainer
from modelscope.trainers.builder import TRAINERS
from modelscope.utils.audio.audio_utils import TtsTrainType
from modelscope.utils.audio.tts_exceptions import (
    TtsTrainingCfgNotExistsException, TtsTrainingDatasetInvalidException,
    TtsTrainingHparamsInvalidException, TtsTrainingInvalidModelException,
    TtsTrainingWorkDirNotExistsException)
from modelscope.utils.config import Config
from modelscope.utils.constant import (DEFAULT_DATASET_NAMESPACE,
                                       DEFAULT_DATASET_REVISION,
                                       DEFAULT_MODEL_REVISION, ModelFile,
                                       Tasks, TrainerStages)
from modelscope.utils.data_utils import to_device
from modelscope.utils.logger import get_logger

logger = get_logger()


@TRAINERS.register_module(module_name=Trainers.speech_kantts_trainer)
class KanttsTrainer(BaseTrainer):
    DATA_DIR = 'data'
    AM_TMP_DIR = 'tmp_am'
    VOC_TMP_DIR = 'tmp_voc'
    ORIG_MODEL_DIR = 'orig_model'

    def __init__(self,
                 model: Union[Model, str],
                 work_dir: str = None,
                 speaker: str = 'F7',
                 lang_type: str = 'PinYin',
                 cfg_file: str = None,
                 train_dataset: Union[MsDataset, str] = None,
                 train_dataset_namespace: str = DEFAULT_DATASET_NAMESPACE,
                 train_dataset_revision: str = DEFAULT_DATASET_REVISION,
                 train_type: dict = {
                     TtsTrainType.TRAIN_TYPE_SAMBERT: {},
                     TtsTrainType.TRAIN_TYPE_VOC: {}
                 },
                 preprocess_skip_script=False,
                 model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
                 **kwargs):

        if not work_dir:
            self.work_dir = tempfile.TemporaryDirectory().name
            if not os.path.exists(self.work_dir):
                os.makedirs(self.work_dir)
        else:
            self.work_dir = work_dir

        if not os.path.exists(self.work_dir):
            raise TtsTrainingWorkDirNotExistsException(
                f'{self.work_dir} not exists')

        self.train_type = dict()
        if isinstance(train_type, dict):
            for k, v in train_type.items():
                if (k == TtsTrainType.TRAIN_TYPE_SAMBERT
                        or k == TtsTrainType.TRAIN_TYPE_VOC
                        or k == TtsTrainType.TRAIN_TYPE_BERT):
                    self.train_type[k] = v

        if len(self.train_type) == 0:
            logger.info('train type empty, default to sambert and voc')
            self.train_type[TtsTrainType.TRAIN_TYPE_SAMBERT] = {}
            self.train_type[TtsTrainType.TRAIN_TYPE_VOC] = {}

        logger.info(f'Set workdir to {self.work_dir}')
        self.data_dir = os.path.join(self.work_dir, self.DATA_DIR)
        self.am_tmp_dir = os.path.join(self.work_dir, self.AM_TMP_DIR)
        self.voc_tmp_dir = os.path.join(self.work_dir, self.VOC_TMP_DIR)
        self.orig_model_dir = os.path.join(self.work_dir, self.ORIG_MODEL_DIR)
        self.raw_dataset_path = ''
        self.skip_script = preprocess_skip_script
        self.audio_config_path = ''
        self.am_config_path = ''
        self.voc_config_path = ''

        shutil.rmtree(self.data_dir, ignore_errors=True)
        shutil.rmtree(self.am_tmp_dir, ignore_errors=True)
        shutil.rmtree(self.voc_tmp_dir, ignore_errors=True)
        shutil.rmtree(self.orig_model_dir, ignore_errors=True)

        os.makedirs(self.data_dir)
        os.makedirs(self.am_tmp_dir)
        os.makedirs(self.voc_tmp_dir)

        if train_dataset:
            if isinstance(train_dataset, str):
                if os.path.exists(train_dataset):
                    logger.info(f'load {train_dataset}')
                    self.raw_dataset_path = train_dataset
                else:
                    logger.info(
                        f'load {train_dataset_namespace}/{train_dataset}')
                    train_dataset = MsDataset.load(
                        dataset_name=train_dataset,
                        namespace=train_dataset_namespace,
                        version=train_dataset_revision)
                    logger.info(f'train dataset:{train_dataset.config_kwargs}')
                    self.raw_dataset_path = self.load_dataset_raw_path(
                        train_dataset)
            else:
                self.raw_dataset_path = self.load_dataset_raw_path(
                    train_dataset)

        if not model:
            raise TtsTrainingInvalidModelException('model param is none')
        if isinstance(model, str):
            model_dir = self.get_or_download_model_dir(model, model_revision)
        else:
            model_dir = model.model_dir
        shutil.copytree(model_dir, self.orig_model_dir)
        self.model_dir = self.orig_model_dir

        if not cfg_file:
            cfg_file = os.path.join(self.model_dir, ModelFile.CONFIGURATION)
        self.parse_cfg(cfg_file)

        if not os.path.exists(self.raw_dataset_path):
            raise TtsTrainingDatasetInvalidException(
                'dataset raw path not exists')

        self.finetune_from_pretrain = False
        self.speaker = speaker
        self.model = None
        self.device = kwargs.get('device', 'gpu')
        self.model = self.get_model(self.model_dir, self.speaker)
        self.lang_type = self.model.lang_type
        if TtsTrainType.TRAIN_TYPE_SAMBERT in self.train_type or TtsTrainType.TRAIN_TYPE_VOC in self.train_type:
            self.audio_data_preprocessor = build_preprocessor(
                dict(type=Preprocessors.kantts_data_preprocessor),
                Tasks.text_to_speech)

    def parse_cfg(self, cfg_file):
        cur_dir = os.path.dirname(cfg_file)
        with open(cfg_file, 'r', encoding='utf-8') as f:
            config = json.load(f)
            if 'train' not in config:
                raise TtsTrainingInvalidModelException(
                    'model not support finetune')
            if 'audio_config' in config['train']:
                audio_config = os.path.join(cur_dir,
                                            config['train']['audio_config'])
                if os.path.exists(audio_config):
                    self.audio_config_path = audio_config
            if 'am_config' in config['train']:
                am_config = os.path.join(cur_dir, config['train']['am_config'])
                if os.path.exists(am_config):
                    self.am_config_path = am_config
            if 'voc_config' in config['train']:
                voc_config = os.path.join(cur_dir,
                                          config['train']['voc_config'])
                if os.path.exists(voc_config):
                    self.voc_config_path = voc_config
            if not self.raw_dataset_path:
                if 'train_dataset' in config['train']:
                    dataset = config['train']['train_dataset']
                    if os.path.exists(dataset):
                        self.raw_dataset_path = dataset
                    else:
                        if 'id' in dataset:
                            namespace = dataset.get('namespace',
                                                    DEFAULT_DATASET_NAMESPACE)
                            revision = dataset.get('revision',
                                                   DEFAULT_DATASET_REVISION)
                            ms = MsDataset.load(
                                dataset_name=dataset['id'],
                                namespace=namespace,
                                version=revision)
                            self.raw_dataset_path = self.load_dataset_raw_path(
                                ms)
                        elif 'path' in dataset:
                            self.raw_dataset_path = dataset['path']

    def load_dataset_raw_path(self, dataset: MsDataset):
        if 'split_config' not in dataset.config_kwargs:
            raise TtsTrainingDatasetInvalidException(
                'split_config not found in config_kwargs')
        if 'train' not in dataset.config_kwargs['split_config']:
            raise TtsTrainingDatasetInvalidException(
                'no train split in split_config')
        return dataset.config_kwargs['split_config']['train']

    def prepare_data(self):
        if self.audio_data_preprocessor:
            audio_config = self.audio_config_path
            if not audio_config or not os.path.exists(audio_config):
                audio_config = self.model.get_voice_audio_config_path(
                    self.speaker)
            se_model = self.model.get_voice_se_model_path(self.speaker)
            self.audio_data_preprocessor(self.raw_dataset_path, self.data_dir,
                                         audio_config, self.speaker,
                                         self.lang_type, self.skip_script,
                                         se_model)

    def prepare_text(self):
        pass

    def get_model(self, model_dir, speaker):
        cfg = Config.from_file(
            os.path.join(self.model_dir, ModelFile.CONFIGURATION))
        model_cfg = cfg.get('model', {})
        model = SambertHifigan(
            model_dir=self.model_dir, is_train=True, **model_cfg)
        return model

    def train(self, *args, **kwargs):
        if not self.model:
            raise TtsTrainingInvalidModelException('model is none')
        ignore_pretrain = False
        if 'ignore_pretrain' in kwargs:
            ignore_pretrain = kwargs['ignore_pretrain']

        if TtsTrainType.TRAIN_TYPE_SAMBERT in self.train_type or TtsTrainType.TRAIN_TYPE_VOC in self.train_type:
            self.prepare_data()
        if TtsTrainType.TRAIN_TYPE_BERT in self.train_type:
            self.prepare_text()
        dir_dict = {
            'work_dir': self.work_dir,
            'am_tmp_dir': self.am_tmp_dir,
            'voc_tmp_dir': self.voc_tmp_dir,
            'data_dir': self.data_dir
        }
        config_dict = {
            'am_config': self.am_config_path,
            'voc_config': self.voc_config_path
        }
        self.model.train(self.speaker, dir_dict, self.train_type, config_dict,
                         ignore_pretrain)

    def evaluate(self, checkpoint_path: str, *args,
                 **kwargs) -> Dict[str, float]:
        return {}
