# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from typing import Set, Type

from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.decompose_sqrt_pass import DecomposeSqrtPass
from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass  # noqa
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
    ReplaceScalarWithTensorByProfilePass,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

# Edge-dialect acosh overload that DecomposeAcoshPass rewrites. Only the
# non-quantized case ("MI" -- presumably the TOSA floating-point / Main
# Inference profile; confirm) is decomposed; quantized acosh is left for
# table-op lowering.
edge_acosh_op = exir_ops.edge.aten.acosh.default


class DecomposeAcoshPass(ArmPass):
    """Decomposes acosh to supported TOSA operations.

    This decomposition is based on the mathematical identity:
        acosh(x) = log(x + sqrt((x - 1) * (x + 1)))

    The radicand is written as (x - 1) * (x + 1) rather than x * x - 1,
    presumably to limit cancellation error near x = 1.

    Only non-quantized acosh is decomposed. Quantized acosh is passed through
    unchanged so that it can be replaced by a table op later
    (InsertTableOpsPass is listed as a required follow-up pass).
    """

    _passes_required_after: Set[Type[ExportPass]] = {
        DecomposeSqrtPass,
        InsertTableOpsPass,
        MatchArgRanksPass,
        ReplaceScalarWithTensorByProfilePass,
        MatchArgDtypePass,
    }

    def call_operator(self, op, args, kwargs, meta, updated=False):
        """Replace a non-quantized edge acosh call with its decomposition.

        Args:
            op: The operator being visited.
            args: Positional arguments of the call; for acosh, ``args[0]`` is
                the input.
            kwargs: Keyword arguments of the call (not used by the
                decomposition).
            meta: Node metadata; it is reused for every emitted node.
            updated: Forwarded unchanged when the call is passed through.

        Returns:
            The result of the parent ``call_operator`` in one of two forms.
            For any operator other than acosh, and for quantized acosh, it is
            the original call. Otherwise it is the final ``log`` node of the
            decomposition.
        """
        # Any other op, or quantized acosh (to be replaced by a table op),
        # is forwarded untouched. The `or` short-circuits, so quantization
        # metadata is only inspected for acosh nodes.
        if op is not edge_acosh_op or self._is_quantized_meta(meta):
            return super().call_operator(op, args, kwargs, meta, updated)

        x = args[0]

        # (x - 1) * (x + 1)
        x_minus_one = super().call_operator(
            exir_ops.edge.aten.sub.Scalar, (x, 1.0), {}, meta, True
        )
        x_plus_one = super().call_operator(
            exir_ops.edge.aten.add.Scalar, (x, 1.0), {}, meta, True
        )
        radicand = super().call_operator(
            exir_ops.edge.aten.mul.Tensor, (x_minus_one, x_plus_one), {}, meta, True
        )

        # sqrt((x - 1) * (x + 1)); sqrt itself is decomposed later by
        # DecomposeSqrtPass (listed in _passes_required_after).
        root = super().call_operator(
            exir_ops.edge.aten.sqrt.default, (radicand,), {}, meta, True
        )

        # x + sqrt((x - 1) * (x + 1))
        log_input = super().call_operator(
            exir_ops.edge.aten.add.Tensor, (x, root), {}, meta, True
        )

        # acosh(x) = log(x + sqrt((x - 1) * (x + 1)))
        return super().call_operator(
            exir_ops.edge.aten.log.default, (log_input,), {}, meta, True
        )
