from pathlib import Path
from PIL import Image, ImageOps, ImageFilter
import numpy as np
import io

# Target length, in pixels, for an image's *longer* side: `_resize_if_small`
# upscales any image whose longer side is shorter than this.
MIN_SIDE: int = 1_200

def _ensure_pil(img_or_path):
    """Return an RGB PIL image built from *img_or_path*.

    Accepted inputs:
      * ``str`` / ``Path`` -- opened from disk;
      * ``bytes`` / ``bytearray`` / ``memoryview`` -- decoded from memory;
      * any object with a ``convert`` method (e.g. a PIL image) -- converted
        via ``convert("RGB")``;
      * any other object -- handed to ``Image.open`` as a file-like object.

    Images opened here are decoded inside a ``with`` block, so a file handle
    opened from a path is released immediately rather than lingering until
    garbage collection. Pillow only closes handles it opened itself, so a
    caller-supplied file object is left for the caller to close.
    """
    if isinstance(img_or_path, (str, Path)):
        source = str(img_or_path)
    elif isinstance(img_or_path, (bytes, bytearray, memoryview)):
        source = io.BytesIO(bytes(img_or_path))
    elif hasattr(img_or_path, "convert"):
        # Already an image: convert() returns a new, independent RGB image.
        return img_or_path.convert("RGB")
    else:
        source = img_or_path  # file-like object (has read/seek)
    with Image.open(source) as opened:
        # convert() fully decodes the pixels into a new image, so the result
        # stays valid after the context manager closes the source.
        return opened.convert("RGB")

def _resize_if_small(pil_img, *, min_side=None, resample=None):
    """Upscale *pil_img* so that its longer side is at least *min_side* pixels.

    Both sides are scaled by the same factor, preserving the aspect ratio.
    Images whose longer side already reaches the target are returned
    unchanged (the same object). So are degenerate images with a zero-length
    side, which cannot be scaled.

    Args:
        pil_img: Image exposing ``size`` and ``resize`` (e.g. a PIL image).
        min_side: Integer target for the longer side; defaults to ``MIN_SIDE``.
        resample: Filter passed to ``resize``; defaults to
            ``Image.Resampling.LANCZOS``.

    Returns:
        The original image, or a resized copy.
    """
    if min_side is None:
        min_side = MIN_SIDE
    w, h = pil_img.size
    long_side = max(w, h)
    if long_side >= min_side or min(w, h) == 0:
        return pil_img
    scale = min_side / long_side
    # round() rather than int(): truncation drops fractional pixels
    # (333 * 1.2 -> 399) and, because long_side * scale can land just below
    # min_side in floating point, could leave the long side a pixel short.
    new_size = (round(w * scale), round(h * scale))
    if resample is None:
        resample = Image.Resampling.LANCZOS
    return pil_img.resize(new_size, resample)

def _otsu_threshold(counts, total):
    """Return the Otsu threshold for a 256-bin intensity histogram.

    *counts* holds the per-level pixel counts and *total* their sum. The
    returned level ``t`` maximizes the between-class variance of the split
    ``[0..t]`` vs ``(t..255]``; the first maximizing level wins ties. Falls
    back to 128 when no split yields positive variance, e.g. when every pixel
    has the same value.
    """
    grand_sum = (np.arange(256) * counts).sum()
    best_cut = 128
    best_score = 0
    back_weight = 0
    back_sum = 0
    for level, count in enumerate(counts):
        back_weight += count
        if back_weight == 0:
            continue  # still before the first occupied level
        fore_weight = total - back_weight
        if fore_weight == 0:
            break  # every remaining split has an empty foreground
        back_sum += level * count
        back_mean = back_sum / back_weight
        fore_mean = (grand_sum - back_sum) / fore_weight
        score = back_weight * fore_weight * (back_mean - fore_mean) ** 2
        if score > best_score:
            best_score, best_cut = score, level
    return best_cut


def _to_bw(pil_img):
    """Binarize *pil_img* into a black/white PIL image.

    The image is converted to grayscale, smoothed with a 3x3 mode filter and
    contrast-stretched. It is then thresholded at the level chosen by Otsu's
    method: pixels strictly above the threshold become 255, the rest 0.
    """
    smoothed = ImageOps.grayscale(pil_img).filter(ImageFilter.ModeFilter(size=3))
    stretched = ImageOps.autocontrast(smoothed)
    pixels = np.array(stretched).astype(np.uint8)
    counts = np.histogram(pixels.flatten(), bins=256, range=(0, 256))[0]
    cut = _otsu_threshold(counts, pixels.size)
    mask = pixels > cut
    return Image.fromarray(mask.astype(np.uint8) * 255)

def preprocess(path_or_pil, save_debug=False, debug_dir=Path("data/output/debug_images"), idx=0):
    """Load an image and produce its RGB and black/white variants.

    Args:
        path_or_pil: A path (``str``/``Path``), encoded image bytes, or an
            already-loaded PIL image.
        save_debug: When true, also write the black/white image to disk.
        debug_dir: Directory for the debug image (``str`` or ``Path``);
            created with parents if missing.
        idx: Integer used to name the debug file ``{idx:05d}.png``.

    Returns:
        ``(pil_rgb, pil_bw, np_rgb, np_gray)``. These are the RGB PIL image
        (upscaled if its longer side was below ``MIN_SIDE``), its binarized
        PIL image, and NumPy arrays of each. Despite its name, ``np_gray``
        holds the binarized (0/255) image, not a grayscale one.
    """
    pil_rgb = _resize_if_small(_ensure_pil(path_or_pil))
    pil_bw = _to_bw(pil_rgb)
    np_rgb = np.array(pil_rgb)
    np_gray = np.array(pil_bw)
    if save_debug:
        # Coerce so callers may pass a plain string; str has no .mkdir().
        debug_dir = Path(debug_dir)
        debug_dir.mkdir(parents=True, exist_ok=True)
        pil_bw.save(debug_dir / f"{idx:05d}.png")
    return pil_rgb, pil_bw, np_rgb, np_gray
