# Copyright (c) Alibaba, Inc. and its affiliates.

import os.path as osp
import tempfile
from typing import Any, Dict

import numpy as np
import torch

from modelscope.metainfo import Pipelines
from modelscope.models.cv.motion_generation import (ClassifierFreeSampleModel,
                                                    create_model,
                                                    load_model_wo_clip)
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.cv.motion_utils.motion_process import recover_from_ric
from modelscope.utils.cv.motion_utils.plot_script import plot_3d_motion
from modelscope.utils.logger import get_logger

# Module-level logger; the pipeline below uses it for model/config loading
# progress messages.
logger = get_logger()


@PIPELINES.register_module(
    Tasks.motion_generation, module_name=Pipelines.motion_generattion)
class MDMMotionGeneration(Pipeline):
    """Text-driven motion generation pipeline.

    The pipeline takes a text prompt, runs a diffusion sampling loop with the
    loaded model to obtain joint positions, and optionally renders the result
    to a video file via ``plot_3d_motion``.
    """

    def __init__(self, model: str, **kwargs):
        """
        use `model` to create motion generation pipeline for prediction
        Args:
            model: model id on modelscope hub.
            **kwargs: forwarded to the base ``Pipeline`` and also merged into
                the model configuration, overriding values read from the
                configuration file.
        """
        super().__init__(model=model, **kwargs)
        model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE)
        logger.info(f'loading model from {model_path}')
        config_path = osp.join(self.model, ModelFile.CONFIGURATION)
        logger.info(f'loading config from {config_path}')
        # Statistics applied in `forward` as `sample * std + mean`.
        self.mean = np.load(osp.join(self.model, 'Mean.npy'))
        self.std = np.load(osp.join(self.model, 'Std.npy'))
        self.cfg = Config.from_file(config_path)
        self.cfg.update({'smpl_data_path': osp.join(self.model, 'smpl')})
        self.cfg.update(kwargs)
        # Fixed generation settings: joint count handed to `recover_from_ric`,
        # frame rate of the rendered video, and number of generated frames.
        self.n_joints = 22
        self.fps = 20
        self.n_frames = 120
        self.mdm, self.diffusion = create_model(self.cfg)
        # Load on CPU first; the wrapped model is moved to the target device
        # below.
        state_dict = torch.load(
            model_path, map_location='cpu', weights_only=True)
        load_model_wo_clip(self.mdm, state_dict)
        self.mdm = ClassifierFreeSampleModel(self.mdm)
        self.mdm.to(self.device)
        self.mdm.eval()
        logger.info('load model done')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        """Validate the prompt and wrap it for `forward`.

        Args:
            input: the text prompt describing the motion.

        Returns:
            A dict with the prompt stored under ``'input_text'``.

        Raises:
            TypeError: if `input` is not a ``str``.
        """
        if not isinstance(input, str):
            raise TypeError(f'input should be a str,'
                            f'  but got {type(input)}')
        return {'input_text': input}

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Sample a motion for the prompt and convert it to joint positions.

        Args:
            input: dict produced by `preprocess`; only ``'input_text'`` is
                read.

        Returns:
            A dict holding the generated motion as a numpy array under
            ``OutputKeys.KEYPOINTS`` and the prompt under ``'text'``.
        """
        prompt = input['input_text']
        # Conditioning for a single sample of `self.n_frames` frames.
        # NOTE(review): 'scale' is presumably the classifier-free guidance
        # strength consumed by ClassifierFreeSampleModel -- confirm.
        model_kwargs = {
            'y': {
                'mask': torch.ones(1, 1, 1, self.n_frames) > 0,
                'lengths': torch.tensor([self.n_frames]),
                'tokens': None,
                'text': [prompt],
                'scale': torch.ones(1, device=self.device) * 2.5
            }
        }
        sample_shape = (1, self.mdm.njoints, self.mdm.nfeats, self.n_frames)
        sample = self.diffusion.p_sample_loop(
            self.mdm,
            sample_shape,
            clip_denoised=False,
            model_kwargs=model_kwargs,
            skip_timesteps=0,
            init_image=None,
            progress=True,
            dump_steps=None,
            noise=None,
            const_noise=False,
        )
        # Move the `njoints` axis last, then undo the normalisation with the
        # stored statistics.
        sample = (sample.cpu().permute(0, 2, 3, 1) * self.std
                  + self.mean).float()
        sample = recover_from_ric(sample, self.n_joints)
        # Merge the two leading axes and reorder for `rot2xyz`.
        sample = sample.view(-1, *sample.shape[2:]).permute(0, 2, 3, 1)

        sample = self.mdm.rot2xyz(
            x=sample,
            mask=None,
            pose_rep='xyz',
            glob=True,
            translation=True,
            jointstype='smpl',
            vertstrans=True,
            betas=None,
            beta=0,
            glob_rot=None,
            get_rotations_back=False)
        # Keep the single generated sample and move its last axis first
        # (presumably giving (frames, joints, 3) -- TODO confirm).
        motion = sample.cpu().numpy()[0].transpose(2, 0, 1)
        return {OutputKeys.KEYPOINTS: motion, 'text': prompt}

    def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        """Optionally render the motion to a video and report its path.

        Args:
            inputs: dict produced by `forward`.
            **kwargs: ``output_video`` may give the target video path. When
                the key is absent a temporary ``.mp4`` path is used; when it
                is explicitly ``None`` no video is rendered.

        Returns:
            `inputs`, updated in place with ``OutputKeys.OUTPUT_VIDEO`` set to
            the video path (or ``None``). The ``'text'`` entry is removed only
            when a video is rendered.
        """
        if 'output_video' in kwargs:
            output_video_path = kwargs['output_video']
        else:
            # Only reserve a temporary name when the caller did not choose a
            # path. The handle is closed (and the empty file removed) right
            # away so the renderer creates the file itself.
            with tempfile.NamedTemporaryFile(suffix='.mp4') as tmp_file:
                output_video_path = tmp_file.name

        # Joint-index chains handed to `plot_3d_motion` to draw the skeleton.
        kinematic_chain = [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10],
                           [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21],
                           [9, 13, 16, 18, 20]]
        if output_video_path is not None:
            keypoints = inputs[OutputKeys.KEYPOINTS]
            caption = inputs.pop('text')
            plot_3d_motion(
                output_video_path,
                kinematic_chain,
                keypoints,
                caption,
                dataset='humanml',
                # Use the frame rate configured in __init__ rather than a
                # duplicated literal.
                fps=self.fps)
        inputs.update({OutputKeys.OUTPUT_VIDEO: output_video_path})
        return inputs
