"""Clustering — Phase 4: Community Detection.

Groups related symbols into functional clusters using a simple
label-propagation algorithm (no external dependency needed).
Produces cluster nodes plus membership edges (currently emitted as
DEFINED_IN edges tagged with ``membership: cluster``).
"""

from __future__ import annotations

from collections import defaultdict
from dataclasses import dataclass

from ..models import Edge, EdgeType, NodeType, SymbolNode
from ._console import console


@dataclass
class Cluster:
    """A functional cluster of related symbols.

    Attributes:
        id: Identifier of the cluster.
        name: Human-readable cluster name.
        members: UIDs of the symbols that belong to the cluster.
        cohesion: Internal edge density, expected to lie in the range 0-1.
    """
    id: str
    name: str
    members: list[str]  # UIDs
    cohesion: float     # 0-1, internal edge density


# Node types that take part in clustering; _build_adjacency ignores nodes of any other type.
_SYMBOL_TYPES = frozenset((NodeType.FUNCTION, NodeType.CLASS, NodeType.METHOD))
# Edge types treated as relationships between symbols when the adjacency map is built.
_RELEVANT_EDGE_TYPES = frozenset((EdgeType.CALLS, EdgeType.IMPORTS, EdgeType.HAS_METHOD, EdgeType.EXTENDS))


def _build_adjacency(nodes: list[SymbolNode], edges: list[Edge]) -> dict[str, set[str]]:
    """Build an undirected adjacency map over symbol nodes.

    An edge is kept only when its type is in ``_RELEVANT_EDGE_TYPES`` and
    both of its endpoints are nodes whose type is in ``_SYMBOL_TYPES``.
    Every kept edge is recorded in both directions.
    """
    eligible = {node.uid for node in nodes if node.node_type in _SYMBOL_TYPES}
    graph: dict[str, set[str]] = defaultdict(set)

    for edge in edges:
        # Guard clauses: drop edges of an irrelevant type or with a non-symbol endpoint.
        if edge.edge_type not in _RELEVANT_EDGE_TYPES:
            continue
        if edge.source_uid not in eligible:
            continue
        if edge.target_uid not in eligible:
            continue

        src, dst = edge.source_uid, edge.target_uid
        graph[src].add(dst)
        graph[dst].add(src)

    return graph


def _label_propagation(adj: dict[str, set[str]], max_iter: int = 20) -> dict[str, int]:
    """Simple label propagation for community detection.

    Each node starts with its own label. Iteratively, each node adopts
    the most common label among its neighbors. Converges to communities.
    """
    # Collect all node UIDs in a single pass (adj keys + their neighbors)
    all_nodes: set[str] = set(adj.keys())
    for neighbors in adj.values():
        all_nodes.update(neighbors)

    # Initialize: each node gets its own label
    labels: dict[str, int] = {uid: i for i, uid in enumerate(all_nodes)}

    # Sort once — deterministic order preserved across iterations
    sorted_nodes = sorted(all_nodes)

    for iteration in range(max_iter):
        changed = False
        for uid in sorted_nodes:
            neighbors = adj.get(uid, set())
            if not neighbors:
                continue

            # Count neighbor labels using inline counting (faster than defaultdict)
            label_counts: dict[int, int] = {}
            for n in neighbors:
                lbl = labels[n]
                label_counts[lbl] = label_counts.get(lbl, 0) + 1

            # Pick most common label
            best_label = max(label_counts, key=label_counts.__getitem__)
            if labels[uid] != best_label:
                labels[uid] = best_label
                changed = True

        if not changed:
            break

    return labels


def _compute_cohesion(member_set: set[str], adj: dict[str, set[str]]) -> float:
    """Compute cluster cohesion: ratio of internal edges to possible edges."""
    n = len(member_set)
    if n <= 1:
        return 1.0

    # Count internal edges (each counted once via intersection)
    internal_edges = sum(
        len(adj.get(uid, set()) & member_set)
        for uid in member_set
    ) // 2  # undirected → halve

    max_edges = n * (n - 1) // 2
    return round(internal_edges / max_edges, 3) if max_edges > 0 else 0.0


def _name_cluster(members: list[str], node_map: dict[str, SymbolNode]) -> str:
    """Generate a human-readable name for a cluster based on file paths and symbols."""
    from collections import Counter

    dir_counter: Counter[str] = Counter()

    for uid in members:
        node = node_map.get(uid)
        if not node:
            continue
        parts = node.file_path.split("/")
        if len(parts) > 1:
            d = parts[-2] if parts[-2] != "src" else parts[-1]
        else:
            d = parts[0].replace(".py", "").replace(".ts", "").replace(".js", "")
        dir_counter[d] += 1

    if dir_counter:
        primary = dir_counter.most_common(1)[0][0]
        name = primary.replace("_", " ").replace("-", " ").title()
        return f"{name} Module"

    return f"Cluster ({len(members)} symbols)"


def compute_clusters(
    nodes: list[SymbolNode], edges: list[Edge]
) -> tuple[list[SymbolNode], list[Edge]]:
    """Phase 4: Detect functional clusters in the code graph.

    Builds an undirected symbol graph, runs label propagation on it, and
    turns every community with at least two members into a cluster node
    plus one membership edge per member. Clusters are numbered from the
    largest to the smallest; members are sorted and ties in size are broken
    by the smallest member UID, so that for a given grouping the cluster ids
    and edge order do not depend on set/dict iteration order.

    Args:
        nodes: All nodes of the code graph.
        edges: All edges of the code graph.

    Returns:
        (cluster_nodes, member_edges) — the new cluster nodes and the edges
        linking each member to its cluster. Both lists are empty when the
        graph contains no relevant relationships.
    """
    # Build adjacency from relevant symbols
    adj = _build_adjacency(nodes, edges)

    if not adj:
        console.print("  [yellow]![/yellow] Clustering: no relationships found, skipping")
        return [], []

    # Run community detection
    labels = _label_propagation(adj)

    # Group by label
    communities: dict[int, list[str]] = defaultdict(list)
    for uid, label in labels.items():
        communities[label].append(uid)

    # Drop single-member communities and sort the members of the rest so the
    # emitted edges (and name tie-breaks) are stable between runs.
    groups = [
        sorted(members)
        for members in communities.values()
        if len(members) >= 2
    ]
    # Largest cluster first; equal sizes are ordered by their smallest UID.
    groups.sort(key=lambda members: (-len(members), members[0]))

    # Build results
    node_map = {n.uid: n for n in nodes}
    cluster_nodes: list[SymbolNode] = []
    member_edges: list[Edge] = []

    for i, members in enumerate(groups):
        cluster_id = f"cluster:{i}"
        name = _name_cluster(members, node_map)
        cohesion = _compute_cohesion(set(members), adj)

        cluster_nodes.append(SymbolNode(
            uid=cluster_id,
            name=name,
            node_type=NodeType.CLUSTER,
            file_path="",
            line_start=0,
            line_end=0,
            properties={"is_cluster": True, "cohesion": cohesion, "size": len(members)},
        ))

        # Membership is expressed by reusing DEFINED_IN edges, marked with a
        # "membership" property to tell them apart.
        member_edges.extend(
            Edge(
                source_uid=member_uid,
                target_uid=cluster_id,
                edge_type=EdgeType.DEFINED_IN,  # Reuse for membership
                properties={"membership": "cluster"},
            )
            for member_uid in members
        )

    console.print(
        f"  [green]✓[/green] Clustering: {len(cluster_nodes)} clusters "
        f"from {sum(len(members) for members in groups)} symbols"
    )

    return cluster_nodes, member_edges
