"""Reference resolver — Phase 3: Resolution.

Resolves cross-file references: imports, function calls, inheritance.
Converts 'unresolved:XYZ' edges into proper node-to-node edges.
"""

from __future__ import annotations

import posixpath
import re
from pathlib import Path

from ..models import Edge, EdgeType, FileNode, NodeType, SymbolNode
from ._console import console

# Precompiled regex patterns (compiled once at import, reused per line/file).
# Python ``from X import Y``: group 1 = module (may be relative, e.g. ``.pkg``),
# group 2 = the rest of the line after ``import``.
_PY_FROM_IMPORT_RE = re.compile(r"^from\s+([\w.]+)\s+import\s+(.+)")
# Python ``import X``: group 1 = the first module named after ``import``.
_PY_IMPORT_RE = re.compile(r"^import\s+([\w.]+)")
# ES6 ``import``/``export ... from 'Y'`` and side-effect ``import 'Y'``.
# ``[^'";]`` (unlike ``.``) also matches newlines, so import lists wrapped
# across several lines are found; stopping at quotes/semicolons keeps a match
# from running past the end of one statement.
_JS_ES6_IMPORT_RE = re.compile(
    r"""\b(?:import|export)\s+(?:[^'";]*?\bfrom\s*)?['"]([^'"]+)['"]"""
)
# CommonJS ``require('Y')``: group 1 = module specifier.
_JS_REQUIRE_RE = re.compile(r"""require\s*\(\s*['"]([^'"]+)['"]\s*\)""")

# Node-type groups built once at import time, so hot loops only do set
# membership tests instead of rebuilding tuples on every call.
# Structural nodes that are never indexed as symbols.
_SKIP_TYPES = frozenset({NodeType.FILE, NodeType.FOLDER})
# Nodes whose bodies are scanned for outgoing calls.
_CALLER_TYPES = frozenset({NodeType.FUNCTION, NodeType.METHOD})
# Nodes that may be the target of a call (callers plus classes).
_CALLABLE_TYPES = _CALLER_TYPES | {NodeType.CLASS}


def _build_symbol_index(nodes: list[SymbolNode]) -> dict[str, list[SymbolNode]]:
    """Group nodes by their ``name`` for quick lookup.

    Nodes whose type is in ``_SKIP_TYPES`` (FILE, FOLDER) are left out. Within
    each list, nodes keep the order in which they appear in *nodes*.
    """
    by_name: dict[str, list[SymbolNode]] = {}
    for sym in nodes:
        if sym.node_type not in _SKIP_TYPES:
            bucket = by_name.get(sym.name)
            if bucket is None:
                bucket = by_name[sym.name] = []
            bucket.append(sym)
    return by_name


def _extract_python_imports(source: str, file_path: str) -> list[Edge]:
    """Extract import statements from Python source.

    This is a line-based scan with no AST. Only statements that start a line
    are recognised, and import-looking lines inside strings or docstrings are
    picked up too. For a parenthesised ``from`` import that spans several
    lines, ``names`` holds only the text on the first line.

    Args:
        source: Full text of the Python file.
        file_path: Path recorded in the edge source, ``file:<file_path>``.

    Returns:
        One IMPORTS edge per imported module, each targeting
        ``import:<module>``. Edges from ``from X import Y`` lines also carry
        the imported names in ``properties["names"]``.
    """
    edges: list[Edge] = []
    file_uid = f"file:{file_path}"

    def add(module: str, **extra: str) -> None:
        edges.append(Edge(
            source_uid=file_uid,
            target_uid=f"import:{module}",
            edge_type=EdgeType.IMPORTS,
            properties={"module": module, **extra},
        ))

    for raw_line in source.split("\n"):
        # Drop trailing comments and anything after ';' so they don't leak
        # into the module / names captured below.
        line = raw_line.split("#", 1)[0].split(";", 1)[0].strip()

        # from X import Y
        m = _PY_FROM_IMPORT_RE.match(line)
        if m:
            add(m.group(1), names=m.group(2).strip())
            continue

        # import X  /  import a, b.c as d  (several modules in one statement)
        m = _PY_IMPORT_RE.match(line)
        if m:
            for clause in line[m.start(1):].split(","):
                clause_match = re.match(r"\s*([\w.]+)", clause)
                if clause_match:
                    add(clause_match.group(1))

    return edges


def _extract_js_imports(source: str, file_path: str) -> list[Edge]:
    """Collect ES6 import and CommonJS ``require`` dependencies of a JS/TS file.

    All ``_JS_ES6_IMPORT_RE`` matches come first, in source order, followed
    by all ``_JS_REQUIRE_RE`` matches, in source order. Each match becomes an
    IMPORTS edge from ``file:<file_path>`` to ``import:<module>``.
    """
    origin = f"file:{file_path}"
    specifiers = [hit.group(1) for hit in _JS_ES6_IMPORT_RE.finditer(source)]
    specifiers.extend(hit.group(1) for hit in _JS_REQUIRE_RE.finditer(source))
    return [
        Edge(
            source_uid=origin,
            target_uid=f"import:{spec}",
            edge_type=EdgeType.IMPORTS,
            properties={"module": spec},
        )
        for spec in specifiers
    ]


def _extract_call_edges(
    symbols: list[SymbolNode], symbol_index: dict[str, list[SymbolNode]]
) -> list[Edge]:
    """Extract CALLS edges by scanning function bodies for known symbol names.

    This is a textual heuristic. Any ``name(`` in the ``body_text`` of a
    FUNCTION/METHOD counts as a call to a FUNCTION/METHOD/CLASS node named
    ``name``, provided it is not preceded by ``.`` or a word character.
    Occurrences inside strings and comments count too. Names shorter than
    2 characters are ignored, and calls to the caller's own name are skipped.
    When several callables share a name, one defined in the caller's file is
    preferred; otherwise the first indexed one is used.

    Args:
        symbols: Candidate callers; only FUNCTION/METHOD nodes are scanned.
        symbol_index: name -> nodes index, as built by ``_build_symbol_index``.

    Returns:
        CALLS edges, one per distinct (caller, target) pair. Within a caller,
        edges follow the order in which callees first appear in its body.
    """
    edges: list[Edge] = []
    seen: set[tuple[str, str]] = set()

    # name -> callable nodes with that name (2+ chars to avoid noise). Only
    # callable nodes are kept, so a name shared with a non-callable node can
    # never resolve a call to that non-callable node.
    callable_index: dict[str, list[SymbolNode]] = {}
    for name, syms in symbol_index.items():
        if len(name) < 2:
            continue
        targets = [s for s in syms if s.node_type in _CALLABLE_TYPES]
        if targets:
            callable_index[name] = targets

    if not callable_index:
        return edges

    # Build combined regex patterns, chunked into groups of MAX_REGEX_NAMES to
    # keep regex size / compile time bounded on huge codebases. Longest names
    # go first so the alternation prefers them over shorter alternatives.
    MAX_REGEX_NAMES = 500
    sorted_names = sorted(callable_index, key=len, reverse=True)
    patterns: list[re.Pattern] = []
    for i in range(0, len(sorted_names), MAX_REGEX_NAMES):
        chunk = sorted_names[i : i + MAX_REGEX_NAMES]
        # (?<![.\w]) rejects attribute calls (obj.name()) and names that are
        # only the tail of a longer identifier.
        patterns.append(re.compile(
            r"(?<![.\w])(" + "|".join(re.escape(n) for n in chunk) + r")\s*\("
        ))

    for caller in symbols:
        if caller.node_type not in _CALLER_TYPES or not caller.body_text:
            continue
        body = caller.body_text

        # (position, name) for every call site. Sorting by position keeps the
        # edge order deterministic instead of depending on set iteration.
        hits = sorted(
            (m.start(), m.group(1)) for pat in patterns for m in pat.finditer(body)
        )
        for callee in dict.fromkeys(found for _, found in hits):
            if callee == caller.name:  # Skip self-recursion (by name)
                continue
            targets = callable_index[callee]

            # Prefer a target defined in the caller's own file.
            target = next(
                (t for t in targets if t.file_path == caller.file_path),
                targets[0],
            )

            key = (caller.uid, target.uid)
            if key in seen:
                continue
            seen.add(key)
            edges.append(Edge(
                source_uid=caller.uid,
                target_uid=target.uid,
                edge_type=EdgeType.CALLS,
            ))

    return edges


# Extensions stripped from JS/TS paths and specifiers to get an extension-less
# stem such as ``src/auth/login``.
_JS_EXTS = (".ts", ".tsx", ".js", ".jsx", ".mjs", ".cjs")


def _strip_ext(path: str, exts: tuple[str, ...]) -> str:
    """Return *path* minus the first extension in *exts* that it ends with."""
    for ext in exts:
        if path.endswith(ext):
            return path[: -len(ext)]
    return path


def _relative_python_module(module: str, importer: str) -> str:
    """Turn a relative Python import (``.x``, ``..y.z``, ``.``) into a dotted name.

    The folder of *importer* (a ``/``-separated path) is the starting package,
    and each leading dot beyond the first climbs one folder. Returns ``""``
    when the import climbs above the root or names nothing.
    """
    rest = module.lstrip(".")
    depth = len(module) - len(rest)
    package = posixpath.dirname(importer)
    parts = package.split("/") if package else []
    keep = len(parts) - (depth - 1)
    if keep < 0:
        return ""
    return ".".join(parts[:keep] + ([rest] if rest else []))


def _import_candidates(module: str, importer: str) -> list[str]:
    """Return the file-map keys to try for *module*, most specific first."""
    candidates: list[str] = []
    if module.startswith(("./", "../")):
        # JS/TS relative specifier: resolve it against the importer's folder,
        # accepting both ``x.ts`` and ``x/index.ts``.
        path = posixpath.normpath(
            posixpath.join(posixpath.dirname(importer), module)
        )
        path = _strip_ext(path, _JS_EXTS)
        candidates += [path, f"{path}/index"]
    elif module.startswith("."):
        # Python relative import (``from .x import y``).
        dotted = _relative_python_module(module, importer)
        if dotted:
            candidates += [dotted, f"{dotted}.__init__"]
    candidates.append(module)  # exact key, as registered in the file map
    if module and not module.startswith(".") and "/" not in module:
        candidates.append(f"{module}.__init__")  # Python package import
    return candidates


def _last_segment(module: str) -> str:
    """Return the final component of an import spec, without a JS/TS extension.

    ``./auth/login.js`` -> ``login``; ``src.auth.login`` or ``.login`` -> ``login``.
    """
    if "/" in module:
        return _strip_ext(module.rstrip("/").rsplit("/", 1)[-1], _JS_EXTS)
    return module.rsplit(".", 1)[-1]


def _resolve_import_edges(
    import_edges: list[Edge], files: list[FileNode]
) -> list[Edge]:
    """Resolve ``import:<module>`` references to actual ``file:<path>`` nodes.

    Each edge is matched in two stages:
      1. Exact keys: relative resolution against the importing file
         (``./x``, ``../x``, ``.x``, ``..x``), then the module spec as-is,
         then the Python-package (``__init__``) form.
      2. Fallback: the last module/path segment alone.
    Edges that match nothing are dropped.

    Args:
        import_edges: Edges produced by the ``_extract_*_imports`` helpers.
            The importer is read from ``source_uid`` (``file:<path>``) and the
            spec from ``properties["module"]``.
        files: All scanned files.

    Returns:
        New IMPORTS edges from the importer to the resolved file node, sharing
        the original edge's ``properties``.
    """
    source_exts = _JS_EXTS + (".py",)
    file_map: dict[str, str] = {}
    suffix_map: dict[str, str] = {}

    for f in files:
        file_uid = f"file:{f.relative_path}"
        path = f.relative_path.replace("\\", "/")

        # Python: src/auth/login.py -> src.auth.login
        py_module = path.replace("/", ".")
        if py_module.endswith(".py"):
            py_module = py_module[:-3]
        file_map[py_module] = file_uid

        # JS/TS: src/auth/login.ts -> src/auth/login and ./src/auth/login
        stem = _strip_ext(path, _JS_EXTS)
        file_map[stem] = file_uid
        file_map[f"./{stem}"] = file_uid

        # Fallback: the bare file name without its source extension. Python
        # files must lose ".py" here as well, or "login" never matches.
        folder, _, name = path.rpartition("/")
        name = _strip_ext(name, source_exts)
        suffix_map[name] = file_uid
        if name in ("__init__", "index") and folder:
            # A package entry point also answers to its folder name, unless a
            # real module file claims that name.
            suffix_map.setdefault(folder.rpartition("/")[2], file_uid)

    resolved: list[Edge] = []
    for edge in import_edges:
        module = edge.properties.get("module", "")
        src = edge.source_uid
        importer = src[len("file:"):] if src.startswith("file:") else ""
        importer = importer.replace("\\", "/")

        target_uid = next(
            (
                file_map[key]
                for key in _import_candidates(module, importer)
                if key in file_map
            ),
            None,
        )
        if target_uid is None:
            last = _last_segment(module)
            target_uid = suffix_map.get(last) if last else None

        if target_uid:
            resolved.append(Edge(
                source_uid=edge.source_uid,
                target_uid=target_uid,
                edge_type=EdgeType.IMPORTS,
                properties=edge.properties,
            ))

    return resolved


def _resolve_extends_edges(
    edges: list[Edge], symbol_index: dict[str, list[SymbolNode]]
) -> list[Edge]:
    """Resolve ``unresolved:ClassName`` EXTENDS edges to actual class nodes.

    Edges whose target does not start with ``unresolved:`` are ignored. The
    text after the first ``:`` is looked up in *symbol_index*, and the first
    CLASS node found becomes the target. Names with no CLASS match produce
    no edge.

    Returns:
        New EXTENDS edges; the input edges are not modified.
    """
    def first_class(name: str) -> SymbolNode | None:
        return next(
            (s for s in symbol_index.get(name, []) if s.node_type == NodeType.CLASS),
            None,
        )

    out: list[Edge] = []
    for e in edges:
        if e.target_uid.startswith("unresolved:"):
            base = first_class(e.target_uid.split(":", 1)[1])
            if base is not None:
                out.append(Edge(
                    source_uid=e.source_uid,
                    target_uid=base.uid,
                    edge_type=EdgeType.EXTENDS,
                ))
    return out


def resolve_references(
    files: list[FileNode],
    nodes: list[SymbolNode],
    edges: list[Edge],
) -> list[Edge]:
    """Phase 3: Resolve all cross-file references.

    Steps:
      1. Extract import statements from Python and JavaScript/TypeScript files
         and resolve them to ``file:<path>`` nodes.
      2. Resolve ``unresolved:<Name>`` edges to class nodes (EXTENDS).
      3. Scan function/method bodies for calls to known symbols (CALLS).

    Args:
        files: Scanned files. ``source_text`` is used when present; otherwise
            the file at ``path`` is read. Unreadable files are skipped.
        nodes: Symbol nodes from the earlier phases.
        edges: Existing edges. Only those whose target starts with
            ``unresolved:`` are inspected, and none are modified.

    Returns:
        New edges to add (import, call, extends resolutions).
    """
    symbol_index = _build_symbol_index(nodes)
    new_edges: list[Edge] = []

    # 1. Extract and resolve imports (use cached source_text from Phase 2).
    # Pick the extractor first, so files in languages we never extract
    # imports from are not read from disk at all.
    import_edges: list[Edge] = []
    for f in files:
        if f.language == "python":
            extract = _extract_python_imports
        elif f.language in ("javascript", "typescript"):
            extract = _extract_js_imports
        else:
            continue

        source = f.source_text
        if not source:
            try:
                with open(f.path, "r", encoding="utf-8", errors="replace") as fh:
                    source = fh.read()
            except OSError:
                continue  # unreadable file: skip its imports, keep going

        import_edges.extend(extract(source, f.relative_path))

    resolved_imports = _resolve_import_edges(import_edges, files)
    new_edges.extend(resolved_imports)

    # 2. Resolve EXTENDS edges (unresolved:ClassName -> actual class)
    unresolved = [e for e in edges if e.target_uid.startswith("unresolved:")]
    resolved_extends = _resolve_extends_edges(unresolved, symbol_index)
    new_edges.extend(resolved_extends)

    # 3. Extract CALLS edges from function bodies
    all_symbols = [n for n in nodes if n.node_type not in _SKIP_TYPES]
    call_edges = _extract_call_edges(all_symbols, symbol_index)
    new_edges.extend(call_edges)

    console.print(
        f"  [green]✓[/green] Resolution: {len(resolved_imports)} imports, "
        f"{len(call_edges)} calls, {len(resolved_extends)} extends"
    )
    return new_edges
