# Copyright (c) Alibaba, Inc. and its affiliates.

import os
import random
from pathlib import Path
from typing import Any, Dict

import librosa
import torch
from fairseq.data.audio.feature_transforms import \
    CompositeAudioFeatureTransform
from fairseq.data.audio.speech_to_text_dataset import S2TDataConfig

from modelscope.utils.chinese_utils import pre_chinese
from modelscope.utils.constant import ModeKeys
from .base import OfaBasePreprocessor
from .utils.text2phone import Text2Phone


class OfaASRPreprocessor(OfaBasePreprocessor):
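    """Preprocessor for the OFA automatic speech recognition (ASR) task.

    Converts raw audio (and, for training, transcripts) into the filterbank,
    token, and phone tensors consumed by the OFA ASR model.

    A minimal usage sketch; the ``cfg`` object, the model directory, and the
    ``wav`` column name are illustrative assumptions, not values fixed by
    this file::

        preprocessor = OfaASRPreprocessor(cfg, '/path/to/model_dir')
        sample = preprocessor({'wav': '/path/to/audio.wav'})
    """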

    def __init__(self,
                 cfg,
                 model_dir,
                 mode=ModeKeys.INFERENCE,
                 *args,
                 **kwargs):
        """preprocess the data

        Args:
            cfg(modelscope.utils.config.ConfigDict) : model config
            model_dir (str): model path,
            mode: preprocessor mode (model mode)
        """
        super(OfaASRPreprocessor, self).__init__(cfg, model_dir, mode, *args,
                                                 **kwargs)
        # Build the train/test audio feature-transform pipelines from fbank_config.yaml
        self.data_cfg = S2TDataConfig(
            Path(os.path.join(model_dir, 'fbank_config.yaml')))
        self.train_audio_feature_transforms = CompositeAudioFeatureTransform.from_config_dict(
            self.data_cfg.get_feature_transforms('train', True))
        self.test_audio_feature_transforms = CompositeAudioFeatureTransform.from_config_dict(
            self.data_cfg.get_feature_transforms('test', False))
        self.text2phone_tokenizer = Text2Phone(
            os.path.join(model_dir, 'text2phone_dict.txt'))
        self.phone_to_id, self.id_to_phone = self.build_phone_dict(
            os.path.join(model_dir, 'phone_dict.txt'))

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
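        """Route the input to train- or inference-sample building by mode."""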
        if self.mode == ModeKeys.TRAIN:
            return self._build_train_sample(data)
        else:
            return self._build_infer_sample(data)

    def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
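        """Build a training sample: speed-perturbed fbank features, tokenized
        transcript targets, and phone-level targets."""
        # Speed perturbation (randomly 0.9x, 1.0x, or 1.1x) is a common ASR
        # data augmentation.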
        speed = random.choice([0.9, 1.0, 1.1])
        audio_bytes = self.get_audio_bytes(data[self.column_map['wav']])
        wav, sr = librosa.load(audio_bytes, sr=16000, mono=True)
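        # prepare_fbank (inherited from OfaBasePreprocessor) converts the
        # waveform to filterbank features; is_train=True presumably selects
        # the train-time transforms built in __init__.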
        fbank = self.prepare_fbank(
            torch.tensor([wav], dtype=torch.float32),
            sr,
            speed,
            target_sample_rate=16000,
            is_train=True)
        fbank_mask = torch.tensor([True])
        sample = {
            'fbank': fbank,
            'fbank_mask': fbank_mask,
            'label': data[self.column_map['text']]
        }

        target = sample['label']
        if self.language == 'zh':
            target = pre_chinese(target, self.max_tgt_length)
        else:
            target = target.translate(self.transtab).strip()
            target_token_list = target.strip().split()
            target = ' '.join(target_token_list[:self.max_tgt_length])
        sample['target'] = self.tokenize_text(target, add_bos=False)

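        # Phone ids are shifted by constant offsets (+1, and a further +3 for
        # the input-side phone_item), presumably to keep the lowest indices
        # free for special symbols in the model's phone embedding.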
        phone_item = self.to_phone(target) + 1
        phone_mask = torch.tensor([False])

        sample['phone_item'] = phone_item + 3
        sample['phone_target'] = phone_item
        sample['phone_mask'] = phone_mask

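        # Teacher forcing: decoder inputs are the target shifted right, i.e.
        # BOS followed by all but the last target token.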
        sample['prev_output_tokens'] = torch.cat(
            [self.bos_item, sample['target'][:-1]])
        return sample

    def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
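        """Build an inference sample: fbank features plus placeholder phone
        inputs (no transcript is required)."""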
        speed = 1.0
        audio_bytes = self.get_audio_bytes(data[self.column_map['wav']])
        wav, sr = librosa.load(audio_bytes, sr=16000, mono=True)
        fbank = self.prepare_fbank(
            torch.tensor([wav], dtype=torch.float32),
            sr,
            speed,
            target_sample_rate=16000,
            is_train=False)
        fbank_mask = torch.tensor([True])

        sample = {'fbank': fbank, 'fbank_mask': fbank_mask}

        if 'text' in self.column_map and self.column_map['text'] in data:
            sample['label'] = data[self.column_map['text']]

        # Placeholder phone inputs: real phones are unavailable at inference
        # time, so fixed dummy ids are fed instead.
        sample['phone_item'] = torch.tensor([6, 6, 6])
        sample['phone_mask'] = torch.tensor([False])

        return sample

    def to_phone(self, text):
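        """Convert text to a tensor of phone ids via the text2phone tokenizer."""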
        phones = self.text2phone_tokenizer.trans(text)
        ids = torch.tensor([self.phone_to_id[x] for x in phones.split(' ')])
        return ids

    def build_phone_dict(self, phone_dict_path):
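        """Load a phone dictionary file, mapping each line's first field to a
        line-index id, and return both directions of the mapping."""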
        phone_to_id = dict()
        id_to_phone = dict()
        with open(phone_dict_path, 'r', encoding='utf-8') as phone_dict_file:
            for i, line in enumerate(phone_dict_file):
                phone = line.strip().split(' ')[0]
                phone_to_id[phone] = i
                id_to_phone[i] = phone
        return phone_to_id, id_to_phone
