"""
dl_pipeline.py

Hybrid Deep Learning Extraction Pipeline
CPU-Friendly Version

Architecture:
PDF
 ↓
DL Extraction Pipeline
 ↓
Confidence Scoring
 ↓
If score < DL_THRESHOLD (70):
    also invoke the existing DLM engine
 ↓
Merge Results
 ↓
Final JSON

Requirements:
pip install paddleocr paddlepaddle ultralytics transformers torch torchvision pdf2image pillow opencv-python

System dependency (required by pdf2image to rasterize PDF pages):
sudo apt install poppler-utils
"""

import os
import re
import json
import cv2
import torch
import numpy as np

from pathlib import Path
from pdf2image import convert_from_path
from paddleocr import PaddleOCR
from ultralytics import YOLO
from transformers import pipeline

# ==============================
# IMPORT YOUR EXISTING DLM ENGINE
# ==============================

from extract_land import extract_land_record as extract_land_record_dlm


# ==============================
# CONFIG
# ==============================

# Minimum DL confidence (0-100 scale, see compute_dl_confidence) at or
# above which the DLM fallback is skipped.
DL_THRESHOLD: int = 70

# Compute device label.
# NOTE(review): not referenced in this module (the NER pipeline hard-codes
# device=-1) — confirm whether anything else still reads it.
DEVICE: str = "cpu"

# Optional YOLO layout-detection weights; layout detection is skipped when
# this file does not exist.
YOLO_MODEL_PATH: str = "models/layout_detector.pt"

# ==============================
# OCR ENGINE
# ==============================

print("Loading PaddleOCR...")

# PaddleOCR subclass used as the module-wide OCR engine.
#
# NOTE(review): the base class is constructed with lang='en'. The 'bn'
# values are only assigned as plain instance attributes (self.lang before
# super().__init__, self.rec_lang after it). Nothing visible here feeds them
# into model selection, so the loaded recognition model is presumably the
# English one. The header/entry regexes elsewhere match Bengali labels
# (e.g. "মৌজা", "জেলা"), so confirm that Bengali text is actually recognized
# with the installed PaddleOCR version.
class SafePaddleOCR(PaddleOCR):
    """PaddleOCR engine constructed with ocr_version='PP-OCRv3'."""

    def __init__(self):
        # Pre-set attributes before the base initializer runs; the stated
        # intent was to get past PaddleOCR's language validation.
        # NOTE(review): the base __init__ may overwrite these — confirm.
        self.ocr_version = 'PP-OCRv3'
        self.lang = 'bn'
        super().__init__(lang='en', ocr_version='PP-OCRv3')

        # Intended to switch text recognition to Bengali after construction.
        # NOTE(review): assigning an attribute after construction does not by
        # itself load a different recognition model — verify it has any effect.
        self.rec_lang = 'bn'

# Single shared OCR instance, created at import time and used by run_ocr().
ocr_engine = SafePaddleOCR()

# ==============================
# OPTIONAL NER MODEL
# ==============================

# Optional Indic NER model used by extract_entries_dl to pick owner/father
# names; when it cannot be loaded, a positional fallback is used instead.
try:
    ner_pipeline = pipeline(
        "token-classification",
        model="ai4bharat/IndicNER",
        device=-1,  # -1 = run on CPU
    )
except Exception as exc:
    # Catch Exception rather than using a bare except so KeyboardInterrupt /
    # SystemExit (e.g. Ctrl-C during a model download) still propagate, and
    # report why NER is disabled instead of failing silently.
    print(f"IndicNER unavailable, using positional name fallback: {exc}")
    ner_pipeline = None

# ==============================
# OPTIONAL LAYOUT MODEL
# ==============================

# The layout detector is optional: load it only when the weights file is
# present, otherwise leave it as None so detect_layout() returns nothing.
layout_model = YOLO(YOLO_MODEL_PATH) if os.path.exists(YOLO_MODEL_PATH) else None


# ==============================
# BASIC UTILS
# ==============================

# Translation table: Bengali digits ০-৯ -> ASCII digits 0-9.
BN_TO_EN = str.maketrans("০১২৩৪৫৬৭৮৯", "0123456789")


def bn_to_en(text):
    """Return *text* with every Bengali digit replaced by its ASCII digit.

    Falsy input (None, "") yields ""; all non-digit characters pass through.
    """
    return text.translate(BN_TO_EN) if text else ""


def normalize_text(text):
    """Trim *text* and collapse every internal whitespace run to one space.

    Falsy input (None, "") yields "".
    """
    if text:
        return re.sub(r"\s+", " ", text.strip())
    return ""


# ==============================
# PDF TO IMAGES
# ==============================

def pdf_to_images(pdf_path):
    """Rasterize every page of *pdf_path* at 250 DPI.

    Returns one ``{"page": index, "image": ndarray}`` dict per page, in
    document order. ``index`` is 0-based, and ``image`` is the rendered page
    as a NumPy array converted from RGB to OpenCV's BGR channel order.
    """
    rendered_pages = convert_from_path(pdf_path, dpi=250)

    return [
        {
            "page": index,
            "image": cv2.cvtColor(np.array(rendered), cv2.COLOR_RGB2BGR),
        }
        for index, rendered in enumerate(rendered_pages)
    ]


# ==============================
# OCR EXTRACTION
# ==============================

def run_ocr(images):
    """OCR every page and return one record per recognized text line.

    Args:
        images: list of ``{"page": ..., "image": ...}`` dicts as produced by
            ``pdf_to_images``.

    Returns:
        Flat list of dicts with ``text`` (whitespace-normalized), ``conf``
        (float), ``xmin``/``ymin``/``xmax``/``ymax`` (axis-aligned bounds of
        the detected box's points) and ``page``. Lines whose normalized text
        is empty are skipped, as are pages with no detections.
    """
    all_words = []

    for page_data in images:
        page_no = page_data["page"]

        result = ocr_engine.ocr(page_data["image"])

        # One entry per input image. PaddleOCR 2.x reports a page with no
        # detected text as [None], so check both levels; iterating
        # result[0] when it is None would raise TypeError.
        # NOTE(review): the indexing below assumes the 2.x layout
        # [[bbox_points, (text, score)], ...] — confirm for the installed
        # PaddleOCR version.
        if not result or not result[0]:
            continue

        for line in result[0]:
            bbox = line[0]
            text = normalize_text(line[1][0])
            conf = float(line[1][1])

            if not text:
                continue

            xs = [point[0] for point in bbox]
            ys = [point[1] for point in bbox]

            all_words.append({
                "text": text,
                "conf": conf,
                "xmin": min(xs),
                "ymin": min(ys),
                "xmax": max(xs),
                "ymax": max(ys),
                "page": page_no,
            })

    return all_words


# ==============================
# LAYOUT DETECTION
# ==============================

def detect_layout(images):
    """Run the optional YOLO layout detector over every page.

    Returns a flat list with one ``{"page", "class", "confidence", "bbox"}``
    dict per detected box. ``class`` is the integer class id, ``confidence``
    is a float, and ``bbox`` is the box's ``xyxy`` coordinates as a list.
    Returns an empty list when no layout model was loaded.
    """
    if layout_model is None:
        return []

    detections = []

    for page_data in images:
        page_index = page_data["page"]
        for prediction in layout_model(page_data["image"]):
            for box in prediction.boxes:
                detections.append({
                    "page": page_index,
                    "class": int(box.cls[0]),
                    "confidence": float(box.conf[0]),
                    "bbox": box.xyxy[0].tolist(),
                })

    return detections


# ==============================
# HEADER EXTRACTION
# ==============================

def extract_header_dl(words):
    """Pull header fields out of the concatenated OCR text with regexes.

    The word texts are joined with single spaces in the order given. Each
    field takes group 1 of the first match of its pattern, and fields with
    no match are "". jl_no, daag_no and total_land_acre are converted from
    Bengali to ASCII digits; the other fields are returned as matched.

    NOTE(review): daag_no takes the first 2-6 char digit/slash token and
    total_land_acre the first decimal number anywhere in the text, so either
    can pick up an unrelated number. block/district only accept uppercase
    Latin values. Confirm these heuristics against real records.
    """
    full_text = " ".join(w["text"] for w in words)

    # (output key, pattern, convert Bengali digits to ASCII?)
    header_rules = (
        ("jl_no", r"জে\.?এল\s*নং\s*([\d০-৯]+)", True),
        ("daag_no", r"\b([\d০-৯/]{2,6})\b", True),
        ("mouza", r"মৌজা[:ঃ]?\s*(\S+)", False),
        ("block", r"ব্লক[:ঃ]?\s*([A-Z0-9\-]+)", False),
        ("district", r"জেলা[:ঃ]?\s*([A-Z]+)", False),
        ("total_land_acre", r"([\d০-৯]+\.[\d০-৯]+)", True),
    )

    header = {}
    for key, pattern, to_ascii_digits in header_rules:
        match = re.search(pattern, full_text)
        if match is None:
            header[key] = ""
        elif to_ascii_digits:
            header[key] = bn_to_en(match.group(1))
        else:
            header[key] = match.group(1)

    return header


# ==============================
# ROW GROUPING
# ==============================

def group_rows(words, y_threshold=15):
    """Cluster OCR words into visual rows, each ordered left-to-right.

    Words are sorted by (page, ymin). A word joins the current row when it
    is on the same page and its ``ymin`` is within ``y_threshold`` (same
    units as ``ymin``) of the row's first, i.e. top-most, word. Otherwise it
    starts a new row. Each row is then sorted by ``xmin`` so that joining its
    texts follows reading order; the y-sort alone leaves words within a row
    in vertical-jitter order.

    Args:
        words: dicts with at least ``page`` and ``ymin`` keys; ``xmin`` is
            used for in-row ordering when present (missing counts as 0).
        y_threshold: maximum vertical distance from the row's anchor word.

    Returns:
        List of rows (lists of word dicts); ``[]`` for empty input.
    """
    if not words:
        return []

    ordered = sorted(words, key=lambda w: (w["page"], w["ymin"]))

    rows = []
    current = [ordered[0]]

    for word in ordered[1:]:
        anchor = current[0]
        same_line = (
            word["page"] == anchor["page"]
            and abs(word["ymin"] - anchor["ymin"]) < y_threshold
        )
        if same_line:
            current.append(word)
        else:
            rows.append(current)
            current = [word]

    rows.append(current)

    # Downstream parsing reads table columns positionally, so restore
    # left-to-right order within each row (the sort is stable on ties).
    return [sorted(row, key=lambda w: w.get("xmin", 0)) for row in rows]


# ==============================
# ENTITY EXTRACTION
# ==============================

def extract_entries_dl(words):
    """Build khatian entries from OCR words, one candidate per visual row.

    A row is kept only if it contains an integer-like token (taken as the
    khatian number) and at least one decimal number. The first decimal
    becomes ``ansha`` and the second, if any, ``area_acres``. Owner/father
    names come from the NER model when it is loaded and yields enough
    tokens. Otherwise the 2nd and 3rd whitespace-separated tokens of the row
    are used.
    """
    entries = []

    for row in group_rows(words):
        row_text = " ".join(x["text"] for x in row)
        # Convert digits once; both numeric searches need the ASCII form.
        ascii_text = bn_to_en(row_text)

        khatian_match = re.search(r"\b\d[\d/]*\b", ascii_text)
        if not khatian_match:
            continue

        decimals = re.findall(r"\d+\.\d+", ascii_text)
        if not decimals:
            continue

        owner = ""
        father = ""

        if ner_pipeline:
            try:
                ner_result = ner_pipeline(row_text)
                # NOTE(review): tokens are taken positionally regardless of
                # their entity label — confirm this is intended.
                tokens = [x["word"] for x in ner_result]
                owner = " ".join(tokens[:2]) if len(tokens) >= 2 else ""
                father = " ".join(tokens[2:4]) if len(tokens) >= 4 else ""
            except Exception:
                # NER is best-effort per row: on failure use the positional
                # fallback below. Catching Exception (not a bare except)
                # keeps KeyboardInterrupt/SystemExit propagating.
                owner = ""
                father = ""

        if not owner:
            texts = row_text.split()
            if len(texts) >= 2:
                owner = texts[1]
            if len(texts) >= 3:
                father = texts[2]

        entries.append({
            "khatian_no": khatian_match.group(0),
            "owner_name": owner,
            "father_husband_name": father,
            "ansha": decimals[0],  # non-empty: checked above
            "area_acres": decimals[1] if len(decimals) >= 2 else "",
        })

    return entries


# ==============================
# CONFIDENCE SCORING
# ==============================

def compute_dl_confidence(result, words):
    """Score a DL extraction result from 0 to 100 (higher is better).

    Penalties subtracted from a starting score of 100:

    * 40 when ``result`` has no ``khatian_entries``;
    * 5 for each falsy header among jl_no, daag_no, mouza, block, district;
    * 20 when the mean OCR ``conf`` of ``words`` is below 0.6, or 10 when it
      is below 0.75 (skipped if ``words`` is empty);
    * 20 when more than half of the entries have no ``owner_name``;
    * 10 when any non-empty ``area_acres`` is not ``<digits>.<digits>``.

    The total is clamped so it never drops below 0.
    """
    entries = result.get("khatian_entries", [])
    penalties = []

    if len(entries) == 0:
        penalties.append(40)

    header_keys = ("jl_no", "daag_no", "mouza", "block", "district")
    absent_headers = len([key for key in header_keys if not result.get(key)])
    penalties.append(5 * absent_headers)

    if words:
        mean_conf = np.mean([w["conf"] for w in words])
        if mean_conf < 0.6:
            penalties.append(20)
        elif mean_conf < 0.75:
            penalties.append(10)

    if entries:
        ownerless = len([e for e in entries if not e.get("owner_name")])
        if ownerless / len(entries) > 0.5:
            penalties.append(20)

    bad_areas = [
        e for e in entries
        if e.get("area_acres", "")
        and not re.match(r"^\d+\.\d+$", e.get("area_acres", ""))
    ]
    if bad_areas:
        penalties.append(10)

    return max(100 - sum(penalties), 0)


# ==============================
# MERGE RESULTS
# ==============================

def merge_results(dl_result, dlm_result):
    """Fill gaps in a DL result with values from the DLM engine's result.

    Header fields that are falsy in ``dl_result`` take the DLM value, which
    is ``None`` if the DLM result lacks the key. For entries: if DL found
    none, the DLM entries are used as-is. Otherwise entries are paired by
    list position, and each falsy field of a DL entry is filled from the DLM
    entry at the same position; surplus entries from the longer list are
    kept. ``total_entries`` is recomputed to match the returned entries, and
    ``fallback_used`` is set to True. Neither input dict is mutated.

    NOTE(review): positional pairing assumes both engines emit rows in the
    same order; consider pairing on khatian_no if that does not hold.
    """
    from itertools import zip_longest

    final = dl_result.copy()

    header_keys = ("jl_no", "daag_no", "mouza", "block", "district", "total_land_acre")
    for key in header_keys:
        if not final.get(key):
            final[key] = dlm_result.get(key)

    dl_entries = final.get("khatian_entries", [])
    dlm_entries = dlm_result.get("khatian_entries", [])

    if len(dl_entries) == 0:
        final["khatian_entries"] = dlm_entries
    else:
        entry_fields = (
            "khatian_no",
            "owner_name",
            "father_husband_name",
            "ansha",
            "area_acres",
        )
        merged_entries = []
        # fillvalue={} is only read (.copy()/.get()), so sharing it is safe.
        for dl_e, dlm_e in zip_longest(dl_entries, dlm_entries, fillvalue={}):
            merged = dl_e.copy()
            for field in entry_fields:
                if not merged.get(field):
                    merged[field] = dlm_e.get(field)
            merged_entries.append(merged)
        final["khatian_entries"] = merged_entries

    # The DL count covered only DL rows and goes stale once DLM rows are
    # substituted or merged in; keep it consistent with the returned list.
    final["total_entries"] = len(final.get("khatian_entries") or [])

    final["fallback_used"] = True

    return final


# ==============================
# DL EXTRACTION
# ==============================

def extract_land_record_dl(pdf_path):
    """Run the DL-only extraction on *pdf_path* and score the result.

    Rasterizes the PDF, OCRs it, and runs the optional layout detector. It
    then extracts header fields and khatian entries from the OCR words. The
    returned dict holds the source file name, the header fields, the entries
    and their count, the number of layout objects, and ``dl_confidence``
    from ``compute_dl_confidence``.
    """
    page_images = pdf_to_images(pdf_path)
    ocr_words = run_ocr(page_images)
    layout_objects = detect_layout(page_images)

    header = extract_header_dl(ocr_words)
    khatian_entries = extract_entries_dl(ocr_words)

    result = {"source_file": Path(pdf_path).name}
    for key in ("jl_no", "daag_no", "mouza", "block", "district", "total_land_acre"):
        result[key] = header.get(key, "")
    result["total_entries"] = len(khatian_entries)
    result["khatian_entries"] = khatian_entries
    result["layout_objects"] = len(layout_objects)

    result["dl_confidence"] = compute_dl_confidence(result, ocr_words)

    return result


# ==============================
# MAIN HYBRID PIPELINE
# ==============================

def process_land_record(pdf_path):
    """Extract a land record, falling back to the DLM engine when needed.

    The DL pipeline always runs first. If its ``dl_confidence`` is at least
    ``DL_THRESHOLD``, the DL result is returned with ``pipeline="DL_ONLY"``.
    Otherwise the DLM engine also runs, and the merged result is returned
    with ``pipeline="DL_PLUS_DLM"``.
    """
    print("\nRunning DL pipeline...")
    dl_result = extract_land_record_dl(pdf_path)

    confidence = dl_result["dl_confidence"]
    print(f"DL Confidence: {confidence}")

    if confidence < DL_THRESHOLD:
        print("DL confidence low. Running DLM fallback...")
        combined = merge_results(dl_result, extract_land_record_dlm(pdf_path))
        combined["pipeline"] = "DL_PLUS_DLM"
        return combined

    dl_result["pipeline"] = "DL_ONLY"
    return dl_result


# ==============================
# CLI
# ==============================

if __name__ == "__main__":

    import sys

    if len(sys.argv) < 2:
        # Report the usage error on stderr with a non-zero status so callers
        # and shell scripts can detect it. sys.exit is used because the
        # site-provided exit() is meant for the interactive prompt and is
        # absent when Python runs with -S.
        print("Usage:", file=sys.stderr)
        print("python dl_pipeline.py file.pdf", file=sys.stderr)
        sys.exit(2)

    pdf_path = sys.argv[1]

    result = process_land_record(pdf_path)

    print(json.dumps(result, ensure_ascii=False, indent=2))