# Copyright (c) Alibaba, Inc. and its affiliates.
import os

from modelscope.metainfo import Hooks
from modelscope.trainers.hooks.builder import HOOKS
from modelscope.trainers.hooks.hook import Hook
from modelscope.trainers.hooks.priority import Priority
from modelscope.utils.checkpoint import save_checkpoint
from modelscope.utils.torch_utils import is_master


@HOOKS.register_module(module_name=Hooks.SparsityHook)
class SparsityHook(Hook):
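    """Applies progressive sparsity (pruning) to a model during training.

    Before training, the targeted linear layers are converted to
    ``SparseLinear`` modules and their parameters are rewired into the
    optimizer. At each iteration the sparsity ratio is annealed from
    ``initial_sparsity`` to ``final_sparsity``, and after training the
    pruned model is exported to ``<save_dir>/compress_model.pth``.

    A minimal, illustrative configuration (``'<method>'`` is a placeholder
    for a pruning method supported by ``convert_sparse_network``)::

        hook = SparsityHook(
            pruning_method='<method>',
            config={
                'compress_module': ['encoder'],  # empty list: whole model
                'final_sparsity': 0.5,
            })
    """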

    PRIORITY = Priority.HIGHEST

    def __init__(self, pruning_method, config=None, save_dir=None):
        # A ``config={}`` default would be shared across instances (mutable
        # default argument); normalize ``None`` to a fresh dict instead.
        if config is None:
            config = {}
        self.pruning_method = pruning_method
        self.save_dir = save_dir

        self.compress_module = config.get('compress_module', [])
        self.weight_rank = config.get('weight_rank', 8)
        self.weight_beta = config.get('weight_beta', 1)
        self.mask_rank = config.get('mask_rank', 8)
        self.mask_alpha1 = config.get('mask_alpha1', 1)
        self.mask_alpha2 = config.get('mask_alpha2', 1)

        self.step = 0
        self.total_step = 0
        self.frequency = config.get('frequency', 1)
        self.initial_warmup = config.get('initial_warmup', 0.1)
        self.final_warmup = config.get('final_warmup', 0.3)
        self.initial_sparsity = config.get('initial_sparsity', 0.0)
        self.final_sparsity = config.get('final_sparsity', 0.0)

    def before_run(self, trainer):
        import torch

        from .utils import SparseLinear, convert_sparse_network

        if self.save_dir is None:
            self.save_dir = trainer.work_dir

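        # No target modules configured: sparsify the entire model.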
        if not self.compress_module:
            convert_sparse_network(
                trainer.model,
                pruning_method=self.pruning_method,
                weight_rank=self.weight_rank,
                weight_beta=self.weight_beta,
                mask_rank=self.mask_rank,
                mask_alpha1=self.mask_alpha1,
                mask_alpha2=self.mask_alpha2,
                logger=trainer.logger,
            )
        else:
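            # Sparsify only the configured submodules, matched by their fully
            # qualified names from ``named_modules()``.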
            for cm in self.compress_module:
                for name, module in trainer.model.named_modules():
                    if name != cm:
                        continue
                    convert_sparse_network(
                        module,
                        pruning_method=self.pruning_method,
                        weight_rank=self.weight_rank,
                        weight_beta=self.weight_beta,
                        mask_rank=self.mask_rank,
                        mask_alpha1=self.mask_alpha1,
                        mask_alpha2=self.mask_alpha2,
                        logger=trainer.logger,
                    )
                    break

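        # Drop the original dense weights from the existing optimizer param
        # groups: parameters whose values match a SparseLinear weight are now
        # owned by the converted modules and are re-registered below. Matching
        # is by value (compared in fp16) because conversion may copy rather
        # than alias the original tensors.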
        for i in range(len(trainer.optimizer.param_groups)):
            kept_params = []
            for param in trainer.optimizer.param_groups[i]['params']:
                found = False
                for name, module in trainer.model.named_modules():
                    if isinstance(module, SparseLinear) and torch.equal(
                            param.half(), module.weight.data.half()):
                        found = True
                        break

                if not found:
                    kept_params.append(param)

            trainer.optimizer.param_groups[i]['params'] = kept_params

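        # Register all trainable parameters of the converted SparseLinear
        # modules (including any newly introduced mask/low-rank factors) as a
        # fresh param group.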
        new_params = []
        for name, module in trainer.model.named_modules():
            if isinstance(module, SparseLinear):
                new_params.extend(
                    [p for p in module.parameters() if p.requires_grad])

        trainer.optimizer.add_param_group({'params': new_params})

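        # Total number of training iterations; drives the sparsity schedule.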
        self.total_step = trainer.iters_per_epoch * trainer._max_epochs

    def before_train_iter(self, trainer):
        from .utils import schedule_sparsity_ratio, update_network_sparsity

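        # Compute the scheduled sparsity ratio for the current step from the
        # warmup windows and the initial/final sparsity targets.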
        cur_sparsity = schedule_sparsity_ratio(
            self.step,
            self.total_step,
            self.frequency,
            self.initial_warmup,
            self.final_warmup,
            self.initial_sparsity,
            self.final_sparsity,
        )

        update_network_sparsity(trainer.model, cur_sparsity)

        if is_master():
            trainer.logger.info(
                f'Step[{self.step}/{self.total_step}] current sparsity ratio = {cur_sparsity}'
            )

        self.step += 1

    def after_run(self, trainer):
        from .utils import generate_sparse_model

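        # Convert the trained model into its final sparse form before saving.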
        generate_sparse_model(trainer.model, logger=trainer.logger)

        self._save_checkpoint(trainer)

    def _save_checkpoint(self, trainer):
        # Write the compressed checkpoint from the master rank only, so that
        # distributed workers do not race on the same output file.
        if is_master():
            trainer.logger.info('Saving checkpoint at final compress')
            cur_save_name = os.path.join(self.save_dir, 'compress_model.pth')
            save_checkpoint(trainer.model, cur_save_name, trainer.optimizer)
