# Copyright (c) 2022 Zhipu.AI
import os
import random
import re
import sys
import time
from functools import partial
from typing import Dict, List, Tuple

import torch
from SwissArmyTransformer import mpu
from SwissArmyTransformer.generation.autoregressive_sampling import (
    get_masks_and_position_ids_default, update_mems)

from modelscope.metainfo import Models
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.outputs import OutputKeys
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from .generation import BaseStrategy, BeamSearchStrategy
from .initialize import initialize, initialize_model_and_tokenizer

torch.set_num_threads(24)

logger = get_logger()


def batch_filling_sequence(
        model,
        seqs,
        context_lengths,
        strategy,
        max_memory_length=100000,
        get_masks_and_position_ids=get_masks_and_position_ids_default,
        mems=None,
        **kw_args):
    '''Generate completions for a batch of partially filled sequences.

        seqs: [batch_size, seq_length], each row like
            [2, 3, 5, ..., -1 (to be generated), -1, ...].
        mems: [num_layers, batch_size, len_mems(index), mem_hidden_size]
            key/value cache covering the first ``mems.shape[2]`` context tokens.
            mems are first-class citizens here, but we make no assumption about
            what is memorized; input mems are only used for multi-phase generation.
    '''
    assert len(seqs.shape) == 2

    # building the initial tokens, attention_mask, and position_ids
    batch_size, context_length = seqs.shape
    seqs, attention_mask, position_ids = get_masks_and_position_ids(seqs)
    tokens = seqs[..., :context_length]
    if attention_mask.dtype != torch.bool:
        attention_mask = attention_mask.type_as(next(
            model.parameters()))  # if fp16
    # initialize generation
    counter = context_length - 1  # index of the last fixed (non-generated) token
    index = 0 if mems is None else mems.shape[
        2]  # Next forward starting index, also the length of cache.
    num_beams = 1
    # step-by-step generation
    while counter < seqs.shape[1] - 1:
        # Now, we want to generate seq[counter + 1],
        # token[:, index: counter+1] needs forwarding.
        # forward
        tokens = tokens.reshape(batch_size * num_beams, -1)
        mems = mems.reshape(mems.shape[0], batch_size
                            * num_beams, mems.shape[-2],
                            mems.shape[-1]) if mems is not None else None
        logits, *output_per_layers = model(
            tokens[:, index:],
            position_ids[..., index:counter + 1],
            attention_mask[...,
                           index:counter + 1, :counter + 1],  # TODO memlen
            mems=mems,
            **kw_args)
        mem_kv = [o['mem_kv'] for o in output_per_layers]
        mems = update_mems(mem_kv, mems, max_memory_length=max_memory_length)
        if counter == context_length - 1:
            logits = logits[torch.arange(batch_size), context_lengths - 1]
        else:
            logits = logits[:, -1]
        counter += 1
        index = counter

        # sampling
        logits = logits.reshape(batch_size, num_beams, -1)
        tokens = tokens.reshape(batch_size, num_beams, -1)
        mems = mems.reshape(mems.shape[0], batch_size, num_beams,
                            mems.shape[-2], mems.shape[-1])
        tokens, mems = strategy.forward(logits, tokens, mems)
        if len(tokens.shape) == 3 and num_beams == 1:
            num_beams = tokens.shape[1]
            position_ids = position_ids.unsqueeze(1).expand(
                batch_size, num_beams, -1).reshape(batch_size * num_beams, -1)
            attention_mask_shape = attention_mask.shape[-3:]
            attention_mask = attention_mask.unsqueeze(1).expand(
                batch_size, num_beams, -1, -1,
                -1).reshape(batch_size * num_beams, *attention_mask_shape)
        if strategy.is_done:
            break
    return strategy.finalize(tokens, mems)
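

# Usage sketch (illustrative, mirroring the call in `fill_blanks` below; it
# assumes `model`, `tokenizer`, `strategy`, `args`, `seq`, `mask_position` and
# `use_gmask` are already prepared as in that function):
#
#   input_seq = torch.cuda.LongTensor([seq + [tokenizer.get_command('sop')]])
#   output, mems = batch_filling_sequence(
#       model,
#       input_seq,
#       torch.cuda.LongTensor([input_seq.shape[-1]]),
#       strategy=strategy,
#       get_masks_and_position_ids=partial(
#           get_masks_and_position_ids,
#           mask_position=mask_position,
#           max_gen_length=args.out_seq_length - input_seq.shape[-1],
#           gmask=use_gmask))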


def add_generation_specific_args(parser):
    parser.add_argument(
        '--sampling-strategy',
        type=str,
        default='BaseStrategy',
        help='Type of sampling strategy.')
    parser.add_argument(
        '--min-gen-length',
        type=int,
        default=0,
        help='The minimum length each blank should generate.')
    parser.add_argument(
        '--print-all-beams',
        action='store_true',
        help='Print all output generated by beam search strategy.')
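

# For reference, the extra flags registered above are parsed when this module
# is launched through `initialize(extra_args_provider=add_generation_specific_args)`
# below, e.g. (values are illustrative):
#
#   --sampling-strategy BeamSearchStrategy --min-gen-length 10 --print-all-beams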


def isEnglish(s):
    """Heuristically treat a string as English iff it is pure ASCII."""
    try:
        s.encode(encoding='utf-8').decode('ascii')
    except UnicodeDecodeError:
        return False
    else:
        return True
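

# Examples of the ASCII heuristic above (illustrative):
#   isEnglish('Tsinghua University')  -> True
#   isEnglish('清华大学')              -> False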


def get_masks_and_position_ids(seq,
                               mask_position,
                               max_gen_length,
                               gmask=False):
    context_length = seq.shape[1]
    tokens = torch.nn.functional.pad(
        seq, (0, max_gen_length), mode='constant', value=-1)
    attention_mask = torch.ones((1, tokens.shape[-1], tokens.shape[-1]),
                                device=tokens.device)
    attention_mask.tril_()
    attention_mask[..., :context_length - 1] = 1
    attention_mask.unsqueeze_(1)
    attention_mask = (attention_mask < 0.5).bool()

    position_ids = torch.arange(
        tokens.shape[-1], dtype=torch.long, device=tokens.device)
    if not gmask:
        position_ids[context_length - 1:] = mask_position

    position_ids = position_ids.unsqueeze(0)

    return tokens, attention_mask, position_ids
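

# Shape sketch for the helper above (illustrative): for a [1, L] context `seq`
# and `max_gen_length` G, it returns
#   tokens:         [1, L + G], with the G generation slots pre-filled with -1
#   attention_mask: [1, 1, L + G, L + G], bool; True where the causal mask was
#                   0, i.e. positions that must not be attended to
#   position_ids:   [1, L + G]; for [MASK]/[sMASK] (gmask=False) the ids from
#                   position L - 1 onward are pinned to `mask_position`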


def fill_blanks(args, raw_text: str, model, tokenizer,
                strategy) -> Tuple[List[str], List[str], List[List[str]]]:
    # add MASK
    generation_mask = '[gMASK]'
    if '[MASK]' in raw_text:
        generation_mask = '[MASK]'
    elif '[sMASK]' in raw_text:
        generation_mask = '[sMASK]'
    use_gmask = '[MASK]' not in raw_text and '[sMASK]' not in raw_text

    mask_pattern = r'\[[sg]?MASK\]'
    text_list = re.split(mask_pattern, raw_text)
    pattern_list = re.compile(mask_pattern).findall(raw_text)
    seq = []
    for i in range(len(pattern_list)):
        pattern = pattern_list[i]
        sub_text = text_list[i]
        seq.extend(tokenizer.tokenize(sub_text))
        seq.append(tokenizer.get_command(pattern))

    seq.extend(tokenizer.tokenize(text_list[-1]))

    if 'MASK]' not in raw_text:
        seq += [tokenizer.get_command(generation_mask)]
        raw_text += ' ' + generation_mask
    if not raw_text.endswith('MASK]'):
        seq = seq + [tokenizer.get_command('eos')]
    if mpu.get_model_parallel_rank() == 0:
        logger.info('\nInput: {}\n'.format(raw_text))
    if len(seq) > args.max_sequence_length:
        raise ValueError('text too long.')

    # generation
    is_english = isEnglish(raw_text)
    output_list = [seq]
    num_output = args.num_beams if args.sampling_strategy == 'BeamSearchStrategy' else 1
    last_pos, answers, answers_with_style, blanks = (
        [0] * num_output,
        ['' for _ in range(num_output)],
        ['' for _ in range(num_output)],
        [[] for _ in range(num_output)],
    )

    # repeatedly detect the first remaining mask position and fill it in
    while True:
        seq = output_list[0]
        # detect mask position
        mask_token = tokenizer.get_command(generation_mask)
        if mask_token not in seq:
            break
        mask_position = seq.index(mask_token)

        output_list = []

        input_seq = torch.cuda.LongTensor(
            [seq + [tokenizer.get_command('sop')]],
            device=args.device,
        )
        output, _ = batch_filling_sequence(
            model,
            input_seq,
            torch.cuda.LongTensor([input_seq.shape[-1]], device=args.device),
            strategy=strategy,
            get_masks_and_position_ids=partial(
                get_masks_and_position_ids,
                mask_position=mask_position,
                max_gen_length=args.out_seq_length - input_seq.shape[-1],
                gmask=use_gmask,
            ),
        )
        if isinstance(output, torch.Tensor):  # different strategies
            output = output.tolist()
        output = output[0]  # batch_size = 1
        output_list.extend(output)

        # strip trailing -1 padding and splice the generated tokens back into seq
        for i in range(len(output_list)):
            output = output_list[i].tolist() if isinstance(
                output_list[i], torch.Tensor) else output_list[i]
            try:
                unfinished = output.index(-1)
            except ValueError:
                unfinished = len(output)
            if output[unfinished - 1] in strategy.end_tokens:
                unfinished -= 1
            bog = output.index(tokenizer.get_command('sop'))

            prefix = tokenizer.detokenize(output[last_pos[i]:mask_position])
            blank = tokenizer.detokenize(output[bog + 1:unfinished])
            answers_with_style[i] += (
                prefix + (' ' if is_english else '') +  # noqa
                ('\033[4m' if use_gmask else '\x1b[0;32m\033[4m') + blank
                +  # noqa
                ('\033[0m' if use_gmask else '\033[0m\x1b[0m') +  # noqa
                (' ' if is_english else ''))  # noqa
            blanks[i].append(blank)
            last_pos[i] = mask_position + unfinished - (bog + 1)
            output_list[i] = output[:mask_position] + output[
                bog + 1:unfinished] + output[mask_position + 1:bog]

    for i, output in enumerate(output_list):
        if output[-1] == tokenizer.get_command('eos'):
            output = output[:-1]
        answers_with_style[i] += tokenizer.detokenize(output[last_pos[i]:])
        answers[i] = tokenizer.detokenize(output)

    return answers, answers_with_style, blanks
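

# Usage sketch for `fill_blanks` (prompts are illustrative):
#   - blank infilling:       'Tsinghua University is located in [MASK].'
#   - sentence infilling:    'The capital of France is [sMASK].'
#   - open-ended generation: a prompt without any mask token gets a trailing
#     '[gMASK]' appended automatically.
# It returns (answers, answers_with_style, blanks), where `blanks[i]` holds the
# text generated for each mask of the i-th output (one per beam, or a single one
# for BaseStrategy).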


@MODELS.register_module(Tasks.text_generation, module_name=Models.glm130b)
class GLM130bForTextGeneration(TorchModel):

    def __init__(self, model_dir: str, *args, **kwargs):
        """Initialize the GLM-130B model from the `model_dir` path.

        Args:
            model_dir (str): the model path.
        """
        super().__init__(model_dir, *args, **kwargs)
        self.cfg = Config.from_file(
            os.path.join(model_dir, ModelFile.CONFIGURATION))
        args = initialize(extra_args_provider=add_generation_specific_args)
        args.seed = random.randint(1, sys.maxsize - 1)
        args.sampling_strategy = self.cfg.model.sampling_strategy
        args.out_seq_length = self.cfg.model.out_seq_length
        args.min_gen_length = self.cfg.model.min_gen_length
        args.num_beams = self.cfg.model.num_beams
        args.length_penalty = self.cfg.model.length_penalty
        args.no_repeat_ngram_size = self.cfg.model.no_repeat_ngram_size
        args.temperature = self.cfg.model.temperature
        args.top_k = self.cfg.model.top_k
        args.top_p = self.cfg.model.top_p
        args.load = model_dir

        logger.info('Loading model and tokenizer ...')
        self.model, self.tokenizer = initialize_model_and_tokenizer(args)

        end_tokens = [
            self.tokenizer.get_command('eop'),
            self.tokenizer.get_command('eos')
        ]

        if args.sampling_strategy == 'BaseStrategy':
            self.strategy = BaseStrategy(
                batch_size=1,
                temperature=args.temperature,
                top_k=args.top_k,
                top_p=args.top_p,
                end_tokens=end_tokens)
        elif args.sampling_strategy == 'BeamSearchStrategy':
            self.strategy = BeamSearchStrategy(
                1,
                args.num_beams,
                length_penalty=args.length_penalty,
                consider_end=True,
                end_tokens=end_tokens,
                no_repeat_ngram_size=args.no_repeat_ngram_size,
                min_gen_length=args.min_gen_length,
            )
        else:
            raise ValueError(f'unknown strategy {args.sampling_strategy}')

        self.args = args
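
    # Configuration sketch (values are illustrative): the `model` section of the
    # configuration file read above is expected to provide keys such as
    #   sampling_strategy: 'BaseStrategy'  # or 'BeamSearchStrategy'
    #   out_seq_length: 256
    #   min_gen_length: 0
    #   num_beams: 4
    #   length_penalty: 1.0
    #   no_repeat_ngram_size: 3
    #   temperature: 1.0
    #   top_k: 0
    #   top_p: 0.7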

    def func(self, raw_text):
        """Fill all mask blanks in `raw_text` and return the styled answer."""
        answers, answers_with_style, blanks = fill_blanks(
            self.args, raw_text, self.model, self.tokenizer, self.strategy)

        if mpu.get_model_parallel_rank() == 0:
            logger.info('Output: ' + str(answers_with_style[0]))

        return str(answers_with_style[0])

    def forward(self, input: str) -> Dict[str, str]:
        # Rank 0 owns the raw query; it is broadcast to the other model-parallel
        # ranks so that every rank runs generation on the same text.
        raw_text, is_stop = '', False
        if torch.distributed.get_rank() == 0:
            raw_text = input
            if not raw_text:
                return {OutputKeys.TEXT: 'Query should not be empty!'}
            if raw_text == 'stop':
                is_stop = True
            torch.distributed.broadcast_object_list([raw_text, is_stop])
        else:
            # Non-zero ranks receive the query in place via broadcast_object_list.
            info = [raw_text, is_stop]
            torch.distributed.broadcast_object_list(info)
            raw_text, is_stop = info
        if is_stop:
            return
        try:
            start_time = time.time()
            res = self.func(raw_text)
            if torch.distributed.get_rank() == 0:
                logger.info('\nTime taken: {:.2f}s\n'.format(time.time()
                                                             - start_time))
        except (ValueError, FileNotFoundError) as e:
            return {OutputKeys.TEXT: str(e)}
        logger.info('Generation finished.')
        return {OutputKeys.TEXT: res}
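

# End-to-end usage sketch (illustrative; '<glm130b-model-dir>' is a placeholder,
# and GLM-130B needs a multi-GPU environment set up for model parallelism):
#
#   from modelscope.pipelines import pipeline
#   from modelscope.utils.constant import Tasks
#
#   pipe = pipeline(Tasks.text_generation, model='<glm130b-model-dir>')
#   print(pipe('The capital of France is [MASK].'))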
