# Copyright (c) Alibaba, Inc. and its affiliates.

import contextlib
import hashlib
import os
import pickle
import random
import re
import shutil
import tempfile
from collections import OrderedDict
from collections.abc import Mapping
from pathlib import Path
from types import FunctionType
from typing import Any, Dict, Optional, Union

import json
import numpy as np
import torch
import torch.optim
from torch import nn

from .test_utils import compare_arguments_nested


class RegressTool:
    """This class is used to stop inference/training results from changing by some unaware affections by unittests.

    Firstly, run a baseline test to create a result file, then changes can be observed between
    the latest version and the baseline file.
    """

    def __init__(self,
                 baseline: Optional[bool] = None,
                 store_func: Optional[FunctionType] = None,
                 load_func: Optional[FunctionType] = None):
        """A func to store the baseline file and a func to load the baseline file.
        """
        self.baseline = baseline
        self.store_func = store_func
        self.load_func = load_func
        print(f'Current working dir is: {Path.cwd()}')

    def store(self, local, remote):
        if self.store_func is not None:
            self.store_func(local, remote)
        else:
            path = os.path.abspath(
                os.path.join(Path.cwd(), 'data', 'test', 'regression'))
            os.makedirs(path, exist_ok=True)
            shutil.copy(local, os.path.join(path, remote))

    def load(self, local, remote):
        if self.load_func is not None:
            self.load_func(local, remote)
        else:
            path = os.path.abspath(
                os.path.join(Path.cwd(), 'data', 'test', 'regression'))
            baseline = os.path.join(path, remote)
            if not os.path.exists(baseline):
                raise ValueError(f'baseline file {baseline} does not exist')
            with open(baseline, 'rb') as f:
                print(
                    f'baseline file found: {baseline}, md5: {hashlib.md5(f.read()).hexdigest()}'
                )
            if os.path.exists(local):
                os.remove(local)
            os.symlink(baseline, local, target_is_directory=False)

    @contextlib.contextmanager
    def monitor_module_single_forward(self,
                                      module: nn.Module,
                                      file_name: str,
                                      compare_fn=None,
                                      compare_model_output=True,
                                      **kwargs):
        """Monitor a pytorch module in a single forward.

        Args:
            module: A torch module
            file_name: The file_name to store or load file
            compare_fn: A custom fn used to compare the results manually.
            compare_model_output: Only compare the input module's output, skip all other tensors

        >>> def compare_fn(v1, v2, key, type):
        >>>     return None

        v1 is the baseline value
        v2 is the value of the current version
        key is the key of the submodule
        type is one of 'input' or 'output'

        kwargs:
            atol: The absolute tolerance between two np arrays.
            rtol: The relative tolerance between two np arrays.
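
        A minimal usage sketch (the model and file name are illustrative; monitoring
        only runs when the REGRESSION_BASELINE environment variable is set):

        >>> tool = RegressTool(baseline=False)
        >>> with tool.monitor_module_single_forward(model, 'my_model_forward'):
        >>>     model(**inputs)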
        """
        baseline = os.getenv('REGRESSION_BASELINE')
        if baseline is None or self.baseline is None:
            yield
            return

        baseline = self.baseline
        io_json = {}
        absolute_path = f'./{file_name}.bin'
        if not isinstance(module, nn.Module):
            assert hasattr(module, 'model')
            module = module.model

        hack_forward(module, file_name, io_json)
        intercept_module(module, io_json)
        yield
        hack_forward(module, None, None, restore=True)
        intercept_module(module, None, restore=True)
        if baseline:
            with open(absolute_path, 'wb') as f:
                pickle.dump(io_json, f)
            self.store(absolute_path, f'{file_name}.bin')
            os.remove(absolute_path)
        else:
            name = os.path.basename(absolute_path)
            baseline = os.path.join(tempfile.gettempdir(), name)
            self.load(baseline, name)
            with open(baseline, 'rb') as f:
                base = pickle.load(f)

            class SafeNumpyEncoder(json.JSONEncoder):

                def parse_default(self, obj):
                    if isinstance(obj, np.ndarray):
                        return obj.tolist()

                    if isinstance(obj, np.floating):
                        return float(obj)

                    if isinstance(obj, np.integer):
                        return int(obj)

                    return json.JSONEncoder.default(self, obj)

                def default(self, obj):
                    try:
                        return self.parse_default(obj)
                    except Exception:
                        print(
                            f'Type {obj.__class__} cannot be serialized and printed'
                        )
                        return None

            if compare_model_output:
                print(
                    'Ignoring inner modules; only the output of the model will be verified.'
                )
                base = {
                    key: value
                    for key, value in base.items() if key == file_name
                }
                for key, value in base.items():
                    value['input'] = {'args': None, 'kwargs': None}
                io_json = {
                    key: value
                    for key, value in io_json.items() if key == file_name
                }
                for key, value in io_json.items():
                    value['input'] = {'args': None, 'kwargs': None}

            print(f'baseline: {json.dumps(base, cls=SafeNumpyEncoder)}')
            print(f'latest  : {json.dumps(io_json, cls=SafeNumpyEncoder)}')
            if not compare_io_and_print(base, io_json, compare_fn, **kwargs):
                raise ValueError('Results do not match!')

    @contextlib.contextmanager
    def monitor_module_train(self,
                             trainer: Union[Dict, Any],
                             file_name,
                             level='config',
                             compare_fn=None,
                             ignore_keys=None,
                             compare_random=True,
                             reset_dropout=True,
                             lazy_stop_callback=None,
                             **kwargs):
        """Monitor a pytorch module's backward data and cfg data within a step of the optimizer.

        This is usually useful when you try to change some dangerous code
        which has the risk of affecting the training loop.

        Args:
            trainer: A dict or an object that contains the model/optimizer/lr_scheduler
            file_name: The file_name to store or load file
            level: The regression level.
                'strict' for matching every single tensor.
                    Please make sure the parameters of the head are fixed
                    and the drop-out rate is zero.
                'config' for matching the initial config, like the cfg file, optimizer param_groups,
                    lr_scheduler params and the random seed.
                'metric' for comparing the best metrics in the evaluation loop.
            compare_fn: A custom fn used to compare the results manually.
            ignore_keys: The keys of the named_parameters to ignore.
            compare_random: Whether to compare random settings, default True.
            reset_dropout: Reset the rate of all dropout modules to 0.0.
            lazy_stop_callback: A callback passed in; when the monitoring is over, this callback will be called.
            kwargs:
                atol: The absolute tolerance between two np arrays.
                rtol: The relative tolerance between two np arrays.

        >>> def compare_fn(v1, v2, key, type):
        >>>     return None

        v1 is the baseline value
        v2 is the value of the current version
        key is the key of the module/parameter
        type is one of 'input', 'output', 'backward', 'optimizer', 'lr_scheduler', 'cfg', 'state'
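
        A minimal usage sketch (the trainer contents and names are illustrative; monitoring
        only runs when the REGRESSION_BASELINE environment variable is set):

        >>> tool = RegressTool(baseline=False)
        >>> trainer = {'model': model, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler}
        >>> with tool.monitor_module_train(trainer, 'my_train_baseline', level='config'):
        >>>     train_one_step(trainer)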
        """
        baseline = os.getenv('REGRESSION_BASELINE')
        if baseline is None or self.baseline is None:
            yield
            return

        baseline = self.baseline

        io_json = {}
        bw_json = {}
        absolute_path = f'./{file_name}.bin'

        if level == 'strict':
            print(
                "[Important] The level of regression is 'strict', please make sure your model's parameters are "
                'fixed and all drop-out rates have been set to zero.')

        assert hasattr(
            trainer, 'model') or 'model' in trainer, 'model must be in trainer'
        module = trainer['model'] if isinstance(trainer,
                                                dict) else trainer.model
        if not isinstance(module, nn.Module):
            assert hasattr(module, 'model')
            module = module.model

        assert hasattr(
            trainer, 'optimizer'
        ) or 'optimizer' in trainer, 'optimizer must be in trainer'
        assert hasattr(
            trainer, 'lr_scheduler'
        ) or 'lr_scheduler' in trainer, 'lr_scheduler must be in trainer'
        optimizer: torch.optim.Optimizer = trainer['optimizer'] if isinstance(
            trainer, dict) else trainer.optimizer
        lr_scheduler: torch.optim.lr_scheduler._LRScheduler = trainer['lr_scheduler'] if isinstance(trainer, dict) \
            else trainer.lr_scheduler
        torch_state = numpify_tensor_nested(torch.get_rng_state())
        np_state = np.random.get_state()
        random_seed = random.getstate()
        seed = getattr(trainer, '_seed', getattr(trainer, 'seed', None))

        if reset_dropout:
            with torch.no_grad():

                def reinit_dropout(_module):
                    for name, submodule in _module.named_children():
                        if isinstance(submodule, torch.nn.Dropout):
                            setattr(_module, name, torch.nn.Dropout(0.))
                        else:
                            reinit_dropout(submodule)

                reinit_dropout(module)

        if level == 'strict':
            hack_forward(module, file_name, io_json)
            intercept_module(module, io_json)
        hack_backward(
            module, optimizer, bw_json, lazy_stop_callback=lazy_stop_callback)
        yield
        hack_backward(module, optimizer, None, restore=True)
        if level == 'strict':
            hack_forward(module, None, None, restore=True)
            intercept_module(module, None, restore=True)

        optimizer_dict = optimizer.state_dict()
        optimizer_dict.pop('state', None)
        summary = {
            'forward': io_json,
            'backward': bw_json,
            'optimizer': {
                'type': optimizer.__class__.__name__,
                'defaults': optimizer.defaults,
                'state_dict': optimizer_dict
            },
            'lr_scheduler': {
                'type': lr_scheduler.__class__.__name__,
                'state_dict': lr_scheduler.state_dict()
            },
            'cfg': trainer.cfg.to_dict() if hasattr(trainer, 'cfg') else None,
            'state': {
                'torch_state': torch_state,
                'np_state': np_state,
                'random_seed': random_seed,
                'seed': seed,
            }
        }

        if baseline:
            with open(absolute_path, 'wb') as f:
                pickle.dump(summary, f)
            self.store(absolute_path, f'{file_name}.bin')
            os.remove(absolute_path)
        else:
            name = os.path.basename(absolute_path)
            baseline = os.path.join(tempfile.gettempdir(), name)
            self.load(baseline, name)
            with open(baseline, 'rb') as f:
                baseline_json = pickle.load(f)

            if level == 'strict' and not compare_io_and_print(
                    baseline_json['forward'], io_json, compare_fn, **kwargs):
                raise RuntimeError('Forward results do not match!')
            if not compare_backward_and_print(
                    baseline_json['backward'],
                    bw_json,
                    compare_fn=compare_fn,
                    ignore_keys=ignore_keys,
                    level=level,
                    **kwargs):
                raise RuntimeError('Backward results do not match!')
            cfg_opt1 = {
                'optimizer': baseline_json['optimizer'],
                'lr_scheduler': baseline_json['lr_scheduler'],
                'cfg': baseline_json['cfg'],
                'state': None if not compare_random else baseline_json['state']
            }
            cfg_opt2 = {
                'optimizer': summary['optimizer'],
                'lr_scheduler': summary['lr_scheduler'],
                'cfg': summary['cfg'],
                'state': None if not compare_random else summary['state']
            }
            if not compare_cfg_and_optimizers(cfg_opt1, cfg_opt2, compare_fn,
                                              **kwargs):
                raise RuntimeError('Cfg or optimizers do not match!')


class MsRegressTool(RegressTool):

    class EarlyStopError(Exception):
        pass

    @contextlib.contextmanager
    def monitor_ms_train(self,
                         trainer,
                         file_name,
                         level='config',
                         compare_fn=None,
                         ignore_keys=None,
                         compare_random=True,
                         lazy_stop_callback=None,
                         **kwargs):
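        """Monitor a ModelScope trainer's training loop; see `monitor_module_train` for the arguments.

        The trainer's `train_loop` is wrapped so that, unless a custom `lazy_stop_callback`
        is given, an `EarlyStopHook` stops the run right after the first optimizer step.
        """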

        if lazy_stop_callback is None:

            def lazy_stop_callback():

                class EarlyStopHook:
                    PRIORITY = 90

                    def before_run(self, trainer):
                        pass

                    def after_run(self, trainer):
                        pass

                    def before_epoch(self, trainer):
                        pass

                    def after_epoch(self, trainer):
                        pass

                    def before_iter(self, trainer):
                        pass

                    def before_train_epoch(self, trainer):
                        self.before_epoch(trainer)

                    def before_val_epoch(self, trainer):
                        self.before_epoch(trainer)

                    def after_train_epoch(self, trainer):
                        self.after_epoch(trainer)

                    def after_val_epoch(self, trainer):
                        self.after_epoch(trainer)

                    def before_train_iter(self, trainer):
                        self.before_iter(trainer)

                    def before_val_iter(self, trainer):
                        self.before_iter(trainer)

                    def after_train_iter(self, trainer):
                        self.after_iter(trainer)

                    def after_val_iter(self, trainer):
                        self.after_iter(trainer)

                    def every_n_epochs(self, trainer, n):
                        return (trainer.epoch + 1) % n == 0 if n > 0 else False

                    def every_n_inner_iters(self, trainer, n):
                        return (trainer.inner_iter
                                + 1) % n == 0 if n > 0 else False

                    def every_n_iters(self, trainer, n):
                        return (trainer.iter + 1) % n == 0 if n > 0 else False

                    def end_of_epoch(self, trainer):
                        return trainer.inner_iter + 1 == trainer.iters_per_epoch

                    def is_last_epoch(self, trainer):
                        return trainer.epoch + 1 == trainer.max_epochs

                    def is_last_iter(self, trainer):
                        return trainer.iter + 1 == trainer.max_iters

                    def get_triggered_stages(self):
                        return []

                    def state_dict(self):
                        return {}

                    def load_state_dict(self, state_dict):
                        pass

                    def after_iter(self, trainer):
                        raise MsRegressTool.EarlyStopError('Test finished.')

                trainer.register_hook(EarlyStopHook())

        def _train_loop(trainer, *args_train, **kwargs_train):
            with self.monitor_module_train(
                    trainer,
                    file_name,
                    level,
                    compare_fn=compare_fn,
                    ignore_keys=ignore_keys,
                    compare_random=compare_random,
                    lazy_stop_callback=lazy_stop_callback,
                    **kwargs):
                try:
                    return trainer.train_loop_origin(*args_train,
                                                     **kwargs_train)
                except MsRegressTool.EarlyStopError:
                    pass

        trainer.train_loop_origin, trainer.train_loop = \
            trainer.train_loop, type(trainer.train_loop)(_train_loop, trainer)
        yield


def compare_module(module1: nn.Module, module2: nn.Module):
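    """Return True if all parameters of the two modules are element-wise equal."""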
    for p1, p2 in zip(module1.parameters(), module2.parameters()):
        if p1.data.ne(p2.data).sum() > 0:
            return False
    return True


def numpify_tensor_nested(tensors, reduction=None, clip_value=10000):
    """Numpify `tensors` (even if it's a nested list/tuple/dict of tensors).

    Tensor values are clipped to [-clip_value, clip_value], and an optional
    'sum' or 'mean' reduction is applied.
    """
    try:
        from modelscope.outputs import ModelOutputBase
    except ImportError:
        ModelOutputBase = dict
    if isinstance(tensors, (Mapping, ModelOutputBase)):
        return OrderedDict({
            k: numpify_tensor_nested(t, reduction, clip_value)
            for k, t in tensors.items()
        })
    if isinstance(tensors, list):
        return list(
            numpify_tensor_nested(t, reduction, clip_value) for t in tensors)
    if isinstance(tensors, tuple):
        return tuple(
            numpify_tensor_nested(t, reduction, clip_value) for t in tensors)
    if isinstance(tensors, torch.Tensor):
        t: np.ndarray = tensors.cpu().numpy()
        if clip_value is not None:
            t = np.where(t > clip_value, clip_value, t)
            t = np.where(t < -clip_value, -clip_value, t)
        if reduction == 'sum':
            return t.sum(dtype=float)
        elif reduction == 'mean':
            return t.mean(dtype=float)
        return t
    return tensors


def detach_tensor_nested(tensors):
    """Detach `tensors` (even if it's a nested list/tuple/dict of tensors)."""
    try:
        from modelscope.outputs import ModelOutputBase
    except ImportError:
        ModelOutputBase = dict
    if isinstance(tensors, (Mapping, ModelOutputBase)):
        return OrderedDict(
            {k: detach_tensor_nested(t)
             for k, t in tensors.items()})
    if isinstance(tensors, list):
        return list(detach_tensor_nested(t) for t in tensors)
    if isinstance(tensors, tuple):
        return tuple(detach_tensor_nested(t) for t in tensors)
    if isinstance(tensors, torch.Tensor):
        return tensors.detach()
    return tensors


def hack_forward(module: nn.Module,
                 name,
                 io_json,
                 restore=False,
                 keep_tensors=False):
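    """Replace `module.forward` with a wrapper that records its inputs and outputs into `io_json` under `name`.

    By default only the 'sum' and 'mean' reductions of each tensor are stored; pass
    `keep_tensors=True` to store the full arrays. Call with `restore=True` to undo the hack.
    """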

    def _forward(self, *args, **kwargs):
        ret = self.forward_origin(*args, **kwargs)
        if keep_tensors:
            args = numpify_tensor_nested(detach_tensor_nested(args))
            kwargs = numpify_tensor_nested(detach_tensor_nested(kwargs))
            output = numpify_tensor_nested(detach_tensor_nested(ret))
        else:
            args = {
                'sum':
                numpify_tensor_nested(
                    detach_tensor_nested(args), reduction='sum'),
                'mean':
                numpify_tensor_nested(
                    detach_tensor_nested(args), reduction='mean'),
            }
            kwargs = {
                'sum':
                numpify_tensor_nested(
                    detach_tensor_nested(kwargs), reduction='sum'),
                'mean':
                numpify_tensor_nested(
                    detach_tensor_nested(kwargs), reduction='mean'),
            }
            output = {
                'sum':
                numpify_tensor_nested(
                    detach_tensor_nested(ret), reduction='sum'),
                'mean':
                numpify_tensor_nested(
                    detach_tensor_nested(ret), reduction='mean'),
            }

        io_json[name] = {
            'input': {
                'args': args,
                'kwargs': kwargs,
            },
            'output': output,
        }
        return ret

    if not restore and not hasattr(module, 'forward_origin'):
        module.forward_origin, module.forward = module.forward, type(
            module.forward)(_forward, module)
    if restore and hasattr(module, 'forward_origin'):
        module.forward = module.forward_origin
        del module.forward_origin


def hack_backward(module: nn.Module,
                  optimizer,
                  io_json,
                  restore=False,
                  lazy_stop_callback=None):
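    """Replace `optimizer.step` with a wrapper that records parameter data and grads into `io_json`.

    For each named parameter, the 'sum' and 'mean' of its data and grad are recorded before
    the step, and of its data again after the step. `lazy_stop_callback`, if given, is called
    after the step. Call with `restore=True` to undo the hack.
    """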

    def _step(self, *args, **kwargs):
        for name, param in module.named_parameters():
            io_json[name] = {
                'data': {
                    'sum':
                    numpify_tensor_nested(
                        detach_tensor_nested(param.data), reduction='sum'),
                    'mean':
                    numpify_tensor_nested(
                        detach_tensor_nested(param.data), reduction='mean'),
                },
                'grad': {
                    'sum':
                    numpify_tensor_nested(
                        detach_tensor_nested(param.grad), reduction='sum'),
                    'mean':
                    numpify_tensor_nested(
                        detach_tensor_nested(param.grad), reduction='mean'),
                }
            }
        ret = self.step_origin(*args, **kwargs)
        for name, param in module.named_parameters():
            io_json[name]['data_after'] = {
                'sum':
                numpify_tensor_nested(
                    detach_tensor_nested(param.data), reduction='sum'),
                'mean':
                numpify_tensor_nested(
                    detach_tensor_nested(param.data), reduction='mean'),
            }
        if lazy_stop_callback is not None:
            lazy_stop_callback()
        return ret

    if not restore and not hasattr(optimizer, 'step_origin'):
        optimizer.step_origin, optimizer.step = optimizer.step, type(
            optimizer.step)(_step, optimizer)
    if restore and hasattr(optimizer, 'step_origin'):
        optimizer.step = optimizer.step_origin
        del optimizer.step_origin


def intercept_module(module: nn.Module,
                     io_json,
                     parent_name=None,
                     restore=False):
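    """Recursively apply `hack_forward` to every submodule, keyed by its dotted module path."""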
    for name, submodule in module.named_children():
        full_name = parent_name + '.' + name if parent_name is not None else name
        hack_forward(submodule, full_name, io_json, restore)
        intercept_module(submodule, io_json, full_name, restore)


def compare_io_and_print(baseline_json, io_json, compare_fn=None, **kwargs):
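    """Compare two forward I/O summaries key by key, printing mismatches; return True if they match."""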
    if compare_fn is None:

        def compare_fn(*args, **kwargs):
            return None

    keys1 = set(baseline_json.keys())
    keys2 = set(io_json.keys())
    missing = keys1 - keys2
    extra = keys2 - keys1
    print(f'unmatched keys, missing in current run: {missing}, extra in current run: {extra}')
    shared_keys = keys1.intersection(keys2)
    match = True
    for key in shared_keys:
        v1 = baseline_json[key]
        v2 = io_json[key]

        v1input = numpify_tensor_nested(v1['input'])
        v2input = numpify_tensor_nested(v2['input'])
        res = compare_fn(v1input, v2input, key, 'input')
        if res is not None:
            print(
                f'input of {key} compared with user compare_fn with result:{res}\n'
            )
            match = match and res
        else:
            match = compare_arguments_nested(
                f'unmatched module {key} input args', v1input['args'],
                v2input['args'], **kwargs) and match
            match = compare_arguments_nested(
                f'unmatched module {key} input kwargs', v1input['kwargs'],
                v2input['kwargs'], **kwargs) and match
        v1output = numpify_tensor_nested(v1['output'])
        v2output = numpify_tensor_nested(v2['output'])
        res = compare_fn(v1output, v2output, key, 'output')
        if res is not None:
            print(
                f'output of {key} compared with user compare_fn with result:{res}\n'
            )
            match = match and res
        else:
            match = compare_arguments_nested(
                f'unmatched module {key} outputs',
                arg1=v1output,
                arg2=v2output,
                **kwargs) and match
    return match


def compare_backward_and_print(baseline_json,
                               bw_json,
                               level,
                               ignore_keys=None,
                               compare_fn=None,
                               **kwargs):
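    """Compare two backward summaries (parameter data and grads around an optimizer step); return True if they match."""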
    if compare_fn is None:

        def compare_fn(*args, **kwargs):
            return None

    keys1 = set(baseline_json.keys())
    keys2 = set(bw_json.keys())
    missing = keys1 - keys2
    extra = keys2 - keys1
    print(f'unmatched backward keys, missing in current run: {missing}, extra in current run: {extra}')
    shared_keys = keys1.intersection(keys2)
    match = True
    for key in shared_keys:
        if ignore_keys is not None and key in ignore_keys:
            continue

        res = compare_fn(baseline_json[key], bw_json[key], key, 'backward')
        if res is not None:
            print(f'backward data of {key} compared with '
                  f'user compare_fn with result:{res}\n')
            match = match and res
        else:
            data1, grad1, data_after1 = baseline_json[key][
                'data'], baseline_json[key]['grad'], baseline_json[key][
                    'data_after']
            data2, grad2, data_after2 = bw_json[key]['data'], bw_json[key][
                'grad'], bw_json[key]['data_after']
            match = compare_arguments_nested(
                f'unmatched module {key} tensor data',
                arg1=data1,
                arg2=data2,
                **kwargs) and match
            if level == 'strict':
                match = compare_arguments_nested(
                    f'unmatched module {key} grad data',
                    arg1=grad1,
                    arg2=grad2,
                    **kwargs) and match
                match = compare_arguments_nested(
                    f'unmatched module {key} data after step', data_after1,
                    data_after2, **kwargs) and match
    return match


def compare_cfg_and_optimizers(baseline_json,
                               cfg_json,
                               compare_fn=None,
                               **kwargs):
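    """Compare the optimizer, lr_scheduler, cfg and random-state summaries; return True if they all match."""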
    if compare_fn is None:

        def compare_fn(*args, **kwargs):
            return None

    optimizer1, lr_scheduler1, cfg1, state1 = baseline_json[
        'optimizer'], baseline_json['lr_scheduler'], baseline_json[
            'cfg'], baseline_json['state']
    optimizer2, lr_scheduler2, cfg2, state2 = cfg_json['optimizer'], cfg_json[
        'lr_scheduler'], cfg_json['cfg'], cfg_json['state']

    match = True
    res = compare_fn(optimizer1, optimizer2, None, 'optimizer')
    if res is not None:
        print(f'optimizer compared with user compare_fn with result:{res}\n')
        match = match and res
    else:
        if optimizer1['type'] != optimizer2['type']:
            print(
                f"Optimizer type not equal: {optimizer1['type']} and {optimizer2['type']}"
            )
            match = False
        match = compare_arguments_nested(
            'unmatched optimizer defaults', optimizer1['defaults'],
            optimizer2['defaults'], **kwargs) and match
        match = compare_arguments_nested(
            'unmatched optimizer state_dict', optimizer1['state_dict'],
            optimizer2['state_dict'], **kwargs) and match

    res = compare_fn(lr_scheduler1, lr_scheduler2, None, 'lr_scheduler')
    if res is not None:
        print(
            f'lr_scheduler compared with user compare_fn with result:{res}\n')
        match = match and res
    else:
        if lr_scheduler1['type'] != lr_scheduler2['type']:
            print(
                f"Lr_scheduler type not equal: {lr_scheduler1['type']} and {lr_scheduler2['type']}"
            )
            match = False
        match = compare_arguments_nested(
            'unmatched lr_scheduler state_dict', lr_scheduler1['state_dict'],
            lr_scheduler2['state_dict'], **kwargs) and match

    res = compare_fn(cfg1, cfg2, None, 'cfg')
    if res is not None:
        print(f'cfg compared with user compare_fn with result:{res}\n')
        match = match and res
    else:
        match = compare_arguments_nested(
            'unmatched cfg', arg1=cfg1, arg2=cfg2, **kwargs) and match

    res = compare_fn(state1, state2, None, 'state')
    if res is not None:
        print(
            f'random state compared with user compare_fn with result:{res}\n')
        match = match and res
    else:
        match = compare_arguments_nested('unmatched random state', state1,
                                         state2, **kwargs) and match

    return match


class IgnoreKeyFn:
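    """A `compare_fn` that makes keys matching any of the given regex patterns always pass.

    Returning True for a matched key skips the default comparison for that key;
    returning None falls back to the default comparison for all other keys.

    A minimal usage sketch (the pattern is illustrative):

    >>> compare_fn = IgnoreKeyFn(['.*dropout'])
    >>> with tool.monitor_module_single_forward(model, 'my_model_forward', compare_fn=compare_fn):
    >>>     model(**inputs)
    """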

    def __init__(self, keys):
        if isinstance(keys, str):
            keys = [keys]
        self.keys = keys if isinstance(keys, list) else []

    def __call__(self, v1output, v2output, key, type):
        for _key in self.keys:
            pattern = re.compile(_key)
            if key is not None and pattern.fullmatch(key):
                return True
        return None
