# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict, Optional, Union

import torch
from torchvision import transforms

from modelscope.metainfo import Pipelines
from modelscope.models import Model
from modelscope.models.cv.image_denoise import NAFNetForImageDenoise
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import ImageDenoisePreprocessor, LoadImage
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()

__all__ = ['ImageDenoisePipeline']


@PIPELINES.register_module(
    Tasks.image_denoising, module_name=Pipelines.image_denoise)
class ImageDenoisePipeline(Pipeline):

    def __init__(self,
                 model: Union[NAFNetForImageDenoise, str],
                 preprocessor: Optional[ImageDenoisePreprocessor] = None,
                 **kwargs):
        """
        use `model` and `preprocessor` to create a cv image denoise pipeline for prediction
        Args:
            model: model id on modelscope hub.
        """
        super().__init__(model=model, preprocessor=preprocessor, **kwargs)
        self.model.eval()
        self.config = self.model.config

        if torch.cuda.is_available():
            self._device = torch.device('cuda')
        else:
            self._device = torch.device('cpu')
        logger.info('image denoise model loaded')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        img = LoadImage.convert_to_img(input)
        # HWC image -> CHW float tensor in [0, 1], plus a batch dimension
        img = transforms.ToTensor()(img)
        return {'img': img.unsqueeze(0).to(self._device)}

    def crop_process(self, input: torch.Tensor) -> torch.Tensor:
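        """Denoise `input` tile by tile to bound peak memory.

        The frame is split into a grid of roughly 512-pixel tiles. Each tile
        is passed through the model together with a 16-pixel halo of
        surrounding context, and only the halo-free interior is written back,
        which suppresses visible seams at tile borders. For example, a
        1x3x1024x768 input yields a 2x1 grid of 512x768 tiles.
        """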
        output = torch.zeros_like(input)  # [1, C, H, W]
        # Split the frame into a grid of roughly 512-pixel tiles.
        ih, iw = input.shape[-2:]
        crop_rows, crop_cols = max(ih // 512, 1), max(iw // 512, 1)
        overlap = 16  # halo of context pixels shared between adjacent tiles

        step_h, step_w = ih // crop_rows, iw // crop_cols
        for y in range(crop_rows):
            for x in range(crop_cols):
                # Tile origin and size; the last row/column absorbs the
                # remainder left by the integer division above.
                crop_y = step_h * y
                crop_x = step_w * x
                crop_h = step_h if y < crop_rows - 1 else ih - crop_y
                crop_w = step_w if x < crop_cols - 1 else iw - crop_x

                # Extract the tile together with its halo, clipped to the frame.
                crop_frames = input[:, :,
                                    max(0, crop_y - overlap
                                        ):min(crop_y + crop_h + overlap, ih),
                                    max(0, crop_x - overlap
                                        ):min(crop_x + crop_w
                                              + overlap, iw)].contiguous()
                # Interior offsets within the padded tile; a tile touching a
                # frame border has no halo on that side. At the bottom/right
                # edge the end index may exceed the tile size; tensor slicing
                # clamps it, so the full interior is still written back.
                h_start = overlap if crop_y - overlap > 0 else 0
                w_start = overlap if crop_x - overlap > 0 else 0
                h_end = (h_start + crop_h
                         if crop_y + crop_h + overlap < ih else ih)
                w_end = (w_start + crop_w
                         if crop_x + crop_w + overlap < iw else iw)

                # Denoise the padded tile and keep only its halo-free interior.
                output[:, :, crop_y:crop_y + crop_h,
                       crop_x:crop_x + crop_w] = self.model._inference_forward(
                           crop_frames)['outputs'][:, :, h_start:h_end,
                                                   w_start:w_end]
        return output

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        # Inference only: keep the model in eval mode and disable autograd.
        self.model.eval()
        with torch.no_grad():
            output = self.crop_process(input['img'])  # [1, C, H, W] tensor

        return {'output_tensor': output}

    def postprocess(self, input: Dict[str, Any]) -> Dict[str, Any]:
        # CHW float tensor in [0, 1] -> HWC uint8 array; clamp first so
        # out-of-range model outputs cannot wrap around in the uint8 cast.
        output_img = (input['output_tensor'].squeeze(0).clamp(0, 1)
                      * 255).cpu().permute(1, 2, 0).numpy().astype('uint8')
        # Reverse the channel order (RGB -> BGR) for the output convention.
        return {OutputKeys.OUTPUT_IMG: output_img[:, :, ::-1]}
