"""
Python polyfills for builtins
"""

from __future__ import annotations

import builtins
import functools
import operator
from typing import Callable, TYPE_CHECKING, TypeVar

from ..decorators import substitute_in_graph


if TYPE_CHECKING:
    from collections.abc import Iterable


# Public names of this module. Each entry is a polyfill defined below under
# the same name as the builtin it stands in for, so inside this module these
# names shadow the builtins. ``iter_`` is not exported here.
__all__ = [
    "all",
    "any",
    "enumerate",
    "sum",
]


# Generic element type used in the signatures of ``enumerate`` and ``sum``.
_T = TypeVar("_T")


@substitute_in_graph(builtins.all, can_constant_fold_through=True)
def all(iterable: Iterable[object], /) -> bool:
    """Return ``True`` if every element of ``iterable`` is truthy.

    An empty ``iterable`` yields ``True``. Iteration stops at the first
    falsy element.
    """
    for item in iterable:
        if item:
            continue
        # First falsy element decides the result; the rest is not consumed.
        return False
    return True


@substitute_in_graph(builtins.any, can_constant_fold_through=True)
def any(iterable: Iterable[object], /) -> bool:
    """Return ``True`` if at least one element of ``iterable`` is truthy.

    An empty ``iterable`` yields ``False``. Iteration stops at the first
    truthy element.
    """
    for item in iterable:
        if not item:
            continue
        # First truthy element decides the result; the rest is not consumed.
        return True
    return False


@substitute_in_graph(builtins.enumerate, is_embedded_type=True)  # type: ignore[arg-type]
def enumerate(iterable: Iterable[_T], start: int = 0) -> Iterable[tuple[int, _T]]:
    """Yield ``(index, element)`` pairs for ``iterable``, counting from ``start``.

    Because this function is a generator, the ``start`` validation below runs
    when iteration begins, not when ``enumerate`` is called.

    Raises:
        TypeError: if ``start`` is not an ``int`` instance.
    """
    if not isinstance(start, int):
        raise TypeError(
            f"{type(start).__name__!r} object cannot be interpreted as an integer"
        )

    index = start
    for item in iterable:
        yield index, item
        index += 1


@substitute_in_graph(builtins.sum, can_constant_fold_through=True)  # type: ignore[arg-type]
def sum(iterable: Iterable[_T], /, start: _T = 0) -> _T:  # type: ignore[assignment]
    """Add the elements of ``iterable`` onto ``start``, left to right.

    Returns ``start`` itself when ``iterable`` is empty.

    NOTE: no type check is performed on ``start`` here; whatever
    ``operator.add`` accepts is accepted.
    """
    combine = operator.add
    return functools.reduce(combine, iterable, start)


class _CallableIterator:
    """Iterator backing the two-argument form ``iter(fn, sentinel)``.

    Each ``__next__`` calls ``fn()`` with no arguments and returns the result,
    until a result matches ``sentinel``. From that point on the iterator stays
    exhausted and ``fn`` is not called again.
    """

    def __init__(self, fn, sentinel):  # type: ignore[no-untyped-def]
        self.fn = fn
        self.sentinel = sentinel
        # Becomes True once the sentinel has been produced, so that later
        # ``__next__`` calls keep raising StopIteration instead of resuming.
        self._exhausted = False

    def __iter__(self):  # type: ignore[no-untyped-def]
        return self

    def __next__(self):  # type: ignore[no-untyped-def]
        """Return the next value from ``fn()``.

        Raises:
            StopIteration: when ``fn()`` returns the sentinel, and on every
                call after that.
        """
        if self._exhausted:
            raise StopIteration

        # The iterator created in this case will call object with no arguments
        # for each call to its __next__() method;
        r = self.fn()

        # If the value returned matches sentinel, StopIteration will be raised.
        # Identity is checked before equality so that a sentinel which does not
        # compare equal to itself (e.g. a NaN object) still ends the iteration,
        # as it does with the builtin.
        if r is self.sentinel or r == self.sentinel:
            self._exhausted = True
            raise StopIteration

        # otherwise the value will be returned.
        return r


class _SENTINEL_MISSING:
    """Marker used as the default ``sentinel`` of ``iter_``.

    The class object itself (never an instance) is compared by identity to
    tell "no sentinel was passed" apart from any real sentinel value.
    """


# TODO(guilhermeleobas): use substitute_in_graph for iter()
def iter_(fn_or_iterable, sentinel=_SENTINEL_MISSING, /):  # type: ignore[no-untyped-def]
    """Polyfill for the builtin ``iter()`` in both of its calling forms.

    ``iter_(obj)`` returns an iterator over ``obj``. ``obj.__iter__()`` is
    tried first; otherwise the sequence protocol is used, calling
    ``obj.__getitem__`` with 0, 1, 2, ... until it raises ``IndexError``.

    ``iter_(fn, sentinel)`` returns a ``_CallableIterator`` built from ``fn``
    and ``sentinel``.

    Raises:
        TypeError: if ``obj`` has neither ``__iter__`` nor ``__getitem__``,
            if ``obj.__iter__()`` returns an object without ``__next__``, or
            if a sentinel is given and ``fn`` is not callable.
    """
    if sentinel is not _SENTINEL_MISSING:
        # If the second argument, sentinel, is given, then object must be a
        # callable object.
        fn = fn_or_iterable

        if not isinstance(fn, Callable):  # type: ignore[arg-type]
            raise TypeError("iter(v, w): v must be a callable")

        return _CallableIterator(fn, sentinel)

    # Without a second argument, object must be a collection object which supports
    # the iterable (__iter__) or the sequence protocol (__getitem__ with an integer
    # starting at 0)
    iterable = fn_or_iterable

    if hasattr(iterable, "__iter__"):
        iterator = iterable.__iter__()
        if not hasattr(iterator, "__next__"):
            # Report the bare type name (not the ``<class '...'>`` repr).
            raise TypeError(
                f"iter() returned non-iterator of type {type(iterator).__name__!r}"
            )
        return iterator

    if hasattr(iterable, "__getitem__"):
        # Needs to be a new function to avoid iter_ becoming a generator
        def sequence_protocol(seq):  # type: ignore[no-untyped-def]
            index = 0
            while True:
                # Only the lookup is guarded: an IndexError from it marks the
                # end of the sequence.
                try:
                    item = seq.__getitem__(index)
                except IndexError:
                    return
                yield item
                index += 1

        return sequence_protocol(iterable)

    # Report the bare type name (not the ``<class '...'>`` repr).
    raise TypeError(f"{type(iterable).__name__!r} object is not iterable")
