# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import os
import typing
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Dict, List, Set

import torch
from executorch.backends.aoti.passes.replace_view_copy_with_view import (
    ReplaceViewCopyWithViewPass,
)
from executorch.exir._serialize._named_data_store import NamedDataStore
from executorch.exir._warnings import experimental
from executorch.exir.backend.backend_details import ExportedProgram, PreprocessResult
from executorch.exir.backend.compile_spec_schema import CompileSpec
from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu
from torch.export.passes import move_to_device_pass


class COMPILE_SPEC_KEYS(Enum):
    """Keys of the ``CompileSpec`` entries used by AOTI-driven backends."""

    # Key of the spec whose value carries the UTF-8 encoded method name
    # (written by ``generate_method_name_compile_spec`` and read back by
    # ``method_name_from_compile_specs``).
    METHOD_NAME = "method_name"


@experimental(
    "This API and all of aoti-driven backend related functionality are experimental."
)
class AotiBackend(ABC):
    """
    Base mixin class for AOTInductor-based backends.

    This class provides common functionality for compiling models using AOTInductor
    with different device targets (CUDA, Metal, etc.).

    This is a mixin class, not an actual backend object, for aoti-driven backends.
    Concrete backends (e.g., CudaBackend, MetalBackend) should inherit from both
    BackendDetails and AotiBackend to get the full functionality.
    """

    @classmethod
    @abstractmethod
    def get_device_name(cls) -> str:
        """Return the device name for this backend (e.g., 'cuda', 'metal')."""
        pass

    @classmethod
    @abstractmethod
    def get_supported_fallback_kernels(cls) -> Dict[str, Any]:
        """Return the set of supported fallback kernels for this backend."""
        pass

    @classmethod
    @abstractmethod
    def get_decomposition_table(cls) -> Dict[Any, Any]:
        """Return the decomposition table for this backend."""
        pass

    @classmethod
    @abstractmethod
    def get_aoti_compile_options(
        cls, compile_specs: List[CompileSpec]
    ) -> Dict[str, typing.Any]:
        """Return the AOTInductor compilation options for this backend."""
        pass

    @classmethod
    @abstractmethod
    def get_custom_passes(cls, compile_specs: List[CompileSpec]) -> List[typing.Any]:
        """Return the list of custom passes to apply after ReplaceViewCopyWithViewPass and before decomposition."""
        pass

    @classmethod
    def save_data_externally(cls) -> bool:
        """
        Return whether to save the named data map to an external .ptd file.

        If True, ``preprocess`` tags the weights blob with the external tag
        ``aoti_<device_name>_blob`` (e.g., aoti_cuda_blob) so that it is saved to
        a separate .ptd file that must be provided at runtime.
        If False, no external tag is set and the data is merged into the .pte file.

        NOTE: ``preprocess`` adds the SO blob without an external tag regardless
        of this setting; only the weights blob is affected.

        Default is False (merge with .pte file). Subclasses can override this.
        """
        return False

    @classmethod
    def get_extra_aoti_compile_context_manager(cls):
        """Return extra context manager to apply during aoti_compile stage. By default returns an empty context manager."""
        return contextlib.nullcontext()

    @classmethod
    @contextlib.contextmanager
    def collect_unsupported_fallback_kernels(cls, missing_fallback_kernels: Set[str]):
        """
        Context manager that records fallback kernels this backend lacks.

        While the context is active, two ``CppWrapperCpu`` code generation hooks
        (``generate_c_shim_extern_kernel_call`` and
        ``generate_fallback_kernel_with_runtime_lookup_aot``) are replaced by
        wrappers. Each wrapper notes the kernel name when it is absent from
        ``get_supported_fallback_kernels()`` and then delegates to the original
        hook. The original hooks are reinstated on exit, even if the body raises.

        Args:
            missing_fallback_kernels: Set mutated in place; the names of
                unsupported kernels are added to it.
        """
        known_kernels = cls.get_supported_fallback_kernels()

        # Keep the pristine hooks: the wrappers delegate to them and the
        # ``finally`` clause below puts them back on the class.
        pristine_extern_call = CppWrapperCpu.generate_c_shim_extern_kernel_call
        pristine_runtime_lookup = (
            CppWrapperCpu.generate_fallback_kernel_with_runtime_lookup_aot
        )

        def record_if_unsupported(name: Any) -> None:
            if name not in known_kernels:
                missing_fallback_kernels.add(name)

        def extern_call_with_tracking(
            self, kernel: str, *args: Any, **kwargs: Any
        ) -> None:
            record_if_unsupported(kernel)
            return pristine_extern_call(self, kernel, *args, **kwargs)

        def runtime_lookup_with_tracking(
            self, op_overload: Any, *args: Any, **kwargs: Any
        ) -> None:
            # Prefer the overload's ``_name`` attribute; fall back to its
            # string form when that attribute is absent.
            record_if_unsupported(getattr(op_overload, "_name", str(op_overload)))
            return pristine_runtime_lookup(self, op_overload, *args, **kwargs)

        CppWrapperCpu.generate_c_shim_extern_kernel_call = extern_call_with_tracking
        CppWrapperCpu.generate_fallback_kernel_with_runtime_lookup_aot = (
            runtime_lookup_with_tracking
        )

        try:
            yield
        finally:
            CppWrapperCpu.generate_c_shim_extern_kernel_call = pristine_extern_call
            CppWrapperCpu.generate_fallback_kernel_with_runtime_lookup_aot = (
                pristine_runtime_lookup
            )

    @classmethod
    def preprocess(
        cls,
        edge_program: ExportedProgram,
        compile_specs: List[CompileSpec],
    ) -> PreprocessResult:
        """
        Preprocess the edge program and compile it using AOTInductor.

        The program is moved to the backend device, rewritten by
        ReplaceViewCopyWithViewPass and the backend's custom passes, optionally
        decomposed, and then compiled with ``torch._inductor.aot_compile``. The
        compile step is expected to produce both a ``.wrapper.so`` file and a
        ``.wrapper_weights.blob`` file (weights separated from the SO file);
        their contents are stored in a NamedDataStore under
        ``<method_name>_so_blob`` and ``<method_name>_weights_blob``.

        Args:
            edge_program: Exported program to compile.
            compile_specs: Compile specs; must contain the method name spec
                (see ``generate_method_name_compile_spec``).

        Returns:
            A PreprocessResult with empty ``processed_bytes`` whose
            ``data_store_output`` carries the SO and weights blobs.

        Raises:
            RuntimeError: If the method name spec is missing, if compilation
                used fallback kernels the backend does not support, or if the
                expected compiled files are not among the returned paths.
        """
        # Resolve the method name before any expensive work. Every successful
        # call needs it, so a missing spec should fail fast instead of after a
        # full compile, and it must not mask the missing-kernel diagnostic.
        method_name = cls.method_name_from_compile_specs(compile_specs)

        device_name = cls.get_device_name()
        decomposition_table = cls.get_decomposition_table()
        options = cls.get_aoti_compile_options(compile_specs)

        # Move the edge_program to the target device
        device_edge_program = move_to_device_pass(
            edge_program, device_name if device_name != "metal" else "mps"
        )

        # Replace view_copy with view
        ReplaceViewCopyWithViewPass()(device_edge_program.graph_module)

        # Apply custom backend-specific passes
        for custom_pass in cls.get_custom_passes(compile_specs):
            if getattr(custom_pass, "requires_exported_program", False):
                custom_pass(device_edge_program)
            else:
                custom_pass(device_edge_program.graph_module)

        # Run decompositions if any
        if decomposition_table:
            device_edge_program = device_edge_program.run_decompositions(
                decomposition_table
            )

        edge_program_module = device_edge_program.module()

        # Grab the example values of all user-input placeholders from the graph
        user_input_names = device_edge_program.graph_signature.user_inputs
        user_input_placeholders = [
            node.meta["val"]
            for node in device_edge_program.graph.nodes
            if node.op == "placeholder" and node.name in user_input_names
        ]

        # Track missing fallback kernels
        missing_fallback_kernels: Set[str] = set()

        # Compile with fallback kernel collection
        with cls.collect_unsupported_fallback_kernels(
            missing_fallback_kernels
        ), torch.no_grad(), cls.get_extra_aoti_compile_context_manager():
            paths = torch._inductor.aot_compile(
                edge_program_module, tuple(user_input_placeholders), options=options
            )

            if missing_fallback_kernels:
                formatted_kernels = "\n  - ".join(sorted(missing_fallback_kernels))
                raise RuntimeError(
                    f"Method {method_name} missing fallback kernels ({len(missing_fallback_kernels)} total):\n  - {formatted_kernels}\n"
                    "Please add them to the AOTI backend."
                )

        # Extract paths - weights are always separated
        so_path = None
        blob_path = None

        if isinstance(paths, list):
            for path in paths:
                if path.endswith(".wrapper.so"):
                    so_path = path
                elif path.endswith(".wrapper_weights.blob"):
                    blob_path = path
        else:
            so_path = paths

        if so_path is None or blob_path is None:
            raise RuntimeError(
                f"Could not find required files in compiled paths, got {paths}"
            )

        try:
            # Read SO file
            with open(so_path, "rb") as f:
                so_data = f.read()

            # Read weights blob
            with open(blob_path, "rb") as f:
                blob_data = f.read()
        finally:
            # Clean up the generated files even when reading them failed, so
            # a failed run does not leave compiled artifacts behind. A file
            # that is already gone is not an error here.
            for generated_path in (so_path, blob_path):
                with contextlib.suppress(FileNotFoundError):
                    os.remove(generated_path)

        # Create named data store
        named_data_store = NamedDataStore()

        named_data_store.add_named_data(method_name + "_so_blob", so_data, 1, None)
        # Determine whether to save named data externally based on backend setting
        # External: save to separate .ptd file, otherwise merge with .pte file
        external_tag = (
            f"aoti_{device_name}_blob" if cls.save_data_externally() else None
        )

        named_data_store.add_named_data(
            method_name + "_weights_blob", blob_data, 1, external_tag
        )

        return PreprocessResult(
            processed_bytes=b"",
            debug_handle_map={},
            data_store_output=named_data_store.get_named_data_store_output(),
        )

    @classmethod
    def generate_method_name_compile_spec(
        cls,
        method_name: str,
    ) -> CompileSpec:
        """
        Build the CompileSpec that carries ``method_name``.

        The spec key is ``COMPILE_SPEC_KEYS.METHOD_NAME.value`` and the spec
        value is the UTF-8 encoding of the name.
        """
        spec_key = COMPILE_SPEC_KEYS.METHOD_NAME.value
        encoded_name = method_name.encode("utf-8")
        return CompileSpec(spec_key, encoded_name)

    @classmethod
    def method_name_from_compile_specs(
        cls,
        compile_specs: List[CompileSpec],
    ) -> str:
        """
        Look up the method name stored in ``compile_specs``.

        The first spec whose key equals ``COMPILE_SPEC_KEYS.METHOD_NAME.value``
        wins; its value is decoded as UTF-8 and returned.

        Raises:
            RuntimeError: If no spec carries the method name key.
        """
        wanted_key = COMPILE_SPEC_KEYS.METHOD_NAME.value
        matching_spec = next(
            (spec for spec in compile_specs if spec.key == wanted_key), None
        )
        if matching_spec is None:
            raise RuntimeError(
                f"Could not find method name in compile specs: {compile_specs}"
            )
        return matching_spec.value.decode("utf-8")
