# Copyright (c) Alibaba, Inc. and its affiliates.
from collections.abc import Mapping

import torch

from modelscope.outputs import ModelOutputBase


def to_device(batch, device, non_blocking=False):
    """Put the data to the target device just before the forward function.

    Tensors are found recursively inside ``ModelOutputBase`` objects, mappings,
    lists and tuples (including namedtuples) and moved with
    ``Tensor.to(device, non_blocking=non_blocking)``. Any other object is
    returned unchanged.

    Args:
        batch: The batch data out of the dataloader.
        device (str | torch.device): The target device for the data.
        non_blocking (bool): Forwarded to ``torch.Tensor.to`` for every tensor
            in ``batch``, including tensors nested inside containers.

    Returns:
        The data on the target device. ``ModelOutputBase`` instances and
        mappings that support item assignment are updated in place and the
        same object is returned. Other mappings, lists and tuples are rebuilt
        with their original type.
    """
    if isinstance(batch, ModelOutputBase):
        for idx in range(len(batch)):
            batch[idx] = to_device(
                batch[idx], device, non_blocking=non_blocking)
        return batch
    # ``dict`` is a registered ``Mapping``, so a single check covers both.
    if isinstance(batch, Mapping):
        if hasattr(batch, '__setitem__'):
            # Reuse mini-batch to keep attributes for prediction.
            # Only existing keys are reassigned, so the mapping's size does
            # not change while it is being iterated.
            for k, v in batch.items():
                batch[k] = to_device(v, device, non_blocking=non_blocking)
            return batch
        return type(batch)({
            k: to_device(v, device, non_blocking=non_blocking)
            for k, v in batch.items()
        })
    if isinstance(batch, tuple) and hasattr(batch, '_fields'):
        # namedtuple constructors take their fields positionally, not as a
        # single iterable, so the converted values must be unpacked.
        return type(batch)(
            *(to_device(v, device, non_blocking=non_blocking) for v in batch))
    if isinstance(batch, (tuple, list)):
        return type(batch)(
            to_device(v, device, non_blocking=non_blocking) for v in batch)
    if isinstance(batch, torch.Tensor):
        return batch.to(device, non_blocking=non_blocking)
    return batch
