# Copyright (c) Alibaba, Inc. and its affiliates.

from typing import Dict

import numpy as np

from modelscope.metainfo import Metrics
from modelscope.outputs import OutputKeys
from modelscope.utils.chinese_utils import remove_space_between_chinese_chars
from modelscope.utils.registry import default_group
from modelscope.utils.tensor_utils import torch_nested_numpify
from .base import Metric
from .builder import METRICS, MetricKeys


@METRICS.register_module(group_key=default_group, module_name=Metrics.accuracy)
class AccuracyMetric(Metric):
    """The metric computation class for classification classes.

    This metric class calculates accuracy for the whole input batches.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.preds = []
        self.labels = []

    def add(self, outputs: Dict, inputs: Dict):
        label_name = OutputKeys.LABEL if OutputKeys.LABEL in inputs else OutputKeys.LABELS
        ground_truths = inputs[label_name]
        eval_results = None
        for key in [
                OutputKeys.CAPTION, OutputKeys.TEXT, OutputKeys.BOXES,
                OutputKeys.LABEL, OutputKeys.LABELS, OutputKeys.SCORES
        ]:
            if key in outputs and outputs[key] is not None:
                eval_results = outputs[key]
                break
        assert type(ground_truths) == type(eval_results)
        ground_truths = torch_nested_numpify(ground_truths)
        for truth in ground_truths:
            self.labels.append(truth)
        eval_results = torch_nested_numpify(eval_results)
        for result in eval_results:
            if isinstance(truth, str):
                if isinstance(result, list):
                    result = result[0]
                assert isinstance(result, str), 'both truth and pred are str'
                self.preds.append(remove_space_between_chinese_chars(result))
            else:
                self.preds.append(result)

    def evaluate(self):
        assert len(self.preds) == len(self.labels)
        return {
            MetricKeys.ACCURACY: (np.asarray([
                pred == ref for pred, ref in zip(self.preds, self.labels)
            ])).mean().item()
        }

    def merge(self, other: 'AccuracyMetric'):
        self.preds.extend(other.preds)
        self.labels.extend(other.labels)

    def __getstate__(self):
        return self.preds, self.labels

    def __setstate__(self, state):
        self.__init__()
        self.preds, self.labels = state
