# Copyright 2024-2026 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from typing import Set, Type

import torch
from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

# Division overloads recognised by DecomposeDivPass, grouped by dialect so
# get_div_decomposition can pick replacement ops from the matching dialect.
# Edge-dialect overloads:
edge_div_ops = (exir_ops.edge.aten.div.Tensor,)
# ATen overloads, including the in-place ``div_`` variant:
aten_div_ops = (
    torch.ops.aten.div.Tensor,
    torch.ops.aten.div_.Tensor,
)


def get_div_decomposition(op) -> tuple:
    """Return the ``(reciprocal_op, mul_op)`` pair used to rewrite ``op``.

    The replacement ops come from the same dialect as ``op``: edge-dialect
    ops when ``op`` is in ``edge_div_ops``, and ``torch.ops.aten`` ops when
    ``op`` is in ``aten_div_ops``.

    Args:
        op: The division operator being decomposed.

    Returns:
        A 2-tuple ``(reciprocal_op, mul_op)``.

    Raises:
        RuntimeError: If ``op`` is not one of the supported div overloads.
    """
    # Edge dialect is checked first, matching the original lookup order.
    if op in edge_div_ops:
        decomposition = (
            exir_ops.edge.aten.reciprocal.default,
            exir_ops.edge.aten.mul.Tensor,
        )
    elif op in aten_div_ops:
        decomposition = (
            torch.ops.aten.reciprocal.default,
            torch.ops.aten.mul.Tensor,
        )
    else:
        raise RuntimeError(f"Can't get div decomposition for op {op}")
    return decomposition


class DecomposeDivPass(ArmPass):
    """Rewrite a division as a reciprocal followed by a multiplication.

    Both the edge-dialect and ATen div overloads listed in ``edge_div_ops``
    and ``aten_div_ops`` are handled.

    Example:
        y = div(a,b)
    Becomes:
        x = reciprocal(b)
        y = mul(a,x)

    """

    # InsertTableOpsPass must run after this pass. Presumably this is because
    # the emitted reciprocal is lowered via a table op; confirm against
    # InsertTableOpsPass.
    _passes_required_after: Set[Type[ExportPass]] = {InsertTableOpsPass}

    def call_operator(self, op, args, kwargs, meta):
        """Decompose supported div ops; forward every other op unchanged.

        The rewrite happens only when ``op`` is a supported div overload and
        ``self.allowed_to_transform(meta)`` returns truthy. The latter is not
        consulted for non-div ops.
        """
        is_div = op in edge_div_ops or op in aten_div_ops
        if not (is_div and self.allowed_to_transform(meta)):
            return super().call_operator(op, args, kwargs, meta)

        reciprocal_op, mul_op = get_div_decomposition(op)
        dividend, divisor = args[0], args[1]

        # div(a, b) -> mul(a, reciprocal(b)); both new nodes reuse the
        # original node's meta and are marked as updated.
        inverse = super().call_operator(
            reciprocal_op, (divisor,), {}, meta, updated=True
        )
        return super().call_operator(
            mul_op, (dividend, inverse), {}, meta, updated=True
        )
