# Copyright (c) Alibaba, Inc. and its affiliates.

import csv
import os
from typing import Dict, Optional, Union

import numpy as np
import speechbrain as sb
import speechbrain.nnet.schedulers as schedulers
import torch
import torch.nn.functional as F
import torchaudio
from torch.cuda.amp import autocast
from torch.utils.data import Dataset
from tqdm import tqdm

from modelscope.metainfo import Trainers
from modelscope.models import Model, TorchModel
from modelscope.msdatasets import MsDataset
from modelscope.trainers.base import BaseTrainer
from modelscope.trainers.builder import TRAINERS
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile
from modelscope.utils.device import create_device
from modelscope.utils.logger import get_logger
from modelscope.utils.torch_utils import (get_dist_info, get_local_rank,
                                          init_dist)

EVAL_KEY = 'si-snr'
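# Note: the loss optimized during training is the negative SI-SNR (see
# save_results, which negates it for reporting), so smaller values are
# better and checkpoints are selected with 'min_key'.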

logger = get_logger()


@TRAINERS.register_module(module_name=Trainers.speech_separation)
class SeparationTrainer(BaseTrainer):
    """A trainer is used for speech separation.

    Args:
        model: id or local path of the model
        work_dir: local path to store all training outputs
        cfg_file: config file of the model
        train_dataset: dataset for training
        eval_dataset: dataset for evaluation
        model_revision: the git revision of the model on the modelhub
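
    Example:
        A minimal usage sketch. The model id and the datasets below are
        placeholder assumptions, not values shipped with this module:

        >>> trainer = SeparationTrainer(
        ...     model='damo/speech-separation-model',  # hypothetical model id
        ...     work_dir='./work_dir',
        ...     train_dataset=train_set,  # a prepared MsDataset or Dataset
        ...     eval_dataset=eval_set)
        >>> trainer.train()
        >>> metrics = trainer.evaluate(checkpoint_path=None)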
    """

    def __init__(self,
                 model: str,
                 work_dir: str,
                 cfg_file: Optional[str] = None,
                 train_dataset: Optional[Union[MsDataset, Dataset]] = None,
                 eval_dataset: Optional[Union[MsDataset, Dataset]] = None,
                 model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
                 **kwargs):

        if isinstance(model, str):
            self.model_dir = self.get_or_download_model_dir(
                model, model_revision)
            if cfg_file is None:
                cfg_file = os.path.join(self.model_dir,
                                        ModelFile.CONFIGURATION)
        else:
            assert cfg_file is not None, 'Config file should not be None if model is not from pretrained!'
            self.model_dir = os.path.dirname(cfg_file)

        BaseTrainer.__init__(self, cfg_file)

        self.model = self.build_model()
        self.work_dir = work_dir
        if kwargs.get('launcher', None) is not None:
            init_dist(kwargs['launcher'])
        _, world_size = get_dist_info()
        self._dist = world_size > 1

        device_name = kwargs.get('device', 'gpu')
        if self._dist:
            local_rank = get_local_rank()
            device_name = f'cuda:{local_rank}'
        self.device = create_device(device_name)

        if 'max_epochs' not in kwargs:
            assert hasattr(
                self.cfg.train, 'max_epochs'
            ), 'max_epochs is missing from the configuration file'
            self._max_epochs = self.cfg.train.max_epochs
        else:
            self._max_epochs = kwargs['max_epochs']
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset

        hparams_file = os.path.join(self.model_dir, 'hparams.yaml')
        overrides = {
            'output_folder': self.work_dir,
            'seed': self.cfg.train.seed,
            'lr': self.cfg.train.optimizer.lr,
            'weight_decay': self.cfg.train.optimizer.weight_decay,
            'clip_grad_norm': self.cfg.train.optimizer.clip_grad_norm,
            'factor': self.cfg.train.lr_scheduler.factor,
            'patience': self.cfg.train.lr_scheduler.patience,
            'dont_halve_until_epoch':
            self.cfg.train.lr_scheduler.dont_halve_until_epoch,
        }
        # Load hyperparameters, applying the overrides above
        from hyperpyyaml import load_hyperpyyaml
        with open(hparams_file) as fin:
            self.hparams = load_hyperpyyaml(fin, overrides=overrides)
        # Create experiment directory
        sb.create_experiment_directory(
            experiment_directory=self.work_dir,
            hyperparams_to_save=hparams_file,
            overrides=overrides,
        )

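        # Run options for speechbrain's Brain. Any distributed setup is
        # handled by modelscope above (init_dist), so speechbrain's own
        # distributed launch stays disabled and only the device is passed.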
        run_opts = {
            'debug': False,
            'device': 'cpu',
            'data_parallel_backend': False,
            'distributed_launch': False,
            'distributed_backend': 'nccl',
            'find_unused_parameters': False
        }
        if self.device.type == 'cuda':
            run_opts['device'] = f'{self.device.type}:{self.device.index}'
        self.epoch_counter = sb.utils.epoch_loop.EpochCounter(self._max_epochs)
        self.hparams['epoch_counter'] = self.epoch_counter
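        # Register the epoch counter and the model modules with the
        # checkpointer so that they are saved and restored together.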
        self.hparams['checkpointer'].add_recoverables(
            {'counter': self.epoch_counter})
        modules = self.model.as_dict()
        self.hparams['checkpointer'].add_recoverables(modules)
        # Brain class initialization
        self.separator = Separation(
            modules=modules,
            opt_class=self.hparams['optimizer'],
            hparams=self.hparams,
            run_opts=run_opts,
            checkpointer=self.hparams['checkpointer'],
        )

    def build_model(self) -> torch.nn.Module:
        """Instantiate a pytorch model and return it."""
        model = Model.from_pretrained(
            self.model_dir, cfg_dict=self.cfg, training=True)
        if isinstance(model, TorchModel) and hasattr(model, 'model'):
            return model.model
        elif isinstance(model, torch.nn.Module):
            return model
        raise TypeError(
            f'Unsupported model type: {type(model)}, '
            'expected TorchModel or torch.nn.Module')

    def train(self, *args, **kwargs):
        """Run the training loop via speechbrain's ``Brain.fit``."""
        self.separator.fit(
            self.epoch_counter,
            self.train_dataset,
            self.eval_dataset,
            train_loader_kwargs=self.hparams['dataloader_opts'],
            valid_loader_kwargs=self.hparams['dataloader_opts'],
        )

    def evaluate(self, checkpoint_path: str, *args,
                 **kwargs) -> Dict[str, float]:
        """Evaluate the model on the eval dataset and return the metric."""
        if checkpoint_path:
            self.hparams['checkpointer'].checkpoints_dir = checkpoint_path
        else:
            self.model.load_check_point(device=self.device)
        value = self.separator.evaluate(
            self.eval_dataset,
            test_loader_kwargs=self.hparams['dataloader_opts'],
            min_key=EVAL_KEY)
        return {EVAL_KEY: value}


class Separation(sb.Brain):
    """A subclass of speechbrain.Brain implements training steps."""

    def compute_forward(self, mix, targets, stage, noise=None):
        """Forward computations from the mixture to the separated signals."""

        # Unpack lists and put tensors in the right device
        mix, mix_lens = mix
        mix, mix_lens = mix.to(self.device), mix_lens.to(self.device)

        # Convert targets to tensor
        targets = torch.cat(
            [
                targets[i][0].unsqueeze(-1)
                for i in range(self.hparams.num_spks)
            ],
            dim=-1,
        ).to(self.device)
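        # targets now has shape [B, T, num_spks]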

        # Add speech distortions
        if stage == sb.Stage.TRAIN:
            with torch.no_grad():
                if self.hparams.use_speedperturb or self.hparams.use_rand_shift:
                    mix, targets = self.add_speed_perturb(targets, mix_lens)

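                    # Re-synthesize the mixture from the perturbed sources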
                    mix = targets.sum(-1)

                if self.hparams.use_wavedrop:
                    mix = self.hparams.wavedrop(mix, mix_lens)

                if self.hparams.limit_training_signal_len:
                    mix, targets = self.cut_signals(mix, targets)

        # Separation
        mix_w = self.modules['encoder'](mix)
        est_mask = self.modules['masknet'](mix_w)
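        # Replicate the encoded mixture once per speaker so that each
        # estimated mask can be applied elementwise; after stacking the
        # shape is [num_spks, B, N, T'] (assuming masknet returns one
        # mask per speaker stacked on dim 0).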
        mix_w = torch.stack([mix_w] * self.hparams.num_spks)
        sep_h = mix_w * est_mask

        # Decoding
        est_source = torch.cat(
            [
                self.modules['decoder'](sep_h[i]).unsqueeze(-1)
                for i in range(self.hparams.num_spks)
            ],
            dim=-1,
        )
        # The time dimension changes after the conv1d in the encoder;
        # pad or trim the estimate back to the input length here.
        T_origin = mix.size(1)
        T_est = est_source.size(1)
        if T_origin > T_est:
            est_source = F.pad(est_source, (0, 0, 0, T_origin - T_est))
        else:
            est_source = est_source[:, :T_origin, :]

        return est_source, targets

    def compute_objectives(self, predictions, targets):
        """Computes the sinr loss"""
        return self.hparams.loss(targets, predictions)

    # yapf: disable
    def fit_batch(self, batch):
        """Trains one batch"""
        # Unpacking batch list
        mixture = batch.mix_sig
        targets = [batch.s1_sig, batch.s2_sig]

        if self.hparams.num_spks == 3:
            targets.append(batch.s3_sig)

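        # Mixed-precision path: run the forward/loss under autocast and
        # scale the gradients with sb.Brain's GradScaler.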
        if self.auto_mix_prec:
            with autocast():
                predictions, targets = self.compute_forward(
                    mixture, targets, sb.Stage.TRAIN)
                loss = self.compute_objectives(predictions, targets)
                # Keep only the hard examples whose loss is above the threshold
                if self.hparams.threshold_byloss:
                    th = self.hparams.threshold
                    loss_to_keep = loss[loss > th]
                    if loss_to_keep.nelement() > 0:
                        loss = loss_to_keep.mean()
                    else:
                        logger.warning('loss has zero elements!')
                else:
                    loss = loss.mean()

            # Skip the update if the loss is non-finite or above the upper limit
            if loss < self.hparams.loss_upper_lim and loss.nelement() > 0:
                self.scaler.scale(loss).backward()
                if self.hparams.clip_grad_norm >= 0:
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(
                        self.modules.parameters(),
                        self.hparams.clip_grad_norm,
                    )
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                self.nonfinite_count += 1
                logger.info(
                    'Infinite or empty loss! It has happened {} times so far'
                    ' - skipping this batch'.format(self.nonfinite_count))
                loss.data = torch.tensor(0).to(self.device)
        else:
            predictions, targets = self.compute_forward(
                mixture, targets, sb.Stage.TRAIN)
            loss = self.compute_objectives(predictions, targets)
            # Keep only the hard examples whose loss is above the threshold
            if self.hparams.threshold_byloss:
                th = self.hparams.threshold
                loss_to_keep = loss[loss > th]
                if loss_to_keep.nelement() > 0:
                    loss = loss_to_keep.mean()
                else:
                    logger.warning('loss has zero elements!')
            else:
                loss = loss.mean()
            # Skip the update if the loss is non-finite or above the upper limit
            if loss < self.hparams.loss_upper_lim and loss.nelement() > 0:
                loss.backward()
                if self.hparams.clip_grad_norm >= 0:
                    torch.nn.utils.clip_grad_norm_(self.modules.parameters(),
                                                   self.hparams.clip_grad_norm)
                self.optimizer.step()
            else:
                self.nonfinite_count += 1
                logger.info(
                    'Infinite or empty loss! It has happened {} times so far'
                    ' - skipping this batch'.format(self.nonfinite_count))
                loss.data = torch.tensor(0).to(self.device)
        self.optimizer.zero_grad()
        return loss.detach().cpu()
    # yapf: enable

    def evaluate_batch(self, batch, stage):
        """Computations needed for validation/test batches"""
        snt_id = batch.id
        mixture = batch.mix_sig
        targets = [batch.s1_sig, batch.s2_sig]
        if self.hparams.num_spks == 3:
            targets.append(batch.s3_sig)

        with torch.no_grad():
            predictions, targets = self.compute_forward(
                mixture, targets, stage)
            loss = self.compute_objectives(predictions, targets)

        # Manage audio file saving
        if stage == sb.Stage.TEST and self.hparams.save_audio:
            if hasattr(self.hparams, 'n_audio_to_save'):
                if self.hparams.n_audio_to_save > 0:
                    self.save_audio(snt_id[0], mixture, targets, predictions)
                    self.hparams.n_audio_to_save -= 1
            else:
                self.save_audio(snt_id[0], mixture, targets, predictions)

        return loss.mean().detach()

    def on_stage_end(self, stage, stage_loss, epoch):
        """Gets called at the end of a epoch."""
        # Compute/store important stats
        stage_stats = {'si-snr': stage_loss}
        if stage == sb.Stage.TRAIN:
            self.train_stats = stage_stats

        # Perform end-of-iteration things, like annealing, logging, etc.
        if stage == sb.Stage.VALID:
            # Learning rate annealing
            if isinstance(self.hparams.lr_scheduler,
                          schedulers.ReduceLROnPlateau):
                current_lr, next_lr = self.hparams.lr_scheduler(
                    [self.optimizer], epoch, stage_loss)
                schedulers.update_learning_rate(self.optimizer, next_lr)
            else:
                # Without ReduceLROnPlateau the learning rate is left unchanged
                current_lr = self.hparams.optimizer.optim.param_groups[0]['lr']

            self.hparams.train_logger.log_stats(
                stats_meta={
                    'epoch': epoch,
                    'lr': current_lr
                },
                train_stats=self.train_stats,
                valid_stats=stage_stats,
            )
            self.checkpointer.save_and_keep_only(
                meta={'si-snr': stage_stats['si-snr']},
                min_keys=['si-snr'],
            )

    def add_speed_perturb(self, targets, targ_lens):
        """Adds speed perturbation and random_shift to the input signals"""

        min_len = -1
        recombine = False

        if self.hparams.use_speedperturb:
            # Performing speed change (independently on each source)
            new_targets = []
            recombine = True

            for i in range(targets.shape[-1]):
                new_target = self.hparams.speedperturb(targets[:, :, i],
                                                       targ_lens)
                new_targets.append(new_target)
                if i == 0:
                    min_len = new_target.shape[-1]
                else:
                    if new_target.shape[-1] < min_len:
                        min_len = new_target.shape[-1]

            if self.hparams.use_rand_shift:
                # Performing random_shift (independently on each source)
                recombine = True
                for i in range(targets.shape[-1]):
                    rand_shift = torch.randint(self.hparams.min_shift,
                                               self.hparams.max_shift, (1, ))
                    new_targets[i] = new_targets[i].to(self.device)
                    new_targets[i] = torch.roll(
                        new_targets[i], shifts=(rand_shift[0], ), dims=1)

            # Re-combination
            if recombine:
                if self.hparams.use_speedperturb:
                    targets = torch.zeros(
                        targets.shape[0],
                        min_len,
                        targets.shape[-1],
                        device=targets.device,
                        dtype=torch.float,
                    )
                for i, new_target in enumerate(new_targets):
                    targets[:, :, i] = new_target[:, 0:min_len]

        mix = targets.sum(-1)
        return mix, targets

    def cut_signals(self, mixture, targets):
        """This function selects a random segment of a given length within the mixture.
        The corresponding targets are selected accordingly"""
        randstart = torch.randint(
            0,
            1 + max(0, mixture.shape[1] - self.hparams.training_signal_len),
            (1, ),
        ).item()
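        # When the mixture is shorter than training_signal_len, randstart
        # is 0 and the slices below keep the full signal.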
        targets = targets[:, randstart:randstart
                          + self.hparams.training_signal_len, :]
        mixture = mixture[:, randstart:randstart
                          + self.hparams.training_signal_len]
        return mixture, targets

    def reset_layer_recursively(self, layer):
        """Reinitializes the parameters of the neural networks"""
        if hasattr(layer, 'reset_parameters'):
            layer.reset_parameters()
        # Recurse over direct children only; iterating layer.modules() here
        # would reinitialize each descendant multiple times.
        for child_layer in layer.children():
            self.reset_layer_recursively(child_layer)

    def save_results(self, test_data):
        """This script computes the SDR and SI-SNR metrics and saves
        them into a csv file"""

        # This package is required for SDR computation
        from mir_eval.separation import bss_eval_sources

        # Path of the csv file that stores the test results
        save_file = os.path.join(self.hparams.output_folder,
                                 'test_results.csv')

        # Variable init
        all_sdrs = []
        all_sdrs_i = []
        all_sisnrs = []
        all_sisnrs_i = []
        csv_columns = ['snt_id', 'sdr', 'sdr_i', 'si-snr', 'si-snr_i']

        test_loader = sb.dataio.dataloader.make_dataloader(
            test_data, **self.hparams.dataloader_opts)

        with open(save_file, 'w') as results_csv:
            writer = csv.DictWriter(results_csv, fieldnames=csv_columns)
            writer.writeheader()

            # Loop over all test sentences
            with tqdm(test_loader, dynamic_ncols=True) as t:
                for i, batch in enumerate(t):

                    # Apply Separation
                    mixture, mix_len = batch.mix_sig
                    snt_id = batch.id
                    targets = [batch.s1_sig, batch.s2_sig]
                    if self.hparams.num_spks == 3:
                        targets.append(batch.s3_sig)

                    with torch.no_grad():
                        predictions, targets = self.compute_forward(
                            batch.mix_sig, targets, sb.Stage.TEST)

                    # Compute SI-SNR
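                    # compute_objectives returns the loss, i.e. the negative
                    # SI-SNR, hence the sign flips before writing to the csv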
                    sisnr = self.compute_objectives(predictions, targets)

                    # Compute SI-SNR improvement
                    mixture_signal = torch.stack(
                        [mixture] * self.hparams.num_spks, dim=-1)
                    mixture_signal = mixture_signal.to(targets.device)
                    sisnr_baseline = self.compute_objectives(
                        mixture_signal, targets)
                    sisnr_i = sisnr.mean() - sisnr_baseline.mean()

                    # Compute SDR
                    sdr, _, _, _ = bss_eval_sources(
                        targets[0].t().cpu().numpy(),
                        predictions[0].t().detach().cpu().numpy(),
                    )

                    sdr_baseline, _, _, _ = bss_eval_sources(
                        targets[0].t().cpu().numpy(),
                        mixture_signal[0].t().detach().cpu().numpy(),
                    )

                    sdr_i = sdr.mean() - sdr_baseline.mean()

                    # Saving on a csv file
                    row = {
                        'snt_id': snt_id[0],
                        'sdr': sdr.mean(),
                        'sdr_i': sdr_i,
                        'si-snr': -sisnr.item(),
                        'si-snr_i': -sisnr_i.item(),
                    }
                    writer.writerow(row)

                    # Metric Accumulation
                    all_sdrs.append(sdr.mean())
                    all_sdrs_i.append(sdr_i.mean())
                    all_sisnrs.append(-sisnr.item())
                    all_sisnrs_i.append(-sisnr_i.item())

                row = {
                    'snt_id': 'avg',
                    'sdr': np.array(all_sdrs).mean(),
                    'sdr_i': np.array(all_sdrs_i).mean(),
                    'si-snr': np.array(all_sisnrs).mean(),
                    'si-snr_i': np.array(all_sisnrs_i).mean(),
                }
                writer.writerow(row)

        logger.info('Mean SISNR is {}'.format(np.array(all_sisnrs).mean()))
        logger.info('Mean SISNRi is {}'.format(np.array(all_sisnrs_i).mean()))
        logger.info('Mean SDR is {}'.format(np.array(all_sdrs).mean()))
        logger.info('Mean SDRi is {}'.format(np.array(all_sdrs_i).mean()))

    def save_audio(self, snt_id, mixture, targets, predictions):
        """Saves the test audio (mixture, targets, and estimated sources)
        on disk."""

        # Create output folder
        save_path = os.path.join(self.hparams.save_folder, 'audio_results')
        os.makedirs(save_path, exist_ok=True)

        for ns in range(self.hparams.num_spks):

            # Estimated source
            signal = predictions[0, :, ns]
            signal = signal / signal.abs().max() * 0.5
            save_file = os.path.join(
                save_path, 'item{}_source{}hat.wav'.format(snt_id, ns + 1))
            torchaudio.save(save_file,
                            signal.unsqueeze(0).cpu(),
                            self.hparams.sample_rate)

            # Original source
            signal = targets[0, :, ns]
            signal = signal / signal.abs().max() * 0.5
            save_file = os.path.join(
                save_path, 'item{}_source{}.wav'.format(snt_id, ns + 1))
            torchaudio.save(save_file,
                            signal.unsqueeze(0).cpu(),
                            self.hparams.sample_rate)

        # Mixture
        signal = mixture[0][0, :]
        signal = signal / signal.abs().max() * 0.5
        save_file = os.path.join(save_path, 'item{}_mix.wav'.format(snt_id))
        torchaudio.save(save_file,
                        signal.unsqueeze(0).cpu(), self.hparams.sample_rate)
