# Copyright (c) Alibaba, Inc. and its affiliates.

import os
import time
from abc import ABC, abstractmethod
from typing import Callable, Dict, Optional

from modelscope.hub.check_model import check_local_model_is_latest
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.trainers.builder import TRAINERS
from modelscope.utils.config import Config
from modelscope.utils.constant import Invoke, ThirdParty
from .utils.log_buffer import LogBuffer


class BaseTrainer(ABC):
    """ Base class for trainer which can not be instantiated.

    BaseTrainer defines necessary interface
    and provide default implementation for basic initialization
    such as parsing config file and parsing commandline args.
    """

    def __init__(self, cfg_file: str, arg_parse_fn: Optional[Callable] = None):
        """ Trainer basic init, should be called in derived class

        Args:
            cfg_file: Path to configuration file.
            arg_parse_fn: Same as ``parse_fn`` in :obj:`Config.to_args`.
        """
        self.cfg = Config.from_file(cfg_file)
        if arg_parse_fn:
            self.args = self.cfg.to_args(arg_parse_fn)
        else:
            self.args = None
        self.log_buffer = LogBuffer()
        self.visualization_buffer = LogBuffer()
        self.timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())

    def get_or_download_model_dir(self,
                                  model,
                                  model_revision=None,
                                  third_party=None):
        """ Get local model directory or download model if necessary.

        Args:
            model (str): model id or path to local model directory.
            model_revision  (str, optional): model version number.
            third_party (str, optional): in which third party library
                this function is called.
        """
        if os.path.exists(model):
            model_cache_dir = model if os.path.isdir(
                model) else os.path.dirname(model)
            check_local_model_is_latest(
                model_cache_dir,
                user_agent={
                    Invoke.KEY: Invoke.LOCAL_TRAINER,
                    ThirdParty.KEY: third_party
                })
        else:
            model_cache_dir = snapshot_download(
                model,
                revision=model_revision,
                user_agent={
                    Invoke.KEY: Invoke.TRAINER,
                    ThirdParty.KEY: third_party
                })
        return model_cache_dir

    @abstractmethod
    def train(self, *args, **kwargs):
        """ Train (and evaluate) process

        Train process should be implemented for specific task or
        model, related parameters have been initialized in
        ``BaseTrainer.__init__`` and should be used in this function
        """
        pass

    @abstractmethod
    def evaluate(self, checkpoint_path: str, *args,
                 **kwargs) -> Dict[str, float]:
        """ Evaluation process

        Evaluation process should be implemented for specific task or
        model, related parameters have been initialized in
        ``BaseTrainer.__init__`` and should be used in this function
        """
        pass


@TRAINERS.register_module(module_name='dummy')
class DummyTrainer(BaseTrainer):

    def __init__(self, cfg_file: str, *args, **kwargs):
        """ Dummy Trainer.

        Args:
            cfg_file: Path to configuration file.
        """
        super().__init__(cfg_file)

    def train(self, *args, **kwargs):
        """ Train (and evaluate) process

        Train process should be implemented for specific task or
        model, related parameters have been initialized in
        ``BaseTrainer.__init__`` and should be used in this function
        """
        cfg = self.cfg.train
        print(f'train cfg {cfg}')

    def evaluate(self,
                 checkpoint_path: str = None,
                 *args,
                 **kwargs) -> Dict[str, float]:
        """ Evaluation process

        Evaluation process should be implemented for specific task or
        model, related parameters have been initialized in
        ``BaseTrainer.__init__`` and should be used in this function
        """
        cfg = self.cfg.evaluation
        print(f'eval cfg {cfg}')
        print(f'checkpoint_path {checkpoint_path}')
