# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Script to combine multiple selected_prim_ops.h header files into a single header.
This is used by selected_prim_operators_genrule to merge prim ops headers from dependencies.
"""

import argparse
import os
import sys
from pathlib import Path
from typing import List, Set


def read_header_file(file_path: Path) -> Set[str]:
    """
    Read a selected_prim_ops.h file and collect its prim-op macro defines.

    A line is kept when, after stripping surrounding whitespace, it starts
    with ``#define INCLUDE_``. Every other line (comments, ``#pragma once``,
    the ``EXECUTORCH_ENABLE_...`` switch, blank lines) is ignored.

    Reading is best-effort: a missing or unreadable file is reported on
    stderr and whatever was collected up to that point is returned, so one
    bad dependency does not abort the whole merge.

    Args:
        file_path: Path to the header file

    Returns:
        Set of unique, whitespace-stripped ``#define INCLUDE_...`` lines
        (empty if the file is missing or contains no such line)
    """
    macros: Set[str] = set()

    try:
        # Explicit encoding so the result does not depend on the locale of
        # the machine running the build.
        with open(file_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()

                # The "#define INCLUDE_" prefix already excludes the
                # "#define EXECUTORCH_ENABLE..." switch, so no extra check
                # is needed.
                if line.startswith("#define INCLUDE_"):
                    macros.add(line)
    except FileNotFoundError:
        print(f"Warning: Header file not found: {file_path}", file=sys.stderr)
    except (OSError, UnicodeError) as e:
        # Other I/O failures (permissions, is-a-directory, ...) and
        # undecodable bytes.
        print(f"Error reading {file_path}: {e}", file=sys.stderr)

    return macros


def combine_prim_ops_headers(header_file_paths: List[str], output_path: str) -> None:
    """
    Combine multiple selected_prim_ops.h files into a single header.

    Each entry of ``header_file_paths`` is a directory; the file
    ``selected_prim_ops.h`` inside it is read with ``read_header_file`` and
    its macros are merged into one de-duplicated, sorted list. Directories
    that do not contain the header are reported on stderr and skipped.

    If at least one macro was collected, the output enables
    ``EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD`` and lists the macros;
    otherwise it only contains comments stating that selective build is
    disabled.

    Args:
        header_file_paths: List of directories that each may contain a
            selected_prim_ops.h file
        output_path: Path to output the combined header; missing parent
            directories are created
    """
    all_macros = set()

    # Read all header files and collect unique macros
    for header_file_path in header_file_paths:
        header_file = Path(header_file_path) / "selected_prim_ops.h"
        if not os.path.exists(header_file):
            print(
                f"Warning: Header file does not exist: {header_file}", file=sys.stderr
            )
            continue
        all_macros.update(read_header_file(header_file))

    # Generate combined header
    header_content = [
        "// Combined header for selective prim ops build",
        "// This file is auto-generated by combining multiple selected_prim_ops.h files",
        "// Do not edit manually.",
        "",
        "#pragma once",
        "",
    ]

    if all_macros:
        header_content.extend(
            [
                "// Enable selective build for prim ops",
                "#define EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD",
                "",
                "// Combined prim ops macros from all dependencies",
            ]
        )

        # Sort macros for deterministic output
        header_content.extend(sorted(all_macros))
    else:
        header_content.extend(
            [
                "// No prim ops found in dependencies - all prim ops will be included",
                "// Selective build is disabled",
            ]
        )

    # Trailing empty entry so the joined text ends with a newline
    header_content.append("")

    # Write the combined header. A bare file name has an empty dirname, and
    # os.makedirs("") raises FileNotFoundError, so only create the parent
    # directory when there is one.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(output_path, "w") as f:
        f.write("\n".join(header_content))


def _get_header_file_paths_from_query_output(query_output_file: str) -> List[str]:
    """
    Parse the output of a Buck query command to extract header file paths.

    The argument must be an ``@``-prefixed path (``@/path/to/file``). The
    prefix is stripped, the file is opened, and every whitespace-separated
    token in it is returned as one path, in file order.

    Args:
        query_output_file: ``@``-prefixed path to the file containing the
            query output

    Returns:
        List of header file paths (empty if the file has no tokens)

    Raises:
        ValueError: If ``query_output_file`` does not start with ``@``
            (including when it is empty)
        OSError: If the referenced file cannot be opened
    """
    # Validate with an explicit raise rather than `assert`: asserts are
    # removed under `python -O`, and indexing [0] would raise IndexError on
    # an empty argument.
    if not query_output_file.startswith("@"):
        raise ValueError(
            "query_output_file is not a valid file path, or it doesn't start with '@'."
        )

    with open(query_output_file[1:], "r") as f:
        # str.split() with no arguments splits on any whitespace run,
        # newlines included, so the whole file can be tokenized at once.
        return f.read().split()


def main() -> None:
    """
    Command-line entry point.

    Parses ``--header_files`` and ``--output_dir``, resolves the list of
    dependency directories from the ``@``-prefixed query output file, and
    writes the merged header to ``<output_dir>/selected_prim_ops.h``.
    Exits with status 1 when the query output lists no paths.
    """
    parser = argparse.ArgumentParser(
        description="Combine multiple selected_prim_ops.h header files"
    )
    parser.add_argument(
        "--header_files",
        required=True,
        help=(
            "'@'-prefixed path to a file that lists, separated by whitespace, "
            "the directories containing selected_prim_ops.h"
        ),
    )
    parser.add_argument(
        "--output_dir", required=True, help="Output directory for combined header"
    )

    args = parser.parse_args()

    header_file_paths = _get_header_file_paths_from_query_output(args.header_files)

    if not header_file_paths:
        print("Error: No header files provided", file=sys.stderr)
        sys.exit(1)

    # Generate output path
    output_path = os.path.join(args.output_dir, "selected_prim_ops.h")

    combine_prim_ops_headers(header_file_paths, output_path)


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
