# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

from ....utils.deps import pipeline_requires_extra
from ...utils.benchmark import benchmark
from ..pp_shitu_v2 import ShiTuV2Pipeline
from .result import FaceRecResult


@benchmark.time_methods
@pipeline_requires_extra("cv")
class FaceRecPipeline(ShiTuV2Pipeline):
    """Face Recognition Pipeline

    Specializes ``ShiTuV2Pipeline`` by overriding how recognition results
    are gathered for detected boxes and how the per-image result object is
    assembled.
    """

    entities = "face_recognition"

    def get_rec_result(
        self, raw_img, det_res, indexer, rec_threshold, hamming_radius, topk
    ):
        """Run recognition and index lookup for every detected box.

        Args:
            raw_img: Image handed to ``self.crop_by_boxes`` together with
                the detected boxes.
            det_res: Detection result; only ``det_res["boxes"]`` is read.
            indexer: Callable invoked with the list of extracted features
                plus the ``score_thres``, ``hamming_radius`` and ``topk``
                keyword arguments.
            rec_threshold: Forwarded to ``indexer`` as ``score_thres``.
            hamming_radius: Forwarded to ``indexer`` unchanged.
            topk: Forwarded to ``indexer`` unchanged.

        Returns:
            A dict with ``"label"`` and ``"score"`` lists holding one entry
            per item produced by ``indexer``. Both lists are empty when
            there are no detected boxes.
        """
        boxes = det_res["boxes"]
        if len(boxes) == 0:
            # Nothing was detected, so skip cropping and recognition.
            return {"label": [], "score": []}

        crops = [crop["img"] for crop in self.crop_by_boxes(raw_img, boxes)]
        features = [pred["feature"] for pred in self.rec_model(crops)]
        matches = indexer(
            features,
            score_thres=rec_threshold,
            hamming_radius=hamming_radius,
            topk=topk,
        )

        # Single pass over the matches, in case the indexer returns a
        # one-shot iterable.
        labels, scores = [], []
        for match in matches:
            labels.append(match["label"])
            scores.append(match["score"])
        return {"label": labels, "score": scores}

    def get_final_result(self, input_data, raw_img, det_res, rec_res):
        """Merge detection and recognition outputs into a ``FaceRecResult``.

        Args:
            input_data: Stored under ``"input_path"`` in the result.
            raw_img: Stored under ``"input_img"`` in the result.
            det_res: Detection result; each entry of ``det_res["boxes"]``
                must provide ``"score"`` and ``"coordinate"``.
            rec_res: Dict with ``"label"`` and ``"score"`` sequences that
                are indexed by the position of each detected box.

        Returns:
            A ``FaceRecResult`` wrapping the input path, the input image
            and one entry per detected box.
        """
        box_entries = []
        for idx, det_box in enumerate(det_res["boxes"]):
            score = rec_res["score"][idx]
            # Convert numpy score arrays to plain Python lists.
            rec_scores = score.tolist() if isinstance(score, np.ndarray) else score
            box_entries.append(
                {
                    "labels": rec_res["label"][idx],
                    "rec_scores": rec_scores,
                    "det_score": det_box["score"],
                    "coordinate": det_box["coordinate"],
                }
            )
        return FaceRecResult(
            {"input_path": input_data, "input_img": raw_img, "boxes": box_entries}
        )
