"""Neo4j graph store — persist and query the knowledge graph."""

from __future__ import annotations

from neo4j import GraphDatabase, Driver

import sys

from rich.console import Console

from ...config import settings
from ..models import Edge, NodeType, ProjectIndex, SymbolNode

# True only when sys.stderr exposes isatty() and reports an interactive terminal.
_is_tty = hasattr(sys.stderr, "isatty") and sys.stderr.isatty()
# Shared console for status messages: writes to stderr, receives the TTY check
# result as force_terminal, and has rich's automatic highlighting turned off.
console = Console(stderr=True, force_terminal=_is_tty, highlight=False)


class GraphStore:
    """Manages Neo4j connection and graph operations."""

    def __init__(self, uri: str = "", user: str = "", password: str = ""):
        self._uri = uri or settings.neo4j_uri
        self._user = user or settings.neo4j_user
        self._password = password or settings.neo4j_password
        self._driver: Driver | None = None

    def connect(self):
        self._driver = GraphDatabase.driver(self._uri, auth=(self._user, self._password))
        self._driver.verify_connectivity()
        console.print("[green]✓[/green] Connected to Neo4j")

    def close(self):
        if self._driver:
            self._driver.close()

    @property
    def driver(self) -> Driver:
        if not self._driver:
            self.connect()
        return self._driver

    # ─── Schema ─────────────────────────────────────────────────

    def setup_indexes(self):
        """Create indexes for fast lookup."""
        queries = [
            "CREATE INDEX IF NOT EXISTS FOR (n:Symbol) ON (n.uid)",
            "CREATE INDEX IF NOT EXISTS FOR (n:Symbol) ON (n.name)",
            "CREATE INDEX IF NOT EXISTS FOR (n:Symbol) ON (n.file_path)",
            "CREATE INDEX IF NOT EXISTS FOR (n:Symbol) ON (n.node_type)",
            "CREATE INDEX IF NOT EXISTS FOR (n:File) ON (n.uid)",
            "CREATE INDEX IF NOT EXISTS FOR (n:Folder) ON (n.uid)",
        ]
        with self.driver.session() as session:
            for q in queries:
                session.run(q)

    # ─── Write ──────────────────────────────────────────────────

    def clear_project(self, root_path: str):
        """Delete all nodes/edges for a project in batches to avoid memory spikes."""
        with self.driver.session() as session:
            while True:
                result = session.run(
                    "MATCH (n) WHERE n.project = $project "
                    "WITH n LIMIT 5000 DETACH DELETE n "
                    "RETURN count(*) AS deleted",
                    project=root_path,
                )
                deleted = result.single()["deleted"]
                if deleted == 0:
                    break

    def store_index(self, index: ProjectIndex):
        """Store a complete ProjectIndex into Neo4j."""
        self.setup_indexes()
        self.clear_project(index.root_path)

        BATCH_SIZE = 500

        # Batch create nodes using UNWIND
        node_batches: dict[str, list[dict]] = {}
        for node in index.nodes:
            label = node.node_type.value
            node_batches.setdefault(label, []).append({
                "uid": node.uid,
                "name": node.name,
                "node_type": node.node_type.value,
                "file_path": node.file_path,
                "line_start": node.line_start,
                "line_end": node.line_end,
                "language": node.language or "",
                "docstring": node.docstring or "",
                "signature": node.signature or "",
                "project": index.root_path,
            })

        with self.driver.session() as session:
            for label, items in node_batches.items():
                for i in range(0, len(items), BATCH_SIZE):
                    batch = items[i : i + BATCH_SIZE]
                    session.run(
                        f"UNWIND $batch AS row "
                        f"CREATE (n:{label}:Symbol) SET n = row",
                        batch=batch,
                    )

        # Batch create edges using UNWIND, grouped by type
        edge_batches: dict[str, list[dict]] = {}
        for edge in index.edges:
            etype = edge.edge_type.value
            edge_batches.setdefault(etype, []).append({
                "src": edge.source_uid,
                "tgt": edge.target_uid,
            })

        with self.driver.session() as session:
            for etype, items in edge_batches.items():
                for i in range(0, len(items), BATCH_SIZE):
                    batch = items[i : i + BATCH_SIZE]
                    session.run(
                        f"UNWIND $batch AS row "
                        f"MATCH (a:Symbol {{uid: row.src}}), (b:Symbol {{uid: row.tgt}}) "
                        f"CREATE (a)-[:{etype}]->(b)",
                        batch=batch,
                    )

        stats = index.stats
        console.print(
            f"[green]✓[/green] Stored {stats['nodes']} nodes, {stats['edges']} edges in Neo4j"
        )

    # ─── Read ───────────────────────────────────────────────────

    def get_symbol(self, uid: str) -> dict | None:
        """Get a single symbol by UID."""
        with self.driver.session() as session:
            result = session.run(
                "MATCH (n:Symbol {uid: $uid}) RETURN properties(n) AS props",
                uid=uid,
            )
            record = result.single()
            return dict(record["props"]) if record else None

    def find_symbols(self, name: str) -> list[dict]:
        """Find symbols by name."""
        with self.driver.session() as session:
            result = session.run(
                "MATCH (n:Symbol {name: $name}) "
                "WHERE n.node_type IN ['Function', 'Class', 'Method'] "
                "RETURN properties(n) AS props",
                name=name,
            )
            return [dict(r["props"]) for r in result]

    def get_context(self, uid: str) -> dict:
        """360° context view of a symbol — incoming and outgoing edges."""
        with self.driver.session() as session:
            # Outgoing
            out_result = session.run(
                "MATCH (n:Symbol {uid: $uid})-[r]->(m:Symbol) "
                "RETURN type(r) AS rel, properties(m) AS target",
                uid=uid,
            )
            outgoing: dict[str, list] = {}
            for record in out_result:
                rel = record["rel"]
                outgoing.setdefault(rel, []).append(dict(record["target"]))

            # Incoming
            in_result = session.run(
                "MATCH (n:Symbol {uid: $uid})<-[r]-(m:Symbol) "
                "RETURN type(r) AS rel, properties(m) AS source",
                uid=uid,
            )
            incoming: dict[str, list] = {}
            for record in in_result:
                rel = record["rel"]
                incoming.setdefault(rel, []).append(dict(record["source"]))

        return {"outgoing": outgoing, "incoming": incoming}

    def get_impact(self, uid: str, direction: str = "downstream", depth: int = 3) -> dict:
        """Blast radius analysis — which symbols are affected."""
        # Single query with variable-length path instead of N separate queries
        path_pattern = f"-[*1..{depth}]->" if direction == "downstream" else f"<-[*1..{depth}]-"

        with self.driver.session() as session:
            result = session.run(
                f"MATCH path = (start:Symbol {{uid: $uid}}){path_pattern}(end:Symbol) "
                f"WHERE end.node_type IN ['Function', 'Class', 'Method'] "
                f"RETURN DISTINCT end.uid AS uid, end.name AS name, "
                f"end.file_path AS file_path, end.node_type AS node_type, "
                f"length(path) AS dist",
                uid=uid,
            )
            results_by_depth: dict[int, list] = {}
            for r in result:
                d = r["dist"]
                results_by_depth.setdefault(d, []).append({
                    "name": r["name"], "file_path": r["file_path"], "node_type": r["node_type"],
                })

        # Risk assessment
        total = sum(len(v) for v in results_by_depth.values())
        if total > 20:
            risk = "CRITICAL"
        elif total > 10:
            risk = "HIGH"
        elif total > 3:
            risk = "MEDIUM"
        else:
            risk = "LOW"

        return {
            "target": uid,
            "direction": direction,
            "risk": risk,
            "total_affected": total,
            "by_depth": {
                f"d{d}": [
                    {"name": n["name"], "file": n["file_path"], "type": n["node_type"]}
                    for n in nodes
                ]
                for d, nodes in results_by_depth.items()
            },
        }

    def get_all_symbols(self, project: str | None = None) -> list[dict]:
        """Get all symbols (for search index building)."""
        query = (
            "MATCH (n:Symbol) WHERE n.node_type IN ['Function', 'Class', 'Method'] "
        )
        if project:
            query += "AND n.project = $project "
        query += "RETURN properties(n) AS props"

        with self.driver.session() as session:
            result = session.run(query, project=project) if project else session.run(query)
            return [dict(r["props"]) for r in result]

    def get_stats(self, project: str | None = None) -> dict:
        """Get graph statistics."""
        with self.driver.session() as session:
            where = "WHERE n.project = $project" if project else ""
            params = {"project": project} if project else {}

            node_count = session.run(
                f"MATCH (n:Symbol) {where} RETURN count(n) AS c", **params
            ).single()["c"]

            # Use directed --> to avoid double-counting undirected edges
            edge_count = session.run(
                f"MATCH (n:Symbol)-[r]->(m:Symbol) {where} RETURN count(r) AS c", **params
            ).single()["c"]

            return {"nodes": node_count, "edges": edge_count}

    # ─── Clusters & Processes ───────────────────────────────────

    def get_clusters(self, project: str | None = None) -> list[dict]:
        """Get all functional clusters with member counts."""
        with self.driver.session() as session:
            # Single query: fetch clusters + their members via OPTIONAL MATCH
            result = session.run(
                "MATCH (c:Cluster:Symbol) "
                "WHERE c.uid STARTS WITH 'cluster:' "
                "OPTIONAL MATCH (m:Symbol)-[:DEFINED_IN]->(c) "
                "WHERE m.node_type IN ['Function', 'Class', 'Method'] "
                "RETURN properties(c) AS props, "
                "collect({uid: m.uid, name: m.name, type: m.node_type, file: m.file_path}) AS members "
                "ORDER BY c.name"
            )
            clusters = []
            for r in result:
                props = dict(r["props"])
                # Filter out null members (from OPTIONAL MATCH with no matches)
                props["members"] = [m for m in r["members"] if m.get("uid") is not None]
                clusters.append(props)
            return clusters

    def get_processes(self, project: str | None = None) -> list[dict]:
        """Get all execution flows."""
        with self.driver.session() as session:
            result = session.run(
                "MATCH (n:Process:Symbol) "
                "WHERE n.uid STARTS WITH 'process:' "
                "RETURN properties(n) AS props"
            )
            return [dict(r["props"]) for r in result]

    def get_process_detail(self, process_uid: str) -> dict | None:
        """Get a specific process with its steps."""
        with self.driver.session() as session:
            result = session.run(
                "MATCH (n:Symbol {uid: $uid}) RETURN properties(n) AS props",
                uid=process_uid,
            )
            record = result.single()
            if not record:
                return None

            props = dict(record["props"])

            # Get step symbols
            steps_result = session.run(
                "MATCH (m:Symbol)-[r:DEFINED_IN]->(p:Symbol {uid: $uid}) "
                "RETURN m.uid AS uid, m.name AS name, m.node_type AS type, "
                "m.file_path AS file, m.signature AS signature",
                uid=process_uid,
            )
            props["step_symbols"] = [dict(s) for s in steps_result]
            return props
