# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import math
from typing import Optional, Set, Type

import torch
from executorch.exir.pass_base import ExportPass, PassResult
from torch._decomp import get_decompositions
from torch.fx.experimental.proxy_tensor import make_fx


class DecomposeScaledDotProductAttention(ExportPass):
    """
    Decompose from scaled_dot_product_attention to multiple nodes.

    Each ``aten.scaled_dot_product_attention.default`` node is traced with
    ``make_fx`` using the decomposition registered for
    ``aten._scaled_dot_product_flash_attention_for_cpu.default``. The traced
    nodes are copied into the graph in front of the original node, the
    original node's users are rewired to the copied result, and the original
    node is erased.

    Args:
        allow_non_fake_inputs: Forwarded to ``make_fx`` as
            ``_allow_non_fake_inputs``. Used by ``call`` whenever the caller
            does not pass an explicit override.
    """

    _passes_required_after: Set[Type[ExportPass]] = set()

    def __init__(self, allow_non_fake_inputs: bool = True) -> None:
        super().__init__()
        # With allow_non_fake_inputs=False, we don't get _unsafe_view ops
        # in the graph, we allow disabling it here.
        self._allow_non_fake_inputs = allow_non_fake_inputs

    def call(
        self,
        graph_module: torch.fx.GraphModule,
        allow_non_fake_inputs: Optional[bool] = None,
    ) -> PassResult:
        """
        Decompose every scaled_dot_product_attention node in ``graph_module``.

        Args:
            graph_module: Module whose graph is rewritten in place.
            allow_non_fake_inputs: Optional per-call override. When ``None``
                (the default) the value given to the constructor is used.

        Returns:
            A ``PassResult`` wrapping the same ``graph_module``; the modified
            flag is always ``True``.
        """
        # Honour the constructor setting unless the caller overrides it here.
        # Previously the constructor value was stored but never consulted.
        if allow_non_fake_inputs is None:
            allow_non_fake_inputs = self._allow_non_fake_inputs

        graph = graph_module.graph
        # Iterate over a snapshot: the loop body inserts and erases nodes.
        for node in list(graph.nodes):
            if node.target != torch.ops.aten.scaled_dot_product_attention.default:
                continue
            self._decompose_sdpa_node(graph_module, node, allow_non_fake_inputs)

        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)

    def _decompose_sdpa_node(
        self,
        graph_module: torch.fx.GraphModule,
        node: torch.fx.Node,
        allow_non_fake_inputs: bool,
    ) -> None:
        """
        Replace a single scaled_dot_product_attention node by its decomposition.

        Args:
            graph_module: Module that owns ``node``.
            node: The scaled_dot_product_attention node to replace; it is
                erased from the graph before this method returns.
            allow_non_fake_inputs: Forwarded to ``make_fx`` as
                ``_allow_non_fake_inputs``.
        """
        graph = graph_module.graph
        # NOTE(review): the tensors fed to make_fx come from all_input_nodes,
        # while the placeholder lookup below is keyed on node.args. These can
        # disagree when a tensor is passed by keyword, or (if all_input_nodes
        # de-duplicates, as it appears to) when the same node is passed twice.
        # Confirm against callers.
        input_tensors = (input_node.meta["val"] for input_node in node.all_input_nodes)
        scale = node.kwargs.get("scale", None)
        # Captured up front so it can be attached to every copied node.
        original_module_stack = node.meta.get("nn_module_stack")

        # refer to pytorch/test/test_decomp.py
        decomposed_module = make_fx(
            node.target,
            decomposition_table=get_decompositions(  # pyre-fixme[6]
                [
                    torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default,
                ]
            ),
            tracing_mode="fake",
            _allow_non_fake_inputs=allow_non_fake_inputs,
        )(*input_tensors)

        with graph.inserting_before(node):
            # Placeholders of the traced module are looked up by name, using
            # "arg{i}_1" for the i-th positional argument of the original node.
            name_to_input_tensor_map = {
                f"arg{i}_1": arg for i, arg in enumerate(node.args)
            }

            decomposed_node_to_subgraph_node: dict[torch.fx.Node, torch.fx.Node] = {}
            last_decomposed_node = None
            # Create a mapping from input nodes in decomposed module to original nodes.
            # In decomposed module, there are only input tensors for placeholder op.
            for decomposed_node in decomposed_module.graph.nodes:
                if decomposed_node.op == "placeholder":
                    decomposed_node_to_subgraph_node[decomposed_node] = (
                        name_to_input_tensor_map[decomposed_node.name]
                    )

                if decomposed_node.op == "output":
                    last_decomposed_node = decomposed_node.args[0]

            # Copy node from decompose graph module
            for decomposed_node in decomposed_module.graph.nodes:
                if decomposed_node.op == "placeholder":
                    continue

                if decomposed_node.op == "output" and last_decomposed_node is not None:
                    # Rewire every consumer of the original node to the copied result.
                    for user in node.users.copy():
                        user.replace_input_with(
                            node,
                            decomposed_node_to_subgraph_node[last_decomposed_node],
                        )
                    continue

                if (
                    scale is not None
                    and decomposed_node.target == torch.ops.aten.mul.Scalar
                ):
                    new_args = list(decomposed_node.args)
                    # Based on the implementation of _scaled_dot_product_attention_math,
                    # the scale is applied to q and k before matmul.
                    # refer to pytorch/aten/src/ATen/native/transformers/attention.cpp#L873
                    new_args[1] = math.sqrt(scale)
                    decomposed_node.args = tuple(new_args)

                subgraph_node = graph.node_copy(
                    decomposed_node,
                    arg_transform=lambda x: decomposed_node_to_subgraph_node[x],
                )
                # Carry the original node's module stack over to its replacement.
                # The earlier code wrote the traced node's value onto the
                # original node instead, which is erased below.
                if original_module_stack is not None:
                    subgraph_node.meta["nn_module_stack"] = original_module_stack
                subgraph_node.meta["source_fn_stack"] = [
                    (subgraph_node, subgraph_node.target)
                ]
                decomposed_node_to_subgraph_node[decomposed_node] = subgraph_node

        graph.erase_node(node)
