# Copyright (c) Alibaba, Inc. and its affiliates.

import io
import os
import re
import sys
import time
from collections import OrderedDict
from shutil import copytree, ignore_patterns, rmtree
from typing import Callable, Dict, Optional, Union

import json
import torch
from torch import nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler

from modelscope.fileio import File, LocalStorage
from modelscope.utils.config import Config, JSONIteratorEncoder
from modelscope.utils.constant import ConfigFields, ModelFile
from modelscope.utils.file_utils import copytree_py37
from modelscope.utils.logger import get_logger
from modelscope.utils.torch_utils import is_master

logger = get_logger()

storage = LocalStorage()


def weights_to_cpu(state_dict):
    """Copy a model state_dict to cpu.

    Args:
        state_dict (OrderedDict): Model weights on GPU.

    Returns:
        OrderedDict: Model weights on CPU.
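
    Example:
        A minimal sketch with a toy module (assumes ``torch`` is available):

        >>> sd = weights_to_cpu(torch.nn.Linear(2, 2).state_dict())
        >>> all(v.device.type == 'cpu' for v in sd.values())
        True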
    """
    state_dict_cpu = OrderedDict()
    for key, val in state_dict.items():
        state_dict_cpu[key] = val.cpu()
    # Keep metadata in state_dict
    state_dict_cpu._metadata = getattr(state_dict, '_metadata', OrderedDict())
    return state_dict_cpu


def save_checkpoint(model: torch.nn.Module,
                    filename: str,
                    optimizer: Optional[Optimizer] = None,
                    lr_scheduler: Optional[_LRScheduler] = None,
                    meta: Optional[dict] = None,
                    with_meta: bool = True,
                    with_model: bool = True) -> None:
    """Save checkpoint to file.

    The checkpoint can contain up to four fields: ``meta``, ``state_dict``,
    ``optimizer`` and ``lr_scheduler``. By default, ``meta`` will contain
    version and time info; note that the optimizer and lr_scheduler states
    are only saved when ``with_meta`` is True.

    Args:
        model (Module): Module whose params are to be saved.
        filename (str): Checkpoint filename.
        optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
        lr_scheduler(:obj:`_LRScheduler`, optional): LRScheduler to be saved.
        meta (dict, optional): Metadata to be saved in checkpoint.
        with_meta (bool, optional): Save meta info.
        with_model(bool, optional): Save model states.
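
    Example:
        A minimal usage sketch; the toy model and the target path are
        hypothetical:

        >>> model = torch.nn.Linear(4, 2)
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        >>> save_checkpoint(model, '/tmp/epoch_1.pth', optimizer=optimizer)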
    """
    checkpoint = {}
    if not with_meta and not with_model:
        raise ValueError(
            'At least one of "with_meta" and "with_model" must be True')

    if with_meta:
        if meta is None:
            meta = {}
        elif not isinstance(meta, dict):
            raise TypeError(
                f'meta must be a dict or None, but got {type(meta)}')
        from modelscope import __version__
        meta.update(modelscope=__version__, time=time.asctime())

        if isinstance(model, torch.nn.parallel.DistributedDataParallel):
            model = model.module

        if hasattr(model, 'CLASSES') and model.CLASSES is not None:
            # save class name to the meta
            meta.update(CLASSES=model.CLASSES)

        checkpoint['meta'] = meta

        # save optimizer state dict in the checkpoint
        if isinstance(optimizer, Optimizer):
            checkpoint['optimizer'] = optimizer.state_dict()
        elif isinstance(optimizer, dict):
            checkpoint['optimizer'] = {}
            for name, optim in optimizer.items():
                checkpoint['optimizer'][name] = optim.state_dict()

        # save lr_scheduler state dict in the checkpoint
        if lr_scheduler is not None and hasattr(lr_scheduler, 'state_dict'):
            checkpoint['lr_scheduler'] = lr_scheduler.state_dict()

    if with_model:
        if isinstance(model, torch.nn.parallel.DistributedDataParallel):
            model = model.module

        _weights = weights_to_cpu(model.state_dict())
        if not with_meta:
            checkpoint = _weights
        else:
            checkpoint['state_dict'] = _weights

    with io.BytesIO() as f:
        torch.save(checkpoint, f)
        File.write(f.getvalue(), filename)


def load_checkpoint(filename,
                    model,
                    optimizer: Optimizer = None,
                    lr_scheduler: _LRScheduler = None):
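    """Load a checkpoint file and restore the model/optimizer/lr_scheduler
    states from it.

    Args:
        filename (str): The checkpoint file to load.
        model: The model into which the weights will be loaded. Pass None to
            skip loading model weights.
        optimizer (:obj:`Optimizer` or dict, optional): The optimizer(s)
            whose states will be restored.
        lr_scheduler (:obj:`_LRScheduler`, optional): The lr_scheduler whose
            state will be restored.

    Returns:
        dict: The ``meta`` field of the checkpoint, or an empty dict if the
        checkpoint carries no meta info.

    Example:
        A minimal round-trip sketch; the path is hypothetical:

        >>> model = torch.nn.Linear(4, 2)
        >>> save_checkpoint(model, '/tmp/ckpt.pth')
        >>> meta = load_checkpoint('/tmp/ckpt.pth', model)
    """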
    if not os.path.exists(filename):
        raise ValueError(f'Checkpoint file {filename} does not exist!')
    checkpoint = torch.load(filename, map_location='cpu', weights_only=True)

    if optimizer is not None:
        if 'optimizer' in checkpoint:
            if isinstance(optimizer, Optimizer):
                optimizer.load_state_dict(checkpoint['optimizer'])
            elif isinstance(optimizer, dict):
                optimizer_dict = checkpoint['optimizer']
                for key, optimizer_ins in optimizer.items():
                    if key in optimizer_dict:
                        optimizer_ins.load_state_dict(optimizer_dict[key])
                    else:
                        logger.warning(
                            f'The state dict of optimizer {key} cannot be found in checkpoint file: {filename}'
                        )
        else:
            logger.warning(
                f'The state dict of optimizer cannot be found in checkpoint file: {filename}'
            )

    if lr_scheduler is not None:
        if 'lr_scheduler' in checkpoint:
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        else:
            logger.warning(
                f'The state dict of lr_scheduler cannot be found in checkpoint file: {filename}'
            )

    if model is not None:
        state_dict = checkpoint.get('state_dict', checkpoint)
        model.load_state_dict(state_dict)
    return checkpoint.get('meta', {})


def load_task_model_checkpoint(model_to_load,
                               model_local_dir,
                               default_dtype=None,
                               load_state_fn=None,
                               **kwargs):
    """
    Load model checkpoint file and feed the parameters into the model.
    Args:
        model_to_load: The model to be load
        model_local_dir: The actual checkpoint dir on local disk.
        default_dtype: Set the default float type by 'torch.set_default_dtype'
        load_state_fn: An optional load_state_fn used to load state_dict into the model.

    Returns:

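
    Example:
        A minimal sketch; ``my_model`` and the directory are hypothetical,
        and the directory must contain the checkpoint file named by
        ``ModelFile.TORCH_MODEL_BIN_FILE``:

        >>> result = load_task_model_checkpoint(my_model, '/tmp/model_dir')
        >>> result['missing_keys']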
    """

    def _add_head_prefix_to_state_dict(state_dicts, head_prefix,
                                       expected_keys_without_head_prefix,
                                       missing_keys):
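        """Prepend ``head_prefix`` to the checkpoint keys the model expects
        under the head, updating the missing-key bookkeeping accordingly."""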
        new_state_dict = OrderedDict()
        for name, param in state_dicts.items():
            if name in expected_keys_without_head_prefix:
                name_with_head = '.'.join([head_prefix, name])
                new_state_dict[name_with_head] = param
                expected_keys_without_head_prefix.remove(name)
                missing_keys = list(set(missing_keys) - {name_with_head})
            else:
                new_state_dict[name] = param

        missing_head_keys = []
        if len(expected_keys_without_head_prefix) > 0:
            missing_head_keys = expected_keys_without_head_prefix.copy()
        return new_state_dict, missing_head_keys, missing_keys

    def _find_mismatched_keys(
        state_dicts,
        model_state_dict,
        loaded_keys,
        prefix,
        add_prefix_to_model,
        remove_prefix_from_model,
        ignore_mismatched_sizes,
    ):
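        """Collect ``(key, checkpoint_shape, model_shape)`` triples for keys
        whose shapes disagree, and drop those keys from the state dict so the
        corresponding weights keep their fresh initialization."""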
        mismatched_key = []
        if ignore_mismatched_sizes:
            for checkpoint_key in loaded_keys:
                model_key = checkpoint_key
                if remove_prefix_from_model:
                    # The model key starts with `prefix` but `checkpoint_key` doesn't, so we add it.
                    model_key = f'{prefix}.{checkpoint_key}'
                elif add_prefix_to_model:
                    # The model key doesn't start with `prefix` but `checkpoint_key` does, so we remove it.
                    model_key = '.'.join(checkpoint_key.split('.')[1:])

                if model_key in model_state_dict:
                    model_shape = model_state_dict[model_key].shape
                    checkpoint_shape = state_dicts[checkpoint_key].shape
                    if checkpoint_shape != model_shape:
                        mismatched_key.append(
                            (checkpoint_key, checkpoint_shape, model_shape))
                        del state_dicts[checkpoint_key]
        return mismatched_key

    def _load_state_dict_into_model(
        model,
        state_dict,
        start_prefix,
        head_prefix_keys,
        load_state_fn=None,
    ):
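        """Load ``state_dict`` into ``model``, either by delegating to
        ``load_state_fn`` when it is provided, or by recursively calling each
        module's ``_load_from_state_dict``."""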
        # Convert old-format keys if needed: legacy checkpoints may name
        # parameters 'gamma'/'beta' instead of 'weight'/'bias'.
        old_keys = []
        new_keys = []
        for key in state_dict.keys():
            new_key = None
            if 'gamma' in key:
                new_key = key.replace('gamma', 'weight')
            if 'beta' in key:
                new_key = key.replace('beta', 'bias')
            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)
        for old_key, new_key in zip(old_keys, new_keys):
            state_dict[new_key] = state_dict.pop(old_key)

        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        error_msgs = []

        if load_state_fn is not None:
            load_state_fn(
                model,
                state_dict,
                prefix=start_prefix,
                head_prefix_keys=head_prefix_keys,
                local_metadata=None,
                error_msgs=error_msgs)
        else:

            def load(module: nn.Module, prefix=''):
                local_metadata = {} if metadata is None else metadata.get(
                    prefix[:-1], {})
                args = (state_dict, prefix, local_metadata, True, [], [],
                        error_msgs)
                module._load_from_state_dict(*args)
                for name, child in module._modules.items():
                    if child is not None:
                        load(child, prefix + name + '.')

            load(model, prefix=start_prefix)

        return error_msgs

    def _load_checkpoint(
        model,
        state_dict,
        load_state_fn,
        ignore_mismatched_sizes,
        _fast_init,
    ):
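        """Align the checkpoint keys with the model's expected keys (handling
        base-model and head prefixes), load the weights, and report missing,
        unexpected and mismatched keys."""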
        # Retrieve missing & unexpected_keys
        model_state_dict = model.state_dict()
        expected_keys = list(model_state_dict.keys())
        keys_from_pretrained = list(state_dict.keys())

        prefix = model.base_model_prefix

        # During loading, the base-model prefix may need to be added or
        # removed, depending on whether the checkpoint and the model agree
        # on it.
        if len(prefix) > 0:
            # e.g. nlp: encoder, decoder
            pretrained_has_prefix_module = any(
                s.startswith(prefix) for s in keys_from_pretrained)
            model_expects_prefix_module = any(
                s.startswith(prefix) for s in expected_keys)
        else:
            # e.g. nlp: encoder-decoder, cv: backbone-head
            pretrained_has_prefix_module = False
            model_expects_prefix_module = False

        remove_prefix_from_model = not pretrained_has_prefix_module and model_expects_prefix_module
        add_prefix_to_model = pretrained_has_prefix_module and not model_expects_prefix_module

        # Initialize here so the name is always defined below, whichever
        # branch runs.
        expected_keys_not_base_model_prefixed = []
        if remove_prefix_from_model:
            expected_keys_not_base_model_prefixed = [
                s for s in expected_keys if not s.startswith(prefix)
            ]
            expected_keys = [
                '.'.join(s.split('.')[1:]) if s.startswith(prefix) else s
                for s in expected_keys
            ]
        elif add_prefix_to_model:
            # backbone only
            expected_keys = ['.'.join([prefix, s]) for s in expected_keys]

        missing_keys = list(set(expected_keys) - set(keys_from_pretrained))
        unexpected_keys = list(set(keys_from_pretrained) - set(expected_keys))

        # In contrast, the head prefix is simple: it is either added or not.
        prefix_heads = model.head_prefix
        expected_head_keys_without_head_prefix = []
        missing_head_keys = []
        unexpected_head_keys = []
        pretrained_has_prefix_head = dict()
        head_prefix_keys = dict()

        # Only handle the case where the head keys do not match the state
        # dict
        if len(prefix_heads) > 0 and len(unexpected_keys) > 0:
            if isinstance(prefix_heads, str):
                prefix_heads = [prefix_heads]

            # Double-check whether each head prefix is present in the
            # pretrained state dict
            for prefix_head in prefix_heads:
                pretrained_has_prefix_head[prefix_head] = any(
                    s.startswith(prefix_head) for s in keys_from_pretrained)

            for prefix_head in prefix_heads:
                expected_keys_without_head_prefix = [
                    '.'.join(s.split('.')[1:]) for s in expected_keys
                    if s.startswith(prefix_head)
                ]
                expected_head_keys_without_head_prefix.extend(
                    expected_keys_without_head_prefix)
                head_prefix_keys[
                    prefix_head] = expected_keys_without_head_prefix
            unexpected_keys = list(
                set(unexpected_keys)
                - set(expected_head_keys_without_head_prefix))
            # Both lists are identical at this point; keep separate names for
            # the diagnostics below.
            unexpected_head_keys = list(unexpected_keys)

        _keys_to_ignore_on_load_missing = kwargs.pop(
            '_keys_to_ignore_on_load_missing', None)
        _keys_to_ignore_on_load_unexpected = kwargs.pop(
            '_keys_to_ignore_on_load_unexpected', None)
        # Some models may have keys that are not in the state by design, removing them before needlessly warning
        # the user.
        if _keys_to_ignore_on_load_missing is not None:
            for pat in _keys_to_ignore_on_load_missing:
                missing_keys = [
                    k for k in missing_keys if re.search(pat, k) is None
                ]

        if _keys_to_ignore_on_load_unexpected is not None:
            for pat in _keys_to_ignore_on_load_unexpected:
                unexpected_keys = [
                    k for k in unexpected_keys if re.search(pat, k) is None
                ]

        # Retrieve uninitialized modules and initialize them before possibly
        # overriding with the pretrained weights.
        if _fast_init:
            uninitialized_modules = retrieve_modules_from_names(
                model,
                missing_keys,
                prefix=prefix,
                add_prefix=add_prefix_to_model,
                remove_prefix=remove_prefix_from_model)
            for module in uninitialized_modules:
                model._init_weights(module)

        # Make sure we are able to load the head correctly by revising the
        # state dict
        missing_head_keys_by_head = dict()
        if len(head_prefix_keys) > 0:
            for head_prefix in head_prefix_keys:
                if not pretrained_has_prefix_head[head_prefix]:
                    state_dict, missing_head_keys, missing_keys = _add_head_prefix_to_state_dict(
                        state_dict, head_prefix, head_prefix_keys[head_prefix],
                        missing_keys)
                    missing_head_keys_by_head[head_prefix] = missing_head_keys

        # Make sure we are able to load base models as well as derived models (with heads)
        start_prefix = ''
        model_to_load = model
        heads_to_load = dict()
        if len(model.base_model_prefix) > 0 and not hasattr(
                model,
                model.base_model_prefix) and pretrained_has_prefix_module:
            start_prefix = model.base_model_prefix + '.'
        if len(model.base_model_prefix) > 0 and hasattr(
                model,
                model.base_model_prefix) and not pretrained_has_prefix_module:
            model_to_load = getattr(model, model.base_model_prefix)
            for head_prefix in prefix_heads:
                heads_to_load[head_prefix] = getattr(model, head_prefix)
            if any(key in expected_keys_not_base_model_prefixed
                   for key in keys_from_pretrained):
                raise ValueError(
                    'The state dictionary of the model you are trying to load is corrupted. Are you sure it was '
                    'properly saved?')

        # Whole checkpoint
        mismatched_keys = _find_mismatched_keys(
            state_dict,
            model_state_dict,
            keys_from_pretrained,
            prefix,
            add_prefix_to_model,
            remove_prefix_from_model,
            ignore_mismatched_sizes,
        )
        error_msgs = _load_state_dict_into_model(
            model_to_load,
            state_dict,
            start_prefix,
            head_prefix_keys,
            load_state_fn=load_state_fn)

        if len(heads_to_load) > 0:
            for head in heads_to_load:
                local_error_msgs = _load_state_dict_into_model(
                    heads_to_load[head],
                    state_dict,
                    head + '.',
                    head_prefix_keys,
                    load_state_fn=load_state_fn)
                error_msgs.extend(local_error_msgs)

        if len(error_msgs) > 0:
            error_msg = '\n\t'.join(error_msgs)
            raise RuntimeError(
                f'Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}'
            )

        if len(unexpected_keys) > 0:
            logger.warning(
                f'Some weights of the model checkpoint were not used when'
                f' initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are'
                f' initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or'
                ' with another architecture (e.g. initializing a BertForTokenClassification model from a'
                ' BertForPreTraining model).\n- This IS NOT expected if you are initializing'
                f' {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical'
                ' (initializing a BertForTokenClassification model from a BertForTokenClassification model).'
            )
        elif len(unexpected_head_keys) > 0:
            logger.warning(
                f'Some weights of the model checkpoint were not used when'
                f' initializing {model.__class__.__name__}: {unexpected_head_keys}\n- This IS NOT expected if you are'
                f' initializing {model.__class__.__name__} from the checkpoint of a model trained on the same task'
                ' but with a different structure (e.g. initializing a BertForTokenClassification model from a'
                ' BertForTokenClassification model).')
        else:
            logger.info(
                f'All model checkpoint weights were used when initializing {model.__class__.__name__}.\n'
            )
        if len(missing_keys) > 0:
            logger.warning(
                f'Some weights of {model.__class__.__name__} were not initialized from the model checkpoint'
                f' and are newly initialized: {missing_keys}\nYou should probably'
                ' TRAIN this model on a down-stream task to be able to use it for predictions and inference.'
            )
        elif len(mismatched_keys) == 0:
            logger.info(
                f'All the weights of {model.__class__.__name__} were initialized from the model checkpoint.\n'
                f'If your task is similar to the task the model of the checkpoint'
                f' was trained on, you can already use {model.__class__.__name__} for predictions without further'
                ' training.')
        if len(mismatched_keys) > 0:
            mismatched_warning = '\n'.join([
                f'- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated'
                for key, shape1, shape2 in mismatched_keys
            ])
            logger.warning(
                f'Some weights of {model.__class__.__name__} were not initialized from the model checkpoint'
                f' and are newly initialized because the shapes did not'
                f' match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able'
                ' to use it for predictions and inference.')

        return missing_keys, unexpected_keys, mismatched_keys, error_msgs

    def retrieve_modules_from_names(model,
                                    names,
                                    prefix=None,
                                    add_prefix=False,
                                    remove_prefix=False):
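        """Map weight names back to the modules that own them so that those
        modules can be (re-)initialized."""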
        module_keys = set(['.'.join(key.split('.')[:-1]) for key in names])

        # torch.nn.ParameterList is a special case where two parameter keywords
        # are appended to the module name, *e.g.* bert.special_embeddings.0
        module_keys = module_keys.union(
            set([
                '.'.join(key.split('.')[:-2]) for key in names
                if key[-1].isdigit()
            ]))

        retrieved_modules = []
        # retrieve all modules that have at least one missing weight name
        for name, module in model.named_modules():
            if remove_prefix:
                name = '.'.join(
                    name.split('.')[1:]) if name.startswith(prefix) else name
            elif add_prefix:
                name = '.'.join([prefix, name]) if len(name) > 0 else prefix

            if name in module_keys:
                retrieved_modules.append(module)

        return retrieved_modules

    def _tie_or_clone_weights(output_embeddings,
                              input_embeddings,
                              torchscript=False):
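        """Tie the output embedding weights to the input embeddings, or clone
        them when exporting with torchscript."""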
        if torchscript:
            output_embeddings.weight = nn.Parameter(
                input_embeddings.weight.clone())
        else:
            output_embeddings.weight = input_embeddings.weight

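        # Pad the bias (if present) so it matches the first dimension of the
        # tied output embedding weight.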
        if getattr(output_embeddings, 'bias', None) is not None:
            output_embeddings.bias.data = nn.functional.pad(
                output_embeddings.bias.data,
                (
                    0,
                    output_embeddings.weight.shape[0]
                    - output_embeddings.bias.shape[0],
                ),
                'constant',
                0,
            )

        if hasattr(output_embeddings, 'out_features') and hasattr(
                input_embeddings, 'num_embeddings'):
            output_embeddings.out_features = input_embeddings.num_embeddings

    def tie_weights(model, tie_word_embeddings=False):
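        """Optionally tie the head's output embeddings to the encoder's input
        embeddings."""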
        if tie_word_embeddings:
            output_embeddings = model.head.get_output_embeddings()
            if output_embeddings is not None:
                input_embeddings = model.encoder.get_input_embeddings()
                _tie_or_clone_weights(output_embeddings, input_embeddings)

    # TODO Sharded ckpt
    ckpt_file = os.path.join(model_local_dir, ModelFile.TORCH_MODEL_BIN_FILE)
    state_dict = torch.load(ckpt_file, map_location='cpu', weights_only=True)
    if default_dtype is not None:
        torch.set_default_dtype(default_dtype)

    missing_keys, unexpected_keys, mismatched_keys, error_msgs = _load_checkpoint(
        model_to_load,
        state_dict,
        load_state_fn=load_state_fn,
        ignore_mismatched_sizes=True,
        _fast_init=True,
    )

    if getattr(kwargs.get('head'), 'tie_word_embeddings', False):
        tie_weights(model_to_load, kwargs.get('head').tie_word_embeddings)

    return {
        'model': model_to_load,
        'missing_keys': missing_keys,
        'unexpected_keys': unexpected_keys,
        'mismatched_keys': mismatched_keys,
        'error_msgs': error_msgs,
    }


def save_configuration(target_folder, config: Dict):
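    """Save the configuration to ``target_folder``, deriving a default
    ``pipeline`` field from the task when it is absent."""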
    if isinstance(config, Config):
        config = config.to_dict()
    if ConfigFields.pipeline not in config:
        config[ConfigFields.pipeline] = {'type': config[ConfigFields.task]}
    cfg_str = json.dumps(config, indent=4, cls=JSONIteratorEncoder)
    config_file = os.path.join(target_folder, ModelFile.CONFIGURATION)
    storage.write(cfg_str.encode(), config_file)


def save_pretrained(model,
                    target_folder: Union[str, os.PathLike],
                    save_checkpoint_name: str = None,
                    save_function: Callable = None,
                    **kwargs):
    """save the pretrained model, its configuration and other related files to a directory, so that it can be re-loaded

    Args:
        model (Model): Model whose params are to be saved.
        target_folder (Union[str, os.PathLike]): Directory to which to save.
            Will be created if it doesn't exist.
        save_checkpoint_name (str): The checkpoint name to be saved in the
            target_folder.
        save_function (Callable): The function to use to save the state
            dictionary.
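
    Example:
        A minimal sketch; ``my_model`` and the folder are hypothetical, and
        ``save_checkpoint`` from this module is used as the save function:

        >>> save_pretrained(my_model, '/tmp/saved',
        ...                 save_checkpoint_name=ModelFile.TORCH_MODEL_BIN_FILE,
        ...                 save_function=save_checkpoint)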
    """

    if save_function is None or not callable(save_function):
        raise Exception('A valid save function must be passed in')

    if target_folder is None or os.path.isfile(target_folder):
        raise ValueError(
            f'Provided path ({target_folder}) should be a directory, not a file'
        )

    if save_checkpoint_name is None:
        raise Exception(
            'A checkpoint name must be passed in for the saving method')

    # Single ckpt path, sharded ckpt logic will be added later
    output_ckpt_path = os.path.join(target_folder, save_checkpoint_name)

    # Copy the model's other files to the save directory, ignoring the
    # original checkpoints and the configuration file
    origin_file_to_be_ignored = [save_checkpoint_name]
    ignore_file_set = set(origin_file_to_be_ignored)
    ignore_file_set.add(ModelFile.CONFIGURATION)
    ignore_file_set.add('*.safetensors')
    ignore_file_set.add('.*')
    if hasattr(model,
               'model_dir') and model.model_dir is not None and is_master():
        if sys.version_info.minor >= 8:
            copytree_func = copytree
        else:  # Python 3.7
            copytree_func = copytree_py37
        copytree_func(
            model.model_dir,
            target_folder,
            ignore=ignore_patterns(*ignore_file_set),
            dirs_exist_ok=True)

    # Save the ckpt to the save directory
    try:
        save_function(model, output_ckpt_path, **kwargs)
    except Exception as e:
        raise Exception(
            f'During saving checkpoints, {type(e).__name__} was thrown '
            f'with message: {e}') from e
