# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import importlib.machinery
import inspect
import os
import sys
from types import ModuleType

import filelock

from ..utils import logging
from ..utils.deps import class_requires_deps


def get_user_home() -> str:
    """Return the current user's home directory (``~`` expanded)."""
    home_dir = os.path.expanduser("~")
    return home_dir


def get_pprndr_home() -> str:
    """Return the path of the ``.pprndr`` directory inside the user's home.

    Only the path is computed; the directory is not created here.
    """
    user_home = get_user_home()
    return os.path.join(user_home, ".pprndr")


def get_sub_home(directory: str) -> str:
    """Return ``<pprndr home>/<directory>``, creating it (and parents) if absent.

    Args:
        directory: Name (or relative path) of the sub-directory.

    Returns:
        The joined path, which is guaranteed to exist on return.
    """
    sub_home = os.path.join(get_pprndr_home(), directory)
    # exist_ok makes repeated calls (and concurrent creators) harmless.
    os.makedirs(sub_home, exist_ok=True)
    return sub_home


# Scratch directory ("tmp" under the pprndr home); created at import time.
# PaddleXCustomOperatorModule.jit_build places its build lock files here.
TMP_HOME = get_sub_home("tmp")

# Registry of custom operators that may be imported as ``paddlex.ops.<name>``.
# Each entry maps the op (and module) name to:
#   sources: C++/CUDA files, relative to the ``paddlex.ops`` package directory
#   version: version string; dropped before the JIT build
# Any remaining keys are forwarded as keyword arguments to
# ``paddle.utils.cpp_extension.load`` by ``jit_build``.
custom_ops = {
    "voxelize": dict(
        sources=["voxel/voxelize_op.cc", "voxel/voxelize_op.cu"],
        version="0.1.0",
    ),
    "iou3d_nms": dict(
        sources=[
            "iou3d_nms/iou3d_cpu.cpp",
            "iou3d_nms/iou3d_nms_api.cpp",
            "iou3d_nms/iou3d_nms.cpp",
            "iou3d_nms/iou3d_nms_kernel.cu",
        ],
        version="0.1.0",
    ),
}


class CustomOpNotFoundException(Exception):
    """Raised when a ``paddlex.ops`` submodule is not a registered custom op.

    Attributes:
        op_name: The unrecognised op (last component of the import name).
    """

    def __init__(self, op_name):
        # Initialise the base Exception so ``args``/``repr`` carry the op name
        # through the standard exception machinery.
        super().__init__(op_name)
        self.op_name = op_name

    def __str__(self):
        return "Couldn't find custom op {}".format(self.op_name)


class CustomOperatorPathFinder:
    """``sys.meta_path`` finder that claims every ``paddlex.ops.<name>`` import.

    Claimed modules are loaded by ``CustomOperatorPathLoader``. Every other
    import, including ``paddlex.ops`` itself, is left to the remaining finders.
    """

    def find_spec(self, fullname: str, path, target=None):
        """Return a module spec for submodules of ``paddlex.ops``, else ``None``.

        Args:
            fullname: Fully qualified name of the module being imported.
            path: Parent package ``__path__`` (unused).
            target: Module being reloaded, if any (unused).
        """
        # Only strict submodules are intercepted. The trailing dot keeps the
        # package itself (looked up again by ``importlib.reload``) and sibling
        # names that merely share the prefix (e.g. ``paddlex.ops_utils``) on
        # the regular import machinery.
        if not fullname.startswith("paddlex.ops."):
            return None
        return importlib.machinery.ModuleSpec(
            name=fullname,
            loader=CustomOperatorPathLoader(),
            is_package=False,
        )


class CustomOperatorPathLoader:
    """``load_module``-style loader for ``paddlex.ops.<op>`` modules.

    Resolution order for an op called ``<op>``:

    1. whatever is already cached in ``sys.modules`` under the full name;
    2. an importable top-level module named ``<op>`` (presumably a
       pre-installed build of the op -- confirm against packaging);
    3. otherwise a lazy ``PaddleXCustomOperatorModule`` placeholder, which
       imports or JIT-builds the op on first attribute access.
    """

    def load_module(self, fullname: str):
        """Return the module for ``fullname`` and cache it in ``sys.modules``.

        Raises:
            CustomOpNotFoundException: if the last name component is not a
                key of ``custom_ops``.
        """
        op_name = fullname.rsplit(".", 1)[-1]
        if op_name not in custom_ops:
            raise CustomOpNotFoundException(op_name)

        if fullname in sys.modules:
            return sys.modules[fullname]

        try:
            module = importlib.import_module(op_name)
        except ImportError:
            # Not importable yet: defer the (expensive) build until first use.
            module = PaddleXCustomOperatorModule(op_name, fullname)
        sys.modules[fullname] = module
        return module


class PaddleXCustomOperatorModule(ModuleType):
    """Lazy placeholder for a custom op module that could not be imported.

    The placeholder is stored in ``sys.modules`` under ``fullname``. On the
    first attribute access that needs the real module, it imports
    ``modulename`` or, failing that, JIT-builds it from the sources listed in
    ``custom_ops``. The real module then replaces the placeholder in
    ``sys.modules``.
    """

    def __init__(self, modulename: str, fullname: str):
        """
        Args:
            modulename: Bare op name (a key of ``custom_ops``); also the name
                the op is imported/built under.
            fullname: Dotted import name, e.g. ``paddlex.ops.voxelize``.
        """
        self.fullname = fullname
        self.modulename = modulename
        self.module = None
        super().__init__(modulename)

    def jit_build(self):
        """JIT-compile the op with ``paddle.utils.cpp_extension.load``.

        Source paths from ``custom_ops`` are resolved relative to the
        directory of the ``paddlex.ops`` package. The ``version`` entry is
        dropped and any other entries are forwarded as keyword arguments.

        Returns:
            Whatever ``paddle.utils.cpp_extension.load`` returns.

        Raises:
            Exception: any error from locating sources or building is logged
                and re-raised unchanged.
        """
        from paddle.utils.cpp_extension import load as paddle_jit_load

        try:
            lockfile = os.path.join(TMP_HOME, "paddlex.ops.{}".format(self.modulename))
            package_file = inspect.getabsfile(sys.modules["paddlex.ops"])
            rootdir = os.path.split(package_file)[0]

            args = custom_ops[self.modulename].copy()
            sources = [os.path.join(rootdir, src) for src in args.pop("sources")]
            args.pop("version")
            # The file lock makes concurrent builds of the same op (e.g. from
            # several processes) take turns instead of racing.
            with filelock.FileLock(lockfile):
                return paddle_jit_load(name=self.modulename, sources=sources, **args)
        except Exception:
            # Narrower than a bare except: Ctrl-C / SystemExit are not
            # reported as build failures. The original error still propagates.
            logging.error("{} built fail!".format(self.modulename))
            raise

    def _load_module(self):
        """Return the real op module, importing or JIT-building it once."""
        if self.module is None:
            try:
                # Retry the plain import first; the op may have become
                # importable since the placeholder was created.
                self.module = importlib.import_module(self.modulename)
            except ImportError:
                logging.warning(
                    "No custom op {} found, try JIT build".format(self.modulename)
                )
                self.module = self.jit_build()
                logging.info("{} built success!".format(self.modulename))

            # Replace the placeholder so later imports get the real module.
            sys.modules[self.fullname] = self.module
        return self.module

    def __getattr__(self, attr: str):
        """Forward attribute lookups to the real module, loading it on demand.

        Python only calls this for names not found on the placeholder itself.

        Raises:
            AttributeError: for a missing module-bookkeeping dunder
                (``__loader__``, ``__package__``, ``__name__``, ``__spec__``);
                these never trigger a load.
            ImportError: if the loaded module has no attribute ``attr``
                (mirrors the message of a failed ``from ... import``).
        """
        # Answer these without loading/building. NOTE(review): presumably so
        # import-system and introspection probes stay cheap -- confirm intent.
        if attr in ["__path__", "__file__"]:
            return None

        # ModuleType defines no __getattr__ to delegate to, so report the
        # missing attribute directly instead of failing on ``super()``.
        if attr in ["__loader__", "__package__", "__name__", "__spec__"]:
            raise AttributeError(
                "module '{}' has no attribute '{}'".format(self.modulename, attr)
            )

        module = self._load_module()
        if not hasattr(module, attr):
            raise ImportError(
                "cannot import name '{}' from '{}' ({})".format(
                    attr, self.modulename, getattr(module, "__file__", None)
                )
            )
        return getattr(module, attr)


# Register the finder ahead of every previously installed finder so that
# ``paddlex.ops.<op>`` imports are resolved by it first.
sys.meta_path[:0] = [CustomOperatorPathFinder()]
