"""
auto_georef_from_one_anchor.py

Usage:
  - Put the mouza PNG (annotated or original, e.g. 'annotated_map.png') where IMAGE_PATH points
  - Edit INPUT variables below (IMAGE_PATH, ANCHOR_PIXEL, ANCHOR_LONLAT)
  - Run: python auto_georef_from_one_anchor.py

Requirements (install):
  pip install numpy opencv-python pillow requests rasterio shapely pyproj simplekml
  # and a system GDAL (gdalwarp/gdal_translate) or install rasterio with GDAL support:
  # on Ubuntu: sudo apt install gdal-bin libgdal-dev
"""

import os, math, json, sys
from io import BytesIO
import numpy as np
import cv2
from PIL import Image
import requests
import rasterio
from rasterio.transform import from_origin
import simplekml

# ---------- USER INPUT ----------
IMAGE_PATH = "/var/www/html/kml_chatgpt/annotated_map.png"   # your annotated mouza PNG (or original img)
# pixel coordinates of the Point B circle that you marked on the annotated map.
# If you are not sure of the exact pixel, estimate the centre of the circle
# visually (e.g. in an image editor) and use that value here.
ANCHOR_PIXEL = (1850, 500)        # (x, y) pixel on the image -> change if needed
ANCHOR_LONLAT = (87.6680856, 23.9426366)  # (lon, lat) of the same point on the ground (lon first)
OUT_PREFIX = "mouza_auto"
# ---------- END USER INPUT ----------

# Satellite tile settings (Esri World Imagery)
TILE_URL = "https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{z}/{y}/{x}"

# Parameters
ZOOM = 17         # tile zoom for detail; increase for higher resolution (costlier)
SEARCH_ROT_DEG = 10   # +/- degrees to search for rotation
SEARCH_SCALE = (0.8, 1.4)  # scale multiplier search range
SCALE_STEPS = 9
ROT_STEPS = 25
MATCH_BOX = 1200   # half-size (px) of the reference window around the anchor used for matching
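
# Note (rule of thumb, not computed by this script): Web Mercator ground resolution
# is roughly 156543 * cos(latitude) / 2**zoom metres per pixel, i.e. about 1 m per
# pixel at ZOOM = 17 near latitude 24 deg, so each 256 px tile spans roughly 280 m.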

# helpers
def deg2num(lat_deg, lon_deg, zoom):
    lat_rad = math.radians(lat_deg)
    n = 2.0 ** zoom
    xtile = int((lon_deg + 180.0) / 360.0 * n)
    ytile = int((1.0 - math.log(math.tan(lat_rad) + (1 / math.cos(lat_rad))) / math.pi) / 2.0 * n)
    return (xtile, ytile)

def tile_xy_bounds(x, y, z):
    # returns lon/lat bounds of tile
    n = 2.0 ** z
    lon1 = x / n * 360.0 - 180.0
    lon2 = (x+1) / n * 360.0 - 180.0
    lat1 = math.degrees(math.atan(math.sinh(math.pi*(1 - 2*y/n))))
    lat2 = math.degrees(math.atan(math.sinh(math.pi*(1 - 2*(y+1)/n))))
    return (lon1, lat2, lon2, lat1)  # left, bottom, right, top

def fetch_tile(z, x, y):
    url = TILE_URL.format(z=z, x=x, y=y)
    r = requests.get(url, timeout=20)
    r.raise_for_status()
    return Image.open(BytesIO(r.content)).convert("RGB")

def build_reference_mosaic(center_lon, center_lat, zoom=17, radius_tiles=2):
    # build mosaic of (2*radius+1)^2 tiles around center
    cx, cy = deg2num(center_lat, center_lon, zoom)
    tiles = []
    for dy in range(-radius_tiles, radius_tiles+1):
        row_imgs = []
        for dx in range(-radius_tiles, radius_tiles+1):
            tx, ty = cx + dx, cy + dy
            try:
                img = fetch_tile(zoom, tx, ty)
            except Exception:
                # tile fetch failed: fall back to a blank white tile
                img = Image.new("RGB", (256,256), (255,255,255))
            row_imgs.append(np.array(img))
        tiles.append(np.hstack(row_imgs))
    mosaic = np.vstack(tiles)
    # compute mosaic bounding box in lon/lat
    tl = tile_xy_bounds(cx-radius_tiles, cy-radius_tiles, zoom)
    br = tile_xy_bounds(cx+radius_tiles, cy+radius_tiles, zoom)
    left = tl[0]; top = tl[3]; right = br[2]; bottom = br[1]
    return mosaic, (left, top, right, bottom)

def image_edges_gray(img_np):
    gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
    # increase contrast
    gray = cv2.equalizeHist(gray)
    edges = cv2.Canny(gray, 50,150)
    # dilate a bit
    edges = cv2.dilate(edges, np.ones((3,3), np.uint8))
    return edges

def compute_score(fixed_edges, moving_edges, x_off, y_off):
    # compute correlation score for overlapping region
    fh, fw = fixed_edges.shape
    mh, mw = moving_edges.shape
    # compute overlapping slice
    x1 = max(0, x_off)
    y1 = max(0, y_off)
    x2 = min(fw, x_off+mw)
    y2 = min(fh, y_off+mh)
    if x2<=x1 or y2<=y1:
        return -1.0
    # slices
    fixed_slice = fixed_edges[y1:y2, x1:x2]
    mov_x1 = x1 - x_off; mov_y1 = y1 - y_off
    moving_slice = moving_edges[mov_y1:mov_y1+(y2-y1), mov_x1:mov_x1+(x2-x1)]
    # normalized cross-correlation of the edge maps; returned as a plain Python
    # float (score in [-1, 1]) so it stays JSON-serializable later
    fixed_norm = fixed_slice.astype(np.float32) - fixed_slice.mean()
    mov_norm = moving_slice.astype(np.float32) - moving_slice.mean()
    denom = np.sqrt((fixed_norm**2).sum()) * np.sqrt((mov_norm**2).sum()) + 1e-6
    score = (fixed_norm * mov_norm).sum() / denom
    return float(score)

def try_search_transform(mosaic_rgb, mouza_rgb, anchor_px, anchor_geo, ref_bbox):
    # prepare edge maps of the reference mosaic and the mouza image
    ref_edges = image_edges_gray(mosaic_rgb)
    mouza_edges = image_edges_gray(mouza_rgb)
    # locate the anchor inside the mosaic via a linear lon/lat -> pixel mapping
    left, top, right, bottom = ref_bbox
    ref_h, ref_w = ref_edges.shape
    def lon_to_x(lon):
        return int((lon - left) / (right - left) * ref_w)
    def lat_to_y(lat):
        return int((top - lat) / (top - bottom) * ref_h)
    anchor_lon, anchor_lat = anchor_geo
    ref_anchor_x = lon_to_x(anchor_lon)
    ref_anchor_y = lat_to_y(anchor_lat)
    # crop reference region
    bx1 = max(0, ref_anchor_x - MATCH_BOX)
    by1 = max(0, ref_anchor_y - MATCH_BOX)
    bx2 = min(ref_w, ref_anchor_x + MATCH_BOX)
    by2 = min(ref_h, ref_anchor_y + MATCH_BOX)
    ref_crop = ref_edges[by1:by2, bx1:bx2]
    # anchor position inside the cropped reference window
    crop_anchor_x = ref_anchor_x - bx1
    crop_anchor_y = ref_anchor_y - by1
    # moving image: anchor pixel in the mouza image
    mh, mw = mouza_edges.shape
    mx_anchor = int(anchor_px[0]); my_anchor = int(anchor_px[1])
    # brute-force search over rotation & scale
    best = {"score": -9e9}
    scales = np.linspace(SEARCH_SCALE[0], SEARCH_SCALE[1], SCALE_STEPS)
    rots = np.linspace(-SEARCH_ROT_DEG, SEARCH_ROT_DEG, ROT_STEPS)
    for s in scales:
        # scaled mouza
        new_w = max(10, int(mw * s))
        new_h = max(10, int(mh * s))
        mov_scaled = cv2.resize(mouza_edges, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
        # anchor position in scaled
        ax = int(mx_anchor * s)
        ay = int(my_anchor * s)
        for r in rots:
            # rotate around anchor
            M = cv2.getRotationMatrix2D((ax, ay), r, 1.0)
            mov_rot = cv2.warpAffine(mov_scaled, M, (new_w, new_h), flags=cv2.INTER_LINEAR, borderValue=0)
            # overlay mov_rot so that its anchor (ax, ay) lands on the anchor's
            # position inside ref_crop (not the crop centre, which can differ
            # when the crop was clipped at the mosaic edge)
            cx = crop_anchor_x - ax
            cy = crop_anchor_y - ay
            score = compute_score(ref_crop, mov_rot, cx, cy)
            if score > best["score"]:
                best.update({
                    "score": score, "scale": s, "rot": r, "cx": cx, "cy": cy,
                    "mov_w": new_w, "mov_h": new_h, "ax": ax, "ay": ay,
                    "ref_bbox": (bx1,by1,bx2,by2)
                })
    return best, ref_crop
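
# Optional sketch (not part of the pipeline above): if you later collect several
# manual pixel -> lon/lat pairs (as suggested at the end of the script), fitting a
# full 2D affine is usually more robust than the single-anchor search. The helper
# below is illustrative only; it assumes OpenCV's estimateAffine2D and needs at
# least 3 non-collinear point pairs.
def fit_affine_from_gcps(pixel_pts, lonlat_pts):
    """Return a 2x3 affine A mapping (px, py, 1) -> (lon, lat), or None on failure."""
    src = np.asarray(pixel_pts, dtype=np.float32).reshape(-1, 1, 2)
    dst = np.asarray(lonlat_pts, dtype=np.float32).reshape(-1, 1, 2)
    A, _inliers = cv2.estimateAffine2D(src, dst)  # RANSAC-based by default
    return A
# Usage sketch: A = fit_affine_from_gcps(pixel_pairs, lonlat_pairs);
# then lon, lat = A @ np.array([px, py, 1.0]) for any mouza pixel (px, py).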

# ---------- MAIN ----------
if __name__ == "__main__":
    if not os.path.exists(IMAGE_PATH):
        print("Image not found:", IMAGE_PATH); sys.exit(1)

    # load mouza image
    mouza = Image.open(IMAGE_PATH).convert("RGB")
    mouza_np = np.array(mouza)

    print("Downloading satellite mosaic around anchor (this may take a few seconds)...")
    lon, lat = ANCHOR_LONLAT
    mosaic_np, ref_coords = build_reference_mosaic(lon, lat, zoom=ZOOM, radius_tiles=2)
    # lon/lat bounding box of the mosaic: (left, top, right, bottom)
    ref_bbox = ref_coords

    print("Computing best transform by searching rotation/scale around anchor...")
    best, ref_crop = try_search_transform(mosaic_np, mouza_np, ANCHOR_PIXEL, (lon, lat), ref_bbox)
    print("Best match score:", best["score"])
    print("scale:", best["scale"], "rotation(deg):", best["rot"])

    # derive an approximate geotransform:
    # the search aligned the (scaled, rotated) anchor pixel with the anchor's
    # position in the reference mosaic, so any mouza pixel can be mapped to
    # lon/lat using the best scale, rotation and the reference origin
    s = best["scale"]
    # compute anchor positions in various spaces
    ref_w, ref_h = mosaic_np.shape[1], mosaic_np.shape[0]
    left,top,right,bottom = ref_bbox
    def lon_from_refx(xpix):
        return left + (xpix / ref_w) * (right - left)
    def lat_from_refy(ypix):
        return top - (ypix / ref_h) * (top - bottom)
    # anchor position in mosaic pixel coordinates
    ref_anchor_x = int((lon - left) / (right - left) * ref_w)
    ref_anchor_y = int((top - lat) / (top - bottom) * ref_h)
    # in the best match, the anchor of the scaled+rotated mouza aligns with (ref_anchor_x, ref_anchor_y)
    ax = int(ANCHOR_PIXEL[0] * s)
    ay = int(ANCHOR_PIXEL[1] * s)
    # forward rotation from mouza pixels to reference pixels. cv2.getRotationMatrix2D
    # with angle a applies [[cos a, sin a], [-sin a, cos a]] about the centre, so the
    # same convention is used here to stay consistent with the search above.
    theta = math.radians(best["rot"])
    cos_t, sin_t = math.cos(theta), math.sin(theta)
    # transform: pixel (px, py) in the original mouza -> scaled & rotated -> reference pixel -> lon/lat
    def mouza_pixel_to_lonlat(px, py):
        # scale
        sx = px * s; sy = py * s
        # rotate about the anchor (ax, ay), matching the search's rotation convention
        dx = sx - ax; dy = sy - ay
        rx = cos_t * dx + sin_t * dy
        ry = -sin_t * dx + cos_t * dy
        ref_x = ref_anchor_x + rx
        ref_y = ref_anchor_y + ry
        return lon_from_refx(ref_x), lat_from_refy(ref_y)
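
    # Sanity check (illustrative; small error expected): the anchor pixel should map
    # back to roughly the supplied anchor lon/lat (within about one mosaic pixel).
    chk_lon, chk_lat = mouza_pixel_to_lonlat(ANCHOR_PIXEL[0], ANCHOR_PIXEL[1])
    print("Anchor round-trip check:", (chk_lon, chk_lat), "vs supplied", ANCHOR_LONLAT)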

    # produce a small set of GCPs (corners + some interior points)
    h, w = mouza_np.shape[0], mouza_np.shape[1]
    sample_pixels = [
        (0,0), (w-1,0), (0,h-1), (w-1,h-1),
        (ANCHOR_PIXEL[0], ANCHOR_PIXEL[1]),
        (w//2, h//2),
        (int(w*0.25), int(h*0.75))
    ]
    gcp_list = []
    for px,py in sample_pixels:
        lonlat = mouza_pixel_to_lonlat(px,py)
        gcp_list.append({"pixel_x": int(px), "pixel_y": int(py), "lon": float(lonlat[0]), "lat": float(lonlat[1])})

    # save transform & GCPs
    out_json = OUT_PREFIX + "_best_transform.json"
    with open(out_json, "w") as f:
        json.dump({"anchor_pixel": ANCHOR_PIXEL, "anchor_lonlat": ANCHOR_LONLAT,
                   "best": best, "gcp_samples": gcp_list, "ref_bbox": ref_bbox}, f, indent=2)
    print("Saved best transform info to", out_json)

    # create a georeferenced GeoTIFF with an approximate, north-up affine transform
    # (the fitted rotation is NOT encoded by from_origin). Pixel size in degrees is
    # estimated from two mapped points: the anchor pixel and the image centre; if the
    # anchor lies very close to the centre this estimate becomes unstable.
    lon0, lat0 = mouza_pixel_to_lonlat(ANCHOR_PIXEL[0], ANCHOR_PIXEL[1])
    lon1, lat1 = mouza_pixel_to_lonlat(int(w//2), int(h//2))
    # compute pixel size degrees/pixel
    dx_deg = (lon1 - lon0) / ((w//2) - ANCHOR_PIXEL[0] + 1e-9)
    dy_deg = (lat1 - lat0) / ((h//2) - ANCHOR_PIXEL[1] + 1e-9)
    # approximate geotransform (top-left lon/lat)
    # determine top-left by mapping (0,0)
    lon_tl, lat_tl = mouza_pixel_to_lonlat(0,0)
    # create geotiff with rasterio (WGS84)
    out_tif = OUT_PREFIX + "_georef.tif"
    transform = from_origin(lon_tl, lat_tl, abs(dx_deg), abs(dy_deg))
    # write with rasterio (orientation approximate; rotation not encoded)
    dtype = mouza_np.dtype
    channels = mouza_np.shape[2]
    with rasterio.open(out_tif, 'w', driver='GTiff', height=h, width=w, count=channels,
                       dtype=dtype, crs='EPSG:4326', transform=transform) as dst:
        for i in range(channels):
            dst.write(mouza_np[:,:,i], i+1)
    print("Wrote approximate georeferenced GeoTIFF:", out_tif)

    # create a simple KML GroundOverlay (a single image overlay, not a tiled superoverlay)
    kml = simplekml.Kml()
    # compute polygon coordinates of image bounds in lonlat
    corners = [mouza_pixel_to_lonlat(0,0),
               mouza_pixel_to_lonlat(w,0),
               mouza_pixel_to_lonlat(w,h),
               mouza_pixel_to_lonlat(0,h)]
    # axis-aligned bounding box of the mapped image corners (a LatLonBox cannot
    # represent the fitted rotation exactly; KML does allow a rotation value if needed)
    overlay = kml.newgroundoverlay(name="Mouza overlay")
    lons = [c[0] for c in corners]; lats = [c[1] for c in corners]
    north = max(lats); south = min(lats); east = max(lons); west = min(lons)
    overlay.icon.href = os.path.abspath(IMAGE_PATH)
    overlay.latlonbox.north = north
    overlay.latlonbox.south = south
    overlay.latlonbox.east = east
    overlay.latlonbox.west = west
    kml_path = OUT_PREFIX + "_overlay.kml"
    kml.save(kml_path)
    print("Saved KML overlay:", kml_path)

    print("\nDone. Files produced:")
    print(" - Approx GeoTIFF:", out_tif)
    print(" - KML overlay:", kml_path)
    print(" - Best transform JSON:", out_json)
    print("\nIf alignment looks poor, please provide 1-2 more exact anchors (pixel->lonlat) or use the GCP HTML picker to give 4-6 pairs.")
