# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import os

import paddle


def is_flash_attn_available():
    """Probe whether a usable flash attention kernel is available on the current device."""
    try:
        if "npu" in paddle.get_device():  # NOTE: flash attention on NPU has not been tested yet
            # Register the custom NPU ops shipped as .so libraries under CUSTOM_DEVICE_ROOT.
            for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")):
                if lib.endswith(".so"):
                    paddle.utils.cpp_extension.extension_utils.load_op_meta_info_and_register_op(
                        lib
                    )
            from paddle.base import core

            def flash_attention_npu(
                query,
                key,
                value,
                dropout=0.0,
                causal=False,
                return_softmax=False,
                *,
                fixed_seed_offset=None,
                rng_name="",
                training=True,
                name=None,
                attn_mask=None,
                is_varlen=False,
                batch_size=None,
                seq_length=None,
            ):
                """Wrap the custom NPU flash attention op behind the same positional
                signature as ``paddle.nn.functional.flash_attention.flash_attention``."""
                is_triangle_upper_mask = True

                if is_varlen:
                    # Packed (variable-length) mode: collapse the batch and sequence dims
                    # so all tokens form one contiguous sequence.
                    if len(query.shape) == 4:
                        B, S, N, D = query.shape
                        query = query.reshape([B * S, N, D])
                        key = key.reshape([B * S, N, D])
                        value = value.reshape([B * S, N, D])
                    else:
                        assert batch_size is not None
                        assert seq_length is not None
                        B = batch_size
                        S = seq_length
                    # Cumulative end offsets of each sample in the packed batch: [S, 2S, ..., B * S].
                    actual_seq_q_len = actual_seq_kv_len = list(range(S, B * S + S, S))
                else:
                    # Dense mode: no packed-sequence offsets are needed.
                    actual_seq_q_len = actual_seq_kv_len = []

                # Dispatch to the registered custom op; the first output is the attention result.
                out = core.eager._run_custom_op(
                    "flash_attention_npu",
                    query,
                    key,
                    value,
                    fixed_seed_offset,
                    attn_mask,
                    actual_seq_q_len,
                    actual_seq_kv_len,
                    dropout,
                    causal,
                    return_softmax,
                    not training,
                    is_triangle_upper_mask,
                    is_varlen,
                )[0]

                return out

            # Smoke-test the NPU kernel on a tiny bfloat16 input, then expose the wrapper
            # under paddle.nn.functional so has_flash_attn_func() can find it later.
            q = paddle.rand((1, 4, 2, 8)).astype("bfloat16")
            _ = flash_attention_npu(q, q, q, 0.9, False, False)
            paddle.nn.functional.flash_attention_npu = flash_attention_npu
            return True
        # GPU path: run a tiny forward pass to confirm the built-in kernel is usable.
        q = paddle.rand((1, 4, 2, 8)).astype("bfloat16")
        _ = paddle.nn.functional.flash_attention.flash_attention(
            q, q, q, 0.9, False, False
        )
        return True
    except Exception:
        # Any failure (missing op library, unsupported device or dtype, etc.) means
        # flash attention cannot be used here.
        return False


# Probed once at import time.
HAS_FLASH_ATTN = is_flash_attn_available()


def has_flash_attn_func():
    """Return ``(flash_attn_func, flash_attn_varlen_func)``, or ``(None, None)`` when
    flash attention is unavailable; on NPU the same wrapper serves both roles."""
    if HAS_FLASH_ATTN:
        try:
            if "npu" in paddle.get_device():
                flash_attn_func_npu = paddle.nn.functional.flash_attention_npu
                return flash_attn_func_npu, flash_attn_func_npu
            else:
                from paddle.nn.functional.flash_attention import (
                    flash_attention as flash_attn_func,
                )
                from paddle.nn.functional.flash_attention import (
                    flash_attn_unpadded as flash_attn_varlen_func,
                )

                return flash_attn_func, flash_attn_varlen_func
        except Exception:
            return None, None
    else:
        return None, None
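

# Minimal usage sketch (illustrative only, not part of this module's API): fetch the
# kernels once and fall back to the regular attention path when they are unavailable.
if __name__ == "__main__":
    flash_attn_func, flash_attn_varlen_func = has_flash_attn_func()
    if flash_attn_func is None:
        print("flash attention unavailable; use the standard attention path instead")
    else:
        q = paddle.rand((1, 4, 2, 8)).astype("bfloat16")
        result = flash_attn_func(q, q, q, 0.0, True, False)
        # The built-in kernel returns (out, softmax); the NPU wrapper returns the tensor directly.
        out = result[0] if isinstance(result, tuple) else result
        print("flash attention output shape:", out.shape)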
