# main.py

import json
import os
import re
import secrets
from itertools import islice
from pathlib import Path
from typing import List, Dict, Any, Optional

from fastapi import FastAPI, HTTPException, Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from langchain_community.utilities import SQLDatabase
from langchain_community.vectorstores import Chroma
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_ollama import ChatOllama
from pydantic import BaseModel
from sqlalchemy import create_engine, text
from sqlalchemy.engine import URL

# Application instance that the @app.get / @app.post route decorators in this
# module register their handlers on.
app = FastAPI(
    version="0.1.1",
    title="Agriculture Chat-to-SQL API (Local Ollama)",
    description="Text-to-SQL using local llama3.2:3b-instruct via Ollama",
)

# ────────────────────────────────────────────────
# Configuration
# ────────────────────────────────────────────────

# Every setting below can be overridden through an AGRI_* environment
# variable. The literal fallbacks are the values this module has always used,
# so behaviour is unchanged when no variable is set.
# NOTE(review): the fallback password and API key are real secrets committed
# to source control; rotate them and drop the literals once every deployment
# provides the environment variables.
db_url = URL.create(
    drivername="postgresql+psycopg2",
    username=os.environ.get("AGRI_DB_USER", "postgres"),
    password=os.environ.get("AGRI_DB_PASSWORD", "gQube1_#@96740"),
    host=os.environ.get("AGRI_DB_HOST", "139.59.87.82"),
    port=int(os.environ.get("AGRI_DB_PORT", "5432")),
    database=os.environ.get("AGRI_DB_NAME", "four_s_india"),
)

# JSON file with the schema chunks that get_vectorstore() embeds.
SCHEMA_CHUNKS_PATH = Path(
    os.environ.get("AGRI_SCHEMA_CHUNKS_PATH", "schema_chunks.json")
)

# Bearer token that verify_api_key() compares incoming credentials against.
API_KEY = os.environ.get("AGRI_API_KEY", "testing-2026-kolkata-llama-api-key")

# Security scheme used as the dependency of verify_api_key().
security = HTTPBearer()

def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Check the bearer token of the current request against API_KEY.

    Args:
        credentials: Credentials object produced by the ``security`` dependency.

    Returns:
        The presented token string when it matches API_KEY.

    Raises:
        HTTPException: 401 with detail "Invalid API Key" on a mismatch.
    """
    # Compare as bytes: compare_digest only accepts str arguments when they
    # are pure ASCII, and a header value is not guaranteed to be.
    presented = credentials.credentials.encode("utf-8")
    expected = API_KEY.encode("utf-8")

    # Constant-time comparison so response timing does not reveal how much of
    # a guessed key was correct (a plain != stops at the first difference).
    if not secrets.compare_digest(presented, expected):
        raise HTTPException(status_code=401, detail="Invalid API Key")
    return credentials.credentials

# ────────────────────────────────────────────────
# Globals
# ────────────────────────────────────────────────

# Lazily created singletons; each stays None until its get_*() accessor is
# called for the first time.
engine = db = vectorstore = llm = None

def get_engine():
    """Return the module-wide SQLAlchemy engine, creating it on first use.

    The engine is built from ``db_url`` with ``pool_pre_ping=True`` and cached
    in the ``engine`` global for all later calls.
    """
    global engine
    if engine is not None:
        return engine

    engine = create_engine(db_url, pool_pre_ping=True)
    return engine

def get_db():
    """Return the module-wide SQLDatabase wrapper, creating it on first use.

    The wrapper is built around the engine from get_engine() and cached in the
    ``db`` global.
    """
    global db
    if db is not None:
        return db

    db = SQLDatabase(get_engine())
    return db

def get_vectorstore():
    """Return the module-wide Chroma store of schema chunks, building it once.

    On the first call the chunk list is read from SCHEMA_CHUNKS_PATH (a JSON
    array whose items carry a "text" entry and an optional "metadata" entry),
    embedded with the all-MiniLM-L6-v2 sentence-transformers model and stored
    in the "agri_schema_local" collection. Later calls reuse the cached store.

    Raises:
        RuntimeError: If the schema chunk file does not exist.
    """
    global vectorstore
    if vectorstore is not None:
        return vectorstore

    if not SCHEMA_CHUNKS_PATH.exists():
        raise RuntimeError("schema_chunks.json missing – run ingestion first")

    chunks = json.loads(SCHEMA_CHUNKS_PATH.read_text(encoding="utf-8"))

    # Split every chunk into its text and its (possibly absent) metadata.
    texts = []
    metadatas = []
    for chunk in chunks:
        texts.append(chunk["text"])
        metadatas.append(chunk.get("metadata", {}))

    embedder = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    )
    vectorstore = Chroma.from_texts(
        collection_name="agri_schema_local",
        texts=texts,
        metadatas=metadatas,
        embedding=embedder,
    )
    return vectorstore

def get_llm():
    """Return the module-wide ChatOllama client, creating it on first use.

    The client targets the llama3.2:3b-instruct-q4_0 model on the Ollama
    server at http://localhost:11434 and is cached in the ``llm`` global.
    """
    global llm
    if llm is not None:
        return llm

    llm = ChatOllama(
        base_url="http://localhost:11434",
        model="llama3.2:3b-instruct-q4_0",
        temperature=0.0,
        num_ctx=4096,      # context window passed to the model
        num_predict=1024,  # cap on generated tokens
    )
    return llm

# ────────────────────────────────────────────────
# Improved SQL Extraction (more robust patterns)
# ────────────────────────────────────────────────

# Opening of a read-only statement: the SELECT keyword as a whole word, or a
# CTE header such as "WITH name AS (" / "WITH RECURSIVE name (cols) AS (".
# The CTE form is spelled out so that the English word "with" in surrounding
# prose is not mistaken for the start of a query.
_SQL_START = (
    r"\b(?:"
    r"WITH\s+(?:RECURSIVE\s+)?\w+\s*(?:\([^()]*\)\s*)?"
    r"AS\s*(?:(?:NOT\s+)?MATERIALIZED\s*)?\("
    r"|SELECT\b"
    r")"
)
_SQL_FLAGS = re.DOTALL | re.IGNORECASE

# ```sql ... ``` fenced block.
_FENCED_SQL_RE = re.compile(r"```sql\s*(.*?)\s*```", _SQL_FLAGS)

# Untagged ``` ... ``` fenced block whose content opens with a statement.
_FENCED_STATEMENT_RE = re.compile(
    r"```\s*(" + _SQL_START + r".*?)\s*```", _SQL_FLAGS
)

# Unfenced statement running up to the first ';', blank line or end of text.
# Single-quoted literals are consumed as a unit so that a ';' or blank line
# inside a string value does not cut the statement short.
_RAW_STATEMENT_RE = re.compile(
    r"(" + _SQL_START + r"(?:'[^']*'|.)*?)(?:\n{2,}|$|;)", _SQL_FLAGS
)


def extract_sql(text: str) -> Optional[str]:
    """Pull a SQL statement out of raw model output.

    Strategies are tried in order and the first hit wins:

    1. A ```sql fenced block: its content is returned as-is (stripped),
       whatever statement it contains.
    2. Any other fenced block whose content starts with SELECT or a CTE
       header: its content is returned as-is (stripped).
    3. An unfenced SELECT / CTE statement anywhere in the text, ending at the
       first ';' outside a string literal, the first blank line, or the end
       of the text. The result is normalised to end with exactly one ';'.

    Args:
        text: Raw text produced by the model.

    Returns:
        The extracted SQL, or None when no statement is found (for example
        when the model answered "INSUFFICIENT SCHEMA").
    """
    text = text.strip()

    fenced = _FENCED_SQL_RE.search(text) or _FENCED_STATEMENT_RE.search(text)
    if fenced:
        return fenced.group(1).strip()

    raw = _RAW_STATEMENT_RE.search(text)
    if raw:
        return raw.group(1).strip().rstrip(";") + ";"

    return None

# ────────────────────────────────────────────────
# Execute SQL
# ────────────────────────────────────────────────

def run_sql(sql: str, max_rows: int) -> List[Dict[str, Any]]:
    """Execute ``sql`` and return at most ``max_rows`` rows as dictionaries.

    The statement runs on a fresh connection from get_engine(), inside a
    transaction that is switched to READ ONLY first, so PostgreSQL itself
    refuses writes and DDL coming from generated SQL.

    NOTE(review): this is a safety net, not a sandbox. SQL text that ends or
    reconfigures the transaction itself can step outside it, so the service
    should additionally connect with a role that has only SELECT privileges.

    Args:
        sql: SQL text to execute.
        max_rows: Maximum number of rows to return; values below zero are
            treated as zero.

    Returns:
        A list of ``{column_name: value}`` dictionaries, one per row.
    """
    row_cap = max(max_rows, 0)  # islice rejects negative stops

    with get_engine().connect() as conn:
        conn.execute(text("SET TRANSACTION READ ONLY"))
        result = conn.execute(text(sql))
        columns = list(result.keys())
        return [dict(zip(columns, row)) for row in islice(result, row_cap)]

# ────────────────────────────────────────────────
# Models
# ────────────────────────────────────────────────

# Request body of POST /query.
class QueryRequest(BaseModel):
    # Natural-language question; used for schema retrieval and in the prompt.
    question: str
    # Upper bound on returned rows (handed to run_sql by the /query handler);
    # defaults to 150.
    max_rows: Optional[int] = 150

# Response body of POST /query. Only `question` is always set; the handler
# fills the remaining fields depending on how far processing got.
class QueryResponse(BaseModel):
    # The question from the request, echoed back.
    question: str
    # SQL extracted from the model output, once extraction succeeded.
    generated_sql: Optional[str] = None
    # Result rows as {column: value} dictionaries, on successful execution.
    rows: Optional[List[Dict[str, Any]]] = None
    # len(rows), on successful execution.
    row_count: Optional[int] = None
    # Raw model output, set when no SQL could be extracted from it.
    natural_answer: Optional[str] = None
    # Human-readable failure message (no SQL produced, or execution failed).
    error: Optional[str] = None

# ────────────────────────────────────────────────
# Endpoints
# ────────────────────────────────────────────────

@app.get("/health")
async def health():
    """Unauthenticated liveness probe; reports a fixed status and model tag."""
    payload = {
        "status": "ok",
        "llm": "ollama-llama3.2:3b-instruct-q4_0",
    }
    return payload

@app.get("/tables")
async def list_tables(_: str = Depends(verify_api_key)):
    """Return the usable table names of the database in sorted order.

    Requires a valid API key (enforced by the verify_api_key dependency).
    """
    names = list(get_db().get_usable_table_names())
    names.sort()
    return {"tables": names}

@app.get("/table/{table_name}")
async def get_table_info(table_name: str, _: str = Depends(verify_api_key)):
    """Return the table description SQLDatabase reports for ``table_name``.

    Requires a valid API key (enforced by the verify_api_key dependency).
    The description comes from get_table_info_no_throw, whose result is
    passed through unchanged under the "info" key.
    """
    database = get_db()
    description = database.get_table_info_no_throw([table_name])
    return {
        "table": table_name,
        "info": description,
    }

# Row cap used when the client explicitly sends "max_rows": null.
_DEFAULT_MAX_ROWS = 150

# A generated statement is only executed when it opens with SELECT or WITH.
_READ_ONLY_START_RE = re.compile(r"\s*(?:SELECT|WITH)\b", re.IGNORECASE)


@app.post("/query", response_model=QueryResponse)
async def ask(req: QueryRequest, _: str = Depends(verify_api_key)):
    """Answer a natural-language question by generating and running SQL.

    Steps: retrieve the six most similar schema chunks, prompt the model for
    a query, extract the SQL from its output, refuse anything that is not a
    single SELECT/WITH statement, execute it and return the rows.

    Args:
        req: The question and the optional row cap.
        _: API key placeholder; present only to enforce verify_api_key.

    Returns:
        A QueryResponse. Failures to produce, accept or execute SQL are
        reported through its ``error`` field with HTTP 200.

    Raises:
        HTTPException: 500 for any other failure (retrieval, model call, ...).

    NOTE(review): similarity_search, invoke and run_sql are synchronous calls
    made directly inside this coroutine, so the event loop is occupied while
    they run. Moving them to a worker thread would first require making the
    lazy get_*() initialisers safe against concurrent first calls.
    """
    try:
        # 1. Retrieve relevant schema context.
        docs = get_vectorstore().similarity_search(req.question, k=6)
        context = "\n\n".join(d.page_content for d in docs)

        # 2. Append the special rule for the state column.
        state_hint = (
            "\n\nSPECIAL RULE FOR STATE: "
            "Column 'state' is character varying(255) → string/text. "
            "ALWAYS use quotes: state = 'West Bengal' or state = '4' "
            "(never state = 4 without quotes). "
            "Do NOT use numeric comparison without CAST or quotes."
        )
        context += state_hint

        # 3. Prompts.
        system_prompt = """You are a strict PostgreSQL SQL generator.
Output ONLY the SQL query — nothing else. No explanation, no markdown, no code blocks, no text before or after the SQL.

STRICT RULES - YOU MUST FOLLOW:
- Use ONLY tables that appear EXACTLY in the schema context below.
- NEVER invent or use tables like 'farmer_csv_records', 'users_csv', or any name not listed.
- If no suitable table exists → output exactly: INSUFFICIENT SCHEMA
- For counting farmers → MUST use 'upload_farmer' or 'agri_users' tables.
- Column 'state' is string → always use quotes: 'West Bengal' or '4'
- Use COUNT(*) for row counts
- End every query with semicolon ;
- Never INSERT, UPDATE, DELETE, DROP"""

        user_prompt = f"""Schema context:
{context}

Question: {req.question}

SQL query:"""

        # 4. Call the model.
        response = get_llm().invoke([
            SystemMessage(content=system_prompt),
            HumanMessage(content=user_prompt)
        ])
        raw_output = response.content.strip()

        # 5. Extract SQL from the model output.
        sql = extract_sql(raw_output)
        if not sql:
            return QueryResponse(
                question=req.question,
                natural_answer=raw_output,
                error="Model did not produce valid SQL. Try rephrasing or check schema."
            )

        # 6. Coarse safety gate. The question is untrusted and steers the
        #    model, so the prompt's "never INSERT/UPDATE/..." rule cannot be
        #    relied on. Accept only one statement that opens with SELECT or
        #    WITH; any ';' other than trailing ones counts as a second
        #    statement. This deliberately also rejects the rare query with a
        #    ';' inside a string literal.
        statement_body = sql.strip().rstrip(";")
        if not _READ_ONLY_START_RE.match(sql) or ";" in statement_body:
            return QueryResponse(
                question=req.question,
                generated_sql=sql,
                error="Refusing to execute: only a single read-only SELECT statement is allowed."
            )

        # 7. Execute. The model field is Optional, so an explicit null in
        #    the request arrives here as None; fall back to the default cap.
        max_rows = _DEFAULT_MAX_ROWS if req.max_rows is None else req.max_rows
        try:
            rows = run_sql(sql, max_rows)
        except Exception as ex:
            return QueryResponse(
                question=req.question,
                generated_sql=sql,
                error=f"SQL execution failed: {str(ex)}"
            )

        return QueryResponse(
            question=req.question,
            generated_sql=sql,
            rows=rows,
            row_count=len(rows)
        )

    except Exception as e:
        # Keep the original exception as the cause for server-side tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e

if __name__ == "__main__":
    # Script entry point: serve this module's app on all interfaces, port
    # 8001, with auto-reload switched on.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 8001, "reload": True}
    uvicorn.run("main:app", **server_options)