# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Any, cast, Dict, List, Tuple

import torch
from executorch.backends.cadence.aot.compiler_utils import get_shape
from executorch.backends.cadence.aot.quantizer.patterns import (
    AddmmPattern,
    AddPattern,
    BmmPattern,
    CatPattern,
    Conv1dPattern,
    Conv1dReluPattern0,
    Conv1dReluPattern1,
    Conv2dPattern,
    Conv2dReluPattern0,
    Conv2dReluPattern1,
    LayerNormPattern,
    LinearPattern,
    MatmulPattern,
    MixedW8A32ConvPattern,
    MixedW8A32GruPattern,
    MixedW8A32LinearPattern,
    ReluPattern0,
    ReluPattern1,
    SoftmaxPattern,
)
from executorch.backends.cadence.aot.quantizer.utils import (
    check_out_zero_point_is_min_range,
    copy_node_metadata,
    create_zero_bias_int32,
    find_sequential_partitions_aten,
    get_conv_args,
    quantize_tensor_multiplier,
)
from executorch.exir.pass_base import ExportPass
from torch import fx
from torch.fx import GraphModule
from torch.fx.passes.infra.pass_base import PassResult
from torch.fx.passes.utils.fuser_utils import legalize_graph


# Generic alias for FX argument values, used instead of a bare `Any` in the
# annotations below to avoid pyre errors.
# pyre-ignore[33]: `_ModelInputsType` cannot alias to `Any`.
ArgsType = Any

# Tuples of related pattern classes, so QuantFusion.call can dispatch on a
# whole family of patterns with a single isinstance() check.
ReluPatterns = (ReluPattern0, ReluPattern1)
ConvPatterns = (Conv1dPattern, Conv2dPattern)
ConvReluPatterns = (
    # 1-D conv followed by relu
    Conv1dReluPattern0,
    Conv1dReluPattern1,
    # 2-D conv followed by relu
    Conv2dReluPattern0,
    Conv2dReluPattern1,
)


def get_args_and_kwargs_add(
    graph_module: GraphModule,
    inputs_inputs: List[fx.Node],
    dequants_inputs: List[fx.Node],
    quant_node: fx.Node,
) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
    """
    Build the positional args for the quantized add replacement op.

    The result is laid out as
    ``(X, X_scale, X_zero_point, Y, Y_scale, Y_zero_point, out_scale,
    out_zero_point)``. Each operand's scale and zero point are args[1] and
    args[2] of its dequantize node. The output ones are args[1] and args[2]
    of ``quant_node``. No kwargs are produced. ``graph_module`` is accepted
    only for signature parity with the other helpers.
    """
    lhs_dequant = dequants_inputs[0]
    rhs_dequant = dequants_inputs[1]
    lhs_scale, lhs_zero_point = lhs_dequant.args[1], lhs_dequant.args[2]
    rhs_scale, rhs_zero_point = rhs_dequant.args[1], rhs_dequant.args[2]

    positional = (
        inputs_inputs[0],
        lhs_scale,
        lhs_zero_point,
        inputs_inputs[1],
        rhs_scale,
        rhs_zero_point,
        quant_node.args[1],
        quant_node.args[2],
    )
    return positional, {}


# Helper computing the args and kwargs of the quantized linear replacement op
def get_args_and_kwargs_linear(
    graph_module: GraphModule,
    inputs_inputs: List[fx.Node],
    dequants_inputs: List[fx.Node],
    weights_inputs: List[fx.Node],
    dequants_weights: List[fx.Node],
    bias_inputs: List[fx.Node],
    quant_node: fx.Node,
) -> Tuple[Tuple[ArgsType], Dict[str, ArgsType]]:
    """
    Returns the args and kwargs for the linear replacement op.

    The bias scale is ``input_scale * weight_scale``. Dividing it by the
    output scale gives the requantization factor, which is encoded by
    ``quantize_tensor_multiplier`` as an integer multiplier/shift pair. When
    the pattern has no bias, a zero int32 bias node is created from the
    weight node via ``create_zero_bias_int32``.

    Returns:
        ``(args, kwargs)``. The args are ``(*inputs, *weights, bias)``. The
        kwargs hold the zero points, multiplier/shift and ``offset=None``.
    """
    input_dequant = dequants_inputs[0]
    weight_dequant = dequants_weights[0]

    # pyre-fixme[58]: Unsupported operand types
    bias_scale = input_dequant.args[1] * weight_dequant.args[1]
    multiplier, shift = quantize_tensor_multiplier(
        torch.tensor([bias_scale / quant_node.args[1]])
    )

    if bias_inputs:
        bias = bias_inputs[0]
    else:
        # No bias in the matched pattern: synthesize a zero int32 bias.
        weight_node = weight_dequant.args[0]
        assert isinstance(weight_node, fx.Node)
        bias = create_zero_bias_int32(graph_module, weight_node, bias_scale)

    kwargs = {
        "src_zero_point": input_dequant.args[2],
        "weight_zero_point": weight_dequant.args[2],
        "out_multiplier": multiplier[0].item(),
        "out_shift": shift[0].item(),
        "out_zero_point": quant_node.args[2],
        "offset": None,
    }
    return (*inputs_inputs, *weights_inputs, bias), kwargs


def _make_layer_norm_fill_node(
    graph_module: GraphModule,
    normalized_shape: ArgsType,
    fill_value: int,
    like: fx.Node,
) -> fx.Node:
    """
    Insert an ``aten.full`` float32 node of ``normalized_shape`` filled with
    ``fill_value``. It serves as a default layer-norm weight or bias.

    The node's fake ``val`` is computed under ``like``'s fake mode, and the
    rest of its metadata is copied from ``like``.
    """
    node = graph_module.graph.call_function(
        torch.ops.aten.full.default,
        (normalized_shape, fill_value),
        {"dtype": torch.float32},
    )
    assert "val" in like.meta, "Missing val metadata on input node"
    fake_mode = like.meta["val"].fake_mode
    assert fake_mode is not None, "fake_mode is None on input node"
    with fake_mode:
        node.meta["val"] = torch.full(
            normalized_shape, fill_value, dtype=torch.float32
        )
    copy_node_metadata(node, like)
    return node


# Helper function to get the args and kwargs for the layer norm replacement op
def get_args_and_kwargs_layer_norm(
    graph_module: GraphModule,
    inputs_inputs: List[fx.Node],
    dequants_inputs: List[fx.Node],
    other_inputs: List[fx.Node],
    quant_node: fx.Node,
) -> Tuple[Tuple[ArgsType], Dict[str, ArgsType]]:
    """
    Returns the args and kwargs for the layer norm replacement op.

    Args:
        graph_module: Graph in which default weight/bias nodes are created.
        inputs_inputs: Quantized input node(s). Exactly one is required when
            a default weight or bias has to be synthesized.
        dequants_inputs: Dequantize node(s) of the input. Their args[1] and
            args[2] must be a scalar scale and zero point (per-tensor only).
        other_inputs: ``[normalized_shape, weight?, bias?]``. A missing weight
            defaults to ones and a missing bias defaults to zeros.
        quant_node: Consumer of the layer_norm output. Its args[1] and args[2]
            are the output scale and zero point. Its args[0] is inspected for
            the original ``aten.layer_norm`` node, from which ``eps`` is taken.

    Returns:
        ``(args, kwargs)`` where args is ``(input, scale, zero_point)``.

    Raises:
        AssertionError: on per-channel quantization, or when a default
            weight/bias is needed but the input metadata is unusable.
    """
    # Check if the input is per-channel quantized
    # TODO(matthiascremon): add proper support and testing for per-channel quantization
    assert isinstance(dequants_inputs[0].args[1], float) and isinstance(
        dequants_inputs[0].args[2], int
    ), "per-channel quantization is not supported for layer norm, both scale and zero_point should be scalars"

    scale = dequants_inputs[0].args[1]
    zero_point = dequants_inputs[0].args[2]

    weight = other_inputs[1] if len(other_inputs) > 1 else None
    if weight is None:
        assert (
            len(inputs_inputs) == 1
        ), f"Expected 1 input for layer norm weight, got {len(inputs_inputs)}"
        weight = _make_layer_norm_fill_node(
            graph_module, other_inputs[0], 1, inputs_inputs[0]
        )

    bias = other_inputs[2] if len(other_inputs) > 2 else None
    if bias is None:
        assert (
            len(inputs_inputs) == 1
        ), f"Expected 1 input for layer norm bias, got {len(inputs_inputs)}"
        bias = _make_layer_norm_fill_node(
            graph_module, other_inputs[0], 0, inputs_inputs[0]
        )

    # Use the eps of the original aten.layer_norm node (the quant node's
    # input) instead of always assuming the default. When eps is not given
    # explicitly, fall back to 1e-05, which is aten.layer_norm's default.
    eps: ArgsType = 1e-05
    ln_node = quant_node.args[0] if quant_node.args else None
    if (
        isinstance(ln_node, fx.Node)
        and ln_node.target == torch.ops.aten.layer_norm.default
    ):
        if len(ln_node.args) > 4:
            eps = ln_node.args[4]
        else:
            eps = ln_node.kwargs.get("eps", eps)

    # Make the args and kwargs for the replacement op
    args = tuple(inputs_inputs + [scale, zero_point])
    kwargs = {
        "normalized_shape": other_inputs[0],
        "weight": weight,
        "bias": bias,
        "eps": eps,
        "output_scale": quant_node.args[1],
        "output_zero_point": quant_node.args[2],
    }
    return args, kwargs


def get_args_and_kwargs_matmul(
    inputs_inputs: List[fx.Node],
    dequants_inputs: List[fx.Node],
    quant_node: fx.Node,
) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
    """
    Build args/kwargs for the quantized matmul replacement op (bmm / matmul).

    The requantization scale is ``(lhs_scale * rhs_scale) / out_scale``,
    encoded as an integer multiplier/shift pair. The args are
    ``(X, X_zero_point, Y, Y_zero_point, None)``. The trailing ``None`` is
    presumably an absent bias; confirm against the op schema. The op is
    always emitted with ``transposed=False``.
    """
    lhs_dequant = dequants_inputs[0]
    rhs_dequant = dequants_inputs[1]

    # pyre-ignore[58]: Unsupported operand
    product_scale = lhs_dequant.args[1] * rhs_dequant.args[1]
    multiplier, shift = quantize_tensor_multiplier(
        torch.tensor([product_scale / quant_node.args[1]])
    )

    positional = (
        inputs_inputs[0],
        lhs_dequant.args[2],
        inputs_inputs[1],
        rhs_dequant.args[2],
        None,
    )
    keyword = {
        "out_multiplier": multiplier[0].item(),
        "out_shift": shift[0].item(),
        "out_zero_point": quant_node.args[2],
        "transposed": False,
    }
    return positional, keyword


def get_args_and_kwargs_cat(
    inputs_inputs: List[fx.Node], other_inputs: List[fx.Node], op_node: fx.Node
) -> Tuple[Tuple[ArgsType], Dict[str, ArgsType]]:
    """
    Build args/kwargs for the quantized cat replacement op.

    The whole ``inputs_inputs`` list becomes the first positional arg and is
    followed by ``other_inputs``. The concat dim is ``op_node.args[1]``, or 0
    when absent, and is coerced to ``int``.
    """
    cat_args = op_node.args
    concat_dim = cat_args[1] if len(cat_args) > 1 else 0
    # pyre-fixme[6]: Incompatible parameter type
    return (inputs_inputs, *other_inputs), {"dim": int(concat_dim)}


def get_args_and_kwargs_conv(
    graph_module: GraphModule,
    inputs_inputs: List[fx.Node],
    dequants_inputs: List[fx.Node],
    weights_inputs: List[fx.Node],
    dequants_weights: List[fx.Node],
    bias_inputs: List[fx.Node],
    quant_node: fx.Node,
    op_node: fx.Node,
) -> Tuple[Tuple[ArgsType], Dict[str, ArgsType]]:
    """
    Build args/kwargs for the quantized conv replacement op.

    Stride, padding and dilation are read from ``op_node.args[3:6]``
    (normalized by ``get_conv_args``). Missing values default to 2-element
    lists. Groups comes from ``op_node.args[6]`` and defaults to 1. The bias
    scale is ``input_scale * weight_scale``. A zero int32 bias is created
    when the pattern has none. The requantization multiplier/shift are
    computed up front because they are needed if the conv is later replaced
    by quantized linear.
    """
    input_dequant = dequants_inputs[0]
    weight_dequant = dequants_weights[0]
    # pyre-fixme[58]: Unsupported operand types
    bias_scale = input_dequant.args[1] * weight_dequant.args[1]

    conv_args = op_node.args
    num_args = len(conv_args)
    stride = get_conv_args(conv_args[3], 1) if num_args >= 4 else [1, 1]
    padding = get_conv_args(conv_args[4], 0) if num_args >= 5 else [0, 0]
    dilation = get_conv_args(conv_args[5], 1) if num_args >= 6 else [1, 1]
    groups = conv_args[6] if num_args >= 7 else 1

    if bias_inputs:
        bias = bias_inputs[0]
    else:
        # No bias in the matched pattern: synthesize a zero int32 bias.
        weight_node = weight_dequant.args[0]
        assert isinstance(weight_node, fx.Node)
        bias = create_zero_bias_int32(graph_module, weight_node, bias_scale)

    out_scale = quant_node.args[1]
    multiplier, shift = quantize_tensor_multiplier(
        torch.tensor([bias_scale / out_scale])
    )

    kwargs = {
        "stride": stride,
        "padding": padding,
        "dilation": dilation,
        "groups": groups,
        "input_zero_point": input_dequant.args[2],
        "weight_zero_point": weight_dequant.args[2],
        "bias_scale": bias_scale,
        "out_scale": out_scale,
        "out_zero_point": quant_node.args[2],
        "out_multiplier": multiplier[0].item(),
        "out_shift": shift[0].item(),
    }
    return (*inputs_inputs, *weights_inputs, bias), kwargs


def get_args_and_kwargs_relu(
    graph_module: GraphModule,
    inputs_inputs: List[fx.Node],
    dequants_inputs: List[fx.Node],
    quant_node: fx.Node,
) -> Tuple[Tuple[ArgsType], Dict[str, ArgsType]]:
    """
    Build args/kwargs for the quantized relu replacement op.

    The requantization factor is ``input_scale / out_scale``, encoded as an
    integer multiplier/shift pair. The quantized inputs are passed through
    positionally.
    """
    input_dequant = dequants_inputs[0]
    # pyre-fixme[58]: Unsupported operand types
    scale_ratio = input_dequant.args[1] / quant_node.args[1]
    multiplier, shift = quantize_tensor_multiplier(torch.tensor([scale_ratio]))

    return tuple(inputs_inputs), {
        "X_zero_point": input_dequant.args[2],
        "out_zero_point": quant_node.args[2],
        "out_multiplier": multiplier[0].item(),
        "out_shift": shift[0].item(),
    }


def get_args_and_kwargs_mixed_w8a32_linear(
    graph_module: GraphModule,
    other_inputs: List[fx.Node],
    weights_inputs: List[fx.Node],
    dequants_weights: List[fx.Node],
    bias_inputs: List[fx.Node],
    dequants_biases: List[fx.Node],
) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
    """
    Build args for the mixed w8a32 linear replacement op.

    The activation (``other_inputs[0]``) is passed as-is. The quantized
    weight and bias are each followed by their scale, which is args[1] of
    the corresponding dequantize node. No kwargs are produced.
    """
    weight_scale = dequants_weights[0].args[1]
    bias_scale = dequants_biases[0].args[1]
    activation = other_inputs[0]
    return (
        activation,
        weights_inputs[0],
        weight_scale,
        bias_inputs[0],
        bias_scale,
    ), {}


def get_args_and_kwargs_softmax(
    graph_module: GraphModule,
    inputs_inputs: List[fx.Node],
    dequants_inputs: List[fx.Node],
    quant_node: fx.Node,
    op_node: fx.Node,
) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
    """
    Build args for the quantized softmax replacement op.

    A zero-filled int32 mask node is inserted into ``graph_module``. Its
    shape is the shape of ``quant_node.args[0]`` with the last dimension
    floor-divided by 16. The softmax dim is forwarded from
    ``op_node.args[1]``.

    Returns:
        ``((input, mask, dim, in_scale, in_zero_point, out_scale,
        out_zero_point), {})``.

    Raises:
        AssertionError: if there is not exactly one input, the input lacks
            fake-tensor metadata, or the output shape is unknown or rank 0.
    """
    # Validate before touching the graph so a failure leaves no dangling node.
    assert (
        len(inputs_inputs) == 1
    ), f"Expected 1 input for softmax, got {len(inputs_inputs)}"
    assert "val" in inputs_inputs[0].meta, "Missing val metadata on input node"
    fake_mode = inputs_inputs[0].meta["val"].fake_mode
    assert fake_mode is not None, "fake_mode is None on input node"

    shape = get_shape(graph_module, cast(fx.Node, quant_node.args[0]))
    # An unknown or rank-0 shape has no last dimension to reduce. Fail with a
    # clear message instead of a bare IndexError on mask_shape[-1].
    assert shape, "Cannot build softmax mask: output shape is unknown or rank 0"
    mask_shape = list(shape)
    # NOTE(review): the // 16 reduction is dictated by the kernel's mask
    # layout, which is not visible here.
    mask_shape[-1] = mask_shape[-1] // 16

    # Make a dummy mask tensor
    mask_tensor = graph_module.graph.call_function(
        torch.ops.aten.full.default,
        (
            mask_shape,
            0.0,
        ),
        {"dtype": torch.int32},
    )
    with fake_mode:
        mask_tensor.meta["val"] = torch.full(mask_shape, 0.0, dtype=torch.int32)
    copy_node_metadata(mask_tensor, inputs_inputs[0])

    # Scalar quantization parameters read from the dequant / quant nodes.
    in_scale = dequants_inputs[0].args[1]
    in_zero_point = dequants_inputs[0].args[2]
    out_scale = quant_node.args[1]
    out_zero_point = quant_node.args[2]

    # Make the args and kwargs for the replacement op
    args = (
        inputs_inputs[0],
        mask_tensor,
        op_node.args[1],
        in_scale,
        in_zero_point,
        out_scale,
        out_zero_point,
    )
    kwargs = {}

    return args, kwargs


def get_args_and_kwargs_mixed_w8a32_conv(
    graph_module: GraphModule,
    other_inputs: List[fx.Node],
    weights_inputs: List[fx.Node],
    dequants_weights: List[fx.Node],
    bias_inputs: List[fx.Node],
    dequants_biases: List[fx.Node],
    op_node: fx.Node,
) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
    """
    Build args for the mixed w8a32 conv replacement op.

    Only the default conv configuration is accepted: stride ``[1]``, padding
    ``[0]``, dilation ``[1]`` and groups ``1`` (checked only when present in
    ``op_node.args``). The activation is permuted with ``[0, 2, 1]`` and the
    weight with ``[2, 0, 1]``. Each permute is inserted as a new graph node
    with matching fake ``val`` and metadata. No kwargs are produced.
    """
    # Stride, padding, dilation, groups not supported yet
    for position, required in ((3, [1]), (4, [0]), (5, [1]), (6, 1)):
        if len(op_node.args) > position:
            assert op_node.args[position] == required

    assert len(dequants_weights) == 1
    assert len(dequants_biases) == 1
    weight_scale = dequants_weights[0].args[1]
    bias_scale = dequants_biases[0].args[1]

    def permuted(source: fx.Node, dims: List[int], label: str) -> fx.Node:
        # Insert aten.permute(source, dims) and mirror its fake val + metadata.
        node = graph_module.graph.call_function(
            torch.ops.aten.permute.default,
            (source, dims),
        )
        assert "val" in source.meta, f"Missing val metadata on {label} node"
        source_val = source.meta["val"]
        assert source_val.fake_mode is not None, f"fake_mode is None on {label} node"
        with source_val.fake_mode:
            node.meta["val"] = torch.ops.aten.permute.default(source_val, dims)
        copy_node_metadata(node, source)
        return node

    # Activation: NCL -> NLC.
    activation_nlc = permuted(other_inputs[0], [0, 2, 1], "input")
    # Weight: last axis moved to the front, (d0, d1, d2) -> (d2, d0, d1).
    # Presumably (out_ch, in_ch, kernel) -> (kernel, out_ch, in_ch); confirm.
    weight_permuted = permuted(weights_inputs[0], [2, 0, 1], "weight")

    return (
        activation_nlc,
        weight_permuted,
        weight_scale,
        bias_inputs[0],
        bias_scale,
    ), {}


def get_args_and_kwargs_mixed_w8a32_gru(
    graph_module: GraphModule,
    other_inputs: List[fx.Node],
    weights_inputs: List[fx.Node],
    dequants_weights: List[fx.Node],
    bias_inputs: List[fx.Node],
    dequants_biases: List[fx.Node],
    op_node: fx.Node,
) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
    """
    Build args for the mixed w8a32 GRU replacement op.

    Exactly two dequantized weights and two dequantized biases are required.
    The naming suggests index 0 is input-to-hidden and index 1 is
    hidden-to-hidden; confirm against the pattern. The args are the two
    float inputs followed by ``(tensor, scale)`` pairs for weight 0,
    weight 1, bias 0 and bias 1. Each scale is args[1] of the tensor's
    dequantize node. ``graph_module`` and ``op_node`` are unused.
    """
    assert len(dequants_weights) == 2
    assert len(dequants_biases) == 2
    w_i_dequant, w_h_dequant = dequants_weights
    b_i_dequant, b_h_dequant = dequants_biases

    return (
        other_inputs[0],
        other_inputs[1],
        weights_inputs[0],
        w_i_dequant.args[1],
        weights_inputs[1],
        w_h_dequant.args[1],
        bias_inputs[0],
        b_i_dequant.args[1],
        bias_inputs[1],
        b_h_dequant.args[1],
    ), {}


class QuantFusion(ExportPass):
    """
    Replace matched quantized partitions with single fused replacement ops.

    Patterns in ``patterns`` are applied in order. For each pattern, the pass:
    - finds sequential partitions of the pattern's aten ops;
    - collects the per-tensor dequantize nodes feeding the anchored inputs,
      weights and biases;
    - builds the replacement op's args/kwargs with the matching
      ``get_args_and_kwargs_*`` helper;
    - inserts ``pattern.replacement_op()`` after the anchor op.

    The fused node takes over the uses of the quant node when the pattern
    anchors an output, and otherwise the uses of the op node. Partition nodes
    are tagged in ``node.meta`` so overlapping partitions are fused only once.
    """

    # Key stored in ``node.meta`` to tag nodes that belong to a fused
    # partition. ``is_fused`` and ``mark_fused`` both use it, so they always
    # agree, including in subclasses.
    _FUSED_META_KEY: str = "QuantFusion"

    # pyre-ignore[2]: Parameter `patterns` has no type specified
    def __init__(self, patterns) -> None:
        super().__init__()
        # pyre-ignore[4]: Parameter `patterns` of class `QuantFusion` has no type specified
        self.patterns = patterns

    @staticmethod
    # pyre-ignore[2]: Parameter `anchor_entries` has no type specified
    def _collect_dequants(anchor_entries) -> List[fx.Node]:
        """
        Return the per-tensor dequantize nodes referenced by anchor entries.

        Each entry is ``(node, idx, *optional_spec)``. ``idx`` is either an
        int into ``node.args`` or a pair indexing into a list argument (as
        used for ``cat`` inputs). Referenced args that are not dequantize
        nodes, including non-Node constants, are skipped.
        """
        dequants: List[fx.Node] = []
        for node, idx, *_spec in anchor_entries:
            arg = (
                node.args[idx]
                if isinstance(idx, int)
                else node.args[idx[0]][idx[1]]
            )
            if (
                isinstance(arg, fx.Node)
                and arg.target
                == torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ):
                dequants.append(arg)
        return dequants

    @staticmethod
    def _move_output_to_end(graph_module: fx.GraphModule) -> None:
        """
        Make sure the graph's ``output`` node is its last node.

        If the last node is not an output node, all output nodes are erased
        and a single new one is appended. It reuses the first output node's
        argument and metadata.
        """
        nodes_list = list(graph_module.graph.nodes)
        if not nodes_list or nodes_list[-1].op == "output":
            return

        output_nodes = [n for n in nodes_list if n.op == "output"]
        output_arg = output_nodes[0].args[0]
        original_meta = output_nodes[0].meta.copy()

        for out_node in output_nodes:
            graph_module.graph.erase_node(out_node)

        new_output_node = graph_module.graph.output(output_arg)
        new_output_node.meta.update(original_meta)

    def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
        """
        Apply every pattern's fusion to ``graph_module`` in place.

        After each pattern, the graph is re-sorted (``legalize_graph``), dead
        nodes are removed, the output node is moved back to the end if
        needed, and the module is recompiled.

        Returns:
            ``PassResult(graph_module, True)``. The graph is always reported
            as modified.
        """
        for pattern in self.patterns:
            fused_partitions = find_sequential_partitions_aten(
                graph_module,
                pattern.partition_types(),
            )
            for fused_partition in fused_partitions:
                anchors, op_node = pattern.get_anchors(graph_module, fused_partition)
                if not anchors or anchors.empty:
                    continue
                # Skip partitions overlapping one that was already fused.
                if any(self.is_fused(p.nodes) for p in fused_partition):
                    continue

                for p in fused_partition:
                    self.mark_fused(p.nodes)

                dequants_inputs = self._collect_dequants(anchors.inputs)
                dequants_weights = self._collect_dequants(anchors.weights)
                dequants_biases = self._collect_dequants(anchors.biases)

                # The quantized tensors feeding each dequantize node.
                inputs_inputs = [node.args[0] for node in dequants_inputs]
                weights_inputs = [node.args[0] for node in dequants_weights]
                bias_inputs = [node.args[0] for node in dequants_biases]
                other_inputs = [node.args[idx] for node, idx in anchors.others]

                assert op_node is not None, "op_node is None"
                # First user of the op; expected to be the output quantize
                # node when the pattern anchors an output.
                quant_node = next(iter(op_node.users), None)
                assert quant_node is not None, f"op_node {op_node.name} has no users"

                with graph_module.graph.inserting_after(op_node):
                    args = tuple(
                        inputs_inputs + weights_inputs + other_inputs + bias_inputs
                    )
                    kwargs = {}
                    if isinstance(pattern, AddPattern):
                        args, kwargs = get_args_and_kwargs_add(
                            graph_module,
                            inputs_inputs,
                            dequants_inputs,
                            quant_node,
                        )
                    elif isinstance(pattern, CatPattern):
                        args, kwargs = get_args_and_kwargs_cat(
                            inputs_inputs, other_inputs, op_node
                        )
                    elif isinstance(pattern, ConvReluPatterns):
                        # For ConvReLU, we are fusing Conv+ReLU
                        # This means that the op we want to get
                        # the replacement args and kwargs for is the
                        # *conv* op, which is the anchor input, NOT
                        # the anchor output (which is the ReLU)
                        check_out_zero_point_is_min_range(
                            quant_node.args[2], quant_node.args[5]
                        )
                        anchor_input_node = anchors.inputs[0][0]
                        args, kwargs = get_args_and_kwargs_conv(
                            graph_module,
                            inputs_inputs,
                            dequants_inputs,
                            weights_inputs,
                            dequants_weights,
                            bias_inputs,
                            quant_node,
                            anchor_input_node,
                        )
                    elif isinstance(pattern, ConvPatterns):
                        args, kwargs = get_args_and_kwargs_conv(
                            graph_module,
                            inputs_inputs,
                            dequants_inputs,
                            weights_inputs,
                            dequants_weights,
                            bias_inputs,
                            quant_node,
                            op_node,
                        )
                    elif isinstance(pattern, LinearPattern):
                        args, kwargs = get_args_and_kwargs_linear(
                            graph_module,
                            inputs_inputs,
                            dequants_inputs,
                            weights_inputs,
                            dequants_weights,
                            bias_inputs,
                            quant_node,
                        )
                    elif isinstance(pattern, LayerNormPattern):
                        args, kwargs = get_args_and_kwargs_layer_norm(
                            graph_module,
                            inputs_inputs,
                            dequants_inputs,
                            other_inputs,
                            quant_node,
                        )
                    elif isinstance(pattern, (BmmPattern, MatmulPattern)):
                        args, kwargs = get_args_and_kwargs_matmul(
                            inputs_inputs,
                            dequants_inputs,
                            quant_node,
                        )
                    elif isinstance(pattern, AddmmPattern):
                        # Transpose the weight tensor
                        transposed_weights = graph_module.graph.call_function(
                            torch.ops.aten.transpose.int,
                            (weights_inputs[0], 0, 1),
                        )
                        assert (
                            "val" in weights_inputs[0].meta
                        ), "Missing val metadata on weight node"
                        original_val = weights_inputs[0].meta["val"]
                        assert (
                            original_val.fake_mode is not None
                        ), "fake_mode is None on weight node"
                        with original_val.fake_mode:
                            transposed_weights.meta["val"] = (
                                torch.ops.aten.transpose.int(original_val, 0, 1)
                            )
                        copy_node_metadata(transposed_weights, weights_inputs[0])

                        # Call linear with transposed weight
                        args, kwargs = get_args_and_kwargs_linear(
                            graph_module,
                            inputs_inputs,
                            dequants_inputs,
                            [transposed_weights],
                            dequants_weights,
                            bias_inputs,
                            quant_node,
                        )
                    elif isinstance(pattern, ReluPatterns):
                        args, kwargs = get_args_and_kwargs_relu(
                            graph_module,
                            inputs_inputs,
                            dequants_inputs,
                            quant_node,
                        )
                    elif isinstance(pattern, SoftmaxPattern):
                        args, kwargs = get_args_and_kwargs_softmax(
                            graph_module,
                            inputs_inputs,
                            dequants_inputs,
                            quant_node,
                            op_node,
                        )
                    elif isinstance(pattern, MixedW8A32LinearPattern):
                        args, kwargs = get_args_and_kwargs_mixed_w8a32_linear(
                            graph_module,
                            other_inputs,
                            weights_inputs,
                            dequants_weights,
                            bias_inputs,
                            dequants_biases,
                        )
                    elif isinstance(pattern, MixedW8A32ConvPattern):
                        args, kwargs = get_args_and_kwargs_mixed_w8a32_conv(
                            graph_module,
                            other_inputs,
                            weights_inputs,
                            dequants_weights,
                            bias_inputs,
                            dequants_biases,
                            op_node,
                        )
                    elif isinstance(pattern, MixedW8A32GruPattern):
                        args, kwargs = get_args_and_kwargs_mixed_w8a32_gru(
                            graph_module,
                            other_inputs,
                            weights_inputs,
                            dequants_weights,
                            bias_inputs,
                            dequants_biases,
                            op_node,
                        )

                    fused = graph_module.graph.call_function(
                        pattern.replacement_op(),
                        args,
                        kwargs,
                    )

                    if len(anchors.output) > 0:
                        # The fused op produces the quantized output directly.
                        fused.meta = quant_node.meta
                        quant_node.replace_all_uses_with(fused)
                    else:
                        fused.meta = op_node.meta
                        op_node.replace_all_uses_with(fused)
                        if op_node.op == "output":
                            _ = graph_module.graph.output((fused,))

            legalize_graph(graph_module)
            graph_module.graph.eliminate_dead_code()
            self._move_output_to_end(graph_module)
            graph_module.recompile()
        return PassResult(graph_module, True)

    @classmethod
    # pyre-ignore[2]: Parameter `nodes` has no type specified
    def is_fused(cls, nodes) -> bool:
        """Return True if any node in ``nodes`` is tagged as already fused."""
        return any(cls._FUSED_META_KEY in n.meta for n in nodes)

    @classmethod
    # pyre-ignore[2]: Parameter `nodes` has no type specified
    def mark_fused(cls, nodes) -> None:
        """Tag every node in ``nodes`` as belonging to a fused partition."""
        for n in nodes:
            n.meta[cls._FUSED_META_KEY] = True
