# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
from typing import Dict, Optional, Union

import json
from funasr.bin import build_trainer

from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers.base import BaseTrainer
from modelscope.trainers.builder import TRAINERS
from modelscope.utils.constant import (DEFAULT_DATASET_NAMESPACE,
                                       DEFAULT_DATASET_REVISION,
                                       DEFAULT_MODEL_REVISION, ModelFile,
                                       Tasks, TrainerStages)
from modelscope.utils.logger import get_logger

logger = get_logger()


@TRAINERS.register_module(module_name=Trainers.speech_asr_trainer)
class ASRTrainer(BaseTrainer):
    DATA_DIR = 'data'

    def __init__(self,
                 model: str,
                 work_dir: Optional[str] = None,
                 distributed: bool = False,
                 dataset_type: str = 'small',
                 data_dir: Optional[Union[MsDataset, str]] = None,
                 model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
                 batch_bins: Optional[int] = None,
                 max_epoch: Optional[int] = None,
                 lr: Optional[float] = None,
                 mate_params: Optional[dict] = None,
                 **kwargs):
        """ASR Trainer.

        Args:
            model (str): model id on the ModelScope hub, or a local model dir
            work_dir (str): output dir for saving results; a temporary
                directory is created if not provided
            distributed (bool): whether to enable DDP training
            dataset_type (str): which funasr dataset type to use,
                'small' by default
            data_dir (MsDataset or str): the training data, either an
                MsDataset with 'train' and 'validation' splits or a mapping
                with a 'raw_data_dir' entry pointing to prepared data
            model_revision (str): the model version to download from the hub
            batch_bins (int): batch size (in bins)
            max_epoch (int): the maximum epoch number for training
            lr (float): learning rate
            mate_params (dict): extra training args forwarded to the
                funasr trainer
        Examples:

        >>> import os
        >>> from modelscope.metainfo import Trainers
        >>> from modelscope.msdatasets import MsDataset
        >>> from modelscope.trainers import build_trainer
        >>> ds_dict = MsDataset.load('speech_asr_aishell1_trainsets')
        >>> kwargs = dict(
        >>>     model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
        >>>     data_dir=ds_dict,
        >>>     work_dir="./checkpoint")
        >>> trainer = build_trainer(
        >>>     Trainers.speech_asr_trainer, default_args=kwargs)
        >>> trainer.train()

        """
        if not work_dir:
            # No work_dir given: fall back to a fresh temporary directory.
            self.work_dir = tempfile.mkdtemp()
        else:
            self.work_dir = work_dir

        if not os.path.exists(self.work_dir):
            raise Exception(f'{self.work_dir} does not exist')

        logger.info(f'Set workdir to {self.work_dir}')

        self.data_dir = os.path.join(self.work_dir, self.DATA_DIR)
        self.raw_dataset_path = ''
        self.distributed = distributed
        self.dataset_type = dataset_type

        shutil.rmtree(self.data_dir, ignore_errors=True)

        os.makedirs(self.data_dir, exist_ok=True)

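        # Use a local model directory directly; otherwise fetch it from the
        # ModelScope hub at the requested revision.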
        if os.path.exists(model):
            model_dir = model
        else:
            model_dir = self.get_or_download_model_dir(model, model_revision)
        self.model_dir = model_dir
        self.model_cfg = os.path.join(self.model_dir, 'configuration.json')
        self.cfg_dict = self.parse_cfg(self.model_cfg)

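        # If the caller provides already prepared data via 'raw_data_dir',
        # use it as-is; otherwise convert the dataset splits into
        # kaldi-style wav.scp/text files under the staging directory.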
        if 'raw_data_dir' not in data_dir:
            self.train_data_dir, self.dev_data_dir = self.load_dataset_raw_path(
                data_dir, self.data_dir)
        else:
            self.data_dir = data_dir['raw_data_dir']
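        # Hand the parsed configuration and prepared data over to funasr.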
        self.trainer = build_trainer.build_trainer(
            modelscope_dict=self.cfg_dict,
            data_dir=self.data_dir,
            output_dir=self.work_dir,
            distributed=self.distributed,
            dataset_type=self.dataset_type,
            batch_bins=batch_bins,
            max_epoch=max_epoch,
            lr=lr,
            mate_params=mate_params)

    def parse_cfg(self, cfg_file):
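        """Parse the model's configuration.json into the flat dict expected
        by funasr, resolving the acoustic model file, model config, cmvn
        file, seg_dict and optional bpemodel/init_model to absolute paths.
        """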
        cur_dir = os.path.dirname(cfg_file)
        cfg_dict = dict()
        with open(cfg_file, 'r', encoding='utf-8') as f:
            config = json.load(f)
            cfg_dict['mode'] = config['model']['model_config']['mode']
            cfg_dict['model_dir'] = cur_dir
            cfg_dict['am_model_file'] = os.path.join(
                cur_dir, config['model']['am_model_name'])
            cfg_dict['am_model_config'] = os.path.join(
                cur_dir, config['model']['model_config']['am_model_config'])
            cfg_dict['finetune_config'] = os.path.join(cur_dir,
                                                       'finetune.yaml')
            cfg_dict['cmvn_file'] = os.path.join(
                cur_dir, config['model']['model_config']['mvn_file'])
            cfg_dict['seg_dict'] = os.path.join(cur_dir, 'seg_dict')
            if 'bpemodel' in config['model']['model_config']:
                cfg_dict['bpemodel'] = os.path.join(
                    cur_dir, config['model']['model_config']['bpemodel'])
            else:
                cfg_dict['bpemodel'] = None
            if 'init_model' in config['model']['model_config']:
                cfg_dict['init_model'] = os.path.join(
                    cur_dir, config['model']['model_config']['init_model'])
            else:
                cfg_dict['init_model'] = cfg_dict['am_model_file']
        return cfg_dict

    def load_dataset_raw_path(self, dataset, output_data_dir):
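        """Prepare the 'train' and 'validation' splits of the dataset as
        separate data directories under output_data_dir.
        """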
        if 'train' not in dataset:
            raise Exception(
                'dataset {0} does not contain a train split'.format(dataset))
        train_data_dir = self.prepare_data(
            dataset, output_data_dir, split='train')
        if 'validation' not in dataset:
            raise Exception(
                'dataset {0} does not contain a dev split'.format(dataset))
        dev_data_dir = self.prepare_data(
            dataset, output_data_dir, split='validation')
        return train_data_dir, dev_data_dir

    def prepare_data(self, dataset, out_base_dir, split='train'):
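        """Dump one dataset split into kaldi-style 'wav.scp' and 'text'
        files under out_base_dir/split and return the resulting directory.
        """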
        out_dir = os.path.join(out_base_dir, split)
        shutil.rmtree(out_dir, ignore_errors=True)
        os.makedirs(out_dir, exist_ok=True)
        data_cnt = len(dataset[split])
        wav_scp_file = os.path.join(out_dir, 'wav.scp')
        text_file = os.path.join(out_dir, 'text')
        with open(wav_scp_file, 'w', encoding='utf-8') as fp_wav_scp, \
                open(text_file, 'w', encoding='utf-8') as fp_text:
            for i in range(data_cnt):
                content = dataset[split][i]
                wav_file = content['Audio:FILE']
                text = content['Text:LABEL']
                # Use the audio file name as the utterance id in both files.
                utt_id = os.path.basename(wav_file)
                fp_wav_scp.write('\t'.join([utt_id, wav_file]) + '\n')
                fp_text.write('\t'.join([utt_id, text]) + '\n')
        return out_dir

    def train(self, *args, **kwargs):
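        """Launch fine-tuning with the underlying funasr trainer."""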
        self.trainer.run()

    def evaluate(self, checkpoint_path: str, *args,
                 **kwargs) -> Dict[str, float]:
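        """Evaluation is not implemented for this trainer."""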
        raise NotImplementedError
