# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch import nn

from modelscope.utils.logger import get_logger

logger = get_logger()


def _symmetric_kl_div(logits1, logits2, attention_mask=None):
    """
    Compute the symmetric KL divergence between two logits tensors.
    :param logits1: The first logits tensor.
    :param logits2: The second logits tensor.
    :param attention_mask: An optional attention mask used to exclude some elements.
    This is usually useful in token-classification tasks.
    If the shape of the logits is [N1, N2, ..., Nn, D], the shape of the attention_mask
    should be [N1, N2, ..., Nn].
    :return: The mean loss.
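
    Example:
        A minimal sketch; the shapes below are arbitrary and chosen only for
        illustration::

            logits1 = torch.randn(2, 8, 5)  # [batch_size, seq_len, num_labels]
            logits2 = torch.randn(2, 8, 5)
            mask = torch.ones(2, 8)  # [batch_size, seq_len]
            loss = _symmetric_kl_div(logits1, logits2, attention_mask=mask)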
    """
    labels_num = logits1.shape[-1]
    kl_div = nn.KLDivLoss(reduction='none')
    log_softmax = nn.LogSoftmax(dim=-1)
    softmax = nn.Softmax(dim=-1)
    # KL(p1 || p2) + KL(p2 || p1), summed over the label dimension.
    loss = torch.sum(
        kl_div(log_softmax(logits1), softmax(logits2)), dim=-1) + torch.sum(
            kl_div(log_softmax(logits2), softmax(logits1)), dim=-1)
    if attention_mask is not None:
        # Average over the unmasked positions only.
        loss = torch.sum(
            loss * attention_mask) / torch.sum(attention_mask) / labels_num
    else:
        loss = torch.mean(loss) / labels_num
    return loss


def compute_adv_loss(embedding,
                     model,
                     ori_logits,
                     ori_loss,
                     adv_grad_factor,
                     adv_bound=None,
                     sigma=5e-6,
                     **kwargs):
    """
    Calculate the adversarial loss of the model.
    :param embedding: The original sentence embedding.
    :param model: The model, or the forward function (including the decoder/classifier),
            which accepts kwargs as input and outputs logits.
    :param ori_logits: The original logits output by the model function.
    :param ori_loss: The original loss.
    :param adv_grad_factor: This factor is multiplied by the normalized gradient of the
            symmetric KL loss, and the result is added to the perturbed embedding.
            For more details please check: https://arxiv.org/abs/1908.04577
            This value typically lies between 1e-7 and 1e-3.
    :param adv_bound: Used to clip the produced embedding to an upper and lower bound
            around the perturbed embedding.
            If not provided, 2 * sigma will be used as the adv_bound factor.
    :param sigma: The standard deviation of the zero-mean normal noise added to the embedding.
            If adv_bound is not provided, 2 * sigma will be used as the adv_bound factor.
    :param kwargs: The input params used in the model function.
    :return: The original loss plus the adversarial loss.
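
    Example:
        A hedged usage sketch. ``backbone`` stands for a hypothetical
        HuggingFace-style sequence-classification model whose forward accepts
        ``attention_mask``/``inputs_embeds`` and returns an output carrying a
        ``logits`` attribute; ``input_ids``, ``attention_mask`` and ``labels``
        are assumed batch tensors::

            inputs_embeds = backbone.get_input_embeddings()(input_ids)
            outputs = backbone(
                attention_mask=attention_mask, inputs_embeds=inputs_embeds)
            ce_loss = nn.CrossEntropyLoss()(outputs.logits, labels)
            total_loss = compute_adv_loss(
                inputs_embeds,
                backbone,
                outputs.logits,
                ce_loss,
                adv_grad_factor=1e-5,
                input_ids=input_ids,
                attention_mask=attention_mask)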
    """
    # If no explicit bound is given, clip within two standard deviations of the noise.
    adv_bound = adv_bound if adv_bound is not None else 2 * sigma
    # Perturb the embedding with zero-mean Gaussian noise (~95% of the noise lies
    # within +/- 2 * sigma).
    embedding_1 = embedding + embedding.data.new(embedding.size()).normal_(
        0, sigma)
    # The perturbed embedding replaces the token ids as the model input.
    kwargs.pop('input_ids', None)
    kwargs.pop('inputs_embeds', None)
    with_attention_mask = kwargs.pop('with_attention_mask', False)
    attention_mask = kwargs['attention_mask'] if with_attention_mask else None
    outputs = model(**kwargs, inputs_embeds=embedding_1)
    v1_logits = outputs.logits
    loss = _symmetric_kl_div(ori_logits, v1_logits, attention_mask)
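    # Gradient of the symmetric KL loss w.r.t. the perturbed embedding, and its
    # per-sample infinity norm over the hidden dimension, maxed over the sequence.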
    emb_grad = torch.autograd.grad(loss, embedding_1)[0].data
    emb_grad_norm = emb_grad.norm(
        dim=2, keepdim=True, p=float('inf')).max(
            1, keepdim=True)[0]
    is_nan = torch.any(torch.isnan(emb_grad_norm))
    if is_nan:
        logger.warning('NaN occurred when calculating the adversarial loss.')
        return ori_loss
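    # Normalize the gradient, take one step of size adv_grad_factor, then clip the
    # result to stay within adv_bound of the perturbed embedding.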
    emb_grad = emb_grad / (emb_grad_norm + 1e-6)
    embedding_2 = embedding_1 + adv_grad_factor * emb_grad
    embedding_2 = torch.max(embedding_1 - adv_bound, embedding_2)
    embedding_2 = torch.min(embedding_1 + adv_bound, embedding_2)
    outputs = model(**kwargs, inputs_embeds=embedding_2)
    adv_logits = outputs.logits
    adv_loss = _symmetric_kl_div(ori_logits, adv_logits, attention_mask)
    return ori_loss + adv_loss


def compute_adv_loss_pair(embedding,
                          model,
                          start_logits,
                          end_logits,
                          ori_loss,
                          adv_grad_factor,
                          adv_bound=None,
                          sigma=5e-6,
                          **kwargs):
    """
    Calculate the adversarial loss of the model. This function is used in the
    paired-logits scenario (e.g. start/end logits).
    :param embedding: The original sentence embedding.
    :param model: The model, or the forward function (including the decoder/classifier),
            which accepts kwargs as input and outputs a pair of logits.
    :param start_logits: The original start logits output by the model function.
    :param end_logits: The original end logits output by the model function.
    :param ori_loss: The original loss.
    :param adv_grad_factor: This factor is multiplied by the normalized gradient of the
            symmetric KL loss, and the result is added to the perturbed embedding.
            For more details please check: https://arxiv.org/abs/1908.04577
            This value typically lies between 1e-7 and 1e-3.
    :param adv_bound: Used to clip the produced embedding to an upper and lower bound
            around the perturbed embedding.
            If not provided, 2 * sigma will be used as the adv_bound factor.
    :param sigma: The standard deviation of the zero-mean normal noise added to the embedding.
            If adv_bound is not provided, 2 * sigma will be used as the adv_bound factor.
    :param kwargs: The input params used in the model function.
    :return: The original loss plus the adversarial loss.
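
    Example:
        A hedged sketch for a span-extraction (question answering) model whose
        ``logits`` output is a ``(start_logits, end_logits)`` pair; ``qa_model``
        and the batch tensors below are assumptions for illustration only::

            inputs_embeds = qa_model.get_input_embeddings()(input_ids)
            outputs = qa_model(
                attention_mask=attention_mask, inputs_embeds=inputs_embeds)
            start_logits, end_logits = outputs.logits
            span_loss = (nn.CrossEntropyLoss()(start_logits, start_positions)
                         + nn.CrossEntropyLoss()(end_logits, end_positions)) / 2
            total_loss = compute_adv_loss_pair(
                inputs_embeds,
                qa_model,
                start_logits,
                end_logits,
                span_loss,
                adv_grad_factor=1e-5,
                input_ids=input_ids,
                attention_mask=attention_mask)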
    """
    # If no explicit bound is given, clip within two standard deviations of the noise.
    adv_bound = adv_bound if adv_bound is not None else 2 * sigma
    # Perturb the embedding with zero-mean Gaussian noise (~95% of the noise lies
    # within +/- 2 * sigma).
    embedding_1 = embedding + embedding.data.new(embedding.size()).normal_(
        0, sigma)
    # The perturbed embedding replaces the token ids as the model input.
    kwargs.pop('input_ids', None)
    kwargs.pop('inputs_embeds', None)
    outputs = model(**kwargs, inputs_embeds=embedding_1)
    v1_logits_start, v1_logits_end = outputs.logits
    # Average the symmetric KL losses of the start and end logits.
    loss = (_symmetric_kl_div(start_logits, v1_logits_start)
            + _symmetric_kl_div(end_logits, v1_logits_end)) / 2
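    # Gradient of the symmetric KL loss w.r.t. the perturbed embedding, and its
    # per-sample infinity norm over the hidden dimension, maxed over the sequence.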
    emb_grad = torch.autograd.grad(loss, embedding_1)[0].data
    emb_grad_norm = emb_grad.norm(
        dim=2, keepdim=True, p=float('inf')).max(
            1, keepdim=True)[0]
    is_nan = torch.any(torch.isnan(emb_grad_norm))
    if is_nan:
        logger.warning(
            'NaN occurred when calculating the pair adversarial loss.')
        return ori_loss
    # Normalize the gradient (with a small epsilon to avoid division by zero, as in
    # compute_adv_loss), take one step of size adv_grad_factor, then clip the result
    # to stay within adv_bound of the perturbed embedding.
    emb_grad = emb_grad / (emb_grad_norm + 1e-6)
    embedding_2 = embedding_1 + adv_grad_factor * emb_grad
    embedding_2 = torch.max(embedding_1 - adv_bound, embedding_2)
    embedding_2 = torch.min(embedding_1 + adv_bound, embedding_2)
    outputs = model(**kwargs, inputs_embeds=embedding_2)
    adv_logits_start, adv_logits_end = outputs.logits
    adv_loss = (_symmetric_kl_div(start_logits, adv_logits_start)
                + _symmetric_kl_div(end_logits, adv_logits_end))
    return ori_loss + adv_loss
