import json
from typing import Optional

from fastapi import FastAPI, HTTPException, Query
from pydantic import BaseModel

from db import get_tables, get_columns
from graph import build_graph
from ml import predict_joinability, best_path
from join_generator import generate_multi_join_sql
from semantic_rename import suggest_column_name
# TODO(review): update_column_name / update_table_name are called by the rename
# endpoints below but are not imported anywhere in this file; import them from
# the module that defines them.

# The FastAPI application instance; the endpoints below register on it via @app.get / @app.post.
app = FastAPI()

class RenameRequest(BaseModel):
    """Request body for ``POST /rename_column``."""

    raw: str  # current (raw) column name to be renamed
    readable: str  # new human-readable name for that column

class TableRenameRequest(BaseModel):
    """Request body for ``POST /rename_table``."""

    raw: str  # current (raw) table name to be renamed
    readable: str  # new human-readable name for that table


@app.get("/tables")
def tables():
    """Return whatever ``db.get_tables()`` reports, unchanged."""
    table_listing = get_tables()
    return table_listing

@app.get("/columns")
def columns():
    """Return whatever ``db.get_columns()`` reports, unchanged."""
    column_listing = get_columns()
    return column_listing

@app.get("/graph")
def graph():
    """Serialize the graph from ``build_graph()`` as lists of nodes and edges.

    Node and edge views are materialized into plain lists so the response
    can be JSON-encoded.
    """
    built = build_graph()
    node_list = list(built.nodes)
    edge_list = list(built.edges)
    return dict(nodes=node_list, edges=edge_list)

@app.get("/predict_join")
def predict(t1: str, t2: str):
    """Return the score ``ml.predict_joinability`` gives the table pair (t1, t2)."""
    return {"joinability_score": predict_joinability(t1, t2)}

@app.get("/best_path")
def path(source: str, target: str):
    """Return the path ``ml.best_path`` computes from ``source`` to ``target``."""
    route = best_path(source, target)
    return {"best_path": route}

def _split_csv(text):
    """Split a comma-separated string into whitespace-stripped, non-empty items."""
    return [item.strip() for item in text.split(",") if item.strip()]

def _parse_column_filters(spec):
    """Parse ``table1.col1,col2;table2.col3`` into ``{table: [col, ...]}``.

    - Empty groups (e.g. from a trailing ``;``) are skipped.
    - Only the first ``.`` separates the table from its column list, so text
      after a second dot is kept rather than silently dropped.
    - A group with no columns (``table`` or ``table.``) selects ``["*"]``.
    - A later group for the same table replaces an earlier one.

    Raises:
        HTTPException: 400 if a group has columns but no table name.
    """
    filters = {}
    if not spec:
        return filters
    for group in spec.split(";"):
        if not group.strip():
            continue
        table, _, col_spec = group.partition(".")
        table = table.strip()
        if not table:
            raise HTTPException(
                status_code=400,
                detail=f"'columns' group {group!r} is missing a table name",
            )
        filters[table] = _split_csv(col_spec) or ["*"]
    return filters

def _parse_json_param(name, raw, expected_type, type_label):
    """Decode the JSON query parameter ``raw``.

    A missing/empty value or a JSON ``null`` yields an empty ``expected_type()``
    (``{}`` or ``[]``).

    Raises:
        HTTPException: 400 if ``raw`` is not valid JSON or does not decode to
            ``expected_type``. Without this, the JSONDecodeError would escape
            the endpoint and reach the client as a 500.
    """
    if not raw:
        return expected_type()
    try:
        value = json.loads(raw)
    except json.JSONDecodeError as err:
        raise HTTPException(
            status_code=400,
            detail=f"'{name}' is not valid JSON: {err.msg}",
        ) from err
    if value is None:
        return expected_type()
    if not isinstance(value, expected_type):
        raise HTTPException(
            status_code=400,
            detail=f"'{name}' must be a JSON {type_label}",
        )
    return value

@app.get("/generate_multi_join_sql_filtered")
def multi_join_sql_filtered(
    source: str,
    targets: str,
    columns: Optional[str] = Query(None, description="Format: table1.col1,col2;table2.col3,col4"),
    where: Optional[str] = Query(None, description="JSON: {\"table\": {\"col\": {\"op\": \"=\", \"value\": \"val\"}}}"),
    order: Optional[str] = Query(None, description="JSON: [[\"table\", \"col\", \"asc\"]]"),
    limit: Optional[int] = Query(None, ge=0),
    offset: Optional[int] = Query(None, ge=0),
):
    """Return the SQL produced by ``generate_multi_join_sql`` for the request.

    Args:
        source: Table the join starts from.
        targets: Comma-separated table names to join. Surrounding whitespace
            and empty entries are ignored.
        columns: Optional per-table column selection,
            ``table1.col1,col2;table2.col3``.
        where: Optional JSON object of filters, forwarded as decoded.
        order: Optional JSON array of ordering entries, forwarded as decoded.
        limit: Optional non-negative row limit; negative values are rejected
            by FastAPI validation.
        offset: Optional non-negative row offset; negative values are rejected
            by FastAPI validation.

    Returns:
        ``{"sql": <result of generate_multi_join_sql>}``.

    Raises:
        HTTPException: 400 if ``targets`` names no table, ``columns`` is
            malformed, or ``where``/``order`` are not valid JSON of the
            expected kind.
    """
    # NOTE(review): every argument here is client-controlled and is forwarded
    # verbatim; confirm generate_multi_join_sql quotes identifiers and
    # parameterizes filter values.
    target_list = _split_csv(targets)
    if not target_list:
        raise HTTPException(
            status_code=400,
            detail="'targets' must name at least one table",
        )
    column_filters = _parse_column_filters(columns)
    where_filters = _parse_json_param("where", where, dict, "object")
    order_by = _parse_json_param("order", order, list, "array")

    sql = generate_multi_join_sql(
        source,
        target_list,
        column_filters,
        where_filters,
        order_by,
        limit,
        offset,
    )
    return {"sql": sql}

@app.get("/suggest_column_name")
def suggest(raw: str):
    """Return ``suggest_column_name``'s suggestion for ``raw`` with its score rounded to 3 places."""
    suggestion, confidence = suggest_column_name(raw)
    return {
        "suggested_name": suggestion,
        "confidence": round(confidence, 3),
    }

@app.get("/columns_with_suggestions")
def columns_with_suggestions():
    """Map each table from ``get_columns()`` to its columns paired with suggested names.

    Each entry is ``{"original": <column>, "suggested": <first element of
    suggest_column_name(column)>}``. Table and column order follow
    ``get_columns()``.
    """
    def _annotate(column):
        return {"original": column, "suggested": suggest_column_name(column)[0]}

    return {
        table: [_annotate(column) for column in table_columns]
        for table, table_columns in get_columns().items()
    }

@app.post("/rename_column")
def rename_column(req: RenameRequest):
    """Rename column ``req.raw`` to ``req.readable`` and return a confirmation message."""
    # NOTE(review): update_column_name is not imported in the visible part of
    # this file, so this raises NameError at request time until it is.
    old_name, new_name = req.raw, req.readable
    update_column_name(old_name, new_name)
    return {"message": f"Renamed '{old_name}' to '{new_name}'"}

@app.post("/rename_table")
def rename_table_api(req: TableRenameRequest):
    """Rename table ``req.raw`` to ``req.readable`` and return a confirmation message."""
    # NOTE(review): update_table_name is not imported in the visible part of
    # this file, so this raises NameError at request time until it is.
    old_name, new_name = req.raw, req.readable
    update_table_name(old_name, new_name)
    return {"message": f"Renamed table '{old_name}' to '{new_name}'"}

