# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Union

import numpy as np

from .setting import BLOCK_LABEL_MAP, LINE_SETTINGS
from .utils import (
    caculate_euclidean_dist,
    calculate_projection_overlap_ratio,
    is_english_letter,
    is_non_breaking_punctuation,
    is_numeric,
)

__all__ = [
    "TextSpan",
    "TextLine",
    "LayoutBlock",
    "LayoutRegion",
]


class TextSpan(object):
    """A single recognized piece of text with its bounding box and label."""

    def __init__(self, box, text, label):
        """
        Build a span from its three components.

        Args:
            box: Bounding box of the span. Other code in this file indexes it
                as ``[x_min, y_min, x_max, y_max]``.
            text: Text content of the span.
            label: Label of the span (this file compares it against values
                such as ``"text"`` and ``"formula"``).
        """
        self.box, self.text, self.label = box, text, label

    def __str__(self) -> str:
        # Only the text is shown; box and label are intentionally omitted.
        return format(self.text)

    def __repr__(self) -> str:
        # Same rendering as ``__str__``.
        return format(self.text)


class TextLine(object):
    """Text line class: a group of spans that belong to the same line."""

    def __init__(
        self,
        spans: Union[List["TextSpan"], None] = None,
        direction="horizontal",
    ):
        """
        Initialize a TextLine object.

        Args:
            spans (List[TextSpan], optional): Spans that make up the line. A
                list passed in is stored as-is (not copied). Defaults to None,
                which gives this instance its own new empty list.
            direction (str): The direction of the text line. Defaults to "horizontal".
        """
        # A literal ``[]`` default would be one list shared by every instance
        # and then mutated by ``add_span``; create a fresh list per instance.
        self.spans = [] if spans is None else spans
        self.direction = direction
        self.region_box = self.get_region_box()
        self.need_new_line = False

    @property
    def labels(self):
        """Labels of all spans, in span order."""
        return [span.label for span in self.spans]

    @property
    def boxes(self):
        """Bounding boxes of all spans, in span order."""
        return [span.box for span in self.spans]

    @property
    def height(self):
        """Extent of ``region_box`` across the line direction (0 for an empty line)."""
        if self.region_box is None:
            return 0
        start_idx = 1 if self.direction == "horizontal" else 0
        end_idx = 3 if self.direction == "horizontal" else 2
        return abs(self.region_box[end_idx] - self.region_box[start_idx])

    @property
    def width(self):
        """Extent of ``region_box`` along the line direction (0 for an empty line)."""
        if self.region_box is None:
            return 0
        start_idx = 0 if self.direction == "horizontal" else 1
        end_idx = 2 if self.direction == "horizontal" else 3
        return abs(self.region_box[end_idx] - self.region_box[start_idx])

    def __str__(self) -> str:
        return f"{' '.join([str(span.text) for span in self.spans])}\n"

    def __repr__(self) -> str:
        # Same rendering as ``__str__``.
        return self.__str__()

    def add_span(self, span: Union["TextSpan", List["TextSpan"]]):
        """
        Add a span to the text line and refresh ``region_box``.

        Args:
            span (Union[TextSpan, List[TextSpan]]): A single TextSpan object or a list of TextSpan objects.
        """
        if isinstance(span, list):
            self.spans.extend(span)
        else:
            self.spans.append(span)
        self.region_box = self.get_region_box()

    def get_region_box(self):
        """
        Get the region box of the text line.

        Returns:
            list: ``[x_min, y_min, x_max, y_max]`` enclosing every span box,
            or None when the line has no spans.
        """
        if not self.spans:
            return None

        # Initialize min and max values with the first span's box
        x_min, y_min, x_max, y_max = self.spans[0].box

        for span in self.spans:
            x_min = min(x_min, span.box[0])
            y_min = min(y_min, span.box[1])
            x_max = max(x_max, span.box[2])
            y_max = max(y_max, span.box[3])

        return [x_min, y_min, x_max, y_max]

    def get_texts(
        self,
        block_label: str,
        block_text_width: int,
        block_start_coordinate: int,
        block_stop_coordinate: int,
        ori_image,
        text_rec_model=None,
        text_rec_score_thresh=None,
    ):
        """
        Get the text of the text line.

        The spans are sorted in place. If the line contains a "formula" span
        and splitting by projection changes the number of spans, every "text"
        span is re-recognized on its crop of ``ori_image`` and ``self.spans``
        is replaced by the resulting list. ``region_box`` is not recomputed
        here.

        Args:
            block_label (str): The label of the block.
            block_text_width (int): The width of the block.
            block_start_coordinate (int): The starting coordinate of the block.
            block_stop_coordinate (int): The stopping coordinate of the block.
            ori_image (np.ndarray): The original image, indexed as
                ``ori_image[y1:y2, x1:x2]``.
            text_rec_model (Any): The text recognition model. It is called with
                a list of crops and each result is read through the keys
                "rec_score" and "rec_text".
            text_rec_score_thresh (float): The text recognition score threshold.
                Re-recognized spans scoring below it are dropped. None disables
                the filtering.

        Returns:
            str: The text of the text line.
        """
        span_box_start_index = 0 if self.direction == "horizontal" else 1
        lines_start_index = 1 if self.direction == "horizontal" else 3
        self.spans.sort(
            key=lambda span: (
                span.box[span_box_start_index] // 2,
                (
                    span.box[lines_start_index]
                    if self.direction == "horizontal"
                    else -span.box[lines_start_index]
                ),
            )
        )
        if "formula" in self.labels:
            sort_index = 0 if self.direction == "horizontal" else 1
            splited_spans = self.split_boxes_by_projection()
            if len(self.spans) != len(splited_spans):
                splited_spans.sort(key=lambda span: span.box[sort_index])
                new_spans = []
                for span in splited_spans:
                    bbox = span.box
                    if span.label == "text":
                        crop_img = ori_image[
                            int(bbox[1]) : int(bbox[3]),
                            int(bbox[0]) : int(bbox[2]),
                        ]
                        crop_img_rec_res = list(text_rec_model([crop_img]))[0]
                        crop_img_rec_score = crop_img_rec_res["rec_score"]
                        crop_img_rec_text = crop_img_rec_res["rec_text"]
                        span.text = crop_img_rec_text
                        # Comparing against the default ``None`` threshold would
                        # raise TypeError, so None means "keep every span".
                        if (
                            text_rec_score_thresh is not None
                            and crop_img_rec_score < text_rec_score_thresh
                        ):
                            continue
                    new_spans.append(span)
                self.spans = new_spans
        line_text = self.format_line(
            block_text_width,
            block_start_coordinate,
            block_stop_coordinate,
            line_gap_limit=self.height * 1.5,
            block_label=block_label,
        )
        return line_text

    def is_projection_contained(self, box_a, box_b, start_idx, end_idx):
        """Check if box_a's [start_idx, end_idx] interval fully covers box_b's."""
        return box_a[start_idx] <= box_b[start_idx] and box_a[end_idx] >= box_b[end_idx]

    def split_boxes_by_projection(self, offset=1e-5):
        """
        Check if there is any complete containment along the line direction
        between the bounding boxes and split the containing box accordingly.

        Note that trimming a containing box writes into that span's ``box``
        in place before a new TextSpan is built from a copy of it.

        Args:
            offset (float): A small offset value to ensure that the split boxes are not too close to the original boxes.
        Returns:
            A new list of spans, including split spans, which keep the `text`
            and `label` of the span they were cut from.
        """

        new_spans = []
        if self.direction == "horizontal":
            projection_start_index, projection_end_index = 0, 2
        else:
            projection_start_index, projection_end_index = 1, 3

        for i in range(len(self.spans)):
            span = self.spans[i]
            is_split = False
            # The inner loop starts at j == i, so each span is first compared
            # with its own box before the spans that follow it.
            for j in range(i, len(self.spans)):
                box_b = self.spans[j].box
                # ``span`` may have been rebound by a previous iteration, so
                # its current box/text/label are re-read every time.
                box_a, text, label = span.box, span.text, span.label
                if self.is_projection_contained(
                    box_a, box_b, projection_start_index, projection_end_index
                ):
                    is_split = True
                    # Part of box_a before box_b: emitted as a separate span
                    if box_a[projection_start_index] < box_b[projection_start_index]:
                        w = (
                            box_b[projection_start_index]
                            - offset
                            - box_a[projection_start_index]
                        )
                        if w > 1:
                            new_bbox = box_a.copy()
                            new_bbox[projection_end_index] = (
                                box_b[projection_start_index] - offset
                            )
                            new_spans.append(
                                TextSpan(
                                    box=np.array(new_bbox),
                                    text=text,
                                    label=label,
                                )
                            )
                    # Part of box_a after box_b: becomes the current span and
                    # keeps being compared with the remaining boxes
                    if box_a[projection_end_index] > box_b[projection_end_index]:
                        w = (
                            box_a[projection_end_index]
                            - box_b[projection_end_index]
                            + offset
                        )
                        if w > 1:
                            box_a[projection_start_index] = (
                                box_b[projection_end_index] + offset
                            )
                            span = TextSpan(
                                box=np.array(box_a),
                                text=text,
                                label=label,
                            )
                # Emit whatever is left of the span once all boxes were checked
                if j == len(self.spans) - 1 and is_split:
                    new_spans.append(span)
            if not is_split:
                new_spans.append(span)

        return new_spans

    def format_line(
        self,
        block_text_width: int,
        block_start_coordinate: int,
        block_stop_coordinate: int,
        line_gap_limit: int = 10,
        block_label: str = "text",
    ) -> str:
        """
        Format a line of text spans based on layout constraints.

        Besides building the text, this wraps formula spans in ``$`` (updating
        ``span.text`` in place) and sets ``self.need_new_line`` when the line
        looks like the end of a paragraph.

        Args:
            block_text_width (int): The width of the block.
            block_start_coordinate (int): The starting coordinate of the block.
            block_stop_coordinate (int): The stopping coordinate of the block.
            line_gap_limit (int): Gap between the last span and the block stop
                (for vertical lines also between the block start and the first
                span) above which the line may be flagged as needing a new
                line. Default is 10.
            block_label (str): The label associated with the entire block. Default is 'text'.
        Returns:
            str: Formatted line of text. Empty when the line has no spans or
            no text.
        """
        # ``get_texts`` can drop every span (all below the score threshold);
        # indexing ``self.spans[0]`` below would then raise IndexError.
        if not self.spans:
            return ""

        first_span_box = self.spans[0].box
        last_span_box = self.spans[-1].box

        line_text = ""
        for span in self.spans:
            if span.label == "formula" and block_label != "formula":
                formula_rec = span.text
                if not formula_rec.startswith("$") and not formula_rec.endswith("$"):
                    if len(self.spans) > 1:
                        span.text = f"${span.text}$"
                    else:
                        # A formula alone on its line starts on a new line
                        span.text = f"\n${span.text}$"
            line_text += span.text
            if (
                len(span.text) > 0 and is_english_letter(line_text[-1])
            ) or span.label == "formula":
                line_text += " "

        if self.direction == "horizontal":
            text_stop_index = 2
        else:
            text_stop_index = 3

        if line_text.endswith(" "):
            line_text = line_text[:-1]

        if len(line_text) == 0:
            return ""

        last_char = line_text[-1]

        if (
            not is_english_letter(last_char)
            and not is_non_breaking_punctuation(last_char)
            and not is_numeric(last_char)
        ) or (
            block_stop_coordinate - last_span_box[text_stop_index]
            > block_text_width * 0.3
        ):
            if (
                self.direction == "horizontal"
                and block_stop_coordinate - last_span_box[text_stop_index]
                > line_gap_limit
            ) or (
                self.direction == "vertical"
                and (
                    block_stop_coordinate - last_span_box[text_stop_index]
                    > line_gap_limit
                    or first_span_box[1] - block_start_coordinate > line_gap_limit
                )
            ):
                self.need_new_line = True

        # A trailing hyphen is removed and nothing is appended, so the next
        # line's text is joined directly to this one.
        if line_text.endswith("-"):
            line_text = line_text[:-1]
            return line_text

        if (len(line_text) > 0 and is_english_letter(last_char)) or line_text.endswith(
            "$"
        ):
            line_text += " "
        if (
            len(line_text) > 0
            and not is_english_letter(last_char)
            and not is_numeric(last_char)
        ) or self.direction == "vertical":
            if (
                block_stop_coordinate - last_span_box[text_stop_index]
                > block_text_width * 0.3
                and len(line_text) > 0
                and not is_non_breaking_punctuation(last_char)
            ):
                line_text += "\n"
                self.need_new_line = True
        elif (
            block_stop_coordinate - last_span_box[text_stop_index]
            > (block_stop_coordinate - block_start_coordinate) * 0.5
        ):
            line_text += "\n"
            self.need_new_line = True

        return line_text


class LayoutBlock(object):
    """Layout Block Class"""

    def __init__(self, label, bbox, content="") -> None:
        """
        Initialize a LayoutBlock object.

        Args:
            label (str): Label assigned to the block.
            bbox (list): Bounding box coordinates of the block, indexed as
                ``[x_min, y_min, x_max, y_max]``.
            content (str, optional): Content of the block. Defaults to an empty string.
        """
        self.label = label
        self.order_label = None
        self.bbox = list(map(int, bbox))
        self.content = content
        self.seg_start_coordinate = float("inf")
        self.seg_end_coordinate = float("-inf")
        # Width/height come from the bbox as passed in, not the int-cast copy.
        self.width = bbox[2] - bbox[0]
        self.height = bbox[3] - bbox[1]
        self.area = float(self.width) * float(self.height)
        self.num_of_lines = 1
        self.image = None
        self.index = None
        self.order_index = None
        self.text_line_width = 1
        self.text_line_height = 1
        self.child_blocks = []
        self.update_direction()

    def __str__(self) -> str:
        _str = f"\n\n#################\nindex:\t{self.index}\nlabel:\t{self.label}\nregion_label:\t{self.order_label}\nbbox:\t{self.bbox}\ncontent:\t{self.content}\n#################"
        return _str

    def __repr__(self) -> str:
        # Same rendering as ``__str__``.
        return self.__str__()

    def to_dict(self) -> dict:
        """Return the instance attribute dictionary itself (not a copy)."""
        return self.__dict__

    def update_direction(self, direction=None) -> None:
        """
        Update the direction of the block based on its bounding box.

        Args:
            direction (str, optional): Direction of the block. If not provided, it will be determined automatically using the bounding box. Defaults to None.
        """
        if not direction:
            direction = self.get_bbox_direction()
        self.direction = direction
        self.update_direction_info()

    def update_direction_info(self) -> None:
        """Update the direction information of the block based on its direction.

        The start/end coordinates are read from the current ``bbox``; the side
        lengths are read from ``width``/``height`` as set in ``__init__``.
        """
        if self.direction == "horizontal":
            self.secondary_direction = "vertical"
            self.short_side_length = self.height
            self.long_side_length = self.width
            self.start_coordinate = self.bbox[0]
            self.end_coordinate = self.bbox[2]
            self.secondary_direction_start_coordinate = self.bbox[1]
            self.secondary_direction_end_coordinate = self.bbox[3]
        else:
            self.secondary_direction = "horizontal"
            self.short_side_length = self.width
            self.long_side_length = self.height
            self.start_coordinate = self.bbox[1]
            self.end_coordinate = self.bbox[3]
            self.secondary_direction_start_coordinate = self.bbox[0]
            self.secondary_direction_end_coordinate = self.bbox[2]

    def append_child_block(self, child_block) -> None:
        """
        Append a child block to the current block.

        The block's ``bbox`` grows to the union of its own box and the child's
        box. If the child has children of its own, they are detached from it
        and attached to this block right after the child (flat hierarchy).

        Args:
            child_block (LayoutBlock): Child block to be added.
        Returns:
            None
        """
        if not self.child_blocks:
            # Remember the block's own box so get_child_blocks can restore it
            self.ori_bbox = self.bbox.copy()
        x1, y1, x2, y2 = self.bbox
        x1_child, y1_child, x2_child, y2_child = child_block.bbox
        # Kept as a list so ``bbox`` has the same type before and after
        # children are attached (``self.bbox.copy()`` above relies on it).
        self.bbox = [
            min(x1, x1_child),
            min(y1, y1_child),
            max(x2, x2_child),
            max(y2, y2_child),
        ]
        self.update_direction_info()
        child_blocks = [child_block]
        if child_block.child_blocks:
            child_blocks.extend(child_block.get_child_blocks())
        self.child_blocks.extend(child_blocks)

    def get_child_blocks(self) -> list:
        """Detach and return all child blocks of the current block.

        The block's ``bbox`` is reset to the box it had before any child was
        appended, and the coordinates derived from it are refreshed.

        Returns:
            list: The detached child blocks (empty if there were none).
        """
        # ``ori_bbox`` only exists once a child has been appended; without
        # children the current bbox already is the block's own box.
        self.bbox = getattr(self, "ori_bbox", self.bbox)
        # Keep start/end coordinates in sync with the restored bbox, as
        # append_child_block does after enlarging it.
        self.update_direction_info()
        child_blocks = self.child_blocks.copy()
        self.child_blocks = []
        return child_blocks

    def get_centroid(self) -> tuple:
        """Get the centroid of the bounding box of the block."""
        x1, y1, x2, y2 = self.bbox
        centroid = ((x1 + x2) / 2, (y1 + y2) / 2)
        return centroid

    def get_bbox_direction(self, direction_ratio: float = 1.0) -> str:
        """
        Determine if a bounding box is horizontal or vertical.

        Args:
            direction_ratio (float): Ratio for determining direction. Default is 1.0.

        Returns:
            str: "horizontal" or "vertical".
        """
        return (
            "horizontal" if self.width * direction_ratio >= self.height else "vertical"
        )

    def calculate_text_line_direction(
        self, bboxes: List[List[int]], direction_ratio: float = 1.5
    ) -> str:
        """
        Calculate the direction of the text based on the bounding boxes.

        Args:
            bboxes (list): A list of bounding boxes.
            direction_ratio (float): Ratio for determining direction. Default is 1.5.

        Returns:
            str: "horizontal" if at least half of the boxes satisfy
            ``width * direction_ratio >= height`` (also for an empty list),
            otherwise "vertical".

        Raises:
            ValueError: If a bounding box does not have exactly 4 values.
        """

        horizontal_box_num = 0
        for bbox in bboxes:
            if len(bbox) != 4:
                raise ValueError(
                    "Invalid bounding box format. Expected a list of length 4."
                )
            x1, y1, x2, y2 = bbox
            width = x2 - x1
            height = y2 - y1
            horizontal_box_num += 1 if width * direction_ratio >= height else 0

        return "horizontal" if horizontal_box_num >= len(bboxes) * 0.5 else "vertical"

    def group_boxes_into_lines(
        self, ocr_rec_res, line_height_iou_threshold
    ) -> List["TextLine"]:
        """
        Group the bounding boxes into lines based on their direction.

        Also updates the block's direction (from the "text" boxes) and its
        ``text_line_height`` / ``text_line_width`` averages.

        Args:
            ocr_rec_res (dict): The result of OCR recognition. The keys
                "boxes", "rec_texts" and "rec_labels" are read.
            line_height_iou_threshold (float): The minimum projection overlap ratio required for a span to join the current line.

        Returns:
            list: A list of TextLines.
        """
        rec_boxes = ocr_rec_res["boxes"]
        rec_texts = ocr_rec_res["rec_texts"]
        rec_labels = ocr_rec_res["rec_labels"]

        text_boxes = [
            rec_boxes[i] for i in range(len(rec_boxes)) if rec_labels[i] == "text"
        ]
        direction = self.calculate_text_line_direction(text_boxes)
        self.update_direction(direction)

        spans = [TextSpan(*span) for span in zip(rec_boxes, rec_texts, rec_labels)]

        if not spans:
            return []

        # sort spans by direction
        if self.direction == "vertical":
            spans.sort(
                key=lambda span: span.box[0], reverse=True
            )  # sort by x coordinate, largest first
            match_direction = "horizontal"
        else:
            spans.sort(
                key=lambda span: span.box[1], reverse=False
            )  # sort by y coordinate
            match_direction = "vertical"

        lines = []
        current_line = TextLine([spans[0]], direction=self.direction)

        for span in spans[1:]:
            overlap_ratio = calculate_projection_overlap_ratio(
                current_line.region_box, span.box, match_direction, mode="small"
            )

            if overlap_ratio >= line_height_iou_threshold:
                current_line.add_span(span)
            else:
                lines.append(current_line)
                current_line = TextLine([span], direction=self.direction)

        lines.append(current_line)

        if lines and self.direction == "vertical":
            line_heights = np.array([line.height for line in lines])
            min_height = np.min(line_heights)
            max_height = np.max(line_heights)

            # only consider filtering when the line heights differ a lot
            if max_height > min_height * 2:
                normal_height_threshold = min_height * 1.1
                normal_height_count = np.sum(line_heights < normal_height_threshold)

                # if fewer than 40% of the lines are below the threshold,
                # keep only the lines whose height does not exceed it
                if normal_height_count < len(lines) * 0.4:
                    keep_condition = line_heights <= normal_height_threshold
                    lines = [line for line, keep in zip(lines, keep_condition) if keep]

        # calculate the average height and width of the text lines
        if lines:
            line_heights = [line.height for line in lines]
            line_widths = [line.width for line in lines]
            self.text_line_height = np.mean(line_heights)
            self.text_line_width = np.mean(line_widths)
        else:
            self.text_line_height = 0
            self.text_line_width = 0

        return lines

    def update_text_content(
        self,
        image: list,
        ocr_rec_res: dict,
        text_rec_model: Any,
        text_rec_score_thresh: Union[float, None] = None,
    ) -> None:
        """
        Update the text content of the block based on the OCR result.

        Sets ``content`` and ``num_of_lines`` and may update
        ``seg_start_coordinate`` / ``seg_end_coordinate``.

        Args:
            image (list): The input image.
            ocr_rec_res (dict): The result of OCR recognition.
            text_rec_model (Any): The model used for text recognition.
            text_rec_score_thresh (Union[float, None]): The score threshold for text recognition, forwarded unchanged to ``TextLine.get_texts``.

        Returns:
            None
        """

        if len(ocr_rec_res["rec_texts"]) == 0:
            self.content = ""
            return

        lines = self.group_boxes_into_lines(
            ocr_rec_res,
            LINE_SETTINGS.get("line_height_iou_threshold", 0.8),
        )

        # words start coordinate and stop coordinate in the line
        coord_start_idx = 0 if self.direction == "horizontal" else 1
        coord_end_idx = coord_start_idx + 2

        if self.label == "reference":
            # reference blocks use the extent of the recognized boxes
            rec_boxes = ocr_rec_res["boxes"]
            block_start = min([box[coord_start_idx] for box in rec_boxes])
            block_stop = max([box[coord_end_idx] for box in rec_boxes])
        else:
            block_start = self.bbox[coord_start_idx]
            block_stop = self.bbox[coord_end_idx]

        text_lines = []
        text_width_list = []
        need_new_line_num = 0

        for line_idx, line in enumerate(lines):
            text_width_list.append(line.width)
            # get text from line
            line_text = line.get_texts(
                block_label=self.label,
                block_text_width=max(text_width_list),
                block_start_coordinate=block_start,
                block_stop_coordinate=block_stop,
                ori_image=image,
                text_rec_model=text_rec_model,
                text_rec_score_thresh=text_rec_score_thresh,
            )

            if line.need_new_line:
                need_new_line_num += 1

            # set segment start and end coordinate
            # (get_texts may drop spans, so guard against a line left empty)
            # NOTE(review): with a single line only the start coordinate is
            # set and seg_end_coordinate keeps its -inf default; confirm that
            # callers rely on this before changing the ``elif``.
            if line.spans:
                if line_idx == 0:
                    self.seg_start_coordinate = line.spans[0].box[0]
                elif line_idx == len(lines) - 1:
                    self.seg_end_coordinate = line.spans[-1].box[2]

            text_lines.append(line_text)

        delim = LINE_SETTINGS["delimiter_map"].get(self.label, "")

        if delim == "":
            content = ""
            pre_line_end = False
            last_char = ""
            for idx, line_text in enumerate(text_lines):
                if len(line_text) == 0:
                    continue

                line = lines[idx]
                if pre_line_end:
                    # gap between the block start and where this line starts
                    start_gep_len = line.region_box[coord_start_idx] - block_start
                    if (
                        (
                            start_gep_len > line.height * 1.5
                            and not is_english_letter(last_char)
                            and not is_numeric(last_char)
                        )
                        or start_gep_len > (block_stop - block_start) * 0.4
                    ) and not content.endswith("\n"):
                        line_text = "\n" + line_text
                content += f"{line_text}"

                # look past a trailing space when picking the last character
                if len(line_text) > 2 and line_text.endswith(" "):
                    last_char = line_text[-2]
                else:
                    last_char = line_text[-1]
                if (
                    len(line_text) > 0
                    and not line_text.endswith("\n")
                    and not is_english_letter(last_char)
                    and not is_non_breaking_punctuation(last_char)
                    and not is_numeric(last_char)
                    and need_new_line_num > len(text_lines) * 0.5
                ) or need_new_line_num > len(text_lines) * 0.6:
                    content += "\n"
                if (
                    block_stop - line.region_box[coord_end_idx]
                    > (block_stop - block_start) * 0.3
                ):
                    pre_line_end = True
        else:
            content = delim.join(text_lines)

        self.content = content
        self.num_of_lines = len(text_lines)


class LayoutRegion(LayoutBlock):
    """LayoutRegion class: a LayoutBlock labelled "region" that indexes the blocks it contains."""

    def __init__(
        self,
        bbox,
        blocks: Union[List[LayoutBlock], None] = None,
    ) -> None:
        """
        Initialize a LayoutRegion object.

        Args:
            bbox (List[int]): The bounding box of the region.
            blocks (List[LayoutBlock], optional): A list of blocks that belong
                to this region. Defaults to None, which is treated as no blocks.
        """
        super().__init__("region", bbox, content="")
        # The base class stores an int-cast copy; the region keeps the bbox
        # object exactly as it was passed in.
        self.bbox = bbox
        self.block_map = {}
        self.direction = "horizontal"
        self.doc_title_block_idxes = []
        self.paragraph_title_block_idxes = []
        self.vision_block_idxes = []
        self.unordered_block_idxes = []
        self.vision_title_block_idxes = []
        self.normal_text_block_idxes = []
        self.euclidean_distance = float(np.inf)
        self.header_block_idxes = []
        self.footer_block_idxes = []
        self.text_line_width = 20
        self.text_line_height = 10
        self.num_of_lines = 10
        # ``None`` replaces the former mutable ``[]`` default argument.
        self.init_region_info_from_layout([] if blocks is None else blocks)
        self.update_euclidean_distance()

    def init_region_info_from_layout(self, blocks: List[LayoutBlock]) -> None:
        """Initialize the information about the layout region from the given blocks.

        Every block gets its position in ``blocks`` as ``index`` and is
        registered in ``block_map`` and in exactly one of the ``*_block_idxes``
        lists. The region's direction and average text line size are taken
        from the blocks that match none of the special label groups.

        Args:
            blocks (List[LayoutBlock]): A list of blocks that belong to this region.
        Returns:
            None
        """
        # Checked in this order; the first group containing the label wins.
        label_buckets = (
            ("header_labels", self.header_block_idxes),
            ("doc_title_labels", self.doc_title_block_idxes),
            ("paragraph_title_labels", self.paragraph_title_block_idxes),
            ("vision_labels", self.vision_block_idxes),
            ("vision_title_labels", self.vision_title_block_idxes),
            ("footer_labels", self.footer_block_idxes),
            ("unordered_labels", self.unordered_block_idxes),
        )
        horizontal_normal_text_block_num = 0
        text_line_height_list = []
        text_line_width_list = []
        for idx, block in enumerate(blocks):
            self.block_map[idx] = block
            block.index = idx
            for map_key, bucket in label_buckets:
                if block.label in BLOCK_LABEL_MAP[map_key]:
                    bucket.append(idx)
                    break
            else:
                # No special group matched: treat it as a normal text block
                self.normal_text_block_idxes.append(idx)
                text_line_height_list.append(block.text_line_height)
                text_line_width_list.append(block.text_line_width)
                if block.direction == "horizontal":
                    horizontal_normal_text_block_num += 1
        # Horizontal when at least half of the normal text blocks are
        # horizontal (which includes the case of no normal text blocks).
        direction = (
            "horizontal"
            if horizontal_normal_text_block_num
            >= len(self.normal_text_block_idxes) * 0.5
            else "vertical"
        )
        self.update_direction(direction)
        self.text_line_width = (
            np.mean(text_line_width_list) if text_line_width_list else 20
        )
        self.text_line_height = (
            np.mean(text_line_height_list) if text_line_height_list else 10
        )

    def update_euclidean_distance(self):
        """Update euclidean distance between each block and the reference point.

        The reference point is ``(0, 0)`` for a horizontal region and
        ``(self.bbox[2], 0)`` for a vertical one. The smallest distance over
        all blocks is stored, or 0 when the region has no blocks.
        """
        blocks: List[LayoutBlock] = list(self.block_map.values())
        if self.direction == "horizontal":
            ref_point = (0, 0)
            block_distance = [
                caculate_euclidean_dist((block.bbox[0], block.bbox[1]), ref_point)
                for block in blocks
            ]
        else:
            ref_point = (self.bbox[2], 0)
            block_distance = [
                caculate_euclidean_dist((block.bbox[2], block.bbox[1]), ref_point)
                for block in blocks
            ]
        self.euclidean_distance = min(block_distance) if len(block_distance) > 0 else 0

    def update_direction(self, direction=None):
        """
        Update the direction of the layout region.

        On top of the base-class update, this stores the bbox indices of the
        main and secondary directions and the region's center coordinate
        along each of them.

        Args:
            direction (str): The new direction of the layout region. When not
                given, the base class derives it from the bounding box.
        """
        super().update_direction(direction=direction)
        if self.direction == "horizontal":
            self.direction_start_index = 0
            self.direction_end_index = 2
            self.secondary_direction_start_index = 1
            self.secondary_direction_end_index = 3
            self.secondary_direction = "vertical"
        else:
            self.direction_start_index = 1
            self.direction_end_index = 3
            self.secondary_direction_start_index = 0
            self.secondary_direction_end_index = 2
            self.secondary_direction = "horizontal"

        self.direction_center_coordinate = (
            self.bbox[self.direction_start_index] + self.bbox[self.direction_end_index]
        ) / 2
        self.secondary_direction_center_coordinate = (
            self.bbox[self.secondary_direction_start_index]
            + self.bbox[self.secondary_direction_end_index]
        ) / 2
