# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from copy import deepcopy

from .inference import PaddlePredictorOption, create_predictor
from .modules import (
    build_dataset_checker,
    build_evaluator,
    build_exportor,
    build_trainer,
)


# TODO(gaotingquan): support _ModelBasedConfig
def create_model(model_name, model_dir=None, *args, **kwargs):
    """Create an inference-only model wrapper for ``model_name``.

    All arguments are forwarded to ``create_predictor`` through
    ``_ModelBasedInference``. Config-based models are not created here yet.

    Args:
        model_name: Name of the model to build a predictor for.
        model_dir: Optional directory holding the model files.
        *args: Extra positional arguments forwarded after ``model_dir``.
        **kwargs: Extra keyword arguments forwarded unchanged.

    Returns:
        _ModelBasedInference: A wrapper around the created predictor.
    """
    # Pass model_name/model_dir positionally. If they were passed as keywords,
    # any extra positional ``args`` would bind to the predictor factory's first
    # parameters and collide with them ("got multiple values for argument").
    return _ModelBasedInference(model_name, model_dir, *args, **kwargs)


class _BaseModel:
    def check_dataset(self, *args, **kwargs):
        raise Exception("check_dataset is not supported!")

    def train(self, *args, **kwargs):
        raise Exception("train is not supported!")

    def evaluate(self, *args, **kwargs):
        raise Exception("evaluate is not supported!")

    def export(self, *args, **kwargs):
        raise Exception("export is not supported!")

    def predict(self, *args, **kwargs):
        raise Exception("predict is not supported!")

    def set_predict(self, *args, **kwargs):
        raise Exception("set_predict is not supported!")

    def __call__(self, *args, **kwargs):
        yield from self.predict(*args, **kwargs)


class _ModelBasedInference(_BaseModel):
    """Inference-only model wrapper around a predictor from ``create_predictor``.

    Attributes not found on the wrapper are looked up on the wrapped predictor.
    """

    def __init__(self, *args, **kwargs):
        # All arguments are forwarded unchanged to ``create_predictor``.
        self._predictor = create_predictor(*args, **kwargs)

    def predict(self, *args, **kwargs):
        """Yield every result produced by calling the wrapped predictor."""
        yield from self._predictor(*args, **kwargs)

    def set_predictor(self, **kwargs):
        """Forward keyword settings to the wrapped predictor's ``set_predictor``."""
        self._predictor.set_predictor(**kwargs)

    def __getattr__(self, name):
        """Delegate unknown attribute lookups to the wrapped predictor.

        Raises:
            AttributeError: If neither the wrapper nor the predictor has ``name``,
                or if the predictor has not been set on this instance.
        """
        # ``__getattr__`` only runs after normal lookup fails. Reading
        # ``self._predictor`` here would re-enter ``__getattr__`` whenever
        # ``_predictor`` is missing (e.g. ``create_predictor`` raised in
        # ``__init__``, or copy/pickle built the object via ``__new__``),
        # recursing until RecursionError. Read the instance dict directly.
        try:
            predictor = self.__dict__["_predictor"]
            return getattr(predictor, name)
        except (KeyError, AttributeError):
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no attribute '{name}'"
            ) from None


class _ModelBasedConfig(_BaseModel):
    """Model wrapper driven by a configuration object.

    The config is read through ``config.Global.model``,
    ``config.Global.get("device")`` and ``config.Predict``. It is presumably a
    dict-like config object (it supports ``get``/``pop``/``in`` and
    ``deepcopy``) — verify against callers.
    """

    def __init__(self, config=None, *args, **kwargs):
        super().__init__()
        # ``config`` defaults to None only for signature compatibility; it is
        # required. Fail with a clear message instead of an opaque
        # "'NoneType' object has no attribute 'Global'".
        if config is None:
            raise ValueError("`config` is required to build a config-based model.")
        self._config = config
        self._model_name = config.Global.model

    def _build_predictor(self):
        """Create a predictor from the config.

        Returns:
            tuple: ``(predict_kwargs, predictor)``. ``predict_kwargs`` is a deep
            copy of ``config.Predict`` with the construction-only keys
            (``model_dir``, ``kernel_option``, ``use_hpip``, ``hpi_config``,
            ``genai_config``) removed; it still contains ``input``.
            ``predictor`` is the object returned by ``create_predictor``.

        Raises:
            ValueError: If ``config.Predict`` has no ``input`` entry.
        """
        # Copy so that popping construction-only keys leaves the config intact.
        predict_kwargs = deepcopy(self._config.Predict)

        # Validate before create_predictor so a missing input fails fast rather
        # than after the (potentially expensive) predictor construction. An
        # explicit raise is used because ``assert`` is stripped under ``-O``.
        if "input" not in predict_kwargs:
            raise ValueError(
                "The `Predict` section of the config must contain an `input` entry."
            )

        model_dir = predict_kwargs.pop("model_dir", None)
        device = self._config.Global.get("device", None)

        # Sentinel distinguishes "key absent" from "key present with value None".
        UNSET = object()
        kernel_option = predict_kwargs.pop("kernel_option", UNSET)
        hpi_config = predict_kwargs.pop("hpi_config", UNSET)
        genai_config = predict_kwargs.pop("genai_config", UNSET)

        # use_hpip is always passed explicitly, defaulting to False.
        create_predictor_kwargs = {"use_hpip": predict_kwargs.pop("use_hpip", False)}
        if kernel_option is not UNSET:
            create_predictor_kwargs["pp_option"] = PaddlePredictorOption(
                **kernel_option
            )
        if hpi_config is not UNSET:
            create_predictor_kwargs["hpi_config"] = hpi_config
        if genai_config is not UNSET:
            create_predictor_kwargs["genai_config"] = genai_config

        predictor = create_predictor(
            self._model_name,
            model_dir=model_dir,
            device=device,
            **create_predictor_kwargs,
        )
        return predict_kwargs, predictor

    def check_dataset(self):
        """Run the dataset checker built from the config and return its result."""
        dataset_checker = build_dataset_checker(self._config)
        return dataset_checker.check()

    def train(self):
        """Run the trainer built from the config (returns None)."""
        trainer = build_trainer(self._config)
        trainer.train()

    def evaluate(self):
        """Run the evaluator built from the config and return its result."""
        evaluator = build_evaluator(self._config)
        return evaluator.evaluate()

    def export(self):
        """Run the exporter built from the config and return its result."""
        exportor = build_exportor(self._config)
        return exportor.export()

    def predict(self):
        """Yield predictions for the config's ``Predict`` settings.

        This is a generator, so the predictor is only built once iteration
        starts.
        """
        predict_kwargs, predictor = self._build_predictor()
        yield from predictor(**predict_kwargs)
