# Copyright (c) Alibaba, Inc. and its affiliates.

import os.path as osp
from typing import Any, Dict, Union

import cv2
import numpy as np
import PIL
import torch
from PIL import Image

from modelscope.metainfo import Pipelines
from modelscope.models.cv.virual_tryon import SDAFNet_Tryon
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import load_image
from modelscope.utils.constant import ModelFile, Tasks


@PIPELINES.register_module(
    Tasks.virtual_try_on, module_name=Pipelines.virtual_try_on)
class VirtualTryonPipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
        """
        use `model` to create a virtual tryon pipeline for prediction
        Args:
            model: model id on modelscope hub.
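
        Example (a usage sketch; '<model-id>' and the image paths are
        placeholders, so substitute a virtual try-on model id from the
        ModelScope hub and real images):
            >>> from modelscope.pipelines import pipeline
            >>> from modelscope.utils.constant import Tasks
            >>> tryon = pipeline(Tasks.virtual_try_on, model='<model-id>')
            >>> result = tryon(('masked_model.jpg', 'pose.jpg', 'cloth.jpg'))
            >>> result[OutputKeys.OUTPUT_IMG]  # (H, W, 3) RGB float array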
        """
        super().__init__(model=model, **kwargs)
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

        def filter_param(src_params, own_state):
            # Copy checkpoint tensors into own_state wherever the (possibly
            # 'module.'-prefixed) key and the tensor shape both match, e.g.
            # a DataParallel key 'module.encoder.0.weight' fills the plain
            # key 'encoder.0.weight'; mismatched entries are skipped.
            own_keys = list(own_state.keys())
            for name, param in src_params.items():
                # Strip the leading 'module.' prefix added by nn.DataParallel.
                if name.startswith('module.'):
                    name = name[len('module.'):]
                # Drop inner '.module.' segments when the target model does
                # not use them (heuristic check on the first target key).
                if own_keys and '.module.' not in own_keys[0]:
                    name = name.replace('.module.', '.')
                if (name in own_state
                        and own_state[name].shape == param.shape):
                    own_state[name].copy_(param)

        def load_pretrained(model, src_params):
            # Some checkpoints nest the weights under a 'state_dict' key.
            if 'state_dict' in src_params:
                src_params = src_params['state_dict']
            own_state = model.state_dict()
            filter_param(src_params, own_state)
            model.load_state_dict(own_state)

        # After the base-class init, self.model is the resolved local model
        # directory; keep the path before rebinding self.model to the network.
        self.local_path = self.model
        src_params = torch.load(
            osp.join(self.local_path, ModelFile.TORCH_MODEL_FILE),
            map_location='cpu',
            weights_only=True)
        self.model = SDAFNet_Tryon(ref_in_channel=6).to(self.device)
        load_pretrained(self.model, src_params)
        self.model.eval()
        self.size = 192
        from torchvision import transforms

        # InterpolationMode.BILINEAR replaces the deprecated integer code 2.
        self.test_transforms = transforms.Compose([
            transforms.Resize(
                self.size,
                interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
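
    # test_transforms sketch: a W x H PIL image with W < H comes out as a
    # float tensor of shape (3, round(H * 192 / W), 192) with values in
    # [-1, 1], since Resize with an int scales the shorter edge to 192 and
    # Normalize(mean=0.5, std=0.5) maps [0, 1] pixels to [-1, 1].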

    def preprocess(self, input: Union[Dict[str, Any],
                                      tuple]) -> Dict[str, Any]:
        """Convert the three inputs into normalized, batched tensors.

        Each item may be an image path or URL, a PIL image, or a BGR
        ndarray; a tuple is read as (masked_model, pose, cloth).
        """
        if isinstance(input, tuple):
            index_model = 0
            index_pose = 1
            index_cloth = 2
        else:
            index_model = 'masked_model'
            index_pose = 'pose'
            index_cloth = 'cloth'
        if isinstance(input[index_model], str):
            img_agnostic = load_image(input[index_model])
            pose = load_image(input[index_pose])
            cloth_img = load_image(input[index_cloth])
        elif isinstance(input[index_model], PIL.Image.Image):
            img_agnostic = input[index_model].convert('RGB')
            pose = input[index_pose].convert('RGB')
            cloth_img = input[index_cloth].convert('RGB')
        elif isinstance(input[index_model], np.ndarray):
            img_agnostic = input[index_model]
            pose = input[index_pose]
            cloth_img = input[index_cloth]
            if img_agnostic.ndim == 2:
                # Grayscale arrays: expand to three (BGR) channels first.
                img_agnostic = cv2.cvtColor(img_agnostic, cv2.COLOR_GRAY2BGR)
                pose = cv2.cvtColor(pose, cv2.COLOR_GRAY2BGR)
                cloth_img = cv2.cvtColor(cloth_img, cv2.COLOR_GRAY2BGR)
            # ndarrays follow the OpenCV BGR convention; flip to RGB for PIL.
            img_agnostic = Image.fromarray(
                img_agnostic[:, :, ::-1].astype('uint8')).convert('RGB')
            pose = Image.fromarray(
                pose[:, :, ::-1].astype('uint8')).convert('RGB')
            cloth_img = Image.fromarray(
                cloth_img[:, :, ::-1].astype('uint8')).convert('RGB')
        else:
            raise TypeError(
                f'each input should be str, PIL.Image.Image or np.ndarray,'
                f' but got {type(input[index_model])}')

        img_agnostic = self.test_transforms(img_agnostic)
        pose = self.test_transforms(pose)
        cloth_img = self.test_transforms(cloth_img)
        inputs = {
            'masked_model': img_agnostic.unsqueeze(0),
            'pose': pose.unsqueeze(0),
            'cloth': cloth_img.unsqueeze(0)
        }
        return inputs
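
    # preprocess() accepts either form; the paths below are hypothetical
    # placeholders:
    #   preprocess(('masked_model.jpg', 'pose.jpg', 'cloth.jpg'))
    #   preprocess({'masked_model': 'masked_model.jpg',
    #               'pose': 'pose.jpg', 'cloth': 'cloth.jpg'})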

    def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        img_agnostic = inputs['masked_model'].to(self.device)
        pose = inputs['pose'].to(self.device)
        cloth_img = inputs['cloth'].to(self.device)
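        # Pose and the masked person image are concatenated channel-wise
        # (3 + 3 channels), matching SDAFNet_Tryon(ref_in_channel=6) above.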
        ref_input = torch.cat((pose, img_agnostic), dim=1)
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            tryon_result = self.model(ref_input, cloth_img, img_agnostic)
        return {OutputKeys.OUTPUT_IMG: tryon_result}

    def postprocess(self, outputs: Dict[str, Any]) -> Dict[str, Any]:
        # NCHW -> NHWC, drop the batch dim, then map the model's [-1, 1]
        # output back to [0, 255].
        tryon_result = outputs[OutputKeys.OUTPUT_IMG].permute(0, 2, 3,
                                                              1).squeeze(0)
        tryon_result = tryon_result.add(1.).div(2.).mul(255)
        tryon_result = tryon_result.detach().cpu().numpy()
        outputs[OutputKeys.OUTPUT_IMG] = tryon_result
        return outputs
