import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import List

import fitz  # PyMuPDF
from pdf2image import convert_from_path

from ocr.engines import OCREngine
from ocr.preproc import preprocess

# Process-wide OCR engine, built lazily on first use.
_engine = None
# Serialises the first construction of ``_engine`` so that concurrent callers
# (e.g. thread-pool workers) cannot each build their own engine.
_engine_lock = threading.Lock()

def get_shared_engine(prefer="tesseract", lang="en", tesseract_cmd=None):
    """Return the module-wide ``OCREngine``, creating it on first call.

    The arguments are forwarded to ``OCREngine`` only when the engine is
    first constructed. Every later call returns that same instance and
    ignores its arguments, so the first caller decides the configuration.

    Args:
        prefer: Forwarded to ``OCREngine(prefer=...)`` on first construction.
        lang: Forwarded to ``OCREngine(lang=...)`` on first construction.
        tesseract_cmd: Forwarded to ``OCREngine(tesseract_cmd=...)`` on first
            construction.

    Returns:
        The shared ``OCREngine`` instance.
    """
    global _engine
    # Fast path: once built, no locking is needed to hand the engine out.
    if _engine is None:
        with _engine_lock:
            # Re-check under the lock: another thread may have finished
            # building the engine while this one was waiting.
            if _engine is None:
                _engine = OCREngine(prefer=prefer, lang=lang, tesseract_cmd=tesseract_cmd)
    return _engine

def pdf_has_text_per_page(pdf_path: Path) -> List[bool]:
    """Report, page by page, whether a PDF has a non-blank text layer.

    Args:
        pdf_path: Path of the PDF to inspect.

    Returns:
        One bool per page, in page order: ``True`` when the page's extracted
        text is non-empty after stripping whitespace. An empty list is
        returned when the PDF has no pages or when opening/reading it fails.
    """
    try:
        # The context manager closes the document (and its file handle) even
        # if text extraction fails part-way through.
        with fitz.open(str(pdf_path)) as doc:
            return [bool(page.get_text().strip()) for page in doc]
    except Exception:
        # Deliberate best-effort: any failure is reported as "no information"
        # so callers can fall back to another strategy.
        return []

def pdf_extract_text_per_page(pdf_path: Path) -> List[str]:
    """Extract the embedded text layer of every page of a PDF.

    Args:
        pdf_path: Path of the PDF to read.

    Returns:
        One string per page, in page order; a page without text yields ``""``.

    Raises:
        Whatever ``fitz.open`` or ``page.get_text`` raises; unlike
        ``pdf_has_text_per_page`` nothing is swallowed here.
    """
    # The context manager closes the document once all pages have been read,
    # instead of leaving the file handle open until garbage collection.
    with fitz.open(str(pdf_path)) as doc:
        return [page.get_text("text") or "" for page in doc]

def ocr_image_page_from_pil(pil_image, page_idx: int, prefer="tesseract", lang="en", save_debug=False):
    """Run OCR on a single page image and return the recognised text.

    The image goes through ``preprocess`` first, which yields four variants
    of the page (RGB and black/white images plus RGB and grayscale arrays);
    all four are passed to the shared engine's ``run`` method.

    Args:
        pil_image: Page image to recognise.
        page_idx: Page index, passed to ``preprocess`` as ``idx``.
        prefer: Engine preference used if the shared engine is not built yet.
        lang: Language used if the shared engine is not built yet.
        save_debug: Passed to ``preprocess`` as ``save_debug``.

    Returns:
        The ``text`` attribute of the engine's result.
    """
    variants = preprocess(pil_image, save_debug=save_debug, idx=page_idx)
    rgb_image, bw_image, rgb_array, gray_array = variants
    shared = get_shared_engine(prefer=prefer, lang=lang)
    outcome = shared.run(
        pil_rgb=rgb_image,
        pil_bw=bw_image,
        np_rgb=rgb_array,
        np_gray=gray_array,
    )
    return outcome.text

def pdf_ocr_pages_via_images(pdf_path: Path, dpi: int = 300, max_workers: int = 4, prefer="tesseract", lang="en", save_debug=False) -> List[str]:
    """OCR every page of a PDF by rasterising it and recognising each image.

    Pages are rendered with ``convert_from_path`` at ``dpi`` and handed to
    ``ocr_image_page_from_pil`` on a thread pool of ``max_workers`` threads.

    Args:
        pdf_path: Path of the PDF to OCR.
        dpi: Rendering resolution passed to ``convert_from_path``.
        max_workers: Size of the thread pool.
        prefer: Forwarded to ``ocr_image_page_from_pil``.
        lang: Forwarded to ``ocr_image_page_from_pil``.
        save_debug: Forwarded to ``ocr_image_page_from_pil``.

    Returns:
        One string per page, in page order. A page whose OCR raised is
        represented by an ``[ERROR on page N: ...]`` marker (N is the
        zero-based index); a falsy OCR result becomes ``""``.
    """
    images = convert_from_path(str(pdf_path), dpi=dpi)
    results = [None] * len(images)
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # Remember which page each future belongs to, since futures finish
        # in arbitrary order.
        index_by_future = {}
        for page_no, image in enumerate(images):
            task = pool.submit(ocr_image_page_from_pil, image, page_no, prefer, lang, save_debug)
            index_by_future[task] = page_no
        for finished in as_completed(index_by_future):
            i = index_by_future[finished]
            try:
                text = finished.result()
            except Exception as e:
                # A failing page must not abort the others; record a marker.
                text = f"[ERROR on page {i}: {e}]"
            results[i] = text
    return [text if text else "" for text in results]

def process_pdf_pages(pdf_path: Path, dpi=300, max_workers=4, prefer="tesseract", lang="en", save_debug=False) -> List[str]:
    """Return the text of every page of a PDF, using OCR only where needed.

    The embedded text layer is preferred. Pages whose text layer is blank
    are filled in from OCR. If the text layer cannot be read at all (or the
    PDF reports no pages), the result is the OCR output alone.

    Args:
        pdf_path: Path (or path-like string) of the PDF to process.
        dpi: Rendering resolution for the OCR pass.
        max_workers: Thread-pool size for the OCR pass.
        prefer: OCR engine preference for the OCR pass.
        lang: OCR language for the OCR pass.
        save_debug: Forwarded to the OCR pass. Note that no OCR (and hence
            no debug output) happens when every page already has text.

    Returns:
        One string per page, in page order.

    Raises:
        Whatever ``pdf_ocr_pages_via_images`` raises when the OCR pass itself
        fails; text-layer failures are absorbed and trigger the OCR fallback.
    """
    pdf_path = Path(pdf_path)
    ocr_kwargs = dict(dpi=dpi, max_workers=max_workers, prefer=prefer, lang=lang, save_debug=save_debug)

    # Reading the text layer is best-effort: any failure simply means the
    # whole document is handled by OCR. Only this step is guarded, so an OCR
    # failure is never "handled" by running the same OCR pass a second time.
    try:
        text_flags = pdf_has_text_per_page(pdf_path)
        raw_texts = pdf_extract_text_per_page(pdf_path) if text_flags else []
    except Exception:
        raw_texts = []

    # Every page already has usable text: the merge below would return
    # raw_texts unchanged, so skip the expensive OCR pass entirely.
    if raw_texts and all(text.strip() for text in raw_texts):
        return raw_texts

    ocr_texts = pdf_ocr_pages_via_images(pdf_path, **ocr_kwargs)
    if not raw_texts:
        return ocr_texts

    try:
        return [raw if raw.strip() else ocr_texts[i] for i, raw in enumerate(raw_texts)]
    except IndexError:
        # OCR produced fewer pages than the text layer, so page indices
        # cannot be matched up; fall back to the OCR output already computed.
        return ocr_texts
