# Copyright (c) Alibaba, Inc. and its affiliates.

import math
import os.path as osp
from typing import Any, Dict

import torch

from modelscope.metainfo import Pipelines
from modelscope.models.cv.action_recognition import (BaseVideoModel,
                                                     PatchShiftTransformer)
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import ReadVideoData
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.action_recognition, module_name=Pipelines.action_recognition)
class ActionRecognitionPipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
        """
        Use `model` to create an action recognition pipeline for prediction.

        Args:
            model: model id on the ModelScope hub.
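
        Example (a minimal usage sketch; the model id and video path are
        placeholders):
            >>> from modelscope.pipelines import pipeline
            >>> from modelscope.utils.constant import Tasks
            >>> recognizer = pipeline(
            ...     Tasks.action_recognition,
            ...     model='<model-id-on-modelscope-hub>')
            >>> recognizer('path/to/video.mp4')
            {'labels': ...}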
        """
        super().__init__(model=model, **kwargs)
        model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE)
        logger.info(f'loading model from {model_path}')
        config_path = osp.join(self.model, ModelFile.CONFIGURATION)
        logger.info(f'loading config from {config_path}')
        self.cfg = Config.from_file(config_path)

        # Build the video model from the config, load the trained weights from
        # the checkpoint's 'model_state' entry, then switch to inference mode.
        self.infer_model = BaseVideoModel(cfg=self.cfg).to(self.device)
        self.infer_model.load_state_dict(
            torch.load(
                model_path, map_location=self.device,
                weights_only=True)['model_state'])
        self.infer_model.eval()
        self.label_mapping = self.cfg.label_mapping
        logger.info('load model done')

    def preprocess(self, input: Input) -> Dict[str, Any]:
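        """Decode the video at path `input` into a clip tensor placed on the
        pipeline device.
        """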
        if isinstance(input, str):
            video_input_data = ReadVideoData(self.cfg, input).to(self.device)
        else:
            raise TypeError(
                f'input should be a str, but got {type(input)}')
        result = {'video_data': video_input_data}
        return result

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
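        """Predict a class index from the decoded clips and map it to its
        human-readable label via the config's `label_mapping`.
        """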
        pred = self.perform_inference(input['video_data'])
        output_label = self.label_mapping[str(pred)]
        return {OutputKeys.LABELS: output_label}

    @torch.no_grad()
    def perform_inference(self, data, max_bsz=4):
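        """Run inference in mini-batches of at most `max_bsz` clips to bound
        GPU memory, average the predictions over all clips, and return the
        index of the predicted class.
        """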
        iter_num = math.ceil(data.size(0) / max_bsz)
        preds_list = []
        for i in range(iter_num):
            preds_list.append(
                self.infer_model(data[i * max_bsz:(i + 1) * max_bsz])[0])
        pred = torch.cat(preds_list, dim=0)
        return pred.mean(dim=0).argmax().item()

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs


@PIPELINES.register_module(
    Tasks.action_recognition, module_name=Pipelines.pst_action_recognition)
class PSTActionRecognitionPipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
        """
        Use `model` to create a PST (Patch Shift Transformer) action
        recognition pipeline for prediction.

        Args:
            model: model id on the ModelScope hub.
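
        Example (a minimal usage sketch; the model id and video path are
        placeholders):
            >>> from modelscope.pipelines import pipeline
            >>> from modelscope.utils.constant import Tasks
            >>> recognizer = pipeline(
            ...     Tasks.action_recognition,
            ...     model='<pst-model-id-on-modelscope-hub>')
            >>> recognizer('path/to/video.mp4')
            {'labels': ...}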
        """
        super().__init__(model=model, **kwargs)
        model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE)
        logger.info(f'loading model from {model_path}')
        config_path = osp.join(self.model, ModelFile.CONFIGURATION)
        logger.info(f'loading config from {config_path}')
        self.cfg = Config.from_file(config_path)
        # Build the Patch Shift Transformer from the model directory, load the
        # trained weights from the checkpoint's 'state_dict' entry, then switch
        # to inference mode.
        self.infer_model = PatchShiftTransformer(model).to(self.device)
        self.infer_model.load_state_dict(
            torch.load(
                model_path, map_location=self.device,
                weights_only=True)['state_dict'])
        self.infer_model.eval()
        self.label_mapping = self.cfg.label_mapping
        logger.info('load model done')

    def preprocess(self, input: Input) -> Dict[str, Any]:
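        """Decode the video at path `input` into a clip tensor placed on the
        pipeline device.
        """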
        if isinstance(input, str):
            video_input_data = ReadVideoData(self.cfg, input).to(self.device)
        else:
            raise TypeError(
                f'input should be a str, but got {type(input)}')
        result = {'video_data': video_input_data}
        return result

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
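        """Predict a class index from the decoded clips and map it to its
        human-readable label via the config's `label_mapping`.
        """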
        pred = self.perform_inference(input['video_data'])
        output_label = self.label_mapping[str(pred)]
        return {OutputKeys.LABELS: output_label}

    @torch.no_grad()
    def perform_inference(self, data, max_bsz=4):
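        """Run inference in mini-batches of at most `max_bsz` clips to bound
        GPU memory; unlike `ActionRecognitionPipeline`, the PST model's output
        is used directly (no tuple indexing). Predictions are averaged over
        all clips and the index of the predicted class is returned.
        """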
        iter_num = math.ceil(data.size(0) / max_bsz)
        preds_list = []
        for i in range(iter_num):
            preds_list.append(
                self.infer_model(data[i * max_bsz:(i + 1) * max_bsz]))
        pred = torch.cat(preds_list, dim=0)
        return pred.mean(dim=0).argmax().item()

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs
