import os
from typing import Union

import torch
from deepspeed import DeepSpeedEngine
from megatron_util import mpu
from torch import nn

from modelscope.metainfo import Trainers
from modelscope.models.base import TorchModel
from modelscope.models.nlp.plug import DistributedPlug
from modelscope.models.nlp.plug.backbone import BertLayerNorm
from modelscope.models.nlp.plug.generator import TextGenerator
from modelscope.utils.constant import ModeKeys
from ..base import TRAINERS
from ..nlp_trainer import NlpEpochBasedTrainer


@TRAINERS.register_module(module_name=Trainers.nlp_plug_trainer)
class PlugTrainer(NlpEpochBasedTrainer):

    def build_model(self) -> Union[nn.Module, TorchModel]:
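        # Build the distributed PLUG model; the local rank and the master
        # address/port used for model-parallel initialization are read from the
        # environment variables set by the launcher.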
        rank = int(os.environ.get('LOCAL_RANK', -1))
        master_ip = os.environ.get('MASTER_ADDR', '127.0.0.1')
        master_port = os.environ.get('MASTER_PORT', '29500')
        model = DistributedPlug(
            self.model_dir,
            rank,
            master_ip=master_ip,
            master_port=master_port,
            **self.cfg.model)
        self.unwrap_module(model.model).model_dir = self.model_dir
        return model.model

    def to_parallel(self, model) -> Union[nn.Module, TorchModel]:
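        # Wrap with ModelScope's model-parallel-aware DistributedDataParallel
        # rather than torch.nn.parallel.DistributedDataParallel.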
        from modelscope.utils.nlp.distributed import DistributedDataParallel as DDP
        return DDP(model)

    def _get_params_for_weight_decay_optimization(self, module):
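        # Split parameters into two groups: LayerNorm parameters and biases are
        # excluded from weight decay; the remaining parameters (except any
        # `mask`/`mask_score` parameters, which are skipped) keep the default
        # weight decay.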

        weight_decay_params = {'params': []}
        no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
        for module_ in module.modules():
            if isinstance(module_, (BertLayerNorm, torch.nn.LayerNorm)):
                no_weight_decay_params['params'].extend([
                    p for p in list(module_._parameters.values())
                    if p is not None
                ])
            else:
                weight_decay_params['params'].extend([
                    p for n, p in list(module_._parameters.items())
                    if p is not None and 'mask_score' not in n
                    and 'mask' not in n and n != 'bias'
                ])
                no_weight_decay_params['params'].extend([
                    p for n, p in list(module_._parameters.items())
                    if p is not None and n == 'bias'
                ])

        return weight_decay_params, no_weight_decay_params

    def create_optimizer_and_scheduler(self):
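        # Build a DeepSpeedCPUAdam optimizer over the PLUG embeddings, encoder
        # layers and decoder layers, together with an AnnealingLR schedule.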
        optimizer, lr_scheduler = self.optimizers
        optimizer_cfg = self.cfg.train.get('optimizer', None)
        optim_options = {}
        if optimizer_cfg is not None:
            optim_options = optimizer_cfg.pop('options', {})
        from deepspeed.ops.adam import DeepSpeedCPUAdam
        model = self.model

        embeddings = model.module.model.bert.embeddings
        layers = model.module.model.bert.encoder.layer
        dec_layers = model.module.model.decoder.decoder
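        # Collect (weight-decay, no-weight-decay) parameter groups for each
        # sub-module.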
        param_groups = []
        param_groups += list(
            self._get_params_for_weight_decay_optimization(layers))
        param_groups += list(
            self._get_params_for_weight_decay_optimization(embeddings))
        param_groups += list(
            self._get_params_for_weight_decay_optimization(dec_layers))

        # Megatron's mpu utilities expect a `model_parallel` attribute on every
        # parameter; default it to False where the flag is missing.
        for param_group in param_groups:
            for param in param_group['params']:
                if not hasattr(param, 'model_parallel'):
                    param.model_parallel = False
        optimizer = DeepSpeedCPUAdam(
            param_groups,
            lr=optimizer_cfg.lr,
            weight_decay=optimizer_cfg.weight_decay)

        lr_scheduler_cfg = self.cfg.train.get('lr_scheduler', None)

        lr_options = {}
        if lr_scheduler_cfg is not None:
            assert optimizer is not None
            lr_options = lr_scheduler_cfg.pop('options', {})
        from modelscope.models.nlp.plug.AnnealingLR import AnnealingLR
        num_iters = self.max_iters
        lr_scheduler = AnnealingLR(
            optimizer,
            start_lr=optimizer_cfg.lr,
            warmup_iter=lr_scheduler_cfg.warmup * num_iters,
            num_iters=num_iters,
            decay_style=lr_scheduler_cfg.decay_style,
            last_iter=-1)

        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        return self.optimizer, self.lr_scheduler, optim_options, lr_options

    def _get_masks_and_position_ids(self, data, eod_token):
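        # Build a causal (lower-triangular) attention mask, a loss mask that
        # zeroes out positions holding the EOD token, and position ids.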
        # Extract batch size and sequence length.
        batch_size, seq_length = data.size()

        # Attention mask (lower triangular).
        att_mask_batch = 1
        attention_mask = torch.tril(
            torch.ones((att_mask_batch, seq_length, seq_length),
                       device=data.device)).view(att_mask_batch, 1, seq_length,
                                                 seq_length)

        # Loss mask.
        loss_mask = torch.ones(
            data.size(), dtype=torch.float, device=data.device)
        loss_mask[data == eod_token] = 0.0

        # Position ids.
        position_ids = torch.arange(
            seq_length, dtype=torch.long, device=data.device)
        position_ids = position_ids.unsqueeze(0).expand_as(data)
        return attention_mask, loss_mask, position_ids

    def train_step(self, model, inputs):
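        # Compute the decoder loss only; the backward pass and optimizer step
        # are driven outside this method by the trainer's hooks.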
        self._mode = ModeKeys.TRAIN
        # format inputs
        checkpoint_activations = getattr(self.cfg.train,
                                         'checkpoint_activations', True)
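        # Teacher forcing: the decoder consumes labels[:, :-1] and is trained
        # to predict labels[:, 1:].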
        tgt_tokens = inputs['labels'][:, :-1].contiguous()
        tgt_labels = inputs['labels'][:, 1:].contiguous()
        tgt_attention_mask, dec_loss_mask, position_ids = self._get_masks_and_position_ids(
            tgt_tokens, 0)
        if getattr(self.cfg.train, 'fp16', None):
            tgt_attention_mask = tgt_attention_mask.half()

        # forward step
        _, output = model(
            inputs['input_ids'],
            None,
            inputs['attention_mask'],
            tgt_tokens,
            position_ids,
            tgt_attention_mask,
            checkpoint_activations=checkpoint_activations)

        # Cross entropy over the model-parallel vocabulary partitions, averaged
        # over the non-masked target positions.
        losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(),
                                                  tgt_labels)
        dec_loss_mask = dec_loss_mask.view(-1)
        loss = torch.sum(losses.view(-1) * dec_loss_mask) / dec_loss_mask.sum()

        # add model output info to log
        self.train_outputs = {'loss': loss}
        self.log_buffer.update(self.train_outputs)

    def evaluation_step(self, data):
        # The model may be wrapped twice: by a DeepSpeedEngine and by DDP;
        # unwrap the DeepSpeedEngine level if present.
        if isinstance(self.model, DeepSpeedEngine):
            model = self.model.module
        else:
            model = self.model

        model.eval()

        # model: fp16 wrapper; model.module: DistributedPlug
        vocab_size = self.unwrap_module(self.model).config.original_vocab_size
        batch_size = data['input_ids'].shape[0]
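        # Beam-search generator built from the unwrapped model and the
        # evaluation tokenizer.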
        beam_generator = TextGenerator(model,
                                       self.eval_preprocessor.nlp_tokenizer,
                                       None)

        with torch.no_grad():
            tokens = data['input_ids'].long()
            padding_mask = data['attention_mask'].byte()
            target_ids = data['labels'].long()
            target_labels = target_ids[:, 1:].contiguous()
            encoder_inputs = [tokens, None, padding_mask]
            result = beam_generator.translate_batch(encoder_inputs)
            pred_list = result['predictions']
            target_list = target_labels.cpu().numpy().tolist()
            result['preds'] = []
            data['tgts'] = []
            for i in range(batch_size):
                pred_ids = pred_list[i][0]
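                # Ids beyond the original vocabulary (padded vocab entries) are
                # mapped to 100, assumed here to be the tokenizer's [UNK] id.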
                pred_ids[pred_ids > vocab_size - 1] = 100
                pred_ids = pred_ids.cpu().numpy().tolist()

                gold_string = self.eval_preprocessor.decode(
                    target_list[i], skip_special_tokens=True)
                pred_string = self.eval_preprocessor.decode(
                    pred_ids, skip_special_tokens=True)
                result['preds'].append(pred_string)
                data['tgts'].append(gold_string)
        return result
