# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import re
import subprocess
import sys
from pathlib import Path


# Command-line flags passed to flatc ahead of the schema paths. The "-o"
# output path is relative because _run_flatc invokes flatc with cwd=repo_root.
_FLATC_ARGS = [
    "--python",
    "--python-typing",
    "--gen-object-api",
    "--gen-mutable",
    "-o",
    "exir/_serialize/generated",
    "--gen-compare",
]

# Import rewrites applied to flatc output. Each _OLD_* text emitted by flatc is
# replaced by the matching _NEW_* text, which imports from the in-tree
# executorch.exir._serialize.generated package instead.

# from executorch_flatbuffer... -> from executorch.exir._serialize.generated.executorch_flatbuffer...
_OLD_FROM_IMPORT = "from executorch_flatbuffer"
_NEW_FROM_IMPORT = "from executorch.exir._serialize.generated.executorch_flatbuffer"

# import executorch_flatbuffer.<Module> -> from <generated>.executorch_flatbuffer import <Module>
_OLD_IMPORT_PREFIX = "import executorch_flatbuffer."
_NEW_IMPORT_PREFIX = (
    "from executorch.exir._serialize.generated.executorch_flatbuffer import "
)

# import executorch_flatbuffer -> from <generated> import executorch_flatbuffer
_OLD_PACKAGE_IMPORT = "import executorch_flatbuffer"
_NEW_PACKAGE_IMPORT = (
    "from executorch.exir._serialize.generated import executorch_flatbuffer"
)


def _repo_root() -> Path:
    """Return the repository root: the grandparent of this script's directory."""
    script_path = Path(__file__).resolve()
    # parents[0] is the script's own directory (exir/_serialize), parents[1] is
    # exir/, and parents[2] is the repository root.
    return script_path.parents[2]


def _flatc_executable() -> str:
    """Return the flatc command: $FLATC_EXECUTABLE when set, otherwise "flatc"."""
    return os.environ.get("FLATC_EXECUTABLE", "flatc")


def _run_flatc(repo_root: Path) -> None:
    """Run flatc over every ``schema/*.fbs`` file to generate Python bindings.

    The output directory is created up front. flatc runs with ``cwd`` set to
    *repo_root* because ``_FLATC_ARGS`` names the output path relative to it.

    Args:
        repo_root: Repository root containing ``schema/`` and ``exir/``.

    Raises:
        SystemExit: If no schema files exist, or the flatc executable cannot
            be found.
        subprocess.CalledProcessError: If flatc exits with a non-zero status.
    """
    schema_dir = repo_root / "schema"
    schema_files = sorted(schema_dir.glob("*.fbs"))
    if not schema_files:
        raise SystemExit(f"No schema files found in {schema_dir}")

    output_dir = repo_root / "exir" / "_serialize" / "generated"
    output_dir.mkdir(parents=True, exist_ok=True)

    flatc = _flatc_executable()
    cmd = [flatc, *_FLATC_ARGS, *(str(path) for path in schema_files)]
    try:
        subprocess.run(cmd, check=True, cwd=repo_root)
    except FileNotFoundError as err:
        # Without this, a missing flatc surfaces as a bare traceback. Report it
        # like the other fatal errors here and say how to point at a binary.
        raise SystemExit(
            f"flatc executable {flatc!r} not found; install flatc or set "
            "FLATC_EXECUTABLE to its path"
        ) from err


# Matches a bare ``executorch_flatbuffer.<...>`` reference. The lookbehind
# rejects occurrences that are part of a longer dotted path (such as the
# rewritten ``...generated.executorch_flatbuffer.Foo`` imports) or of a longer
# identifier; neither needs the name ``executorch_flatbuffer`` to be bound.
_BARE_PACKAGE_REF = re.compile(r"(?<![\w.])executorch_flatbuffer\.")


def _rewrite_import_line(line: str) -> str:
    """Return *line* with flatc's ``executorch_flatbuffer`` imports repointed.

    The word-boundary checks keep names that merely start with
    ``executorch_flatbuffer`` (for example ``executorch_flatbuffer_ext``)
    from being rewritten. Indentation and the line ending are preserved.
    """
    line = re.sub(
        re.escape(_OLD_FROM_IMPORT) + r"\b", lambda _m: _NEW_FROM_IMPORT, line
    )
    stripped = line.lstrip()
    indent = line[: len(line) - len(stripped)]
    if stripped.startswith(_OLD_IMPORT_PREFIX):
        # import executorch_flatbuffer.Foo -> from <generated>.executorch_flatbuffer import Foo
        remainder = stripped[len(_OLD_IMPORT_PREFIX) :]
        return f"{indent}{_NEW_IMPORT_PREFIX}{remainder}"
    if re.match(re.escape(_OLD_PACKAGE_IMPORT) + r"\b", stripped):
        # import executorch_flatbuffer -> from <generated> import executorch_flatbuffer
        remainder = stripped[len(_OLD_PACKAGE_IMPORT) :]
        return f"{indent}{_NEW_PACKAGE_IMPORT}{remainder}"
    return line


def _insert_package_import(contents: str) -> str:
    """Insert ``_NEW_PACKAGE_IMPORT`` after leading blank/comment/__future__ lines."""
    lines = contents.splitlines(keepends=True)
    insert_at = 0
    while insert_at < len(lines):
        stripped = lines[insert_at].strip()
        # ``from __future__`` imports must stay at the top of the module, so
        # the new import goes after them (and after the header comments).
        is_preamble = (
            not stripped
            or stripped.startswith("#")
            or stripped.startswith("from __future__ import")
        )
        if not is_preamble:
            break
        insert_at += 1
    if insert_at and not lines[insert_at - 1].endswith("\n"):
        # Don't glue the import onto a final line that has no newline.
        lines[insert_at - 1] += "\n"
    lines.insert(insert_at, _NEW_PACKAGE_IMPORT + "\n")
    return "".join(lines)


def _rewrite_imports(repo_root: Path) -> int:
    """Rewrite flatc's imports in the generated files to the in-tree package path.

    flatc derives Python import paths from the schema namespace
    (``executorch_flatbuffer``). Changing that namespace would cause widespread
    breaking changes across the Python and C++ codebases. Instead, the
    generated files are rewritten to import from
    ``executorch.exir._serialize.generated.executorch_flatbuffer``.

    Args:
        repo_root: Repository root containing ``exir/_serialize/generated``.

    Returns:
        The number of generated ``.py`` files whose contents changed.

    Raises:
        SystemExit: If the generated ``executorch_flatbuffer`` directory is
            missing.
    """
    generated_dir = (
        repo_root / "exir" / "_serialize" / "generated" / "executorch_flatbuffer"
    )
    if not generated_dir.is_dir():
        raise SystemExit(f"Expected generated directory at {generated_dir}")

    updated = 0
    for path in sorted(generated_dir.rglob("*.py")):
        contents = path.read_text(encoding="utf-8")
        new_contents = "".join(
            _rewrite_import_line(line) for line in contents.splitlines(keepends=True)
        )
        # ``from ... import Foo`` (rewritten from ``import executorch_flatbuffer.Foo``)
        # no longer binds ``executorch_flatbuffer``. Files that still refer to
        # ``executorch_flatbuffer.<...>`` therefore need the package imported.
        if (
            _BARE_PACKAGE_REF.search(new_contents)
            and _NEW_PACKAGE_IMPORT not in new_contents
        ):
            new_contents = _insert_package_import(new_contents)
        if new_contents != contents:
            path.write_text(new_contents, encoding="utf-8")
            updated += 1
    return updated


def _write_init_files(repo_root: Path) -> int:
    """Write the ``__init__.py`` files for the generated packages.

    * ``exir/_serialize/generated/__init__.py`` imports and re-exports the
      ``executorch_flatbuffer`` subpackage.
    * ``exir/_serialize/generated/executorch_flatbuffer/__init__.py`` imports
      every sibling ``*.py`` module and lists it in ``__all__``.

    Files whose contents are already up to date are left untouched.

    Args:
        repo_root: Repository root containing ``exir/_serialize/generated``.

    Returns:
        The number of ``__init__.py`` files actually written (0, 1, or 2).

    Raises:
        SystemExit: If the ``executorch_flatbuffer`` directory does not exist.
            Nothing is written in that case.
    """
    generated_dir = repo_root / "exir" / "_serialize" / "generated"
    generated_fb_dir = generated_dir / "executorch_flatbuffer"
    # Validate before writing anything, so a missing flatc output cannot leave
    # behind a top-level __init__.py that imports a nonexistent subpackage.
    if not generated_fb_dir.is_dir():
        raise SystemExit(f"Expected generated directory at {generated_fb_dir}")

    header = "# automatically generated by exir/_serialize/generate_program.py, do not modify"

    top_init_contents = "\n".join(
        [
            header,
            "",
            "from . import executorch_flatbuffer",
            "",
            "__all__ = [",
            '    "executorch_flatbuffer",',
            "]",
            "",
        ]
    )

    module_names = sorted(
        path.stem
        for path in generated_fb_dir.glob("*.py")
        if path.name != "__init__.py"
    )
    fb_init_contents = "\n".join(
        [
            header,
            "",
            *(f"from . import {name}" for name in module_names),
            "",
            "__all__ = [",
            *(f'    "{name}",' for name in module_names),
            "]",
            "",
        ]
    )

    written = 0
    for init_path, contents in (
        (generated_dir / "__init__.py", top_init_contents),
        (generated_fb_dir / "__init__.py", fb_init_contents),
    ):
        # Skip files that already match, so the returned count reflects only
        # files that actually changed.
        if init_path.is_file() and init_path.read_text(encoding="utf-8") == contents:
            continue
        init_path.write_text(contents, encoding="utf-8")
        written += 1
    return written


def main() -> int:
    """Regenerate the flatbuffer Python bindings and fix up their packaging.

    Runs flatc, rewrites the generated imports, then writes the package
    ``__init__.py`` files. A summary count goes to stderr.

    Returns:
        The process exit status: always 0, because failures raise instead.
    """
    root = _repo_root()
    _run_flatc(root)
    # Left-to-right evaluation keeps the import rewrite ahead of the __init__ writes.
    total = _rewrite_imports(root) + _write_init_files(root)
    print(f"Updated {total} file(s).", file=sys.stderr)
    return 0


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())
