# Copyright (c) Alibaba, Inc. and its affiliates.

import time
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np

from modelscope.metainfo import Trainers
from modelscope.trainers.base import BaseTrainer
from modelscope.trainers.builder import TRAINERS
from modelscope.utils.logger import get_logger

logger = get_logger()


@TRAINERS.register_module(module_name=Trainers.bert_sentiment_analysis)
class SequenceClassificationTrainer(BaseTrainer):

    def __init__(self, cfg_file: str, *args, **kwargs):
        """ A trainer is used for Sequence Classification

        Based on Config file (*.yaml or *.json), the trainer trains or evaluates on a dataset

        Args:
            cfg_file (str): the path of config file
        Raises:
            ValueError: _description_
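
        Example (a minimal usage sketch; the config and checkpoint paths below
        are hypothetical placeholders, not files shipped with this repository):

            >>> trainer = SequenceClassificationTrainer('/path/to/sequence_classification.json')
            >>> results = trainer.evaluate(checkpoint_path='/path/to/bert-base-sst2')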
        """
        super().__init__(cfg_file)

    def train(self, *args, **kwargs):
        logger.info('Train')
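        # training is not implemented yet; this trainer currently only
        # supports evaluation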
        ...

    def __attr_is_exist(self, attr: str) -> Tuple[str, Any]:
        """Get an attribute from the config; return False if it does not exist.

        Example:

            >>> self.__attr_is_exist('model path')
            ('model-path', '/workspace/bert-base-sst2')
            >>> self.__attr_is_exist('model weights')
            ('model-weights', False)

        Args:
            attr (str): A space-separated attribute path,
                e.g. 'model path' -> config['model']['path'].

        Returns:
            Tuple[str, Any]: The hyphen-joined attribute name, and the attribute
                value, or False if the attribute is missing or empty.
        """
        paths = attr.split(' ')
        attr_str: str = '-'.join(paths)
        target = self.cfg[paths[0]] if hasattr(self.cfg, paths[0]) else None

        for path_ in paths[1:]:
            if not hasattr(target, path_):
                return attr_str, False
            target = target[path_]

        if target and target != '':
            return attr_str, target
        return attr_str, False

    def evaluate(self,
                 checkpoint_path: Optional[str] = None,
                 *args,
                 **kwargs) -> Dict[str, float]:
        """evaluate a dataset

        evaluate a dataset via a specific model from the `checkpoint_path` path, if the `checkpoint_path`
        does not exist, read from the config file.

        Args:
            checkpoint_path (Optional[str], optional): the model path. Defaults to None.

        Returns:
            Dict[str, float]: the results about the evaluation
            Example:
            {"accuracy": 0.5091743119266054, "f1": 0.673780487804878}
        """
        import torch
        from easynlp.appzoo import load_dataset
        from easynlp.appzoo.dataset import GeneralDataset
        from easynlp.appzoo.sequence_classification.model import \
            SequenceClassification
        from easynlp.utils import losses
        from sklearn.metrics import f1_score
        from torch.utils.data import DataLoader

        raise_str = 'Attribute {} is not given in config file!'

        metrics = self.__attr_is_exist('evaluation metrics')
        eval_batch_size = self.__attr_is_exist('evaluation batch_size')
        test_dataset_path = self.__attr_is_exist('dataset valid file')

        attrs = [metrics, eval_batch_size, test_dataset_path]
        for attr_ in attrs:
            if not attr_[-1]:
                raise AttributeError(raise_str.format(attr_[0]))

        if not checkpoint_path:
            checkpoint_path = self.__attr_is_exist('evaluation model_path')[-1]
            if not checkpoint_path:
                raise ValueError(
                    'Argument checkpoint_path must be passed if evaluation-model_path is not given in the config file!'
                )

        max_sequence_length = kwargs.get(
            'max_sequence_length',
            self.__attr_is_exist('evaluation max_sequence_length')[-1])
        if not max_sequence_length:
            raise ValueError(
                'Argument max_sequence_length must be passed '
                'if evaluation-max_sequence_length does not exist in the config file!'
            )

        # get the raw online dataset
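        # the config value is expected to be of the form '<dataset_name>/<subset_name>'
        # (e.g. 'glue/sst2', an illustrative value), which is split into
        # positional arguments for easynlp's load_dataset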
        raw_dataset = load_dataset(*test_dataset_path[-1].split('/'))
        valid_dataset = raw_dataset['validation']

        # generate a standard dataloader
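        # GeneralDataset tokenizes the raw examples with the tokenizer found at
        # checkpoint_path and truncates them to max_sequence_length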
        pre_dataset = GeneralDataset(valid_dataset, checkpoint_path,
                                     max_sequence_length)
        valid_dataloader = DataLoader(
            pre_dataset,
            batch_size=eval_batch_size[-1],
            shuffle=False,
            collate_fn=pre_dataset.batch_fn)

        # generate a model
        model = SequenceClassification.from_pretrained(checkpoint_path)

        # copy from easynlp (start)
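        # switch to eval mode and accumulate loss / accuracy statistics over
        # the validation set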
        model.eval()
        total_loss = 0
        total_steps = 0
        total_samples = 0
        hit_num = 0
        total_num = 0

        logits_list = list()
        y_trues = list()

        total_spent_time = 0.0
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        model.to(device)
        for _step, batch in enumerate(valid_dataloader):
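            # move tensor fields of the batch to the target device; non-tensor
            # fields are passed through, and a RuntimeError falls back to the
            # untouched batch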
            try:
                batch = {
                    key:
                    val.to(device) if isinstance(val, torch.Tensor) else val
                    for key, val in batch.items()
                }
            except RuntimeError:
                batch = {key: val for key, val in batch.items()}

            infer_start_time = time.time()
            with torch.no_grad():
                label_ids = batch.pop('label_ids')
                outputs = model(batch)
            infer_end_time = time.time()
            total_spent_time += infer_end_time - infer_start_time

            assert 'logits' in outputs
            logits = outputs['logits']

            y_trues.extend(label_ids.tolist())
            logits_list.extend(logits.tolist())
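            # count correct top-1 predictions for accuracy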
            hit_num += torch.sum(
                torch.argmax(logits, dim=-1) == label_ids).item()
            total_num += label_ids.shape[0]

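            # single-logit outputs are treated as regression (MSE loss);
            # 2-D logits as classification (cross entropy)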
            if len(logits.shape) == 1 or logits.shape[-1] == 1:
                tmp_loss = losses.mse_loss(logits, label_ids)
            elif len(logits.shape) == 2:
                tmp_loss = losses.cross_entropy(logits, label_ids)
            else:
                raise RuntimeError(
                    'Unexpected logits shape: {}'.format(logits.shape))

            total_loss += tmp_loss.mean().item()
            total_steps += 1
            total_samples += valid_dataloader.batch_size
            if (_step + 1) % 100 == 0:
                total_step = len(
                    valid_dataloader.dataset) // valid_dataloader.batch_size
                logger.info('Eval: {}/{} steps finished'.format(
                    _step + 1, total_step))

        logger.info('Inference time = {:.2f}s, [{:.4f} ms / sample] '.format(
            total_spent_time, total_spent_time * 1000 / total_samples))

        eval_loss = total_loss / total_steps
        logger.info('Eval loss: {}'.format(eval_loss))

        logits_list = np.array(logits_list)
        eval_outputs = list()
        for metric in metrics[-1]:
            if metric.endswith('accuracy'):
                acc = hit_num / total_num
                logger.info('Accuracy: {}'.format(acc))
                eval_outputs.append(('accuracy', acc))
            elif metric == 'f1':
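                # binary classification reports plain (binary-averaged) F1;
                # multi-class reports both macro- and micro-averaged F1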
                if model.config.num_labels == 2:
                    f1 = f1_score(y_trues, np.argmax(logits_list, axis=-1))
                    logger.info('F1: {}'.format(f1))
                    eval_outputs.append(('f1', f1))
                else:
                    f1 = f1_score(
                        y_trues,
                        np.argmax(logits_list, axis=-1),
                        average='macro')
                    logger.info('Macro F1: {}'.format(f1))
                    eval_outputs.append(('macro-f1', f1))
                    f1 = f1_score(
                        y_trues,
                        np.argmax(logits_list, axis=-1),
                        average='micro')
                    logger.info('Micro F1: {}'.format(f1))
                    eval_outputs.append(('micro-f1', f1))
            else:
                raise NotImplementedError('Metric %s not implemented' % metric)
        # copy from easynlp (end)

        return dict(eval_outputs)
