# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import copy
import os
import random
import statistics
import sys
import tempfile
import unittest
from contextlib import redirect_stdout

from typing import Callable, List, Union

from unittest.mock import MagicMock, patch

import pandas as pd

import torch
import torch.fx
import torch.utils._pytree as pytree

from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.devtools import generate_etrecord, parse_etrecord
from executorch.devtools.debug_format.et_schema import OperatorNode
from executorch.devtools.etdump.schema_flatcc import ProfileEvent
from executorch.devtools.etrecord.tests.etrecord_test import TestETRecord

from executorch.devtools.inspector import (
    _inspector,
    Event,
    EventBlock,
    Inspector,
    PerfData,
)
from executorch.devtools.inspector._inspector import (
    DebugEventSignature,
    flatcc,
    InstructionEvent,
    InstructionEventSignature,
    ProfileEventSignature,
    TimeScale,
)
from executorch.devtools.inspector.tests.inspector_test_utils import (
    check_if_debug_handle_to_op_names_match,
    check_if_intermediate_outputs_match,
    model_registry,
)
from executorch.exir import (
    EdgeCompileConfig,
    EdgeProgramManager,
    ExecutorchProgramManager,
    to_edge,
    to_edge_transform_and_lower,
)
from executorch.exir.capture._config import ExecutorchBackendConfig
from executorch.extension.pybindings.portable_lib import (
    _load_for_executorch_from_buffer,
)
from torch.export import export, ExportedProgram


# Models for testing inplace ops intermediate output logging
class IndexPutModel(torch.nn.Module):
    """
    Minimal module whose forward pass writes values into a buffer with index_put.

    The out-of-place ``index_put`` is used deliberately: according to the
    original authors, enabling the reinplace_pass rewrites it to the inplace
    ``index_put_`` variant, which used to cause problems with event tracer
    logging.
    """

    def __init__(self):
        super().__init__()
        # 5x3 all-zero buffer that forward() uses as the base tensor.
        self.register_buffer("data", torch.zeros(5, 3))

    def forward(self, indices: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
        """Return the sum of a copy of ``data`` with ``values`` placed at ``indices``."""
        # index_put is out-of-place here, so self.data itself is left untouched.
        return self.data.index_put((indices,), values).sum()


# Op type the dataframe test expects on the first generated event.
OP_TYPE = "aten::add"
# Name given to the EventBlock instances built by the tests below.
EVENT_BLOCK_NAME = "block_0"
# Number of rows the dataframe test expects from the generated events.
EVENTS_SIZE = 10
# Length the dataframe test expects for the "raw" entry of a generated event.
RAW_DATA_SIZE = 10
# Placeholder paths handed to Inspector(); the tests that use them patch out
# the etdump/etrecord parsing functions, so no files are read from these paths.
ETDUMP_PATH = "unittest_etdump_path"
ETRECORD_PATH = "unittest_etrecord_path"


# TODO: write an E2E test: create an inspector instance, mock just the file reads, and then verify the external correctness
class TestInspector(unittest.TestCase):
    def test_perf_data(self) -> None:
        """PerfData.p50 should agree with an independently computed median."""
        samples = self._gen_random_float_list()
        perf_data = PerfData(samples)

        # statistics.median is used on purpose: the expectation is computed a
        # different way from the implementation under test.
        expected_p50 = statistics.median(samples)
        self.assertAlmostEqual(perf_data.p50, expected_p50)

    def test_event_block_to_dataframe(self) -> None:
        """EventBlock.to_dataframe() should expose the generated events as rows.

        Only a handful of fields are spot-checked: the row count, the presence
        of an "op_0" event name, and the raw-data length and first op type of
        the first row.
        """
        event_block = EventBlock(
            name=EVENT_BLOCK_NAME, events=self._gen_random_events()
        )

        df = event_block.to_dataframe()

        self.assertEqual(len(df), EVENTS_SIZE)
        # assertIn (rather than assertTrue(x in y)) reports both operands when
        # the membership check fails.
        self.assertIn("op_0", df["event_name"].values)
        self.assertEqual(len(df["raw"].values[0]), RAW_DATA_SIZE)
        self.assertEqual(df["op_types"].values[0][0], OP_TYPE)

    def test_inspector_constructor(self):
        """Inspector.__init__ should call its parsing helpers as expected."""
        # Build the patchers first; they only take effect inside the `with`
        # blocks, which are entered in the same order as listed here.
        parse_patch = patch.object(_inspector, "parse_etrecord", return_value=None)
        etdump_patch = patch.object(
            _inspector, "gen_etdump_object", return_value=None
        )
        blocks_patch = patch.object(EventBlock, "_gen_from_etdump")
        graphs_patch = patch.object(_inspector, "gen_graphs_from_etrecord")

        with parse_patch as parse_mock, etdump_patch as etdump_mock:
            with blocks_patch as blocks_mock, graphs_patch as graphs_mock:
                Inspector(
                    etdump_path=ETDUMP_PATH,
                    etrecord=ETRECORD_PATH,
                )

                # Both inputs are parsed exactly once with the given paths.
                parse_mock.assert_called_once_with(etrecord_path=ETRECORD_PATH)
                etdump_mock.assert_called_once_with(
                    etdump_path=ETDUMP_PATH, etdump_data=None
                )
                blocks_mock.assert_called_once()
                # parse_etrecord() is mocked to return None, so no graphs
                # should be generated from it.
                graphs_mock.assert_not_called()

    def test_default_delegate_time_scale_converter(self):
        """Inspector should pass a callable delegate_time_scale_converter on.

        The Inspector is built with explicit source/target time scales and no
        converter of its own; the keyword arguments received by the patched
        EventBlock._gen_from_etdump are then inspected.
        """
        # Create a context manager to patch functions called by Inspector.__init__
        with patch.object(
            _inspector, "parse_etrecord", return_value=None
        ), patch.object(
            _inspector, "gen_etdump_object", return_value=None
        ), patch.object(
            EventBlock, "_gen_from_etdump"
        ) as mock_gen_from_etdump, patch.object(
            _inspector, "gen_graphs_from_etrecord"
        ), patch.object(
            _inspector, "create_debug_handle_to_op_node_mapping"
        ):
            # Call the constructor of Inspector
            Inspector(
                etdump_path=ETDUMP_PATH,
                etrecord=ETRECORD_PATH,
                source_time_scale=TimeScale.US,
                target_time_scale=TimeScale.S,
            )

            # Fail clearly if the patched function was never reached, instead
            # of hitting an AttributeError on a None call_args below.
            mock_gen_from_etdump.assert_called_once()

            # Look the converter up in the real keyword arguments.
            # `call_args.get(...)` is NOT a dict lookup: `call_args` is a
            # `mock.call` tuple whose attribute access fabricates a new
            # (callable) `call` object, so the old assertion could never fail.
            converter = mock_gen_from_etdump.call_args.kwargs.get(
                "delegate_time_scale_converter"
            )
            self.assertIsInstance(converter, Callable)

    def test_inspector_print_data_tabular(self):
        """print_data_tabular() should run without raising on generated events."""
        # Patchers for everything Inspector.__init__ would otherwise call.
        parse_patch = patch.object(_inspector, "parse_etrecord", return_value=None)
        etdump_patch = patch.object(
            _inspector, "gen_etdump_object", return_value=None
        )
        blocks_patch = patch.object(EventBlock, "_gen_from_etdump")
        graphs_patch = patch.object(_inspector, "gen_graphs_from_etrecord")

        with parse_patch, etdump_patch, blocks_patch, graphs_patch:
            inspector = Inspector(
                etdump_path=ETDUMP_PATH,
                etrecord=ETRECORD_PATH,
            )

            # _gen_from_etdump is patched, so install a real event block with
            # generated events for print_data_tabular() to render.
            generated_block = EventBlock(
                name=EVENT_BLOCK_NAME, events=self._gen_random_events()
            )
            inspector.event_blocks = [generated_block]

            # Smoke test only: discard the printed table and make sure the
            # call does not crash.
            with redirect_stdout(None):
                inspector.print_data_tabular()

    def test_inspector_associate_with_op_graph_nodes_single_debug_handle(self):
        """An event with one integer debug handle picks up its node's metadata."""
        handle = 111
        event = Event(
            name="event_with_single_debug_handle",
            perf_data=PerfData(raw=[]),
            debug_handles=handle,
        )
        relu_node = OperatorNode(
            name="node_0",
            metadata={
                "debug_handle": handle,
                "stack_trace": "stack_trace_relu",
                "nn_module_stack": "module_hierarchy_relu",
            },
            op="op",
        )

        # Method under test: map the debug handle to its operator node(s).
        handle_to_nodes = {handle: [relu_node]}
        event._associate_with_op_graph_nodes(handle_to_nodes)

        # Stack traces and module hierarchy are keyed by node name.
        self.assertEqual(event.stack_traces, {"node_0": "stack_trace_relu"})
        self.assertEqual(event.module_hierarchy, {"node_0": "module_hierarchy_relu"})
        self.assertEqual(event.op_types, ["op"])

    def test_inspector_associate_with_op_graph_nodes_multiple_debug_handles(self):
        """An event with a list of debug handles aggregates every node's metadata."""
        debug_handles = [222, 333]
        relu_handle, conv_handle = debug_handles
        event = Event(
            name="event_with_multiple_debug_handles",
            perf_data=PerfData(raw=[]),
            debug_handles=debug_handles,
        )
        relu_node = OperatorNode(
            name="node_0",
            metadata={
                "debug_handle": relu_handle,
                "stack_trace": "stack_trace_relu",
                "nn_module_stack": "module_hierarchy_relu",
            },
            op="op_0",
        )
        conv_node = OperatorNode(
            name="node_1",
            metadata={
                "debug_handle": conv_handle,
                "stack_trace": "stack_trace_conv",
                "nn_module_stack": "module_hierarchy_conv",
            },
            op="op_1",
        )

        # Method under test: each handle maps to its own single-node list.
        handle_to_nodes = {
            relu_handle: [relu_node],
            conv_handle: [conv_node],
        }
        event._associate_with_op_graph_nodes(handle_to_nodes)

        # Metadata from both nodes must be present, keyed by node name.
        self.assertEqual(
            event.stack_traces,
            {"node_0": "stack_trace_relu", "node_1": "stack_trace_conv"},
        )
        self.assertEqual(
            event.module_hierarchy,
            {"node_0": "module_hierarchy_relu", "node_1": "module_hierarchy_conv"},
        )
        self.assertEqual(event.op_types, ["op_0", "op_1"])

    def test_inspector_delegate_time_scale_converter(self):
        """A delegate time-scale converter on the Event rescales its perf data."""

        def divide_by_ten(event_name, time):
            return time / 10

        event = Event(
            name="",
            _delegate_metadata_parser=None,
            _delegate_time_scale_converter=None,
        )
        signature = ProfileEventSignature(
            name="",
            instruction_id=0,
            delegate_id_str="test_event",
        )
        # A single delegated profile event spanning timestamps 100 to 200.
        delegated_profile_event = ProfileEvent(
            name="test_event",
            chain_index=0,
            instruction_id=0,
            delegate_debug_id_int=None,
            delegate_debug_id_str="test_event_delegated",
            start_time=100,
            end_time=200,
            delegate_debug_metadata=None,
        )
        instruction_events = [
            InstructionEvent(
                signature=InstructionEventSignature(0, 0),
                profile_events=[delegated_profile_event],
            )
        ]

        # Without a converter the raw duration is end - start = 100.
        Event._populate_profiling_related_fields(
            event, signature, instruction_events, 1
        )
        self.assertEqual(event.perf_data.raw[0], 100)

        # With the converter installed the duration is 200/10 - 100/10 = 10.
        event._delegate_time_scale_converter = divide_by_ten
        Event._populate_profiling_related_fields(
            event, signature, instruction_events, 1
        )
        self.assertEqual(event.perf_data.raw[0], 10)

    def test_inspector_get_exported_program(self):
        """get_exported_program() should return an ExportedProgram from an ETRecord.

        The Inspector is constructed with all parsing patched out; a real
        ETRecord is then generated into a temporary directory, parsed, and
        assigned to the inspector before the method under test is called.
        """
        # Create a context manager to patch functions called by Inspector.__init__
        with patch.object(
            _inspector, "parse_etrecord", return_value=None
        ), patch.object(
            _inspector, "gen_etdump_object", return_value=None
        ), patch.object(
            EventBlock, "_gen_from_etdump"
        ), patch.object(
            _inspector, "gen_graphs_from_etrecord"
        ), patch.object(
            _inspector, "create_debug_handle_to_op_node_mapping"
        ):
            # Call the constructor of Inspector
            inspector_instance = Inspector(
                etdump_path=ETDUMP_PATH,
                etrecord=ETRECORD_PATH,
            )

            # Gen a mock etrecord
            captured_output, edge_output, et_output = TestETRecord().get_test_model()
            with tempfile.TemporaryDirectory() as tmpdirname:
                # Build the path once with os.path.join instead of repeating a
                # hand-concatenated "/" separator.
                etrecord_file = os.path.join(tmpdirname, "etrecord.bin")
                generate_etrecord(
                    etrecord_file,
                    edge_output,
                    et_output,
                    extra_recorded_export_modules={
                        "aten_dialect_output": captured_output,
                    },
                )

                # This uses the parse_etrecord imported into this test module,
                # not the patched attribute on _inspector.
                inspector_instance._etrecord = parse_etrecord(etrecord_file)

                # assertIsInstance reports the actual type on failure.
                self.assertIsInstance(
                    inspector_instance.get_exported_program(), ExportedProgram
                )

    def test_populate_debugging_related_fields_raises_for_inconsistent_events(self):
        """Debug events whose tensors differ in size must trigger an AssertionError."""

        def build_tensor_debug_event(sizes, offset):
            # The two events in this test differ only in tensor sizes and offset.
            return flatcc.DebugEvent(
                name="event",
                chain_index=1,
                instruction_id=0,
                delegate_debug_id_int=1,
                delegate_debug_id_str=None,
                debug_entry=flatcc.Value(
                    val=flatcc.ValueType.TENSOR.value,
                    tensor=flatcc.Tensor(
                        scalar_type=flatcc.ScalarType.INT,
                        sizes=sizes,
                        strides=[1],
                        offset=offset,
                    ),
                    tensor_list=None,
                    int_value=None,
                    float_value=None,
                    double_value=None,
                    bool_value=None,
                    output=None,
                ),
            )

        ret_event: Event = Event(
            name="event",
        )

        # The second tensor has sizes [1] while the first has sizes [2].
        first_debug_event = build_tensor_debug_event(sizes=[2], offset=12345)
        second_debug_event = build_tensor_debug_event(sizes=[1], offset=23456)

        events = [
            InstructionEvent(
                signature=InstructionEventSignature(1, 1),
                debug_events=[first_debug_event],
            ),
            InstructionEvent(
                signature=InstructionEventSignature(1, 1),
                debug_events=[second_debug_event],
            ),
        ]

        # Expect AssertionError because the 2 tensors have different sizes.
        with self.assertRaises(AssertionError):
            Event._populate_debugging_related_fields(
                ret_event=ret_event,
                debug_event_signature=DebugEventSignature(instruction_id=1),
                events=events,
            )

    def test_populate_debugging_related_fields_passes_for_consistent_events(self):
        """Debug events that differ only in offset should be accepted."""

        def build_tensor_debug_event(offset):
            # Every field except the tensor offset is shared by both events.
            return flatcc.DebugEvent(
                name="event",
                chain_index=1,
                instruction_id=0,
                delegate_debug_id_int=1,
                delegate_debug_id_str=None,
                debug_entry=flatcc.Value(
                    val=flatcc.ValueType.TENSOR.value,
                    tensor=flatcc.Tensor(
                        scalar_type=flatcc.ScalarType.INT,
                        sizes=[1],
                        strides=[1],
                        offset=offset,
                    ),
                    tensor_list=None,
                    int_value=None,
                    float_value=None,
                    double_value=None,
                    bool_value=None,
                    output=None,
                ),
            )

        ret_event: Event = Event(
            name="event",
        )

        first_debug_event = build_tensor_debug_event(offset=12345)
        second_debug_event = build_tensor_debug_event(offset=23456)

        events = [
            InstructionEvent(
                signature=InstructionEventSignature(1, 1),
                debug_events=[first_debug_event],
            ),
            InstructionEvent(
                signature=InstructionEventSignature(1, 1),
                debug_events=[second_debug_event],
            ),
        ]

        # is_inference_output_equal() is mocked to return True, so the call
        # below is expected to complete without raising.
        equal_patch = patch.object(
            _inspector, "is_inference_output_equal", return_value=True
        )
        with equal_patch:
            Event._populate_debugging_related_fields(
                ret_event=ret_event,
                debug_event_signature=DebugEventSignature(instruction_id=1),
                events=events,
            )

    def test_etrecord_populates_correct_edge_dialect_aot_intermediate_outputs(self):
        """Check AOT intermediate outputs taken from the edge dialect program.

        Exports ConvLinearModel, lowers it to edge dialect and ExecuTorch,
        writes an ETRecord to a temporary file, and then verifies that the
        intermediate outputs and debug-handle-to-op-name mapping the Inspector
        computes for the "edge_dialect_exported_program" reference graph match
        the expectations provided by the model.
        """
        # The temporary file only supplies a path for the ETRecord; it is
        # removed automatically when the `with` block exits.
        with tempfile.NamedTemporaryFile(suffix=".bin") as tmp_file:
            etrecord_path = tmp_file.name
            mod = model_registry["ConvLinearModel"]()
            input_tensor = torch.tensor(
                [[[[1.0, 2.0], [3.0, 4.0]]]], requires_grad=True
            )
            aten_model: ExportedProgram = export(mod, (input_tensor,), strict=True)
            edge_program_manager: EdgeProgramManager = to_edge(
                aten_model, compile_config=EdgeCompileConfig(_check_ir_validity=True)
            )
            # Copy the edge program manager before to_executorch() is called;
            # the copy is what gets recorded in the ETRecord below.
            edge_program_manager_copy = copy.deepcopy(edge_program_manager)
            et_program_manager: ExecutorchProgramManager = (
                edge_program_manager.to_executorch()
            )
            # Generate ETRecord
            generate_etrecord(
                etrecord_path, edge_program_manager_copy, et_program_manager
            )
            # Patch out ETRecord consumption and all etdump handling so that
            # constructing the Inspector needs no real etdump file.
            with patch.object(
                Inspector, "_consume_etrecord", return_value=None
            ), patch.object(
                _inspector, "gen_etdump_object", return_value=None
            ), patch.object(
                EventBlock, "_gen_from_etdump"
            ), patch.object(
                _inspector, "gen_graphs_from_etrecord"
            ):
                # Call the constructor of Inspector
                inspector_instance = Inspector(
                    etdump_path=ETDUMP_PATH,
                    etrecord=etrecord_path,
                )

                # Use the first element of the export example inputs
                # (presumably the positional args tuple) as representative inputs.
                inspector_instance._etrecord._representative_inputs = (
                    aten_model.example_inputs[0]
                )

                # First resolve the reference graph, then get intermediate outputs
                reference_graph_module, _ = inspector_instance._resolve_reference_graph(
                    "edge_dialect_exported_program"
                )
                aot_intermediate_outputs, aot_debug_handle_to_op_names = (
                    inspector_instance._get_aot_intermediate_outputs_and_op_names(
                        reference_graph_module
                    )
                )
                # Compare against the expected values supplied by the model itself.
                self.assertTrue(
                    check_if_intermediate_outputs_match(
                        aot_intermediate_outputs,
                        mod.get_edge_dialect_expected_intermediate_outputs(),
                    )
                )

                self.assertTrue(
                    check_if_debug_handle_to_op_names_match(
                        aot_debug_handle_to_op_names,
                        mod.get_edge_dialect_expected_debug_handle_to_op_names(),
                    )
                )

    def test_etrecord_populates_correct_export_program_aot_intermediate_outputs(self):
        """Check AOT intermediate outputs taken from the exported program.

        Same flow as the edge dialect variant of this test, except that the
        ETRecord is generated with ``exported_program=aten_model`` and the
        Inspector is asked to resolve the "exported_program" reference graph.
        """
        # The temporary file only supplies a path for the ETRecord; it is
        # removed automatically when the `with` block exits.
        with tempfile.NamedTemporaryFile(suffix=".bin") as tmp_file:
            etrecord_path = tmp_file.name
            mod = model_registry["ConvLinearModel"]()
            input_tensor = mod.get_input()
            aten_model: ExportedProgram = export(mod, (input_tensor,), strict=True)
            edge_program_manager: EdgeProgramManager = to_edge(aten_model)
            # Copy the edge program manager before to_executorch() is called;
            # the copy is what gets recorded in the ETRecord below.
            edge_program_manager_copy = copy.deepcopy(edge_program_manager)
            et_program_manager: ExecutorchProgramManager = (
                edge_program_manager.to_executorch()
            )
            # Generate ETRecord with the exported program
            generate_etrecord(
                etrecord_path,
                edge_program_manager_copy,
                et_program_manager,
                exported_program=aten_model,
            )
            # Patch out ETRecord consumption and all etdump handling so that
            # constructing the Inspector needs no real etdump file.
            with patch.object(
                Inspector, "_consume_etrecord", return_value=None
            ), patch.object(
                _inspector, "gen_etdump_object", return_value=None
            ), patch.object(
                EventBlock, "_gen_from_etdump"
            ), patch.object(
                _inspector, "gen_graphs_from_etrecord"
            ):
                # Call the constructor of Inspector
                inspector_instance = Inspector(
                    etdump_path=ETDUMP_PATH,
                    etrecord=etrecord_path,
                )

                # Use the first element of the export example inputs
                # (presumably the positional args tuple) as representative inputs.
                inspector_instance._etrecord._representative_inputs = (
                    aten_model.example_inputs[0]
                )

                # First resolve the reference graph, then get intermediate outputs
                reference_graph_module, _ = inspector_instance._resolve_reference_graph(
                    "exported_program"
                )
                aot_intermediate_outputs, aot_debug_handle_to_op_names = (
                    inspector_instance._get_aot_intermediate_outputs_and_op_names(
                        reference_graph_module
                    )
                )
                # Compare against the expected values supplied by the model itself.
                self.assertTrue(
                    check_if_intermediate_outputs_match(
                        aot_intermediate_outputs,
                        mod.get_exported_program_expected_intermediate_outputs(),
                    )
                )
                self.assertTrue(
                    check_if_debug_handle_to_op_names_match(
                        aot_debug_handle_to_op_names,
                        mod.get_exported_program_expected_debug_handle_to_op_names(),
                    )
                )

    def test_get_runtime_intermediate_outputs_and_op_names(self):
        """Runtime outputs and op names should be keyed by the expected handles.

        The expectations below follow the events produced by
        ``_gen_random_events()``, which is defined elsewhere in this class.
        """
        # Patchers for everything Inspector.__init__ would otherwise call.
        parse_patch = patch.object(_inspector, "parse_etrecord", return_value=None)
        etdump_patch = patch.object(
            _inspector, "gen_etdump_object", return_value=None
        )
        blocks_patch = patch.object(EventBlock, "_gen_from_etdump")
        graphs_patch = patch.object(_inspector, "gen_graphs_from_etrecord")

        with parse_patch, etdump_patch, blocks_patch, graphs_patch:
            inspector = Inspector(
                etdump_path=ETDUMP_PATH,
                etrecord=ETRECORD_PATH,
            )

            # _gen_from_etdump is patched, so install a pre-defined event
            # block for the method under test to read.
            inspector.event_blocks = [
                EventBlock(name=EVENT_BLOCK_NAME, events=self._gen_random_events())
            ]

            outputs_by_handle, op_names_by_handle = (
                inspector._get_runtime_intermediate_outputs_and_op_names()
            )

            # Both dictionaries should have exactly 5 keys.
            self.assertEqual(len(outputs_by_handle), 5)
            self.assertEqual(len(op_names_by_handle), 5)

            # Keys (0,) and (1,) must be absent from both dictionaries
            # (OPERATOR_CALL is skipped and op_types are empty).
            for mapping in (outputs_by_handle, op_names_by_handle):
                for skipped_key in ((0,), (1,)):
                    self.assertNotIn(skipped_key, mapping)

            # Same debug_handle but different instruction_id: the last one
            # should be the one recorded.
            self.assertIn((4,), outputs_by_handle)
            self.assertIn((4,), op_names_by_handle)
            recorded_tensor = outputs_by_handle[(4,)][0][0]
            self.assertTrue(
                torch.allclose(recorded_tensor, torch.tensor([4.0, 5.0, 6.0]))
            )
            self.assertEqual(op_names_by_handle[(4,)], ["op_3"])

            # Keys (5,) through (8,) must be present in both dictionaries.
            for handle in range(5, 9):
                debug_key = (handle,)
                self.assertIn(debug_key, outputs_by_handle)
                self.assertIn(debug_key, op_names_by_handle)

    def test_calculate_numeric_gap(self):
        """calculate_numeric_gap("L1") should pair AOT and runtime outputs per handle."""
        # Patchers for everything Inspector.__init__ would otherwise call.
        parse_patch = patch.object(_inspector, "parse_etrecord", return_value=None)
        etdump_patch = patch.object(
            _inspector, "gen_etdump_object", return_value=None
        )
        blocks_patch = patch.object(EventBlock, "_gen_from_etdump")
        graphs_patch = patch.object(_inspector, "gen_graphs_from_etrecord")

        with parse_patch, etdump_patch, blocks_patch, graphs_patch:
            inspector = Inspector(
                etdump_path=ETDUMP_PATH,
                etrecord=ETRECORD_PATH,
            )

            aot_intermediate_outputs = {
                (0,): torch.tensor([1.0, 2.0, 3.0]),
                (1,): torch.tensor([4.0, 5.0, 6.0]),
            }

            runtime_intermediate_outputs = {
                (0,): ([torch.tensor([2.0, 1.0, 4.0])], 1),
                (1,): ([torch.tensor([3.0, 6.0, 5.0])], 1),
            }

            aot_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"}
            runtime_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"}

            # Stand-in graph module returned by the stubbed reference-graph lookup.
            mock_graph_module = MagicMock()

            # Stubs replacing the inspector's data sources with the canned
            # values defined above.
            def stub_resolve_reference_graph(ref_graph=None, disable_validation=False):
                return (mock_graph_module, "exported_program")

            def stub_aot_outputs_and_op_names(reference_graph_module):
                return (aot_intermediate_outputs, aot_debug_handle_to_op_name)

            def stub_runtime_outputs_and_op_names():
                return (runtime_intermediate_outputs, runtime_debug_handle_to_op_name)

            def stub_stack_traces(reference_graph_module, resolved_graph_name):
                return {}

            inspector._resolve_reference_graph = stub_resolve_reference_graph
            inspector._get_aot_intermediate_outputs_and_op_names = (
                stub_aot_outputs_and_op_names
            )
            inspector._get_runtime_intermediate_outputs_and_op_names = (
                stub_runtime_outputs_and_op_names
            )
            inspector._get_aot_debug_handle_to_stack_traces = stub_stack_traces

            # --- L1 distance comparator
            df = inspector.calculate_numeric_gap(distance="L1")
            self.assertIsInstance(df, pd.DataFrame)
            self.assertEqual(len(df), 2)

            expected_cols = {
                "aot_ops",
                "aot_intermediate_output",
                "runtime_ops",
                "runtime_intermediate_output",
                "gap",
                "stacktraces",
            }
            self.assertEqual(set(df.columns), expected_cols)

            for row_index, row in df.iterrows():
                # The row index doubles as the debug handle used to look up
                # the expected AOT/runtime intermediate outputs.
                handle = (row_index,)
                expected_aot = aot_intermediate_outputs[handle]
                expected_runtime = runtime_intermediate_outputs[handle][0][0]

                self.assertTrue(
                    torch.allclose(row["aot_intermediate_output"], expected_aot)
                )
                self.assertTrue(
                    torch.allclose(
                        row["runtime_intermediate_output"], expected_runtime
                    )
                )
                # Each pair differs by 1.0 in all three elements, so the
                # expected gap is 3.0.
                self.assertEqual(row["gap"][0], 3.0)

    def test_calculate_numeric_gap_with_stacktraces(self):
        """calculate_numeric_gap should surface available stack traces per row."""
        # Patchers for everything Inspector.__init__ would otherwise call.
        parse_patch = patch.object(_inspector, "parse_etrecord", return_value=None)
        etdump_patch = patch.object(
            _inspector, "gen_etdump_object", return_value=None
        )
        blocks_patch = patch.object(EventBlock, "_gen_from_etdump")
        graphs_patch = patch.object(_inspector, "gen_graphs_from_etrecord")

        with parse_patch, etdump_patch, blocks_patch, graphs_patch:
            inspector = Inspector(
                etdump_path=ETDUMP_PATH,
                etrecord=ETRECORD_PATH,
            )

            aot_intermediate_outputs = {
                (0,): torch.tensor([1.0, 2.0, 3.0]),
                (1,): torch.tensor([4.0, 5.0, 6.0]),
            }

            runtime_intermediate_outputs = {
                (0,): ([torch.tensor([2.0, 1.0, 4.0])], 1),
                (1,): ([torch.tensor([3.0, 6.0, 5.0])], 1),
            }

            aot_debug_handle_to_op_name = {(0,): ["op_0"], (1,): ["op_1"]}
            runtime_debug_handle_to_op_name = {(0,): ["op_0"], (1,): ["op_1"]}
            aot_debug_handle_to_stack_traces = {
                (0,): {"op_0": "File 'test.py', line 10\n    x * y"},
                (1,): {"op_1": "File 'test.py', line 15\n    x + y"},
            }

            # Stand-in graph module returned by the stubbed reference-graph lookup.
            mock_graph_module = MagicMock()

            # Stubs replacing the inspector's data sources with the canned
            # values defined above.
            def stub_resolve_reference_graph(ref_graph=None, disable_validation=False):
                return (mock_graph_module, "exported_program")

            def stub_aot_outputs_and_op_names(reference_graph_module):
                return (aot_intermediate_outputs, aot_debug_handle_to_op_name)

            def stub_runtime_outputs_and_op_names():
                return (runtime_intermediate_outputs, runtime_debug_handle_to_op_name)

            def stub_stack_traces(reference_graph_module, resolved_graph_name):
                return aot_debug_handle_to_stack_traces

            inspector._resolve_reference_graph = stub_resolve_reference_graph
            inspector._get_aot_intermediate_outputs_and_op_names = (
                stub_aot_outputs_and_op_names
            )
            inspector._get_runtime_intermediate_outputs_and_op_names = (
                stub_runtime_outputs_and_op_names
            )
            inspector._get_aot_debug_handle_to_stack_traces = stub_stack_traces

            df = inspector.calculate_numeric_gap(distance="L1")
            self.assertIsInstance(df, pd.DataFrame)
            self.assertEqual(len(df), 2)

            expected_cols = {
                "aot_ops",
                "aot_intermediate_output",
                "runtime_ops",
                "runtime_intermediate_output",
                "gap",
                "stacktraces",
            }
            self.assertEqual(set(df.columns), expected_cols)

            # Each row's stacktraces entry must equal the canned mapping for
            # the debug handle derived from the row index.
            for row_index, row in df.iterrows():
                handle = (row_index,)
                self.assertEqual(
                    row["stacktraces"], aot_debug_handle_to_stack_traces[handle]
                )

    def test_calculate_numeric_gap_with_custom_comparator(self):
        """Run calculate_numeric_gap with a user-defined NumericalComparatorBase subclass.

        The comparator reports the largest element-wise absolute difference,
        so the expected "gap" values can be derived by hand from the stubbed
        AOT and runtime tensors below.
        """
        from executorch.devtools.inspector.numerical_comparator import (
            NumericalComparatorBase,
        )

        class MaxAbsDiffComparator(NumericalComparatorBase):
            """Gap is max(|a - b|) for two tensors, plain |a - b| otherwise."""

            def element_compare(self, a, b):
                both_tensors = isinstance(a, torch.Tensor) and isinstance(
                    b, torch.Tensor
                )
                if not both_tensors:
                    return abs(a - b)
                return torch.max(torch.abs(a - b)).item()

        # Stub out the helpers that Inspector.__init__ calls.
        with (
            patch.object(_inspector, "parse_etrecord", return_value=None),
            patch.object(_inspector, "gen_etdump_object", return_value=None),
            patch.object(EventBlock, "_gen_from_etdump"),
            patch.object(_inspector, "gen_graphs_from_etrecord"),
        ):
            inspector_instance = Inspector(
                etdump_path=ETDUMP_PATH,
                etrecord=ETRECORD_PATH,
            )

            # Canned data handed back by the stubbed Inspector internals.
            aot_tensors = {
                (0,): torch.tensor([1.0, 2.0, 3.0]),
                (1,): torch.tensor([4.0, 5.0, 6.0]),
            }
            runtime_tensors = {
                (0,): ([torch.tensor([2.0, 1.0, 5.0])], 1),
                (1,): ([torch.tensor([3.0, 6.0, 5.0])], 1),
            }
            aot_op_names = {(0,): "op_0", (1,): "op_1"}
            runtime_op_names = {(0,): "op_0", (1,): "op_1"}
            graph_module_stub = MagicMock()

            def fake_resolve_reference_graph(ref_graph=None, disable_validation=False):
                return graph_module_stub, "exported_program"

            def fake_aot_outputs(reference_graph_module):
                return aot_tensors, aot_op_names

            def fake_runtime_outputs():
                return runtime_tensors, runtime_op_names

            def fake_stack_traces(reference_graph_module, resolved_graph_name):
                return {}

            inspector_instance._resolve_reference_graph = fake_resolve_reference_graph
            inspector_instance._get_aot_intermediate_outputs_and_op_names = (
                fake_aot_outputs
            )
            inspector_instance._get_runtime_intermediate_outputs_and_op_names = (
                fake_runtime_outputs
            )
            inspector_instance._get_aot_debug_handle_to_stack_traces = (
                fake_stack_traces
            )

            # Pass the comparator instance itself as the distance.
            df = inspector_instance.calculate_numeric_gap(
                distance=MaxAbsDiffComparator()
            )

            self.assertIsInstance(df, pd.DataFrame)
            self.assertEqual(len(df), 2)
            self.assertEqual(
                set(df.columns),
                {
                    "aot_ops",
                    "aot_intermediate_output",
                    "runtime_ops",
                    "runtime_intermediate_output",
                    "gap",
                    "stacktraces",
                },
            )

            # Hand-computed max absolute differences:
            #   (0,): |[1, 2, 3] - [2, 1, 5]| = [1, 1, 2] -> 2.0
            #   (1,): |[4, 5, 6] - [3, 6, 5]| = [1, 1, 1] -> 1.0
            for row_idx, expected_gap in enumerate((2.0, 1.0)):
                self.assertEqual(df.iloc[row_idx]["gap"][0], expected_gap)

    def test_calculate_numeric_gap_with_custom_comparator_and_preprocessing(self):
        """Check calculate_numeric_gap with comparators that override preprocessing().

        Two comparators (an MSE one built on NumericalComparatorBase and one
        derived from SNRComparator) share a preprocessing step that doubles
        every runtime tensor. The expected gaps are computed by hand from the
        scaled values.
        """
        from executorch.devtools.inspector.numerical_comparator import (
            IntermediateOutputMapping,
            NumericalComparatorBase,
        )
        from executorch.devtools.inspector.numerical_comparator.snr_numerical_comparator import (
            SNRComparator,
        )

        def scale_runtime_tensors(
            mapping: IntermediateOutputMapping, scale_factor: float = 2.0
        ) -> IntermediateOutputMapping:
            """Return a new mapping whose runtime tensors are multiplied by scale_factor.

            Non-tensor runtime outputs are passed through unchanged.
            """
            return {
                (aot_handle, aot_output): (
                    runtime_handle,
                    (
                        runtime_output * scale_factor
                        if isinstance(runtime_output, torch.Tensor)
                        else runtime_output
                    ),
                )
                for (aot_handle, aot_output), (
                    runtime_handle,
                    runtime_output,
                ) in mapping.items()
            }

        class ScalingPreprocessingMixin:
            """Shared scaling preprocessing() hook plus a flag recording that it ran."""

            def __init__(self, scale_factor: float = 2.0):
                super().__init__()
                self.scale_factor = scale_factor
                self.preprocessing_called = False

            def preprocessing(
                self, mapping: IntermediateOutputMapping
            ) -> IntermediateOutputMapping:
                self.preprocessing_called = True
                return scale_runtime_tensors(mapping, self.scale_factor)

        class MSEComparatorWithScaling(
            ScalingPreprocessingMixin, NumericalComparatorBase
        ):
            """Mean squared error between the two values."""

            def element_compare(self, a, b) -> float:
                if not (isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor)):
                    return (a - b) ** 2
                squared_error = torch.square(a.float() - b.float())
                return torch.mean(squared_error).item()

        class SNRComparatorWithScaling(ScalingPreprocessingMixin, SNRComparator):
            """SNRComparator's comparison combined with the shared scaling step."""

        # Stub out the helpers that Inspector.__init__ calls.
        with (
            patch.object(_inspector, "parse_etrecord", return_value=None),
            patch.object(_inspector, "gen_etdump_object", return_value=None),
            patch.object(EventBlock, "_gen_from_etdump"),
            patch.object(_inspector, "gen_graphs_from_etrecord"),
        ):
            inspector_instance = Inspector(
                etdump_path=ETDUMP_PATH,
                etrecord=ETRECORD_PATH,
            )

            aot_tensors = {
                (0,): torch.tensor([1.0, 2.0, 3.0]),
                (1,): torch.tensor([4.0, 5.0, 6.0]),
            }
            # Doubled by preprocessing to [2, 2, 2] and [4, 4, 4].
            runtime_tensors = {
                (0,): ([torch.tensor([1.0, 1.0, 1.0])], 1),
                (1,): ([torch.tensor([2.0, 2.0, 2.0])], 1),
            }
            aot_op_names = {(0,): "op_0", (1,): "op_1"}
            runtime_op_names = {(0,): "op_0", (1,): "op_1"}
            graph_module_stub = MagicMock()

            def fake_resolve_reference_graph(ref_graph=None, disable_validation=False):
                return graph_module_stub, "exported_program"

            def fake_aot_outputs(reference_graph_module):
                return aot_tensors, aot_op_names

            def fake_runtime_outputs():
                return runtime_tensors, runtime_op_names

            def fake_stack_traces(reference_graph_module, resolved_graph_name):
                # No stack traces are supplied for this test.
                return {}

            inspector_instance._resolve_reference_graph = fake_resolve_reference_graph
            inspector_instance._get_aot_intermediate_outputs_and_op_names = (
                fake_aot_outputs
            )
            inspector_instance._get_runtime_intermediate_outputs_and_op_names = (
                fake_runtime_outputs
            )
            inspector_instance._get_aot_debug_handle_to_stack_traces = (
                fake_stack_traces
            )

            expected_cols = {
                "aot_ops",
                "aot_intermediate_output",
                "runtime_ops",
                "runtime_intermediate_output",
                "gap",
                "stacktraces",
            }

            def assert_gap_frame_shape(frame):
                self.assertIsInstance(frame, pd.DataFrame)
                self.assertEqual(len(frame), 2)
                self.assertEqual(set(frame.columns), expected_cols)

            def assert_gaps(frame, expected_gaps):
                for row_idx, expected_gap in enumerate(expected_gaps):
                    self.assertAlmostEqual(
                        frame.iloc[row_idx]["gap"][0], expected_gap, places=5
                    )

            def snr_db(signal_power, error_power):
                # Same formula as the hand computation: 10 * log10(signal / error).
                return 10 * torch.log10(torch.tensor(signal_power / error_power)).item()

            # --- Part 1: MSE comparator with the scaling preprocessing ---
            mse_comparator = MSEComparatorWithScaling(scale_factor=2.0)
            df_mse = inspector_instance.calculate_numeric_gap(
                distance=mse_comparator, reference_graph="NOT_USED_NAME"
            )
            self.assertTrue(mse_comparator.preprocessing_called)
            assert_gap_frame_shape(df_mse)
            # (0,): AOT [1, 2, 3] vs scaled runtime [2, 2, 2] -> mean(1, 0, 1) = 2/3
            # (1,): AOT [4, 5, 6] vs scaled runtime [4, 4, 4] -> mean(0, 1, 4) = 5/3
            assert_gaps(df_mse, (2.0 / 3.0, 5.0 / 3.0))

            # --- Part 2: SNR comparator with the same scaling preprocessing ---
            snr_comparator = SNRComparatorWithScaling(scale_factor=2.0)
            df_snr = inspector_instance.calculate_numeric_gap(
                distance=snr_comparator, reference_graph="NOT_USED_NAME"
            )
            self.assertTrue(snr_comparator.preprocessing_called)
            assert_gap_frame_shape(df_snr)
            # (0,): signal = mean(1, 4, 9) = 14/3, error = mean(1, 0, 1) = 2/3
            #       SNR = 10 * log10(7) ~= 8.451
            # (1,): signal = mean(16, 25, 36) = 77/3, error = mean(0, 1, 4) = 5/3
            #       SNR = 10 * log10(77/5) ~= 11.875
            assert_gaps(
                df_snr,
                (
                    snr_db(14.0 / 3.0, 2.0 / 3.0),
                    snr_db(77.0 / 3.0, 5.0 / 3.0),
                ),
            )

    def test_calculate_numeric_gap_with_invalid_preprocessing_output(self):
        """Check that malformed preprocessing() results are rejected.

        Each comparator below returns a differently malformed value from
        preprocessing(). For every one, calculate_numeric_gap must raise the
        listed exception type with the listed fragment in its message.

        The scenarios run under ``subTest`` so that one failing scenario does
        not hide the results of the others.
        """
        from executorch.devtools.inspector.numerical_comparator import (
            NumericalComparatorBase,
        )

        class NonDictPreprocessingComparator(NumericalComparatorBase):
            def preprocessing(self, mapping):
                return "invalid"  # Should return a dict

            def element_compare(self, a, b) -> float:
                return 0.0

        class InvalidKeyFormatComparator(NumericalComparatorBase):
            def preprocessing(self, mapping):
                # Key is a bare string instead of a (debug_handle, output) pair.
                return {"invalid_key": ((0,), torch.tensor([1.0]))}

            def element_compare(self, a, b) -> float:
                return 0.0

        class InvalidKeyDebugHandleComparator(NumericalComparatorBase):
            def preprocessing(self, mapping):
                # Debug handle inside the key holds a string, not an int.
                return {
                    (("not_int",), torch.tensor([1.0])): ((0,), torch.tensor([1.0]))
                }

            def element_compare(self, a, b) -> float:
                return 0.0

        class InvalidValueFormatComparator(NumericalComparatorBase):
            def preprocessing(self, mapping):
                # Value is a bare string instead of a (debug_handle, output) pair.
                return {((0,), torch.tensor([1.0])): "invalid_value"}

            def element_compare(self, a, b) -> float:
                return 0.0

        class InvalidValueDebugHandleComparator(NumericalComparatorBase):
            def preprocessing(self, mapping):
                # Debug handle inside the value holds a string, not an int.
                return {
                    ((0,), torch.tensor([1.0])): (("not_int",), torch.tensor([1.0]))
                }

            def element_compare(self, a, b) -> float:
                return 0.0

        # (label, comparator class, expected exception, expected message fragment)
        invalid_cases = (
            (
                "non-dict return value",
                NonDictPreprocessingComparator,
                TypeError,
                "must return a dict",
            ),
            (
                "invalid key format",
                InvalidKeyFormatComparator,
                ValueError,
                "Invalid key format",
            ),
            (
                "invalid debug handle in key",
                InvalidKeyDebugHandleComparator,
                ValueError,
                "Invalid AOT debug handle",
            ),
            (
                "invalid value format",
                InvalidValueFormatComparator,
                ValueError,
                "Invalid value format",
            ),
            (
                "invalid debug handle in value",
                InvalidValueDebugHandleComparator,
                ValueError,
                "Invalid runtime debug handle",
            ),
        )

        # Stub out the helpers that Inspector.__init__ calls.
        with patch.object(
            _inspector, "parse_etrecord", return_value=None
        ), patch.object(
            _inspector, "gen_etdump_object", return_value=None
        ), patch.object(
            EventBlock, "_gen_from_etdump"
        ), patch.object(
            _inspector, "gen_graphs_from_etrecord"
        ):
            inspector_instance = Inspector(
                etdump_path=ETDUMP_PATH,
                etrecord=ETRECORD_PATH,
            )

            aot_intermediate_outputs = {
                (0,): torch.tensor([1.0, 2.0, 3.0]),
            }
            runtime_intermediate_outputs = {
                (0,): ([torch.tensor([1.0, 1.0, 1.0])], 1),
            }
            aot_debug_handle_to_op_name = {(0,): "op_0"}
            runtime_debug_handle_to_op_name = {(0,): "op_0"}

            # Replace the Inspector internals with canned data so only the
            # preprocessing output varies between scenarios.
            mock_graph_module = MagicMock()
            inspector_instance._resolve_reference_graph = (
                lambda ref_graph=None, disable_validation=False: (
                    mock_graph_module,
                    "exported_program",
                )
            )
            inspector_instance._get_aot_intermediate_outputs_and_op_names = (
                lambda reference_graph_module: (
                    aot_intermediate_outputs,
                    aot_debug_handle_to_op_name,
                )
            )
            inspector_instance._get_runtime_intermediate_outputs_and_op_names = (
                lambda: (runtime_intermediate_outputs, runtime_debug_handle_to_op_name)
            )
            inspector_instance._get_aot_debug_handle_to_stack_traces = (
                lambda reference_graph_module, resolved_graph_name: {}
            )

            for label, comparator_cls, expected_error, expected_fragment in invalid_cases:
                with self.subTest(case=label):
                    with self.assertRaises(expected_error) as context:
                        inspector_instance.calculate_numeric_gap(
                            distance=comparator_cls()
                        )
                    self.assertIn(expected_fragment, str(context.exception))

    def test_calculate_numeric_gap_with_reference_graph_name(self):
        """Call calculate_numeric_gap with reference_graph naming an entry in the ETRecord's graph_map.

        Unlike most sibling tests, _resolve_reference_graph and
        _get_aot_intermediate_outputs_and_op_names are NOT replaced here;
        only the capturer and op-name mapping helpers they rely on are patched.
        """
        from executorch.devtools.etrecord import ETRecord

        class MockGraphModule:
            """Minimal stand-in exposing ``graph.nodes`` and ``module()``."""

            def __init__(self):
                graph = MagicMock()
                graph.nodes = []
                self.graph = graph

            def module(self):
                return self

        # Stub out the helpers that Inspector.__init__ calls.
        with (
            patch.object(_inspector, "parse_etrecord", return_value=None),
            patch.object(_inspector, "gen_etdump_object", return_value=None),
            patch.object(EventBlock, "_gen_from_etdump"),
            patch.object(_inspector, "gen_graphs_from_etrecord"),
        ):
            inspector_instance = Inspector(
                etdump_path=ETDUMP_PATH,
                etrecord=ETRECORD_PATH,
            )

            aot_tensors = {
                (0,): torch.tensor([1.0, 2.0, 3.0]),
                (1,): torch.tensor([4.0, 5.0, 6.0]),
            }
            runtime_tensors = {
                (0,): ([torch.tensor([2.0, 3.0, 4.0])], 1),
                (1,): ([torch.tensor([5.0, 6.0, 7.0])], 1),
            }
            aot_op_names = {(0,): "op_0", (1,): "op_1"}
            runtime_op_names = {(0,): "op_0", (1,): "op_1"}

            reference_module = MockGraphModule()

            # A real ETRecord whose graph_map holds the custom graph. The lookup
            # is expected to append "/forward" to a bare graph name, hence the
            # suffixed key.
            etrecord_stub = ETRecord()
            etrecord_stub._representative_inputs = torch.tensor([1.0])
            etrecord_stub.exported_program = None
            etrecord_stub.edge_dialect_program = reference_module
            etrecord_stub.graph_map = {"edge_after_transform/forward": reference_module}
            inspector_instance._etrecord = etrecord_stub

            def fake_runtime_outputs():
                return runtime_tensors, runtime_op_names

            def fake_stack_traces(reference_graph_module, resolved_graph_name):
                # The stand-in graph module carries no stack traces.
                return {}

            inspector_instance._get_runtime_intermediate_outputs_and_op_names = (
                fake_runtime_outputs
            )
            inspector_instance._get_aot_debug_handle_to_stack_traces = (
                fake_stack_traces
            )

            # Capturer instance whose run_and_capture() yields the canned AOT tensors.
            capturer_stub = MagicMock()
            capturer_stub.run_and_capture.return_value = aot_tensors

            with (
                patch(
                    "executorch.devtools.inspector._inspector.IntermediateOutputCapturer",
                    return_value=capturer_stub,
                ),
                patch(
                    "executorch.devtools.inspector._inspector.get_aot_debug_handle_to_op_name_mapping",
                    return_value=aot_op_names,
                ),
            ):
                # Pass the bare name, without the "/forward" suffix.
                df = inspector_instance.calculate_numeric_gap(
                    distance="L1",
                    reference_graph="edge_after_transform",
                )

                self.assertIsInstance(df, pd.DataFrame)
                self.assertEqual(len(df), 2)

    def test_calculate_numeric_gap_with_invalid_reference_graph_name(self):
        """An unknown reference_graph must raise ValueError that names the looked-up key."""
        from executorch.devtools.etrecord import ETRecord

        # Stub out the helpers that Inspector.__init__ calls.
        with (
            patch.object(_inspector, "parse_etrecord", return_value=None),
            patch.object(_inspector, "gen_etdump_object", return_value=None),
            patch.object(EventBlock, "_gen_from_etdump"),
            patch.object(_inspector, "gen_graphs_from_etrecord"),
        ):
            inspector_instance = Inspector(
                etdump_path=ETDUMP_PATH,
                etrecord=ETRECORD_PATH,
            )

            # Real ETRecord with nothing registered in graph_map.
            empty_etrecord = ETRecord()
            empty_etrecord._representative_inputs = torch.tensor([1.0])
            empty_etrecord.graph_map = {}
            inspector_instance._etrecord = empty_etrecord

            with self.assertRaises(ValueError) as raised:
                inspector_instance.calculate_numeric_gap(
                    distance="L1",
                    reference_graph="non_existent_graph",
                )

            # The name has no "/", so the message should mention the
            # "/forward"-suffixed key that was searched for.
            message = str(raised.exception)
            for fragment in ("not found", "non_existent_graph/forward"):
                self.assertIn(fragment, message)

    def test_calculate_numeric_gap_with_exported_program_name_backprop_failure(self):
        """Requesting "exported_program" must raise ValueError when propagate_back_debug_handle returns False."""
        from executorch.devtools.etrecord import ETRecord

        class MockGraphModule:
            """Bare stand-in exposing a ``graph`` attribute and ``module()``."""

            def __init__(self):
                self.graph = MagicMock()

            def module(self):
                return self

        # Stub out the helpers that Inspector.__init__ calls.
        with (
            patch.object(_inspector, "parse_etrecord", return_value=None),
            patch.object(_inspector, "gen_etdump_object", return_value=None),
            patch.object(EventBlock, "_gen_from_etdump"),
            patch.object(_inspector, "gen_graphs_from_etrecord"),
        ):
            inspector_instance = Inspector(
                etdump_path=ETDUMP_PATH,
                etrecord=ETRECORD_PATH,
            )

            # Real ETRecord carrying both an exported program and an edge
            # dialect program, with an empty graph_map.
            etrecord_stub = ETRecord()
            etrecord_stub._representative_inputs = torch.tensor([1.0])
            etrecord_stub.exported_program = MockGraphModule()
            etrecord_stub.edge_dialect_program = MockGraphModule()
            etrecord_stub.export_graph_id = "graph_id"
            etrecord_stub.graph_map = {}
            inspector_instance._etrecord = etrecord_stub

            # Force debug-handle back-propagation to report failure.
            with patch(
                "executorch.devtools.inspector._inspector.propagate_back_debug_handle",
                return_value=False,
            ):
                with self.assertRaises(ValueError) as raised:
                    inspector_instance.calculate_numeric_gap(
                        distance="L1",
                        reference_graph="exported_program",
                    )

                message = str(raised.exception)
                for fragment in (
                    "Cannot use 'exported_program'",
                    "backpropagation failed",
                ):
                    self.assertIn(fragment, message)

    def test_calculate_numeric_gap_with_edge_dialect_exported_program_name(self):
        """Call calculate_numeric_gap with reference_graph="edge_dialect_exported_program".

        Graph resolution is stubbed out below (the stub ignores the requested
        name), so this only checks that the call completes and yields one row
        for the single stubbed debug handle.
        """
        # Stub out the helpers that Inspector.__init__ calls.
        with (
            patch.object(_inspector, "parse_etrecord", return_value=None),
            patch.object(_inspector, "gen_etdump_object", return_value=None),
            patch.object(EventBlock, "_gen_from_etdump"),
            patch.object(_inspector, "gen_graphs_from_etrecord"),
        ):
            inspector_instance = Inspector(
                etdump_path=ETDUMP_PATH,
                etrecord=ETRECORD_PATH,
            )

            # One debug handle on each side.
            aot_tensors = {
                (0,): torch.tensor([1.0, 2.0, 3.0]),
            }
            runtime_tensors = {
                (0,): ([torch.tensor([2.0, 3.0, 4.0])], 1),
            }
            aot_op_names = {(0,): "op_0"}
            runtime_op_names = {(0,): "op_0"}
            graph_module_stub = MagicMock()

            def fake_resolve_reference_graph(ref_graph=None, disable_validation=False):
                return graph_module_stub, "exported_program"

            def fake_aot_outputs(reference_graph_module):
                return aot_tensors, aot_op_names

            def fake_runtime_outputs():
                return runtime_tensors, runtime_op_names

            def fake_stack_traces(reference_graph_module, resolved_graph_name):
                return {}

            inspector_instance._resolve_reference_graph = fake_resolve_reference_graph
            inspector_instance._get_aot_intermediate_outputs_and_op_names = (
                fake_aot_outputs
            )
            inspector_instance._get_runtime_intermediate_outputs_and_op_names = (
                fake_runtime_outputs
            )
            inspector_instance._get_aot_debug_handle_to_stack_traces = (
                fake_stack_traces
            )

            df = inspector_instance.calculate_numeric_gap(
                distance="L1",
                reference_graph="edge_dialect_exported_program",
            )

            self.assertIsInstance(df, pd.DataFrame)
            self.assertEqual(len(df), 1)

    @unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows")
    def test_transformer_block_xnnpack_numeric_gap_within_tolerance(self):
        """
        Test that the numeric gap between AOT and runtime intermediate outputs
        for a ViT model lowered to XNNPACK delegate is within acceptable tolerance.

        This test verifies that when a Vision Transformer (ViT) model is exported
        and lowered to XNNPACK, the intermediate outputs during runtime closely
        match the expected AOT outputs, with gaps remaining within a small range.
        """
        from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
            XnnpackPartitioner,
        )
        from executorch.backends.xnnpack.utils.configs import (
            get_xnnpack_edge_compile_config,
        )
        from executorch.runtime import Method, Program, Runtime, Verification
        from torch import nn as nn

        class SingleBlockTransformer(nn.Module):
            def __init__(
                self,
                vocab_size: int,
                d_model: int = 256,
                nhead: int = 8,
                dim_feedforward: int = 1024,
                max_len: int = 512,
                dropout: float = 0.1,
            ):
                super().__init__()
                self.d_model = d_model

                self.tok_emb = nn.Embedding(vocab_size, d_model)
                self.pos_emb = nn.Embedding(max_len, d_model)

                # Single transformer encoder block
                self.block = nn.TransformerEncoderLayer(
                    d_model=d_model,
                    nhead=nhead,
                    dim_feedforward=dim_feedforward,
                    dropout=dropout,
                    batch_first=True,  # input: (B, T, C)
                    activation="gelu",
                    norm_first=True,
                )

                self.ln = nn.LayerNorm(d_model)
                self.head = nn.Linear(d_model, vocab_size, bias=False)

            def forward(
                self, input_ids: torch.Tensor, attn_mask: torch.Tensor | None = None
            ):
                """
                input_ids: (B, T) LongTensor
                attn_mask (optional): (B, T) where 1/True = keep, 0/False = pad
                """
                B, T = input_ids.shape
                pos = torch.arange(T, device=input_ids.device).unsqueeze(0).expand(B, T)

                x = self.tok_emb(input_ids) + self.pos_emb(pos)  # (B, T, d_model)

                # Convert padding mask to TransformerEncoderLayer's expected format:
                # src_key_padding_mask: (B, T) with True = PAD (masked out)
                src_key_padding_mask = None
                if attn_mask is not None:
                    src_key_padding_mask = ~attn_mask.to(torch.bool)

                x = self.block(
                    x, src_key_padding_mask=src_key_padding_mask
                )  # (B, T, d_model)
                x = self.ln(x)
                logits = self.head(x)  # (B, T, vocab_size)
                return logits

        vocab_size = 5000
        model = SingleBlockTransformer(
            vocab_size=vocab_size, d_model=256, nhead=8, max_len=128
        )
        model_inputs = (
            torch.randint(0, vocab_size, (1, 32)),
            torch.ones(1, 32, dtype=torch.bool),
        )

        with tempfile.TemporaryDirectory() as temp_dir:
            # Export and lower model to XNNPACK delegate
            aten_model: ExportedProgram = export(model, model_inputs, strict=True)
            edge_program_manager = to_edge_transform_and_lower(
                aten_model,
                partitioner=[XnnpackPartitioner()],
                compile_config=get_xnnpack_edge_compile_config(),
                generate_etrecord=True,
            )

            et_program_manager: ExecutorchProgramManager = (
                edge_program_manager.to_executorch()
            )

            pte_path = os.path.join(temp_dir, "model.pte")
            et_program_manager.save(pte_path)

            # Dump ETRecord containing debug info for export progress
            etrecord = et_program_manager.get_etrecord()

            # Set the input for numerical discrepancy detection
            etrecord.update_representative_inputs(model_inputs)
            etrecord_path = os.path.join(temp_dir, "etrecord.bin")
            etrecord.save(etrecord_path)

            # Load and run PTE through Runtime API
            et_runtime: Runtime = Runtime.get()
            program: Program = et_runtime.load_program(
                pte_path,
                verification=Verification.Minimal,
                enable_etdump=True,
                debug_buffer_size=1024 * 1024 * 1024,  # 1GB
            )

            forward: Method = program.load_method("forward")
            forward.execute(model_inputs)

            # Dump ETDump recording execution data
            etdump_path = os.path.join(temp_dir, "etdump.etdp")
            debug_buffer_path = os.path.join(temp_dir, "debug_buffer.bin")
            program.write_etdump_result_to_file(etdump_path, debug_buffer_path)

            # Check if event tracer actually captured data (requires build-time config)
            if not os.path.exists(etdump_path):
                self.skipTest(
                    "Event tracer not enabled. Run with --config executorch.event_tracer_enabled=true"
                )

            # Create Inspector and calculate numeric gap
            inspector = Inspector(
                etdump_path=etdump_path,
                etrecord=etrecord_path,
                debug_buffer_path=debug_buffer_path,
            )

            df: pd.DataFrame = inspector.calculate_numeric_gap("MSE")

            # Verify we got results
            self.assertIsNotNone(df)
            self.assertGreater(len(df), 0)

            # Define tolerance threshold for numeric gap
            TOLERANCE = 1e-1

            # Check that each gap value is within acceptable tolerance
            for idx, row in df.iterrows():
                gap_value = row["gap"]
                # Handle case where gap might be a list
                if isinstance(gap_value, list):
                    gap_value = gap_value[0] if gap_value else 0.0

                runtime_ops = row["runtime_ops"]
                aot_ops = row["aot_ops"]

                self.assertLessEqual(
                    gap_value,
                    TOLERANCE,
                    f"Gap at index {idx} ( aot_ops: {aot_ops}, runtime_ops: {runtime_ops}) is {gap_value}, "
                    f"which exceeds tolerance {TOLERANCE}",
                )

            # Verify that stacktraces column exists and contains valid data
            self.assertIn("stacktraces", df.columns)
            for _, row in df.iterrows():
                stacktraces = row["stacktraces"]
                aot_ops = row["aot_ops"]
                # stacktraces should be a dict
                self.assertIsInstance(stacktraces, dict)
                # Each aot_op should have a corresponding entry in stacktraces
                for op_name in aot_ops:
                    self.assertIn(
                        op_name,
                        stacktraces,
                        f"Missing stack trace for operator {op_name}",
                    )
                    # Stack trace can be None or a string
                    # (None when model was exported without stack trace preservation)
                    stack_trace = stacktraces[op_name]
                    self.assertIsInstance(stack_trace, str)
                    # Stack traces should contain file information
                    self.assertIn(
                        "File",
                        stack_trace,
                        f"Stack trace for {op_name} doesn't contain file info",
                    )

    @unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows")
    def test_intermediate_tensor_comparison_with_torch_export(self):
        """Test intermediate tensor comparison using torch.export.export and to_edge_transform_and_lower.

        The test exports a small model, lowers it with the XNNPACK partitioner
        while generating an ETRecord, runs the saved program through the pybind
        API with ETDump enabled, and finally checks the DataFrame returned by
        ``Inspector.calculate_numeric_gap("SNR")``: it must contain two rows and
        a ``stacktraces`` column with a string stack trace per AOT operator.

        Note: This test requires event tracer to be enabled. Run with:
            --config executorch.event_tracer_enabled=true
        """

        class SimpleTestModel(torch.nn.Module):
            """A simple test model for demonstration purposes."""

            def __init__(self, hidden_size: int = 32, num_layers: int = 2):
                super().__init__()
                self.layers = torch.nn.ModuleList(
                    [
                        torch.nn.Linear(hidden_size, hidden_size)
                        for _ in range(num_layers)
                    ]
                )
                self.activation = torch.nn.ReLU()
                self.output_layer = torch.nn.Linear(hidden_size, 10)

            def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
                # Only layers[0] and layers[1] are used, so num_layers must be >= 2.
                x = self.activation(self.layers[0](x))
                y = self.activation(self.layers[1](x))
                # Two outputs: the second hidden activation and the projection
                # of the first hidden activation.
                return y, self.output_layer(x)

        # Create test model and inputs
        model = SimpleTestModel(hidden_size=32, num_layers=2)
        model.eval()

        # Create representative inputs (smaller for faster testing)
        batch_size, seq_len, hidden_size = 1, 8, 32
        input_tensor = torch.randn(batch_size, seq_len, hidden_size)
        example_inputs = (input_tensor,)
        # Flatten once; the same leaves are recorded in the ETRecord and fed to
        # the runtime below.
        flattened_inputs = pytree.tree_flatten(example_inputs)[0]

        with tempfile.TemporaryDirectory() as tmp_dir:
            model_path = os.path.join(tmp_dir, "model.pte")
            etrecord_path = os.path.join(tmp_dir, "etrecord.bin")
            etdump_path = os.path.join(tmp_dir, "etdump.etdp")
            debug_buffer_path = os.path.join(tmp_dir, "debug_buffer.bin")

            # Step 1: Export using torch.export.export
            exported_program = torch.export.export(model, example_inputs)
            self.assertIsNotNone(exported_program)

            # Step 2: Lower to XNNPACK with generate_etrecord=True
            edge_compile_config = EdgeCompileConfig(_check_ir_validity=False)
            edge_program_manager = to_edge_transform_and_lower(
                exported_program,
                partitioner=[XnnpackPartitioner()],
                compile_config=edge_compile_config,
                generate_etrecord=True,
            )
            self.assertIsNotNone(edge_program_manager)

            # Step 3: Convert to executorch
            executorch_program = edge_program_manager.to_executorch()

            # Step 4: Fetch the ETRecord, attach the representative inputs and
            # save it next to the PTE file
            et_record = executorch_program.get_etrecord()
            self.assertIsNotNone(et_record)
            et_record.update_representative_inputs(flattened_inputs)
            et_record.save(etrecord_path)

            with open(model_path, "wb") as f:
                executorch_program.write_to_file(f)

            # Step 5: Test intermediate output comparison using pybind APIs
            # Read the PTE file
            with open(model_path, "rb") as f:
                pte_buffer = f.read()

            # Load the PTE file with ETDump enabled using pybind API
            executorch_module = _load_for_executorch_from_buffer(
                pte_buffer,
                enable_etdump=True,
                debug_buffer_size=1024 * 1024,  # 1MB for testing
            )
            self.assertIsNotNone(executorch_module)

            # Run the model with the given input using pybind API
            executorch_module.run_method("forward", tuple(flattened_inputs))

            # Write the ETDump results to a file using pybind API
            executorch_module.write_etdump_result_to_file(
                etdump_path, debug_buffer_path
            )

            # Check if event tracer actually captured data (requires build-time config)
            if not os.path.exists(etdump_path):
                self.skipTest(
                    "Event tracer not enabled. Run with --config executorch.event_tracer_enabled=true"
                )

            # Step 6: Use Inspector API to compare intermediate outputs
            inspector = Inspector(
                etdump_path=etdump_path,
                etrecord=etrecord_path,
                debug_buffer_path=debug_buffer_path,
            )
            self.assertIsNotNone(inspector)

            # Calculate numerical gap using SNR metric
            df = inspector.calculate_numeric_gap("SNR")

            # Verify that we got the intermediate tensor comparisons; two rows
            # are expected for this model structure and partitioning.
            self.assertEqual(len(df), 2)

            # Verify that stacktraces column exists and contains valid data
            self.assertIn("stacktraces", df.columns)
            for _, row in df.iterrows():
                stacktraces = row["stacktraces"]
                aot_ops = row["aot_ops"]
                # stacktraces should be a dict
                self.assertIsInstance(stacktraces, dict)
                # Each aot_op should have a corresponding entry in stacktraces
                for op_name in aot_ops:
                    self.assertIn(
                        op_name,
                        stacktraces,
                        f"Missing stack trace for operator {op_name}",
                    )
                    # Every operator is required to carry a string stack trace
                    # here (a missing/None trace fails the test).
                    stack_trace = stacktraces[op_name]
                    self.assertIsInstance(stack_trace, str)
                    # Stack traces should contain file information
                    self.assertIn(
                        "File",
                        stack_trace,
                        f"Stack trace for {op_name} doesn't contain file info",
                    )

    def _gen_random_float_list(self) -> List[float]:
        """Return RAW_DATA_SIZE floats, each drawn with random.uniform(0, 10)."""
        samples: List[float] = []
        for _ in range(RAW_DATA_SIZE):
            samples.append(random.uniform(0, 10))
        return samples

    def _gen_random_runtime_output(
        self,
    ) -> List[Union[None, List[torch.Tensor], bool, float, int, str, torch.Tensor]]:
        """Return a one-element list holding a torch.randn tensor of RAW_DATA_SIZE values."""
        random_tensor = torch.randn(RAW_DATA_SIZE)
        return [random_tensor]

    @unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows")
    def test_disable_debug_handle_validation_with_symbolic_shapes(self):
        """
        Exercise propagate_back_debug_handle with and without handle validation.

        A model is exported with a dynamic batch dimension so that symbolic
        shape nodes appear in the graph. The test then asserts that
        propagate_back_debug_handle returns False when
        disable_debug_handle_valdiation=False and True when it is set to True.
        """
        from executorch.devtools.inspector._inspector_utils import (
            propagate_back_debug_handle,
        )

        class SymbolicShapeModel(torch.nn.Module):
            """Model that will have symbolic shape related operations after export."""

            def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
                # This will create symbolic shape nodes during dynamic export
                batch_size = x.shape[0]
                x = x + torch.rand((batch_size, 1))
                # Masking operation that creates gt/lt nodes
                valid_mask = mask > 0.5
                x = torch.where(valid_mask, x, torch.zeros_like(x))
                return x

        # Build the module and a pair of (x, mask) inputs of shape (2, 4).
        module = SymbolicShapeModel()
        n_rows, n_cols = 2, 4
        sample_inputs = (torch.randn(n_rows, n_cols), torch.rand(n_rows, n_cols))

        # Mark dim 0 of both inputs as dynamic so export emits symbolic shape nodes.
        shape_spec = {
            "x": {0: torch.export.Dim("batch_size", min=1, max=10)},
            "mask": {0: torch.export.Dim("batch_size", min=1, max=10)},
        }
        aten_program = torch.export.export(
            module, sample_inputs, dynamic_shapes=shape_spec, strict=True
        )

        # In this case the original aten graph has a sym_size_int_2 node, but in
        # the edge_program_manager's node metadata the sym_size node's from_node
        # says sym_size_int_3, which is not in the original aten graph.
        #
        # Creating the edge program is where from_node info can be lost for
        # symbolic shape nodes.
        edge_manager: EdgeProgramManager = to_edge(aten_program)
        # Snapshot taken before to_executorch(); this copy goes into the ETRecord.
        edge_manager_snapshot = copy.deepcopy(edge_manager)
        et_manager: ExecutorchProgramManager = edge_manager.to_executorch()

        with tempfile.NamedTemporaryFile(suffix=".bin") as record_file:
            record_path = record_file.name

            # Generate ETRecord with the exported program (aten graph)
            generate_etrecord(
                record_path,
                edge_manager_snapshot,
                et_manager,
                exported_program=aten_program,
            )

            # ETDump parsing is patched out; only the ETRecord contents matter here.
            with patch.object(
                _inspector, "gen_etdump_object", return_value=None
            ), patch.object(EventBlock, "_gen_from_etdump"):
                inspector = Inspector(
                    etdump_path=ETDUMP_PATH,
                    etrecord=record_path,
                )

                # Pull the pieces propagate_back_debug_handle needs out of the etrecord.
                loaded_record = inspector._etrecord
                recorded_export = loaded_record.exported_program
                recorded_graph_id = loaded_record.export_graph_id
                recorded_edge_program = loaded_record.edge_dialect_program

                for component in (
                    recorded_export,
                    recorded_graph_id,
                    recorded_edge_program,
                ):
                    self.assertIsNotNone(component)

                propagate_args = (
                    recorded_export,
                    recorded_graph_id,
                    recorded_edge_program,
                )

                # Validation on: expected to report failure (False) because the
                # symbolic shape nodes lost their from_node info.
                strict_outcome = propagate_back_debug_handle(
                    *propagate_args,
                    disable_debug_handle_valdiation=False,
                )
                self.assertFalse(
                    strict_outcome,
                    "propagate_back_debug_handle should return False when validation is enabled "
                    "and symbolic shape nodes lose from_node info",
                )

                # Validation off: the same call is expected to succeed (True).
                lenient_outcome = propagate_back_debug_handle(
                    *propagate_args,
                    disable_debug_handle_valdiation=True,
                )
                self.assertTrue(
                    lenient_outcome,
                    "propagate_back_debug_handle should return True when validation is disabled, "
                    "allowing best effort comparison even when from_node info is lost",
                )

    def _gen_random_events(self) -> List[Event]:
        """
        Build a list of Events populated with random perf data.

        Layout of the returned list, in order:
          * two (OPERATOR_CALL, op_<i>) pairs whose debug handle and instruction
            id are both 0/1 and 2/3; the op_<i> events have empty op_types
          * op_2 and op_3, which share debug handle 4 but have instruction ids
            4 and 5 and fixed debug tensors
          * op_<i> for i in range(4, EVENTS_SIZE - 2), with debug handle i + 1
            and instruction id i + 2
        """

        def make_event(name, op_types, debug_handle, instruction_id, debug_data=None):
            # Perf data is drawn first, then (if not supplied) the debug data.
            perf = PerfData(self._gen_random_float_list())
            if debug_data is None:
                debug_data = self._gen_random_runtime_output()
            return Event(
                name=name,
                op_types=op_types,
                perf_data=perf,
                debug_handles=debug_handle,
                debug_data=debug_data,
                _instruction_id=instruction_id,
            )

        generated: List[Event] = []

        for pair_idx in range(2):
            even_id = pair_idx * 2
            odd_id = even_id + 1
            # OPERATOR_CALL with debug_handle/instruction_id 0 and 2
            generated.append(make_event("OPERATOR_CALL", [OP_TYPE], even_id, even_id))
            # op_0/op_1 with empty op_types and debug_handle/instruction_id 1 and 3
            generated.append(make_event(f"op_{pair_idx}", [], odd_id, odd_id))

        # op_2 and op_3 deliberately share debug_handle 4 while differing in
        # instruction_id (4 vs 5) and in their fixed debug tensors.
        generated.append(
            make_event(
                "op_2",
                [OP_TYPE],
                4,
                4,
                debug_data=[torch.tensor([1.0, 2.0, 3.0])],
            )
        )
        generated.append(
            make_event(
                "op_3",
                [OP_TYPE],
                4,
                5,
                debug_data=[torch.tensor([4.0, 5.0, 6.0])],
            )
        )

        # Remaining ops: debug_handle is index + 1, instruction_id is index + 2.
        generated.extend(
            make_event(f"op_{op_idx}", [OP_TYPE], op_idx + 1, op_idx + 2)
            for op_idx in range(4, EVENTS_SIZE - 2)
        )
        return generated


class TestInplaceOpsIntermediateOutput(unittest.TestCase):
    """
    Test suite for verifying that inplace operators correctly log intermediate
    outputs when the event tracer is enabled.

    This validates the fix for an issue where inplace ops converted by the
    reinplace_pass could cause logging errors because the output tensor's data
    pointer was null at the time of logging.

    Note: The reinplace_pass currently only supports converting index_put to
    index_put_ (see executorch/exir/passes/reinplace.py).
    """

    def _run_model_and_get_inspector(
        self,
        model: torch.nn.Module,
        example_inputs: tuple,
        run_reinplace_pass: bool = True,
    ) -> Inspector:
        """
        Helper method to export a model, run it with event tracing, and return
        an Inspector instance for verifying intermediate outputs.

        Args:
            model: Module to export; it is switched to eval mode first.
            example_inputs: Inputs used both for export and for the runtime call.
            run_reinplace_pass: Forwarded to ExecutorchBackendConfig.

        Returns:
            An Inspector built from the ETDump, ETRecord and debug buffer
            produced by the run.

        The calling test is skipped when no ETDump file was written.
        """
        model.eval()

        # NOTE(review): the temporary directory is deleted when this method
        # returns; this presumes Inspector has loaded everything it needs from
        # the files during construction -- confirm.
        with tempfile.TemporaryDirectory() as tmp_dir:
            model_path = os.path.join(tmp_dir, "model.pte")
            etrecord_path = os.path.join(tmp_dir, "etrecord.bin")
            etdump_path = os.path.join(tmp_dir, "etdump.etdp")
            debug_buffer_path = os.path.join(tmp_dir, "debug_buffer.bin")

            # Step 1: Export the model
            exported_program = export(model, example_inputs)
            self.assertIsNotNone(exported_program)

            # Step 2: Convert to edge dialect
            edge_compile_config = EdgeCompileConfig(_check_ir_validity=False)
            edge_program = to_edge(exported_program, compile_config=edge_compile_config)
            self.assertIsNotNone(edge_program)

            # Keep a separate edge program for the etrecord, built from a fresh
            # export rather than the one that is lowered below.
            edge_program_copy = to_edge(
                export(model, example_inputs), compile_config=edge_compile_config
            )

            # Step 3: Convert to executorch, toggling the reinplace pass as requested
            executorch_config = ExecutorchBackendConfig(
                run_reinplace_pass=run_reinplace_pass
            )
            executorch_program = edge_program.to_executorch(config=executorch_config)
            self.assertIsNotNone(executorch_program)

            # Step 4: Generate ETRecord
            generate_etrecord(
                etrecord_path,
                edge_program_copy,
                executorch_program,
            )

            # Step 5: Save the PTE file
            with open(model_path, "wb") as f:
                executorch_program.write_to_file(f)

            # Step 6: Load and run with event tracing enabled
            with open(model_path, "rb") as f:
                pte_buffer = f.read()

            executorch_module = _load_for_executorch_from_buffer(
                pte_buffer,
                enable_etdump=True,
                debug_buffer_size=1024 * 1024,  # 1MB for testing
            )
            self.assertIsNotNone(executorch_module)

            # Run the model on the flattened inputs (pytree is imported at
            # module level).
            flattened_inputs = pytree.tree_flatten(example_inputs)[0]
            executorch_module.run_method("forward", tuple(flattened_inputs))

            # Write ETDump results
            executorch_module.write_etdump_result_to_file(
                etdump_path, debug_buffer_path
            )

            # Check if event tracer captured data
            if not os.path.exists(etdump_path):
                self.skipTest(
                    "Event tracer not enabled. Run with --config executorch.event_tracer_enabled=true"
                )

            # Step 7: Create Inspector and return
            return Inspector(
                etdump_path=etdump_path,
                etrecord=etrecord_path,
                debug_buffer_path=debug_buffer_path,
            )

    def _assert_expected_tensor_logged(
        self,
        inspector: Inspector,
        expected_output: torch.Tensor,
        failure_msg: str,
    ) -> None:
        """
        Scan every logged debug tensor and require one to match expected_output.

        Each tensor found in any event's debug_data must have a non-null data
        pointer. The check passes when at least one of them has the same shape
        as expected_output and is close to it (atol=1e-5); otherwise the test
        fails with failure_msg.
        """
        found_expected_output = False
        for event_block in inspector.event_blocks:
            for event in event_block.events:
                debug_data = getattr(event, "debug_data", None)
                if debug_data is None:
                    continue
                for debug_entry in debug_data:
                    if not isinstance(debug_entry, torch.Tensor):
                        continue

                    # Verify tensor has valid data pointer
                    self.assertIsNotNone(
                        debug_entry.data_ptr(),
                        "Intermediate output tensor should have valid data pointer",
                    )
                    self.assertNotEqual(
                        debug_entry.data_ptr(),
                        0,
                        "Intermediate output tensor data pointer should not be null",
                    )

                    # Compare values only when the shape matches; keep
                    # scanning afterwards so every tensor gets pointer-checked.
                    if debug_entry.shape == expected_output.shape and torch.allclose(
                        debug_entry, expected_output, atol=1e-5
                    ):
                        found_expected_output = True

        self.assertTrue(found_expected_output, failure_msg)

    def test_index_put_without_reinplace_pass(self):
        """
        Test that the model works correctly without the reinplace pass as a
        baseline comparison, and verify intermediate output correctness.
        """
        model = IndexPutModel()
        indices = torch.tensor([0, 2, 4])
        values = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
        example_inputs = (indices, values)

        # Compute expected intermediate output of index_put
        # index_put on zeros(5,3) with indices [0,2,4] and values [[1,2,3],[4,5,6],[7,8,9]]
        # Result should be:
        # [[1, 2, 3],
        #  [0, 0, 0],
        #  [4, 5, 6],
        #  [0, 0, 0],
        #  [7, 8, 9]]
        expected_index_put_output = torch.zeros(5, 3)
        expected_index_put_output[0] = torch.tensor([1.0, 2.0, 3.0])
        expected_index_put_output[2] = torch.tensor([4.0, 5.0, 6.0])
        expected_index_put_output[4] = torch.tensor([7.0, 8.0, 9.0])

        inspector = self._run_model_and_get_inspector(
            model, example_inputs, run_reinplace_pass=False
        )

        self.assertIsNotNone(inspector)
        self.assertGreater(len(inspector.event_blocks), 0)

        # Verify intermediate output correctness (same validation as with reinplace)
        self._assert_expected_tensor_logged(
            inspector,
            expected_index_put_output,
            "Expected to find index_put intermediate output with correct tensor data (without reinplace pass).",
        )

    def test_index_put_intermediate_output_data_correctness(self):
        """
        Test that the intermediate output values captured by the event tracer
        are valid tensors with correct data.

        This specifically validates that:
        1. The output tensor has a valid (non-null) data pointer
        2. The output tensor contains the correct values after index_put_
        """
        model = IndexPutModel()
        # Use simple values to verify correctness
        indices = torch.tensor([0, 1])
        values = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
        example_inputs = (indices, values)

        # Compute expected intermediate output of index_put
        # index_put on zeros(5,3) with indices [0,1] and values [[1,2,3],[4,5,6]]
        # Result should be:
        # [[1, 2, 3],
        #  [4, 5, 6],
        #  [0, 0, 0],
        #  [0, 0, 0],
        #  [0, 0, 0]]
        expected_index_put_output = torch.zeros(5, 3)
        expected_index_put_output[0] = torch.tensor([1.0, 2.0, 3.0])
        expected_index_put_output[1] = torch.tensor([4.0, 5.0, 6.0])

        inspector = self._run_model_and_get_inspector(
            model, example_inputs, run_reinplace_pass=True
        )

        self.assertIsNotNone(inspector)
        self.assertGreater(len(inspector.event_blocks), 0)

        total_events = sum(len(eb.events) for eb in inspector.event_blocks)
        self.assertGreater(
            total_events, 0, "Expected at least one event to be captured"
        )

        # Assert that we found the expected index_put output with correct data
        # This validates that the intermediate output was properly logged
        # and contains the correct tensor values
        self._assert_expected_tensor_logged(
            inspector,
            expected_index_put_output,
            "Expected to find index_put intermediate output with correct tensor data. "
            "The output tensor should match the expected result of index_put operation.",
        )
