# Copyright (c) Alibaba, Inc. and its affiliates.
from abc import ABC, abstractmethod
from typing import Any, Dict, Union

from modelscope.models.base.base_model import Model
from modelscope.utils.config import ConfigDict
from modelscope.utils.logger import get_logger

# Module-level logger obtained from the modelscope logging utility.
logger = get_logger()

# Framework-agnostic tensor alias. Both members are string forward references,
# so evaluating this alias does not require torch or tensorflow to be imported.
Tensor = Union['torch.Tensor', 'tf.Tensor']
# Either a mapping from string keys to tensors, or a Model instance.
Input = Union[Dict[str, Tensor], Model]


class Head(ABC):
    """Abstract base class that defines the interface of a task head.

    Every keyword argument given to the constructor is wrapped in a
    ``ConfigDict`` and kept on ``self.config``. Concrete heads must
    implement both ``forward`` and ``compute_loss``.
    """

    def __init__(self, **kwargs):
        # Keep the raw keyword arguments around as the head configuration.
        head_config = ConfigDict(kwargs)
        self.config = head_config

    @abstractmethod
    def forward(self, *args, **kwargs) -> Dict[str, Any]:
        """Run the downstream task on the output of the backbone model.

        The arguments carry the output produced by the backbone model.

        Returns (Dict[str, Any]): The output of the downstream task.
        """

    @abstractmethod
    def compute_loss(self, *args, **kwargs) -> Dict[str, Any]:
        """Compute the loss of this head during finetuning.

        Returns (Dict[str, Any]): The loss dict.
        """
