# mypy: allow-untyped-defs
import functools
from typing import Callable, Union
from typing_extensions import TypeVarTuple

import torch
import torch.utils._pytree as pytree
from torch._C import DispatchKey
from torch._dispatch.python import suspend_functionalization
from torch._higher_order_ops.utils import _maybe_run_with_interpreter, reenter_make_fx
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch._subclasses.functional_tensor import disable_functional_mode
from torch.fx.experimental.proxy_tensor import (
    disable_proxy_modes_tracing,
    ProxyTorchDispatchMode,
    track_tensor_tree,
)

from .utils import (
    _from_fun,
    _stack_pytree,
    _unstack_pytree,
    create_bw_fn,
    fill_none_with_masks,
    filter_with_masks,
    materialize_as_graph,
    save_tensors_and_symints_for_backward,
    saved_tensors_and_symints,
    split_into_chunks,
)


class MapImpl(HigherOrderOperator):
    def __init__(self):
        super().__init__("map_impl")

    def __call__(self, *args, **kwargs):
        return super().__call__(*args, **kwargs)


map_impl = MapImpl()


def map(
    f: Callable[[pytree.PyTree, tuple[pytree.PyTree, ...]], pytree.PyTree],
    xs: Union[pytree.PyTree, torch.Tensor],
    *args: TypeVarTuple,
):
    r"""
    Performs a map of f with xs. Intuitively, you can think of the semantic being:

    out = []
    for idx in len(xs.size(0)):
        xs_sliced = xs.select(0, idx)
        out.append(f(xs_sliced, *args))
    torch.stack(out)

    .. warning::
        `torch._higher_order_ops.map` is a prototype feature in PyTorch. It currently
        does not support autograd and you may run into miscompiles.
        Read more about feature classification at:
        https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype


    Args:
        f (Callable): a callable that takes an input x, that could either be a single Tensor
            or a nested dict, list of tensors and some additional inputs
        xs: the inputs that're to be mapped over. We'll iterate over the first dim of each x
            and perform f on each slice.

        *args: additional arguments provided to each step of f. They could also be omitted and
            map is able to automatically figure out the read dependency.

    Return:
        the stacked output for each step of f

    Example:

        def f(xs):
            return xs[0] + xs[1] + const1 + const2

        xs = [torch.randn(2, 3), torch.randn(2, 3)]
        const1 = torch.randn(2, 3)
        const2 = torch.randn(2, 3)
        # returns a tensor of shape [2, 2, 3]
        torch._higher_order_ops.map(f, xs)

    """
    flat_xs, xs_spec = pytree.tree_flatten(xs)
    flat_args, args_spec = pytree.tree_flatten(args)
    if not all(isinstance(t, torch.Tensor) for t in flat_xs):
        raise RuntimeError(f"Mapped xs can only consist of tensors. Got xs {flat_xs}.")

    shapes = [xs.shape for xs in flat_xs]
    leading_dim_size = shapes[0][0]
    if leading_dim_size == 0:
        raise RuntimeError("Leading dimensions of mapped xs cannot be 0.")

    if any(cur_shape[0] != leading_dim_size for cur_shape in shapes):
        raise RuntimeError(
            f"Leading dimensions of mapped xs must be consistent. Got shapes {shapes}."
        )

    def run_flattened_map(f, flat_xs, flat_args):
        def wrapped_fn(*flat_args, f, xs_tree_spec, args_tree_spec, num_xs):
            xs = pytree.tree_unflatten(flat_args[:num_xs], xs_tree_spec)
            args = pytree.tree_unflatten(flat_args[num_xs:], args_tree_spec)
            return f(xs, *args)

        inner_f = functools.partial(
            wrapped_fn,
            f=f,
            xs_tree_spec=xs_spec,
            args_tree_spec=args_spec,
            num_xs=len(flat_xs),
        )
        return map_impl(inner_f, flat_xs, flat_args)

    from torch._higher_order_ops.utils import _maybe_compile_and_run_fn

    return _maybe_compile_and_run_fn(run_flattened_map, f, flat_xs, flat_args)


class MapAutogradOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, f, num_mapped_args, *flat_args):
        ctx._f = f
        ctx._num_mapped_args = num_mapped_args
        ctx._num_pos_args = len(flat_args) - num_mapped_args

        # We snapshot the dispatch keys in forward for materializing the
        # the bw_graph in backward.
        ctx._fw_include_key_set = torch._C._dispatch_tls_local_include_set()
        ctx._fw_exclude_key_set = torch._C._dispatch_tls_local_exclude_set()
        save_tensors_and_symints_for_backward(ctx, flat_args)
        with torch._C._AutoDispatchBelowAutograd():
            return (
                *map_impl(f, flat_args[:num_mapped_args], flat_args[num_mapped_args:]),
            )

    @staticmethod
    def backward(ctx, *flat_grads):
        fw_args = saved_tensors_and_symints(ctx)
        num_mapped_args = ctx._num_mapped_args
        num_pos_args = ctx._num_pos_args
        num_grads = len(flat_grads)

        fw_mapped_args, pos_args = split_into_chunks(
            fw_args,
            [
                num_mapped_args,
                num_pos_args,
            ],
        )

        bw_f = create_bw_fn(ctx._f, fw_args)

        grads_tensor_masks = []

        # Create a wrapper around thefor the bw_f
        def bw_f_wrapper(*args):
            nonlocal grads_tensor_masks

            # Dissect args and re-order them for the ``ctx._bw_f``
            # args provided to the wrapper are composed of [*fw_mapped_args, *flat_grads, *pos_args]
            # The content of ``bw_f_tangents`` are the upstream gradients, i.e. flat_grads
            # The content of ``bw_f_primals`` are the fw_args, i.e., [*fw_mapped_args, *pos_args]
            # The bw_f requires *bw_f_primals, *bw_f_tangents
            fw_m_args, bw_f_tangents, pos_args = split_into_chunks(
                args, [num_mapped_args, num_grads, num_pos_args]
            )
            bw_f_primals = *fw_m_args, *pos_args
            gradients = bw_f(*bw_f_primals, *bw_f_tangents)
            grads_tensor_masks = [
                True if isinstance(out, torch.Tensor) else out for out in gradients
            ]
            return filter_with_masks(gradients, grads_tensor_masks)

        def construct_args_single_step_bw():
            unwrapped_mapped_xs = pytree.tree_map(_from_fun, fw_mapped_args)
            example_xs = _unstack_pytree(unwrapped_mapped_xs)[0]
            unwrapped_grads = pytree.tree_map(_from_fun, flat_grads)
            example_grads = _unstack_pytree(unwrapped_grads)[0]
            example_pos_args = [
                _from_fun(arg) if isinstance(arg, torch.Tensor) else arg
                for arg in pos_args
            ]
            return *example_xs, *example_grads, *example_pos_args

        with suspend_functionalization(), disable_functional_mode():
            with disable_proxy_modes_tracing():
                args_single_step_bw = construct_args_single_step_bw()

            # TODO: we need to materialize the bw graphs because dynamo is unable to
            # trace through the joint function when torch.compile torch.autograd.grad.
            fn_bw_gm = materialize_as_graph(
                bw_f_wrapper,
                args_single_step_bw,
                ctx._fw_include_key_set,
                ctx._fw_exclude_key_set,
                force_enable_grad=True,
            )

        grads = map_impl(fn_bw_gm, fw_mapped_args + flat_grads, pos_args)

        return None, None, *fill_none_with_masks(grads, grads_tensor_masks)


def trace_map(proxy_mode, func_overload, f, xs, pos_args):
    with disable_proxy_modes_tracing():
        example_input = _unstack_pytree(xs)[0]

        body_graph = f

        body_graph = reenter_make_fx(body_graph)(*example_input, *pos_args)

    next_name = proxy_mode.tracer.get_fresh_qualname("body_graph_")

    proxy_mode.tracer.root.register_module(next_name, body_graph)

    fake_outs = map_impl(body_graph, xs, pos_args)

    node_args = (body_graph, list(xs), list(pos_args))
    proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
    out_proxy = proxy_mode.tracer.create_proxy(
        "call_function", func_overload, proxy_args, {}, name="map_impl"
    )
    return track_tensor_tree(
        fake_outs, out_proxy, constant=None, tracer=proxy_mode.tracer
    )


@map_impl.py_impl(DispatchKey.CompositeExplicitAutograd)
def map_dense(f, xs, pos_args):
    pytrees = [f(*inp, *pos_args) for inp in _unstack_pytree(xs)]
    return _stack_pytree(pytrees)


@map_impl.py_autograd_impl
def map_autograd(f, xs, pos_args):
    num_mapped_args = len(xs)
    flat_out = MapAutogradOp.apply(f, num_mapped_args, *xs, *pos_args)
    return flat_out


@map_impl.py_impl(ProxyTorchDispatchMode)
def map_proxy_torch_dispatch_mode(mode, f, xs, args):
    return trace_map(mode, map_impl, f, xs, args)


@map_impl.py_impl(FakeTensorMode)
def map_fake_tensor_mode(mode, f, xs, args):
    with mode:
        return map_dense(f, xs, args)


@map_impl.py_functionalize_impl
def map_functionalize(ctx, f, xs, pos_args):
    from torch._higher_order_ops.utils import _check_alias_and_mutation

    unwrapped_xs = ctx.unwrap_tensors(xs)
    unwrapped_args = ctx.unwrap_tensors(pos_args)
    wrapped_fn = ctx.functionalize(_maybe_run_with_interpreter(f))

    with ctx.redispatch_to_next():
        example_inputs = (*_unstack_pytree(unwrapped_xs)[0], *unwrapped_args)
        pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
        _check_alias_and_mutation(f, example_inputs, "map", pre_dispatch)
        map_return = map_impl(wrapped_fn, unwrapped_xs, unwrapped_args)
        return ctx.wrap_tensors(map_return)


def _fake_map(f, x, *args):
    from functorch.experimental.control_flow import _stack_pytree, _unstack_pytree

    x_pytrees = _unstack_pytree(x)
    zs = []
    for xp in x_pytrees:
        zs.append(f(xp, *args))
    return _stack_pytree(zs)
