# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
import json
import os.path as osp
from typing import Any, Dict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from modelscope.metainfo import Models
from modelscope.models import Model
from modelscope.models.builder import MODELS
from modelscope.models.multi_modal.diffusion.diffusion import (
    GaussianDiffusion, beta_schedule)
from modelscope.models.multi_modal.diffusion.structbert import (BertConfig,
                                                                BertModel)
from modelscope.models.multi_modal.diffusion.tokenizer import FullTokenizer
from modelscope.models.multi_modal.diffusion.unet_generator import \
    DiffusionGenerator
from modelscope.models.multi_modal.diffusion.unet_upsampler_1024 import \
    SuperResUNet1024
from modelscope.models.multi_modal.diffusion.unet_upsampler_256 import \
    SuperResUNet256
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.device import create_device
from modelscope.utils.logger import get_logger

logger = get_logger()

__all__ = ['DiffusionForTextToImageSynthesis']


def make_diffusion(schedule,
                   num_timesteps=1000,
                   init_beta=None,
                   last_beta=None,
                   var_type='fixed_small'):
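    """Construct a GaussianDiffusion instance from a named beta schedule.

    Args:
        schedule: schedule name understood by ``beta_schedule``.
        num_timesteps: number of diffusion steps.
        init_beta, last_beta: optional endpoints of the schedule.
        var_type: parameterization of the reverse-process variance.
    """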
    betas = beta_schedule(schedule, num_timesteps, init_beta, last_beta)
    diffusion = GaussianDiffusion(betas, var_type=var_type)
    return diffusion


class Tokenizer(object):
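    """Wrap FullTokenizer to produce fixed-length BERT-style inputs:
    token ids, segment ids and an attention mask of length ``seq_len``."""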

    def __init__(self, vocab_file, seq_len=64):
        self.vocab_file = vocab_file
        self.seq_len = seq_len
        self.tokenizer = FullTokenizer(
            vocab_file=vocab_file, do_lower_case=True)

    def __call__(self, text):
        # tokenization
        tokens = self.tokenizer.tokenize(text)
        tokens = ['[CLS]'] + tokens[:self.seq_len - 2] + ['[SEP]']
        input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        input_mask = [1] * len(input_ids)
        segment_ids = [0] * len(input_ids)

        # padding
        input_ids += [0] * (self.seq_len - len(input_ids))
        input_mask += [0] * (self.seq_len - len(input_mask))
        segment_ids += [0] * (self.seq_len - len(segment_ids))
        assert len(input_ids) == len(input_mask) == len(
            segment_ids) == self.seq_len

        # convert to tensors
        input_ids = torch.LongTensor(input_ids)
        input_mask = torch.LongTensor(input_mask)
        segment_ids = torch.LongTensor(segment_ids)
        return input_ids, segment_ids, input_mask


class DiffusionModel(nn.Module):
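    """Cascaded text-to-image diffusion network: a StructBERT text encoder,
    a 64x64 generator UNet and two super-resolution UNets (256 and 1024)."""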

    def __init__(self, model_dir):
        super(DiffusionModel, self).__init__()
        # including text and generator config
        with open(
                '{}/model_config.json'.format(model_dir),
                encoding='utf-8') as f:
            model_config = json.load(f)

        # text encoder
        text_config = model_config['text_config']
        self.text_encoder = BertModel(BertConfig.from_dict(text_config))

        # generator (64x64)
        generator_config = model_config['generator_config']
        self.unet_generator = DiffusionGenerator(**generator_config)

        # upsampler (256x256)
        upsampler_256_config = model_config['upsampler_256_config']
        self.unet_upsampler_256 = SuperResUNet256(**upsampler_256_config)

        # upsampler (1024x1024)
        upsampler_1024_config = model_config['upsampler_1024_config']
        self.unet_upsampler_1024 = SuperResUNet1024(**upsampler_1024_config)

    def forward(self, noise, timesteps, input_ids, token_type_ids,
                attention_mask):
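        # encode the prompt, then run the 64 -> 256 -> 1024 UNet cascade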
        context, y = self.text_encoder(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask)
        context = context[-1]
        x = self.unet_generator(noise, timesteps, y, context, attention_mask)
        x = self.unet_upsampler_256(noise, timesteps, x,
                                    torch.zeros_like(timesteps), y, context,
                                    attention_mask)
        x = self.unet_upsampler_1024(x, timesteps, x)
        return x


@MODELS.register_module(
    Tasks.text_to_image_synthesis, module_name=Models.diffusion)
class DiffusionForTextToImageSynthesis(Model):
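    """Text-to-image generation with a cascaded diffusion model
    (64x64 generator plus 256 and 1024 super-resolution upsamplers).

    Minimal usage sketch (the model directory path is illustrative):

        model = DiffusionForTextToImageSynthesis('/path/to/model')
        image = model.generate({'text': 'a photo of a corgi'})
        # image: HxWx3 uint8 numpy array
    """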

    def __init__(self, model_dir, device='gpu', **kwargs):
        # fall back to CPU when CUDA is unavailable
        if not torch.cuda.is_available():
            device = 'cpu'
        super().__init__(model_dir=model_dir, device=device, **kwargs)
        diffusion_model = DiffusionModel(model_dir=model_dir)
        pretrained_params = torch.load(
            osp.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE), 'cpu')
        diffusion_model.load_state_dict(pretrained_params)
        diffusion_model.eval()

        self.device = create_device(device)
        diffusion_model.to(self.device)

        # modules
        self.text_encoder = diffusion_model.text_encoder
        self.unet_generator = diffusion_model.unet_generator
        self.unet_upsampler_256 = diffusion_model.unet_upsampler_256
        self.unet_upsampler_1024 = diffusion_model.unet_upsampler_1024

        # text tokenizer
        vocab_path = f'{model_dir}/{ModelFile.VOCAB_FILE}'
        self.tokenizer = Tokenizer(vocab_file=vocab_path, seq_len=64)

        # diffusion process
        with open(
                '{}/diffusion_config.json'.format(model_dir),
                encoding='utf-8') as f:
            diffusion_params = json.load(f)
        self.diffusion_generator = make_diffusion(
            **diffusion_params['generator_config'])
        self.diffusion_upsampler_256 = make_diffusion(
            **diffusion_params['upsampler_256_config'])
        self.diffusion_upsampler_1024 = make_diffusion(
            **diffusion_params['upsampler_1024_config'])

    def forward(self, input: Dict[str, Any]) -> np.ndarray:
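        """One forward pass of the full cascade on caller-provided noise and
        timesteps; use ``generate`` for the full sampling loops."""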
        if not all([key in input for key in ('text', 'noise', 'timesteps')]):
            raise ValueError(
                f'input should contain "text", "noise" and "timesteps", but got {input.keys()}'
            )
        # move the caller-provided noise and timesteps to the model device
        noise = input['noise'].to(self.device)
        timesteps = input['timesteps'].to(self.device)

        input_ids, token_type_ids, attention_mask = self.tokenizer(
            input['text'])
        input_ids = input_ids.to(self.device).unsqueeze(0)
        token_type_ids = token_type_ids.to(self.device).unsqueeze(0)
        attention_mask = attention_mask.to(self.device).unsqueeze(0)
        context, y = self.text_encoder(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask)
        context = context[-1]
        x = self.unet_generator(noise, timesteps, y, context, attention_mask)
        x = self.unet_upsampler_256(noise, timesteps, x,
                                    torch.zeros_like(timesteps), y, context,
                                    attention_mask)
        x = self.unet_upsampler_1024(x, timesteps, x)
        img = x.clamp(-1, 1).add(1).mul(127.5)
        img = img.squeeze(0).permute(1, 2, 0).cpu().numpy().astype(np.uint8)
        return img

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
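        """Identity pass-through; ``generate`` already returns the image."""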
        return inputs

    @torch.no_grad()
    def generate(self, input: Dict[str, Any]) -> np.ndarray:
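        """Sample an image from a text prompt.

        ``input`` must contain 'text'; optional keys such as 'solver'
        ('dpm-solver' or 'ddim'), 'debug' and the per-stage percentile,
        guide-scale and timestep overrides tune the sampling loops below.
        Returns an HxWx3 uint8 numpy array.
        """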
        if 'text' not in input:
            raise ValueError(
                f'input should contain "text", but got {input.keys()}')

        # encode text
        input_ids, token_type_ids, attention_mask = self.tokenizer(
            input['text'])
        input_ids = input_ids.to(self.device).unsqueeze(0)
        token_type_ids = token_type_ids.to(self.device).unsqueeze(0)
        attention_mask = attention_mask.to(self.device).unsqueeze(0)
        context, y = self.text_encoder(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask)
        context = context[-1]

        # choose a proper solver
        solver = input.get('solver', 'dpm-solver')
        if solver == 'dpm-solver':
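            # Each stage below gets a [conditional, unconditional] pair of
            # model_kwargs (the second dict zeroes the text features);
            # guide_scale blends the two predictions for classifier-free
            # guidance.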
            # generation
            img = self.diffusion_generator.dpm_solver_sample_loop(
                noise=torch.randn(1, 3, 64, 64).to(self.device),
                model=self.unet_generator,
                model_kwargs=[{
                    'y': y,
                    'context': context,
                    'mask': attention_mask
                }, {
                    'y': torch.zeros_like(y),
                    'context': torch.zeros_like(context),
                    'mask': attention_mask
                }],
                percentile=input.get('generator_percentile', 0.995),
                guide_scale=input.get('generator_guide_scale', 5.0),
                dpm_solver_timesteps=input.get('dpm_solver_timesteps', 20),
                order=3,
                skip_type='logSNR',
                method='singlestep',
                t_start=0.9946)

            # upsampling (64->256)
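            # Unless in debug mode, the 64px sample is bilinearly upsampled
            # 4x and passed to the upsampler as low-resolution conditioning
            # ('lx'), with its conditioning timestep 'lt' set to zero.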
            if not input.get('debug', False):
                img = F.interpolate(
                    img,
                    scale_factor=4.0,
                    mode='bilinear',
                    align_corners=False)
            img = self.diffusion_upsampler_256.dpm_solver_sample_loop(
                noise=torch.randn_like(img),
                model=self.unet_upsampler_256,
                model_kwargs=[{
                    'lx': img,
                    'lt': torch.zeros(1).to(self.device),
                    'y': y,
                    'context': context,
                    'mask': attention_mask
                }, {
                    'lx': img,
                    'lt': torch.zeros(1).to(self.device),
                    'y': torch.zeros_like(y),
                    'context': torch.zeros_like(context),
                    'mask': torch.zeros_like(attention_mask)
                }],
                percentile=input.get('upsampler_256_percentile', 0.995),
                guide_scale=input.get('upsampler_256_guide_scale', 5.0),
                dpm_solver_timesteps=input.get('dpm_solver_timesteps', 20),
                order=3,
                skip_type='logSNR',
                method='singlestep',
                t_start=0.9946)

            # upsampling (256->1024)
            if not input.get('debug', False):
                img = F.interpolate(
                    img,
                    scale_factor=4.0,
                    mode='bilinear',
                    align_corners=False)
            # the 1024 upsampler conditions on the 256px image via 'concat'
            img = self.diffusion_upsampler_1024.dpm_solver_sample_loop(
                noise=torch.randn_like(img),
                model=self.unet_upsampler_1024,
                model_kwargs={'concat': img},
                percentile=input.get('upsampler_1024_percentile', 0.995),
                dpm_solver_timesteps=input.get('dpm_solver_timesteps', 10),
                order=3,
                skip_type='logSNR',
                method='singlestep',
                t_start=None)
        elif solver == 'ddim':
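            # DDIM sampling: with the default eta=0.0 each loop is
            # deterministic given its initial noise.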
            # generation
            img = self.diffusion_generator.ddim_sample_loop(
                noise=torch.randn(1, 3, 64, 64).to(self.device),
                model=self.unet_generator,
                model_kwargs=[{
                    'y': y,
                    'context': context,
                    'mask': attention_mask
                }, {
                    'y': torch.zeros_like(y),
                    'context': torch.zeros_like(context),
                    'mask': attention_mask
                }],
                percentile=input.get('generator_percentile', 0.995),
                guide_scale=input.get('generator_guide_scale', 5.0),
                ddim_timesteps=input.get('generator_ddim_timesteps', 250),
                eta=input.get('generator_ddim_eta', 0.0))

            # upsampling (64->256)
            if not input.get('debug', False):
                img = F.interpolate(
                    img,
                    scale_factor=4.0,
                    mode='bilinear',
                    align_corners=False)
            img = self.diffusion_upsampler_256.ddim_sample_loop(
                noise=torch.randn_like(img),
                model=self.unet_upsampler_256,
                model_kwargs=[{
                    'lx': img,
                    'lt': torch.zeros(1).to(self.device),
                    'y': y,
                    'context': context,
                    'mask': attention_mask
                }, {
                    'lx': img,
                    'lt': torch.zeros(1).to(self.device),
                    'y': torch.zeros_like(y),
                    'context': torch.zeros_like(context),
                    'mask': torch.zeros_like(attention_mask)
                }],
                percentile=input.get('upsampler_256_percentile', 0.995),
                guide_scale=input.get('upsampler_256_guide_scale', 5.0),
                ddim_timesteps=input.get('upsampler_256_ddim_timesteps', 50),
                eta=input.get('upsampler_256_ddim_eta', 0.0))

            # upsampling (256->1024)
            if not input.get('debug', False):
                img = F.interpolate(
                    img,
                    scale_factor=4.0,
                    mode='bilinear',
                    align_corners=False)
            img = self.diffusion_upsampler_1024.ddim_sample_loop(
                noise=torch.randn_like(img),
                model=self.unet_upsampler_1024,
                model_kwargs={'concat': img},
                percentile=input.get('upsampler_1024_percentile', 0.995),
                ddim_timesteps=input.get('upsampler_1024_ddim_timesteps', 20),
                eta=input.get('upsampler_1024_ddim_eta', 0.0))
        else:
            raise ValueError(
                'currently only supports "ddim" and "dpm-solve" solvers')

        # output
        img = img.clamp(-1, 1).add(1).mul(127.5).squeeze(0).permute(
            1, 2, 0).cpu().numpy().astype(np.uint8)
        return img
