from db import inspector, get_columns
from ml import best_paths
from table_rename import rename_table
from db_master import get_master_mapping

def find_join_key(t1, t2):
    """Return the ``(column_in_t1, column_in_t2)`` pair used to join t1 to t2.

    Resolution order:
      1. A foreign key declared on ``t1`` that refers to ``t2``.
      2. A foreign key declared on ``t2`` that refers to ``t1`` (the edge may
         be traversed against the FK direction).
      3. A column name present in both tables; the alphabetically first one
         is used so the choice is deterministic across runs.

    Only the first column of a (possibly composite) foreign key is used.

    Returns:
        ``(col1, col2)``, or ``(None, None)`` when no key can be found.
    """
    # NOTE(review): fk dict keys match SQLAlchemy's Inspector.get_foreign_keys
    # output; presumably `inspector` is such an Inspector — confirm in db.py.
    for fk in inspector.get_foreign_keys(t1):
        if fk['referred_table'] == t2:
            return fk['constrained_columns'][0], fk['referred_columns'][0]

    # Reverse direction: t2 holds the FK, so t1 supplies the referred column.
    # Without this, e.g. customers -> orders would fall back to name matching
    # and could join on an unrelated shared name such as "id".
    for fk in inspector.get_foreign_keys(t2):
        if fk['referred_table'] == t1:
            return fk['referred_columns'][0], fk['constrained_columns'][0]

    columns = get_columns()
    # Set iteration order of strings varies between runs (hash randomization),
    # so sort to make the fallback column choice reproducible.
    shared = sorted(set(columns[t1]) & set(columns[t2]))
    return (shared[0], shared[0]) if shared else (None, None)

def merge_paths(paths):
    """Merge join paths into a single list of unique edges.

    Each path is a sequence of nodes (table names). Every consecutive pair
    ``(path[i], path[i + 1])`` becomes an edge. Edges are returned in the order
    they are first seen, with duplicates dropped. Nodes must be hashable.

    Args:
        paths: Iterable of node sequences.

    Returns:
        List of ``(from_node, to_node)`` tuples.
    """
    # dict.fromkeys de-duplicates in O(1) per edge while preserving first-seen
    # order; the previous list-membership test was O(n) per edge.
    return list(dict.fromkeys(
        edge for path in paths for edge in zip(path, path[1:])
    ))

def generate_subquery_sql(config: dict) -> str:
    """Render a nested query as a parenthesised SQL fragment.

    ``config`` is unpacked as keyword arguments to ``generate_multi_join_sql``.
    The resulting SQL is stripped of surrounding whitespace and wrapped in
    parentheses so it can be placed after FROM/JOIN.
    """
    inner_sql = generate_multi_join_sql(**config)
    return f"({inner_sql.strip()})"

def _sql_literal(value):
    """Render a Python scalar as an SQL literal.

    ``None`` becomes ``NULL``. Strings are single-quoted, with embedded single
    quotes doubled (standard SQL escaping). Anything else uses ``str()``.
    """
    if value is None:
        return "NULL"
    if isinstance(value, str):
        return "'" + value.replace("'", "''") + "'"
    return str(value)


def _sql_operand(op, val):
    """Render the right-hand side of a condition for operator *op* (upper-case)."""
    is_seq = isinstance(val, (list, tuple))
    if op in ("IN", "NOT IN") and is_seq:
        return "(" + ", ".join(_sql_literal(v) for v in val) + ")"
    if op in ("BETWEEN", "NOT BETWEEN") and is_seq and len(val) == 2:
        # Bounds are quoted like any other literal so string/date ranges work.
        return f"{_sql_literal(val[0])} AND {_sql_literal(val[1])}"
    return _sql_literal(val)


def generate_multi_join_sql(
    source,
    targets,
    column_filters=None,
    where_filters=None,
    order_by=None,
    limit=None,
    offset=None,
    group_by=None,
    aggregations=None,
    having_filters=None,
    subqueries=None,
    master_mappings=None  # NEW: {table: {col: (master_table, key_col, value_col)}}
):
    """Build a SELECT statement joining *source* to every table in *targets*.

    Join paths come from ``best_paths`` and are merged into a de-duplicated
    edge list by ``merge_paths``. Each edge ``(t1, t2)`` is joined on the key
    returned by ``find_join_key``.

    Args:
        source: Table (or subquery alias) placed in the FROM clause.
        targets: Sequence of tables whose columns are selected.
        column_filters: ``{table: [col, ...]}``; tables not listed select ``*``.
        where_filters: ``{table: {col: rule}}``. A rule is
            ``{"op": ..., "value": ...}`` (op defaults to ``=``), or a nested
            dict of further rules rendered as a parenthesised group.
        order_by: Iterable of ``(table, col, direction)``.
        limit: Row limit; emitted whenever it is not ``None`` (including 0).
        offset: Row offset; omitted when falsy.
        group_by: Iterable of ``(table, col)``.
        aggregations: ``{table: {col: func}}``; the column is selected as
            ``FUNC(table.col) AS func_col``.
        having_filters: Same shape as ``where_filters``. Aggregated columns are
            referenced by their ``func_col`` alias.
        subqueries: ``{alias: config}``. Each config is rendered with
            ``generate_subquery_sql`` and used wherever ``alias`` appears as
            a table name.
        master_mappings: ``{table: {col: (master_table, key_col, value_col)}}``.
            Filter values (or each element of a list value) found as keys in
            the mapping from ``get_master_mapping`` are replaced by the mapped
            value before rendering.

    Returns:
        The SQL text, or a ``-- Cannot determine join key ...`` comment string
        when some edge has no usable join key.
    """
    paths = best_paths(source, targets)
    edges = merge_paths(paths)

    column_filters = column_filters or {}
    aggregations = aggregations or {}
    subqueries = subqueries or {}
    master_mappings = master_mappings or {}

    # Render every subquery once up front, keyed by the alias it is used under.
    subquery_aliases = {
        alias: generate_subquery_sql(config) for alias, config in subqueries.items()
    }

    # Load master mappings
    master_lookup = {}
    for table, col_map in master_mappings.items():
        for col, (m_table, key_col, value_col) in col_map.items():
            master_lookup[(table, col)] = get_master_mapping(m_table, key_col, value_col)

    def table_ref(table):
        """Name used to qualify columns of *table*: the alias for a subquery."""
        return table if table in subquery_aliases else rename_table(table)

    def table_source(table):
        """Text placed after FROM / JOIN for *table*."""
        if table in subquery_aliases:
            return f"{subquery_aliases[table]} AS {table}"
        return rename_table(table)

    def enrich_value(table, col, val):
        """Translate *val* through the master mapping for (table, col), if any."""
        mapping = master_lookup.get((table, col))
        if not mapping:
            return val
        if isinstance(val, (list, tuple)):
            # IN / BETWEEN operands are lists and therefore unhashable, so
            # translate element-wise rather than looking up the list itself.
            return [mapping.get(v, v) for v in val]
        return mapping.get(val, val)

    selected_cols = []
    for table in [source, *targets]:
        ref = table_ref(table)
        table_aggs = aggregations.get(table, {})
        for col in column_filters.get(table, ["*"]):
            if col in table_aggs:
                func = table_aggs[col].upper()
                selected_cols.append(f"{func}({ref}.{col}) AS {func.lower()}_{col}")
            else:
                selected_cols.append(f"{ref}.{col}")
    select_clause = "SELECT " + ", ".join(selected_cols)

    sql = f"{select_clause}\nFROM {table_source(source)}"

    for t1, t2 in edges:
        col1, col2 = find_join_key(t1, t2)
        if not col1 or not col2:
            return f"-- Cannot determine join key between {t1} and {t2}"
        # Subqueries are joined as "(...) AS alias", and the ON clause must
        # reference the alias rather than repeating the subquery text.
        sql += (
            f"\nJOIN {table_source(t2)} "
            f"ON {table_ref(t1)}.{col1} = {table_ref(t2)}.{col2}"
        )

    def parse_condition(table, col, rule, use_alias=False):
        """Render one ``{"op": ..., "value": ...}`` rule as an SQL condition."""
        op = rule.get("op", "=").upper()
        val = enrich_value(table, col, rule.get("value"))
        if use_alias and col in aggregations.get(table, {}):
            # HAVING refers to aggregated columns by their SELECT alias.
            col_expr = f"{aggregations[table][col].lower()}_{col}"
        else:
            col_expr = f"{table_ref(table)}.{col}"
        return f"{col_expr} {op} {_sql_operand(op, val)}"

    def build_conditions(filters, logic="AND", use_alias=False):
        """Join all rules in *filters* with *logic*. Returns "" when there are none."""
        clauses = []
        for table, rules in filters.items():
            for col, rule in rules.items():
                if isinstance(rule, dict) and "op" in rule:
                    clauses.append(parse_condition(table, col, rule, use_alias))
                elif isinstance(rule, dict):
                    nested = build_conditions({table: rule}, logic, use_alias)
                    if nested:  # an empty group would render as invalid "()"
                        clauses.append(f"({nested})")
        return f" {logic} ".join(clauses)

    if where_filters:
        where_sql = build_conditions(where_filters)
        if where_sql:  # avoid emitting a dangling "WHERE "
            sql += "\nWHERE " + where_sql

    if group_by:
        group_cols = [f"{table_ref(tbl)}.{col}" for tbl, col in group_by]
        sql += "\nGROUP BY " + ", ".join(group_cols)

    if having_filters:
        having_sql = build_conditions(having_filters, use_alias=True)
        if having_sql:
            sql += "\nHAVING " + having_sql

    if order_by:
        order_clauses = [
            f"{table_ref(tbl)}.{col} {direction.upper()}"
            for tbl, col, direction in order_by
        ]
        sql += "\nORDER BY " + ", ".join(order_clauses)

    # LIMIT 0 is meaningful (no rows), so only None means "no limit".
    if limit is not None:
        sql += f"\nLIMIT {limit}"
    if offset:
        sql += f"\nOFFSET {offset}"

    return sql