# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from dataclasses import dataclass, field
from typing import cast, List, Optional, Sequence, Set, Type

# Import these for the cadence function signatures.
import executorch.backends.cadence.aot.ops_registrations  # noqa: F401

import torch
import torch.fx
from executorch.backends.cadence.aot.pass_utils import (
    CadencePassAttribute,
    get_arg,
    register_cadence_pass,
    RemoveOrReplacePassInterface,
    set_arg,
)

from executorch.backends.cadence.aot.simplify_ops import SimplifySliceOpPass
from executorch.backends.cadence.aot.utils import get_edge_overload_packet
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.pass_manager import PassManager, PassType
from executorch.exir.passes import dead_code_elimination_pass
from torch.fx.node import Node


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class RemoveCloneOpsTransformImported(ExportPass):
    """
    Run the shared ``RemoveCloneOpsTransform`` as a registered Cadence pass and
    then apply dead-code elimination to the module it produced.
    """

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        # Quant/dequant pair elimination is turned off for this transform.
        clone_remover = RemoveCloneOpsTransform(eliminate_quant_dequant_pairs=False)
        manager = PassManager(passes=[clone_remover])
        pass_result = manager(graph_module)
        dead_code_elimination_pass(pass_result.graph_module)
        return pass_result


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class RemoveDetachCopyPass(RemoveOrReplacePassInterface):
    """
    Remove every ``aten.detach_copy`` node by forwarding its input to all of
    its users.
    """

    @property
    def targets(self) -> list[EdgeOpOverload]:
        detach_op = exir_ops.edge.aten.detach_copy.default
        return [detach_op]

    def maybe_remove_or_replace(self, node: Node) -> bool:
        source = node.args[0]
        assert isinstance(source, Node)
        # The rewrite is unconditional, so a change is always reported.
        node.replace_all_uses_with(source)
        return True


# The following class consolidates passes that remove redundant ops: ops that
# are redundant either by virtue of the operation they perform, or in the
# context of inference.
class RemoveRedundantOps:
    # Pass classes in this group; currently only detach_copy removal.
    passes = [
        RemoveDetachCopyPass,
    ]


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class RemoveZeroSizedCatArgsPass(RemoveOrReplacePassInterface):
    """
    Remove zero-element tensors from the inputs of ``aten.cat``.

    - If every input is empty, the cat is replaced with an ``aten.full`` of
      zeros that has the cat's output shape and dtype.
    - If exactly one input survives, the cat is replaced with that input.
    - If some (but not all) inputs were dropped, the cat is rewritten with only
      the surviving inputs.

    An input is dropped only when its ``val`` metadata proves it has zero
    elements. Inputs without ``val`` metadata are kept, because their size is
    unknown and dropping them could discard real data.
    """

    @property
    def targets(self) -> list[EdgeOpOverload]:
        return [exir_ops.edge.aten.cat.default]

    @staticmethod
    def _is_known_empty(arg: Node) -> bool:
        """Return True only if ``arg`` has a ``val`` with zero elements."""
        val = arg.meta.get("val")
        return val is not None and val.numel() == 0

    def maybe_remove_or_replace(self, node: Node) -> bool:
        # Get the cat inputs (first argument is a list of tensors)
        cat_inputs_arg = node.args[0]

        # Assert that cat_inputs_arg is iterable
        assert isinstance(
            cat_inputs_arg, (list, tuple)
        ), "cat_inputs_arg must be a sequence type"

        # Keep every tensor input that is not provably empty. Entries that are
        # not Nodes are not kept.
        cat_inputs: list[Node] = [
            arg
            for arg in cat_inputs_arg
            if isinstance(arg, Node) and not self._is_known_empty(arg)
        ]

        # If all tensors were empty, create a full op with the right shape
        if not cat_inputs:
            empty_shape = node.meta["val"].shape
            dtype = node.meta["val"].dtype
            # Create a new full node
            with node.graph.inserting_before(node):
                full_node = node.graph.call_function(
                    exir_ops.edge.aten.full.default,
                    args=(tuple(empty_shape), 0),
                    kwargs={"dtype": dtype},
                )
                full_node.meta = node.meta.copy()
            node.replace_all_uses_with(full_node)
            return True

        # If only one tensor remains, replace with it
        if len(cat_inputs) == 1:
            node.replace_all_uses_with(cat_inputs[0])
            return True

        # If the number of inputs changed, update the cat args
        if len(cat_inputs) < len(cat_inputs_arg):
            # Update the first argument with filtered inputs
            new_args = list(node.args)
            new_args[0] = cat_inputs
            node.args = tuple(new_args)
            return True

        # No changes needed
        return False


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class RemoveNopExpandOpPass(RemoveOrReplacePassInterface):
    """
    Remove expand ops whose output shape equals their input shape; such an
    expand does nothing.
    """

    @property
    def targets(self) -> list[EdgeOpOverload]:
        expand_variants = [
            exir_ops.edge.aten.expand_copy.default,
            exir_ops.edge.aten.expand.default,
        ]
        return expand_variants

    def maybe_remove_or_replace(self, node: Node) -> bool:
        source = node.args[0]
        assert isinstance(source, Node)
        in_shape = source.meta["val"].shape
        out_shape = node.meta["val"].shape
        if in_shape != out_shape:
            return False
        node.replace_all_uses_with(source)
        return True


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class RemoveToOpsPass(RemoveOrReplacePassInterface):
    """
    Remove ``aten.to.dtype`` and ``aten.to.dtype_layout`` nodes by forwarding
    their input to all users.

    The pass treats every targeted ``aten.to`` variant as a no-op.
    NOTE(review): ``to.dtype`` can in principle change the dtype; confirm the
    graphs reaching this pass only contain dtype-preserving ``to`` ops.
    """

    @property
    def targets(self) -> list[EdgeOpOverload]:
        to_ops = exir_ops.edge.aten.to
        return [to_ops.dtype, to_ops.dtype_layout]

    def maybe_remove_or_replace(self, node: Node) -> bool:
        source = node.args[0]
        assert isinstance(source, Node)
        node.replace_all_uses_with(source)
        return True


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class RemoveZeroSizedConstantPadNd(RemoveOrReplacePassInterface):
    """
    Remove ``constant_pad_nd`` nodes whose padding list contains only zeros.
    The node is replaced by its input.
    """

    @property
    def targets(self) -> list[EdgeOpOverload]:
        return [exir_ops.edge.aten.constant_pad_nd.default]

    def maybe_remove_or_replace(self, node: Node) -> bool:
        # The padding is the second positional argument; without it, bail out.
        if len(node.args) < 2:
            return False

        pad = node.args[1]
        # Any non-sequence padding, or any non-zero entry, means real padding.
        if not isinstance(pad, (list, tuple)) or any(p != 0 for p in pad):
            return False

        source = node.args[0]
        assert isinstance(source, Node)
        node.replace_all_uses_with(source)
        return True


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class RemoveNopSliceOrViewOpPass(RemoveOrReplacePassInterface):
    """
    Remove ``slice_copy`` and ``view_copy`` nodes whose output shape is
    identical to their input shape.
    """

    @property
    def targets(self) -> list[EdgeOpOverload]:
        aten = exir_ops.edge.aten
        return [aten.slice_copy.Tensor, aten.view_copy.default]

    def maybe_remove_or_replace(self, node: Node) -> bool:
        source = node.args[0]
        assert isinstance(source, Node)
        same_shape = source.meta["val"].shape == node.meta["val"].shape
        if same_shape:
            node.replace_all_uses_with(source)
        return same_shape


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class RemoveNopLinalgVectorNormOpPass(RemoveOrReplacePassInterface):
    """
    Remove ``linalg_vector_norm`` nodes that reduce only over size-1 dimensions
    with ``keepdim=True``, so that the output shape equals the input shape.
    Such a node is replaced by its input.

    NOTE(review): the norm of a single element is |x| (and for ord=0 it is an
    indicator of x != 0), so this rewrite is exact only for non-negative
    inputs. Confirm that assumption holds for the models this pass targets.
    """

    @property
    def targets(self) -> list[EdgeOpOverload]:
        return [exir_ops.edge.aten.linalg_vector_norm.default]

    def maybe_remove_or_replace(self, node: Node) -> bool:
        # Only the form where dim (args[2]) and keepdim (args[3]) are passed
        # positionally is handled; with three args or fewer, bail out.
        if len(node.args) <= 3:
            return False
        # If dim is None, or keepdim is False, it is not a nop
        dim = node.args[2]
        keepdim = cast(bool, node.args[3])
        if dim is None or not keepdim:
            return False

        input_node = node.args[0]
        assert isinstance(input_node, Node)
        input_val = input_node.meta["val"]

        # A dtype kwarg that differs from the input dtype changes the result
        # type, so the input cannot stand in for the output.
        requested_dtype = node.kwargs.get("dtype")
        if requested_dtype is not None and requested_dtype != input_val.dtype:
            return False

        shape = input_val.shape
        if isinstance(dim, int):
            reduced_dims: list[int] = [dim]
        else:
            reduced_dims = list(cast(Sequence[int], dim))
        if not reduced_dims:
            # Be conservative: treat an empty dim list as reducing every dim.
            reduced_dims = list(range(len(shape)))

        # Every reduced dimension must be size 1; otherwise the norm changes
        # the values (and the shape) of the tensor. This check previously sat
        # behind an always-false `len(node.args) < 4` guard.
        if any(shape[d] != 1 for d in reduced_dims):
            return False

        node.replace_all_uses_with(input_node)
        return True


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class RemoveContiguousOpPass(RemoveOrReplacePassInterface):
    """
    Remove ``aten.contiguous`` nodes by forwarding their input to all users.

    This relies on the assumption that all tensors are contiguous in
    ExecuTorch and after the Cadence passes; revisit this pass if that
    assumption stops holding. After this pass, the model may no longer be
    runnable with the arguments given to the original graph module.
    """

    @property
    def targets(self) -> list[EdgeOpOverload]:
        contiguous_op = exir_ops.edge.aten.contiguous.default
        return [contiguous_op]

    def maybe_remove_or_replace(self, node: Node) -> bool:
        source = node.args[0]
        assert isinstance(source, Node)
        node.replace_all_uses_with(source)
        return True


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class RemoveAliasCopyOpPass(RemoveOrReplacePassInterface):
    """
    Remove ``aten.alias_copy`` nodes, which are no-ops, by forwarding their
    input to all users.
    """

    @property
    def targets(self) -> list[EdgeOpOverload]:
        alias_op = exir_ops.edge.aten.alias_copy.default
        return [alias_op]

    def maybe_remove_or_replace(self, node: Node) -> bool:
        source = node.args[0]
        assert isinstance(source, Node)
        node.replace_all_uses_with(source)
        return True


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class RemoveNopRequantizeOpPass(RemoveOrReplacePassInterface):
    """
    Remove a ``cadence.requantize.per_tensor`` node when it is an identity,
    i.e. when all of the following hold:
    1. the input scale equals the output scale,
    2. the input zero point equals the output zero point,
    3. the input tensor's dtype equals the requested output dtype.
    """

    @property
    def targets(self) -> list[EdgeOpOverload]:
        requantize_op = exir_ops.edge.cadence.requantize.per_tensor
        return [requantize_op]

    def maybe_remove_or_replace(self, node: Node) -> bool:
        source = node.args[0]
        assert isinstance(source, Node)
        args = node.args
        # Positional layout: (input, in_scale, in_zp, out_scale, out_zp, dtype).
        in_scale, in_zero_point = args[1], args[2]
        out_scale, out_zero_point = args[3], args[4]
        out_dtype = args[5]
        in_dtype = source.meta["val"].dtype

        is_identity = (
            in_scale == out_scale
            and in_zero_point == out_zero_point
            and in_dtype == out_dtype
        )
        if not is_identity:
            return False
        node.replace_all_uses_with(source)
        return True


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class RemoveNopMulOpPass(RemoveOrReplacePassInterface):
    """
    If a mul op multiplies two tensors of the same shape and one of them is an
    all-zeros ``aten.full``, replace the mul with that zero tensor.

    Both operands are checked. The zero tensor is reused only when its dtype
    matches the mul's output dtype, so type promotion is never silently lost.
    """

    @property
    def targets(self) -> list[EdgeOpOverload]:
        return [exir_ops.edge.aten.mul.Tensor]

    @staticmethod
    def _is_zero_full(candidate: Node) -> bool:
        """Return True if ``candidate`` is ``aten.full`` with fill value 0."""
        return (
            candidate.target == exir_ops.edge.aten.full.default
            and len(candidate.args) > 1
            and candidate.args[1] == 0
        )

    def maybe_remove_or_replace(self, node: Node) -> bool:
        input1 = node.args[0]
        input2 = node.args[1]
        assert isinstance(input1, Node)
        assert isinstance(input2, Node)

        # Check if both inputs have the same shape
        if input1.meta["val"].shape != input2.meta["val"].shape:
            return False

        out_dtype = node.meta["val"].dtype
        # Check both operands. An if/elif would let a non-zero full as the
        # first operand hide a zero full in the second.
        for operand in (input1, input2):
            if self._is_zero_full(operand) and operand.meta["val"].dtype == out_dtype:
                node.replace_all_uses_with(operand)
                return True

        return False


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class RemoveNopAddOpPass(RemoveOrReplacePassInterface):
    """
    If an add op adds two tensors of the same shape and one of them is an
    all-zeros ``aten.full``, replace the add with the other tensor.

    ``add.Tensor`` computes ``input1 + alpha * input2``:
    - a zero ``input2`` leaves ``input1`` regardless of ``alpha``;
    - a zero ``input1`` leaves ``alpha * input2``, which equals ``input2``
      only when ``alpha == 1``.
    The surviving tensor is used only when its dtype matches the add's output
    dtype, so type promotion is never silently lost.
    """

    @property
    def targets(self) -> list[EdgeOpOverload]:
        return [exir_ops.edge.aten.add.Tensor]

    @staticmethod
    def _is_zero_full(candidate: Node) -> bool:
        """Return True if ``candidate`` is ``aten.full`` with fill value 0."""
        return (
            candidate.target == exir_ops.edge.aten.full.default
            and len(candidate.args) > 1
            and candidate.args[1] == 0
        )

    def maybe_remove_or_replace(self, node: Node) -> bool:
        input1 = node.args[0]
        input2 = node.args[1]
        assert isinstance(input1, Node)
        assert isinstance(input2, Node)

        # Check if both inputs have the same shape
        if input1.meta["val"].shape != input2.meta["val"].shape:
            return False

        out_dtype = node.meta["val"].dtype
        alpha = node.kwargs.get("alpha", 1)

        # 0 + alpha * input2 == input2 only for alpha == 1.
        if (
            alpha == 1
            and self._is_zero_full(input1)
            and input2.meta["val"].dtype == out_dtype
        ):
            node.replace_all_uses_with(input2)
            return True

        # input1 + alpha * 0 == input1 for any alpha.
        if self._is_zero_full(input2) and input1.meta["val"].dtype == out_dtype:
            node.replace_all_uses_with(input1)
            return True

        return False


@register_cadence_pass(CadencePassAttribute(opt_level=2))
class RemovePermutesAroundElementwiseOps(ExportPass):
    """
    Looks for subgraphs of elementwise ops sandwiched between permutes and removes
    those permutes if possible.

    A subgraph qualifies when:
    - every incoming edge comes from a permute with the same permutation P;
    - every outgoing edge goes to a permute applying the inverse of P;
    - every node inside is in ``permutable_ops``.
    The permutes are then bypassed.

    Some non-elementwise ops (mean, cat, slice) are also allowed, because their
    dimension arguments can be remapped through P.
    """

    @dataclass()
    class Subgraph:
        # Permutation applied by the permutes feeding into the subgraph.
        start_permute: list[int]
        # Inverse of start_permute; required permutation of outgoing permutes.
        end_permute: list[int]
        # Nodes in the subgraph, does not include permutes.
        nodes: set[torch.fx.Node] = field(default_factory=set)
        # Incoming edges to the subgraph from permute nodes.
        edges_in: set[tuple[torch.fx.Node, torch.fx.Node]] = field(default_factory=set)
        # Outgoing edges of the subgraph to permute nodes.
        edges_out: set[tuple[torch.fx.Node, torch.fx.Node]] = field(default_factory=set)

    permutable_ops: set[EdgeOpOverload] = {
        exir_ops.edge.aten.add.Tensor,
        exir_ops.edge.aten.mul.Tensor,
        exir_ops.edge.aten.hardtanh.default,
        exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
        exir_ops.edge.cadence.quantize_per_tensor.default,
        exir_ops.edge.cadence.dequantize_per_tensor.default,
        # Ops that require special handling.
        exir_ops.edge.aten.cat.default,
        exir_ops.edge.aten.mean.dim,
        exir_ops.edge.aten.slice_copy.Tensor,
    }

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        subgraphs_found: list[RemovePermutesAroundElementwiseOps.Subgraph] = []
        processed_nodes: set[torch.fx.Node] = set()
        for permute_node in graph_module.graph.find_nodes(
            op="call_function", target=exir_ops.edge.aten.permute_copy.default
        ):
            start_permute = self.get_permutation(permute_node)
            # Expected end permutation for the subgraph (inverse of start).
            end_permute = [start_permute.index(i) for i in range(len(start_permute))]

            for user in permute_node.users:
                if user.target not in self.permutable_ops:
                    continue
                # Create a separate subgraph for each user since there may be cases
                # where only a portion of the users are permutable.
                subgraph = self.Subgraph(start_permute, end_permute)
                if self.visit(user, subgraph, processed_nodes):
                    subgraphs_found.append(subgraph)
                    processed_nodes.update(subgraph.nodes)

        if not subgraphs_found:
            return PassResult(graph_module, False)

        for subgraph in subgraphs_found:
            self.permute_subgraph(subgraph)

        graph_module.graph.eliminate_dead_code()
        graph_module.recompile()
        return super().call(graph_module)

    def visit(
        self,
        node: torch.fx.Node,
        subgraph: Subgraph,
        processed_nodes: set[torch.fx.Node],
    ) -> bool:
        """
        Grow ``subgraph`` from ``node`` in both directions.

        Returns True if every path out of the region ends at a matching permute.
        Returns False if some neighbor is non-permutable, already claimed by
        another subgraph, or a permute with the wrong permutation.
        """
        if node in subgraph.nodes:
            return True
        if node in processed_nodes or not self.is_node_permutable(node):
            return False
        subgraph.nodes.add(node)

        # Traverse downstream:
        for user in node.users:
            # Output should either go to a matching permute or another permutable op.
            if user.target == exir_ops.edge.aten.permute_copy.default:
                if self.get_permutation(user) != subgraph.end_permute:
                    return False
                subgraph.edges_out.add((node, user))
            elif not self.visit(user, subgraph, processed_nodes):
                return False

        # Traverse upstream:
        for inp in node.all_input_nodes:
            # Input should either come from a matching permute or another permutable op.
            if inp.target == exir_ops.edge.aten.permute_copy.default:
                if self.get_permutation(inp) != subgraph.start_permute:
                    return False
                subgraph.edges_in.add((inp, node))
            elif not self.visit(inp, subgraph, processed_nodes):
                return False

        return True

    def is_node_permutable(self, node: torch.fx.Node) -> bool:
        """Return True if ``node`` can be moved across the permutes."""
        if node.target not in self.permutable_ops:
            return False
        if node.target == exir_ops.edge.aten.mean.dim:
            # keepdim should be True, so the output rank matches the input rank.
            if len(node.args) >= 3:
                if not node.args[2]:
                    return False
            elif "keepdim" in node.kwargs:
                if not node.kwargs["keepdim"]:
                    return False
            else:
                # Default keepdim is False.
                return False
        return True

    def permute_subgraph(self, subgraph: Subgraph) -> None:
        """Bypass the boundary permutes and remap dimension arguments inside."""
        # Skip incoming permutes.
        for inp, out in subgraph.edges_in:
            assert inp.target == exir_ops.edge.aten.permute_copy.default
            if len(inp.args) >= 1:
                out.replace_input_with(inp, cast(torch.fx.Node, inp.args[0]))
            else:
                out.replace_input_with(inp, cast(torch.fx.Node, inp.kwargs["input"]))

        # Skip outgoing permutes.
        for inp, out in subgraph.edges_out:
            assert out.target == exir_ops.edge.aten.permute_copy.default
            out.replace_all_uses_with(inp)

        # Handle dimension related node arguments.
        for node in subgraph.nodes:
            if node.target == exir_ops.edge.aten.cat.default:
                self.update_cat(node, subgraph.start_permute)
            elif node.target == exir_ops.edge.aten.mean.dim:
                self.update_mean_dim(node, subgraph.start_permute)
            elif node.target == exir_ops.edge.aten.slice_copy.Tensor:
                self.update_slice_copy(node, subgraph.start_permute)

    def update_cat(self, node: torch.fx.Node, start_permute: list[int]) -> None:
        """Map cat's ``dim`` from the permuted layout back to the original one."""
        if len(node.args) >= 2:
            node.update_arg(1, start_permute[cast(int, node.args[1])])
        elif "dim" in node.kwargs:
            node.update_kwarg("dim", start_permute[cast(int, node.kwargs["dim"])])
        else:
            # Default cat dim is 0.
            node.update_kwarg("dim", start_permute[0])

    def update_mean_dim(self, node: torch.fx.Node, start_permute: list[int]) -> None:
        """
        Map mean's reduction dims from the permuted layout back to the original.

        ``dim=None`` (reduce over every dimension) is left untouched, because
        the reduction covers all dims in either layout.
        """
        if len(node.args) >= 2:
            dims = node.args[1]
            if dims is not None:
                node.update_arg(
                    1, [start_permute[d] for d in cast(list[int], dims)]
                )
        elif node.kwargs.get("dim") is not None:
            node.update_kwarg(
                "dim",
                [start_permute[d] for d in cast(list[int], node.kwargs["dim"])],
            )

    def update_slice_copy(self, node: torch.fx.Node, start_permute: list[int]) -> None:
        """Map slice's ``dim`` from the permuted layout back to the original one."""
        if len(node.args) >= 2:
            node.update_arg(1, start_permute[cast(int, node.args[1])])
        elif "dim" in node.kwargs:
            node.update_kwarg("dim", start_permute[cast(int, node.kwargs["dim"])])
        else:
            # Default slice_copy dim is 0 (previously raised KeyError here).
            node.update_kwarg("dim", start_permute[0])

    def get_permutation(self, permute_node: torch.fx.Node) -> list[int]:
        """Return the permutation list of a ``permute_copy`` node."""
        assert permute_node.target == exir_ops.edge.aten.permute_copy.default
        if len(permute_node.args) >= 2:
            return cast(list[int], permute_node.args[1])
        assert "dim" in permute_node.kwargs
        return cast(list[int], permute_node.kwargs["dim"])


@register_cadence_pass(CadencePassAttribute(opt_level=2))
class RemoveSqueezeViewBeforeElementwiseOps(ExportPass):
    """
    Looks for subgraphs of the form:
    squeeze -> [elementwise ops] -> view
    and removes the squeeze node by reshaping the intermediate ops. If the final view
    is a corresponding unsqueeze it should also get eliminated by noop view elimination
    later. Only handles simple chain of intermediates now.

    The pass works on view ops instead of squeeze directly, thus it should be run after
    the squeeze/unsqueeze->view lowering.
    """

    intermediate_ops: set[EdgeOpOverload] = {
        exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
        exir_ops.edge.cadence.quantize_per_tensor.default,
        exir_ops.edge.cadence.dequantize_per_tensor.default,
        # Ops that require special handling:
        exir_ops.edge.aten.slice_copy.Tensor,
    }

    def get_squeeze_indices(self, view_node: Node) -> List[int]:
        """
        Returns the indices of the input dimensions that are squeezed in the output if
        view node is a squeeze. Returns an empty list otherwise.

        Input dims are matched greedily, left to right, against output dims.
        An input dim of size 1 that does not match the next output dim is
        treated as squeezed. This includes trailing size-1 dims after every
        output dim has been matched (e.g. [4, 1] -> [4] yields [1]).
        """
        input_node = get_arg(view_node, "input", Node)
        input_shape = input_node.meta["val"].shape
        output_shape = view_node.meta["val"].shape

        if len(input_shape) <= len(output_shape):
            return []

        squeeze_indices: List[int] = []
        out_idx = 0
        for idx, dim in enumerate(input_shape):
            if out_idx < len(output_shape) and dim == output_shape[out_idx]:
                out_idx += 1
            elif dim == 1:
                # Mismatched (or trailing) input dims must be size 1 to be squeezed.
                squeeze_indices.append(idx)
            else:
                return []

        # Check if all the output dimensions are consumed.
        if out_idx != len(output_shape):
            return []

        return squeeze_indices

    def handle_squeeze(self, view_node: Node, visited_view_nodes: Set[Node]) -> bool:
        """
        Bypass ``view_node`` if it is a squeeze that feeds a simple chain of
        ``intermediate_ops`` ending in another view.

        Slice dims along the chain are shifted back to the unsqueezed layout.
        Returns True if the graph was modified.
        """
        if view_node in visited_view_nodes:
            return False

        squeeze_indices = self.get_squeeze_indices(view_node)
        if not squeeze_indices:
            return False

        # Only handle simple chains for now.
        if len(view_node.users) != 1:
            return False
        node = next(iter(view_node.users))

        # Traverse down from the node until finding another view op.
        intermediate_slices = []
        while node.target != exir_ops.edge.aten.view_copy.default:
            # Only handle simple chains for now
            if len(node.users) != 1:
                return False
            if node.target not in self.intermediate_ops:
                return False
            if node.target == exir_ops.edge.aten.slice_copy.Tensor:
                intermediate_slices.append(node)
            node = next(iter(node.users))

        # View node found. We can't optimize this view_node again since the
        # input shape is invalid now so add it to the visited set.
        visited_view_nodes.add(node)

        # Update the intermediate slices: map each squeezed-layout dim to the
        # corresponding unsqueezed dim (squeeze_indices is ascending).
        for slice_node in intermediate_slices:
            slice_rank = len(slice_node.meta["val"].shape)
            slice_dim = get_arg(slice_node, "dim", int)
            if slice_dim < 0:
                slice_dim += slice_rank
            for squeeze_dim in squeeze_indices:
                if slice_dim >= squeeze_dim:
                    slice_dim += 1
            set_arg(slice_node, "dim", slice_dim)

        # Skip the initial view node.
        input_node = get_arg(view_node, "input", Node)
        view_node.replace_all_uses_with(input_node)
        return True

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        visited_view_nodes: Set[Node] = set()
        modified = False
        for view_node in graph_module.graph.find_nodes(
            op="call_function", target=exir_ops.edge.aten.view_copy.default, sort=True
        ):
            modified |= self.handle_squeeze(view_node, visited_view_nodes)

        if modified:
            graph_module.graph.eliminate_dead_code()
            graph_module.recompile()
            return super().call(graph_module)

        return PassResult(graph_module, False)


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class RemoveBranchedQuantDequant(ExportPass):
    """
    Bypass quant->dequant (and dequant->quant) pairs with identical
    parameters when the producer has more than one user.

    FuseQuantDequantToRequantizePass would remove such a pair if the producer
    had a single user. Here the producer is kept for its other users. Only the
    consumer is bypassed, by rewiring its users to the producer's input.
    """

    quantize_op_packets: set[EdgeOpOverloadPacket] = {
        exir_ops.edge.cadence.quantize_per_tensor,
        exir_ops.edge.quantized_decomposed.quantize_per_tensor,
    }
    dequantize_op_packets: set[EdgeOpOverloadPacket] = {
        exir_ops.edge.cadence.dequantize_per_tensor,
        exir_ops.edge.quantized_decomposed.dequantize_per_tensor,
    }

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        directions = (
            (self.quantize_op_packets, self.dequantize_op_packets),
            (self.dequantize_op_packets, self.quantize_op_packets),
        )
        modified = False
        for producer_pkts, consumer_pkts in directions:
            # Call first so both directions always run.
            changed = self.remove_branched(graph_module, producer_pkts, consumer_pkts)
            modified = changed or modified

        if not modified:
            return PassResult(graph_module, False)

        graph_module.graph.eliminate_dead_code()
        return super().call(graph_module)

    @staticmethod
    def _in_packets(node: Node, pkts: set[EdgeOpOverloadPacket]) -> bool:
        """True if ``node`` targets an edge op whose packet is in ``pkts``."""
        return isinstance(node.target, EdgeOpOverload) and (
            get_edge_overload_packet(node.target) in pkts
        )

    def remove_branched(
        self,
        graph_module: torch.fx.GraphModule,
        producer_pkts: set[EdgeOpOverloadPacket],
        consumer_pkts: set[EdgeOpOverloadPacket],
    ) -> bool:
        changed = False
        for producer in graph_module.graph.nodes:
            if producer.op != "call_function" or not self._in_packets(
                producer, producer_pkts
            ):
                continue
            # Single-user pairs are left to other passes.
            if len(producer.users) < 2:
                continue

            for consumer in producer.users:
                if not self._in_packets(consumer, consumer_pkts):
                    continue
                # Quantization parameters (all args after the input) must match.
                if producer.args[1:] != consumer.args[1:]:
                    continue
                consumer.replace_all_uses_with(producer.args[0])
                changed = True

        return changed


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class RemoveCatFromSliceCopyPass(RemoveOrReplacePassInterface):
    """
    Simplifies cat->slice_copy chains where the sliced range lies entirely
    within one of the cat inputs. The slice is rewired to read that input
    directly, with start/end shifted by the input's offset along the cat dim.
    """

    @property
    def targets(self) -> list[EdgeOpOverload]:
        return [exir_ops.edge.aten.slice_copy.Tensor]

    def maybe_remove_or_replace(self, node: Node) -> bool:
        cat_node = get_arg(node, "input", Node)
        slice_dim = get_arg(node, "dim", int)
        start_idx = get_arg(node, "start", Optional[int])
        end_idx = get_arg(node, "end", Optional[int])
        step = get_arg(node, "step", int)

        if cat_node.target != exir_ops.edge.aten.cat.default or step != 1:
            return False

        cat_output_shape = cat_node.meta["val"].shape
        rank = len(cat_output_shape)

        # Make sure cat and slice happen on the same dimension. Normalize
        # negative dims first so that e.g. dim=-1 and dim=rank-1 compare equal.
        cat_dim = get_arg(cat_node, "dim", int)
        if cat_dim < 0:
            cat_dim += rank
        if slice_dim < 0:
            slice_dim += rank
        if cat_dim != slice_dim:
            return False

        # Canonicalize slice indices into [0, dim_size]. Out-of-range indices
        # are clamped to the valid range.
        dim_size = cat_output_shape[cat_dim]
        if start_idx is None:
            start_idx = 0
        elif start_idx < 0:
            start_idx += dim_size
        if end_idx is None:
            end_idx = dim_size
        elif end_idx < 0:
            end_idx += dim_size
        start_idx = min(max(start_idx, 0), dim_size)
        end_idx = min(max(end_idx, 0), dim_size)

        offset = 0
        for cat_input_node in get_arg(cat_node, "tensors", Sequence[Node]):
            cat_input_shape = cat_input_node.meta["val"].shape
            if len(cat_input_shape) != rank:
                # Cannot reason about offsets for a rank-mismatched input; bail.
                return False
            input_size = cat_input_shape[cat_dim]

            # Check if the slice range lies within this cat input's range.
            if offset <= start_idx and end_idx <= offset + input_size:
                node.replace_input_with(cat_node, cat_input_node)
                set_arg(node, "start", start_idx - offset)
                set_arg(node, "end", end_idx - offset)
                return True

            offset += input_size

        return False


class CommonRemovePasses:
    # Removal passes reused as the prefix of CadenceRemoveNops.passes.
    # NOTE(review): list order presumably reflects the intended run order;
    # confirm before reordering.
    passes: List[Type[ExportPass]] = [
        RemoveAliasCopyOpPass,
        RemoveNopExpandOpPass,
        RemoveNopSliceOrViewOpPass,
        RemoveToOpsPass,
        RemoveZeroSizedCatArgsPass,
        RemovePermutesAroundElementwiseOps,
        RemoveSqueezeViewBeforeElementwiseOps,
        RemoveCatFromSliceCopyPass,
        RemoveCloneOpsTransformImported,
    ]


class CadenceRemoveNops:
    # CommonRemovePasses.passes followed by additional Cadence-specific
    # removal passes. The `+` builds a new list, so the shared list is not
    # mutated.
    passes: List[Type[ExportPass]] = CommonRemovePasses.passes + [
        SimplifySliceOpPass,
        RemoveNopRequantizeOpPass,
        RemoveZeroSizedConstantPadNd,
        RemoveContiguousOpPass,
        RemoveNopMulOpPass,
        RemoveNopAddOpPass,
        RemoveNopLinalgVectorNormOpPass,
        RemoveBranchedQuantDequant,
    ]
