from typing import Union

import torch
import torch.utils._pytree as pytree
from torch.export.exported_program import ExportedProgram


# Public API of this module: only the pass entry point is exported.
__all__ = ["move_to_device_pass"]


def move_to_device_pass(
    ep: ExportedProgram, location: Union[torch.device, str, dict[str, str]]
) -> ExportedProgram:
    """
    Move the exported program to the given device.

    Args:
        ep (ExportedProgram): The exported program to move.
        location (Union[torch.device, str, Dict[str, str]]): The device to move the exported program to.
            If a string, it is interpreted as a device name.
            If a dict, it is interpreted as a mapping from
            the existing device to the intended one

    Returns:
        ExportedProgram: The moved exported program.
    """

    def _get_new_device(
        curr_device: torch.device,
        location: Union[torch.device, str, dict[str, str]],
    ) -> str:
        if isinstance(location, dict):
            if str(curr_device) in location.keys():
                return location[str(curr_device)]
            else:
                return str(curr_device)
        else:
            return str(location)

    # move all the state_dict
    for k, v in ep.state_dict.items():
        if isinstance(v, torch.nn.Parameter):
            ep._state_dict[k] = torch.nn.Parameter(
                v.to(_get_new_device(v.device, location)),
                v.requires_grad,
            )
        else:
            ep._state_dict[k] = v.to(_get_new_device(v.device, location))

    # move all the constants
    for k, v in ep.constants.items():
        if isinstance(v, torch.Tensor):
            ep._constants[k] = v.to(_get_new_device(v.device, location))

    # move example_inputs if they exist
    if ep.example_inputs is not None:
        args, kwargs = ep.example_inputs
        moved_args = pytree.tree_map_only(
            torch.Tensor,
            lambda tensor: tensor.to(_get_new_device(tensor.device, location)),
            args,
        )
        moved_kwargs = pytree.tree_map_only(
            torch.Tensor,
            lambda tensor: tensor.to(_get_new_device(tensor.device, location)),
            kwargs,
        )
        ep._example_inputs = (moved_args, moved_kwargs)

    for m in ep.graph_module.modules():
        if isinstance(m, torch.fx.GraphModule):
            for node in m.graph.nodes:
                # move all the nodes kwargs with burnt-in device
                if "device" in node.kwargs:
                    kwargs = node.kwargs.copy()
                    kwargs["device"] = _get_new_device(kwargs["device"], location)
                    node.kwargs = kwargs

                if (
                    node.op == "call_function"
                    and node.target == torch.ops.aten.to.device
                ):
                    args = list(node.args)
                    args[1] = _get_new_device(args[1], location)
                    node.args = tuple(args)

                # move all the tensor metadata
                node.meta["val"] = pytree.tree_map(
                    lambda v: v.to(_get_new_device(v.device, location))
                    if isinstance(v, torch.Tensor)
                    else v,
                    node.meta.get("val"),
                )

    ep.validate()
    return ep
