# Copyright (c) Alibaba, Inc. and its affiliates.

import ast
import math
import os.path as osp
from typing import Any, Dict

import torch
from torchvision import transforms

from modelscope.metainfo import Pipelines
from modelscope.models.cv.tinynas_classfication import get_zennet
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.image_classification, module_name=Pipelines.tinynas_classification)
class TinynasClassificationPipeline(Pipeline):
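    """Image classification pipeline built on a TinyNAS (ZenNet) backbone.

    Loads ZenNet weights from the model directory, resizes and center-crops
    the input image, and returns the top-1 label with its softmax score.
    """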

    def __init__(self, model: str, **kwargs):
        """
        use `model` to create a tinynas classification pipeline for prediction
        Args:
            model: model id on modelscope hub.
        """
        super().__init__(model=model, **kwargs)
        self.path = model
        self.model = get_zennet()

        model_pth_path = osp.join(self.path, ModelFile.TORCH_MODEL_FILE)

        # Checkpoints may wrap the weights in a 'state_dict' key or store
        # them at the top level; handle both layouts.
        checkpoint = torch.load(
            model_pth_path, map_location='cpu', weights_only=True)
        state_dict = checkpoint.get('state_dict', checkpoint)

        self.model.load_state_dict(state_dict, strict=True)
        logger.info('TinyNAS model weights loaded')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        img = LoadImage.convert_to_img(input)

        # The network consumes 224x224 inputs. The image is first resized so
        # that the 380x380 center crop spans 87.5% of the shorter side, then
        # the crop is bilinearly downsampled to the network input size.
        input_image_size = 224
        crop_image_size = 380
        input_image_crop = 0.875
        resize_image_size = int(math.ceil(crop_image_size / input_image_crop))
        # Standard ImageNet normalization statistics.
        transforms_normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        transform_list = [
            transforms.Resize(
                resize_image_size,
                interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(crop_image_size),
            transforms.ToTensor(),
            transforms_normalize,
        ]
        transformer = transforms.Compose(transform_list)

        img = transformer(img)
        # Add a batch dimension and downsample to the network input size,
        # yielding a (1, 3, 224, 224) tensor.
        img = torch.unsqueeze(img, 0)
        img = torch.nn.functional.interpolate(
            img, input_image_size, mode='bilinear')
        return {'img': img}

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        # Inference only: switch to eval mode and disable gradient tracking.
        self.model.eval()
        with torch.no_grad():
            outputs = self.model(input['img'])
        return {'outputs': outputs}

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        # label_map.txt holds a Python dict literal mapping class index to
        # label name; parse it with ast.literal_eval rather than eval to
        # avoid executing arbitrary code from the model directory.
        label_mapping_path = osp.join(self.path, 'label_map.txt')
        with open(label_mapping_path, encoding='utf-8') as f:
            label_dict = ast.literal_eval(f.read())

        # Convert logits to probabilities and report the top-1 prediction.
        output_prob = torch.nn.functional.softmax(inputs['outputs'], dim=-1)
        score = torch.max(output_prob)
        return {
            OutputKeys.SCORES: [score.item()],
            OutputKeys.LABELS: [label_dict[inputs['outputs'].argmax().item()]]
        }
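

# A minimal usage sketch. The hub model id below is an assumption used only
# for illustration; substitute the id of an actual TinyNAS classification
# model from the ModelScope hub:
#
#     from modelscope.pipelines import pipeline
#     from modelscope.utils.constant import Tasks
#
#     classifier = pipeline(Tasks.image_classification,
#                           model='damo/cv_tinynas_classification')  # hypothetical id
#     result = classifier('path/to/image.jpg')
#     print(result)  # {'scores': [...], 'labels': [...]}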
