# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import argparse
import os
import sys
from typing import Any, List

from torchgen.code_template import CodeTemplate  # type: ignore[import-not-found]


# Text of the generated header: `#pragma once`, a comment naming this generator
# script, and a `$defines` placeholder that write_selected_prim_ops() fills in
# via substitute(defines=...).
selected_prim_ops_h_template_str = """#pragma once
/**
 * Generated by executorch/codegen/tools/gen_selected_prim_ops.py
 */

$defines
"""
# Template object built once at import time and reused for every render.
selected_prim_ops_h_template = CodeTemplate(selected_prim_ops_h_template_str)


def normalize_op_name(op_name: str) -> str:
    """
    Build the macro name used to select an operator.

    Each "::" and each "." in ``op_name`` is replaced by a single underscore,
    the result is upper-cased, and "INCLUDE_" is prepended. For example,
    "executorch_prim::et_view.default" becomes
    "INCLUDE_EXECUTORCH_PRIM_ET_VIEW_DEFAULT" and "aten::sym_size.int" becomes
    "INCLUDE_ATEN_SYM_SIZE_INT". Any other character is kept as-is (apart from
    upper-casing).
    """
    macro_body = op_name
    # "::" is handled before "." to mirror the namespace-then-overload layout
    # of names such as "aten::sym_size.int".
    for separator in ("::", "."):
        macro_body = macro_body.replace(separator, "_")
    return "INCLUDE_" + macro_body.upper()


def write_selected_prim_ops(prim_op_names: List[str], output_dir: str) -> None:
    """
    Render selected_prim_ops.h and write it into ``output_dir``.

    One ``#define`` line is emitted per entry of ``prim_op_names``, in input
    order, using the macro name returned by normalize_op_name(). The lines are
    joined with newlines and substituted for the ``$defines`` placeholder of
    the header template.

    Args:
        prim_op_names: Prim op names such as
            ["executorch_prim::et_view.default", "aten::sym_size.int"]
        output_dir: Directory that receives selected_prim_ops.h
    """
    define_lines = [f"#define {normalize_op_name(name)}" for name in prim_op_names]
    rendered_header = selected_prim_ops_h_template.substitute(
        defines="\n".join(define_lines)
    )

    header_path = os.path.join(output_dir, "selected_prim_ops.h")
    # Binary mode plus an explicit UTF-8 encode writes the rendered text
    # without any platform newline translation.
    with open(header_path, "wb") as header_file:
        header_file.write(rendered_header.encode("utf-8"))


def main(argv: List[Any]) -> None:
    """
    Command-line entry point: parse ``argv`` and generate the header.

    Both options are required and each accepts a dashed or an underscored
    spelling. The value of --prim-op-names is split on commas; each piece is
    stripped of surrounding whitespace and empty pieces are dropped before the
    list is handed to write_selected_prim_ops().

    Args:
        argv: Command-line arguments, without the program name.
    """
    arg_parser = argparse.ArgumentParser(
        description="Generate selected prim ops header"
    )
    arg_parser.add_argument(
        "--prim-op-names",
        "--prim_op_names",
        required=True,
        help="Comma-separated list of prim op names to include",
    )
    arg_parser.add_argument(
        "--output-dir",
        "--output_dir",
        required=True,
        help="The directory to store the output header file (selected_prim_ops.h)",
    )
    parsed = arg_parser.parse_args(argv)

    # Strip first, then filter, so whitespace-only entries (e.g. from a
    # trailing comma) are discarded.
    stripped_pieces = (piece.strip() for piece in parsed.prim_op_names.split(","))
    selected_names = [piece for piece in stripped_pieces if piece]

    write_selected_prim_ops(selected_names, parsed.output_dir)


# Run the generator only when this file is executed as a script; the program
# name (sys.argv[0]) is dropped before the arguments are parsed.
if __name__ == "__main__":
    main(sys.argv[1:])
