"""Code parser — Phase 2: Parsing.

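Uses Tree-sitter to extract functions, classes, and methods (including JS/TS
arrow functions assigned to variables) from Python, JavaScript, and TypeScript
source files. Produces SymbolNode objects plus the Edge objects linking them.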
"""

from __future__ import annotations

import os
from pathlib import Path

from ..models import EdgeType, Edge, FileNode, NodeType, SymbolNode
from ._console import console

# Tree-sitter language modules (lazy loaded)
_languages: dict = {}


def _get_language(lang_name: str):
    """Lazy-load tree-sitter language."""
    if lang_name in _languages:
        return _languages[lang_name]

    import tree_sitter

    if lang_name == "python":
        import tree_sitter_python as tspython
        language = tree_sitter.Language(tspython.language())
    elif lang_name == "javascript":
        import tree_sitter_javascript as tsjs
        language = tree_sitter.Language(tsjs.language())
    elif lang_name == "typescript":
        import tree_sitter_typescript as tsts
        language = tree_sitter.Language(tsts.language_typescript())
    else:
        return None

    _languages[lang_name] = language
    return language


def _read_file(path: str) -> bytes | None:
    try:
        with open(path, "rb") as f:
            return f.read()
    except (OSError, UnicodeDecodeError):
        return None


def _extract_text(source: bytes, node) -> str:
    """Extract source text for a tree-sitter node."""
    return source[node.start_byte:node.end_byte].decode("utf-8", errors="replace")


def _get_docstring(source: bytes, body_node) -> str:
    """Try to extract a docstring from the first statement of a body."""
    if body_node is None or body_node.child_count == 0:
        return ""
    first = body_node.children[0]
    if first.type == "expression_statement" and first.child_count > 0:
        expr = first.children[0]
        if expr.type == "string":
            text = _extract_text(source, expr).strip("\"'")
            return text[:500]  # Cap docstring length
    return ""


# ─── Python Extractor ───────────────────────────────────────────────

def _extract_python(source: bytes, tree, file_path: str) -> tuple[list[SymbolNode], list[Edge]]:
    """Extract functions, methods and classes from a parsed Python file.

    Args:
        source: Raw bytes that *tree* was parsed from.
        tree: tree-sitter parse tree of *source*.
        file_path: Path stored on each symbol and used for the ``file:`` uid.

    Returns:
        ``(symbols, edges)`` in source (pre-order) order. Each symbol gets a
        DEFINED_IN edge to its file; each method gets a HAS_METHOD edge from
        its class; each class gets one EXTENDS edge per plain-identifier base,
        targeting ``unresolved:<name>`` for later resolution by name.

    A ``def`` counts as a METHOD only when it belongs to a class body (possibly
    wrapped in decorators or ``if``/``try`` blocks). Functions nested inside
    another function or method are local helpers: they are emitted as
    FUNCTIONs with no HAS_METHOD edge.
    """
    nodes: list[SymbolNode] = []
    edges: list[Edge] = []
    file_uid = f"file:{file_path}"

    # Explicit DFS stack of (node, enclosing class uid, enclosing class name).
    # Iterative instead of recursive so deeply nested syntax (e.g. very long
    # generated operator chains) cannot exceed Python's recursion limit.
    stack: list[tuple] = [(tree.root_node, None, None)]

    def push(children, parent_uid, class_name):
        # Reversed so children are popped -- and symbols emitted -- in source order.
        stack.extend((child, parent_uid, class_name) for child in reversed(children))

    while stack:
        node, parent_uid, class_name = stack.pop()

        if node.type == "function_definition":
            name_node = node.child_by_field_name("name")
            if not name_node:
                continue
            name = _extract_text(source, name_node)
            params_node = node.child_by_field_name("parameters")
            params = _extract_text(source, params_node) if params_node else "()"
            body_node = node.child_by_field_name("body")
            docstring = _get_docstring(source, body_node)
            # ``async def`` parses as a function_definition whose first token is ``async``.
            is_async = node.child_count > 0 and node.children[0].type == "async"
            keyword = "async def" if is_async else "def"

            ntype = NodeType.METHOD if class_name else NodeType.FUNCTION
            uid = SymbolNode.make_uid(file_path, name, node.start_point[0] + 1)
            nodes.append(SymbolNode(
                uid=uid,
                name=name,
                node_type=ntype,
                file_path=file_path,
                line_start=node.start_point[0] + 1,
                line_end=node.end_point[0] + 1,
                language="python",
                docstring=docstring,
                signature=f"{keyword} {name}{params}",
                body_text=_extract_text(source, node)[:2000],
                properties={"class": class_name} if class_name else {},
            ))
            edges.append(Edge(source_uid=uid, target_uid=file_uid, edge_type=EdgeType.DEFINED_IN))

            # HAS_METHOD edge if directly inside a class
            if parent_uid and class_name:
                edges.append(Edge(
                    source_uid=parent_uid, target_uid=uid, edge_type=EdgeType.HAS_METHOD
                ))

            # Defs nested in this body are local helpers, not methods of the
            # enclosing class, so the class context is dropped here.
            push(node.children, None, None)
            continue

        if node.type == "class_definition":
            name_node = node.child_by_field_name("name")
            if not name_node:
                continue
            name = _extract_text(source, name_node)
            uid = SymbolNode.make_uid(file_path, name, node.start_point[0] + 1)

            # Plain-identifier base classes only (dotted/subscripted bases and
            # keyword arguments such as metaclass= are skipped).
            args_node = node.child_by_field_name("superclasses")
            bases = (
                [_extract_text(source, c) for c in args_node.children if c.type == "identifier"]
                if args_node else []
            )

            body_node = node.child_by_field_name("body")
            docstring = _get_docstring(source, body_node)

            nodes.append(SymbolNode(
                uid=uid,
                name=name,
                node_type=NodeType.CLASS,
                file_path=file_path,
                line_start=node.start_point[0] + 1,
                line_end=node.end_point[0] + 1,
                language="python",
                docstring=docstring,
                signature=f"class {name}({', '.join(bases)})" if bases else f"class {name}",
                properties={"bases": bases},
            ))
            edges.append(Edge(source_uid=uid, target_uid=file_uid, edge_type=EdgeType.DEFINED_IN))

            # EXTENDS edges for base classes (resolved later by name)
            for base in bases:
                edges.append(Edge(
                    source_uid=uid,
                    target_uid=f"unresolved:{base}",
                    edge_type=EdgeType.EXTENDS,
                ))

            # Only the class body is walked, with this class as the context.
            if body_node:
                push(body_node.children, uid, name)
            continue

        # Any other node: keep walking with the current class context.
        push(node.children, parent_uid, class_name)

    return nodes, edges


# ─── JavaScript/TypeScript Extractor ────────────────────────────────

def _extract_js_ts(
    source: bytes, tree, file_path: str, language: str
) -> tuple[list[SymbolNode], list[Edge]]:
    """Extract functions, arrow functions, classes and methods from JS/TS.

    Args:
        source: Raw bytes that *tree* was parsed from.
        tree: tree-sitter parse tree of *source*.
        file_path: Path stored on each symbol and used for the ``file:`` uid.
        language: ``"javascript"`` or ``"typescript"``; copied onto each symbol.

    Returns:
        ``(symbols, edges)`` in source (pre-order) order. Every symbol gets a
        DEFINED_IN edge to its file, methods get a HAS_METHOD edge from their
        class, and classes get one EXTENDS edge per plain-identifier
        ``extends`` target, pointing at ``unresolved:<name>``.

    Captured forms: ``function`` / ``function*`` declarations, arrow functions
    assigned with ``const``/``let``/``var``, class declarations (including TS
    ``abstract class``), and ``method_definition`` members of those classes.
    Class context does not carry into function or method bodies, so
    object-literal methods there are not attributed to the enclosing class.
    """
    nodes: list[SymbolNode] = []
    edges: list[Edge] = []
    file_uid = f"file:{file_path}"

    def heritage_bases(class_node) -> list[str]:
        # ``class_heritage`` is an unnamed child in the JS/TS grammars, so it is
        # located by type; the field lookup is kept in case a grammar version
        # exposes it as a ``heritage`` field.
        heritage = class_node.child_by_field_name("heritage")
        if heritage is None:
            heritage = next(
                (c for c in class_node.children if c.type == "class_heritage"), None
            )
        if heritage is None:
            return []
        bases: list[str] = []
        for clause in heritage.children:
            if clause.type == "identifier":
                # JavaScript: class_heritage -> "extends" <expression>
                bases.append(_extract_text(source, clause))
            elif clause.type == "extends_clause":
                # TypeScript: class_heritage -> extends_clause [implements_clause];
                # implemented interfaces are not treated as bases.
                bases.extend(
                    _extract_text(source, c) for c in clause.children if c.type == "identifier"
                )
        return bases

    def starts_with_async(fn_node) -> bool:
        return fn_node.child_count > 0 and fn_node.children[0].type == "async"

    # Explicit DFS stack of (node, enclosing class uid, enclosing class name).
    # Iterative so deeply nested expressions cannot hit the recursion limit.
    stack: list[tuple] = [(tree.root_node, None, None)]

    def push(children, parent_uid, class_name):
        # Reversed so children are popped -- and symbols emitted -- in source order.
        stack.extend((child, parent_uid, class_name) for child in reversed(children))

    while stack:
        node, parent_uid, class_name = stack.pop()
        kind = node.type

        # Function declarations: function foo() {} / function* gen() {}
        if kind in ("function_declaration", "generator_function_declaration"):
            name_node = node.child_by_field_name("name")
            if not name_node:
                continue
            name = _extract_text(source, name_node)
            params_node = node.child_by_field_name("parameters")
            params = _extract_text(source, params_node) if params_node else "()"
            keyword = "function*" if kind == "generator_function_declaration" else "function"
            if starts_with_async(node):
                keyword = f"async {keyword}"

            uid = SymbolNode.make_uid(file_path, name, node.start_point[0] + 1)
            nodes.append(SymbolNode(
                uid=uid,
                name=name,
                node_type=NodeType.FUNCTION,
                file_path=file_path,
                line_start=node.start_point[0] + 1,
                line_end=node.end_point[0] + 1,
                language=language,
                signature=f"{keyword} {name}{params}",
                body_text=_extract_text(source, node)[:2000],
            ))
            edges.append(Edge(source_uid=uid, target_uid=file_uid, edge_type=EdgeType.DEFINED_IN))
            # A function body is a new scope: no class context carries into it.
            push(node.children, None, None)
            continue

        # Arrow functions assigned to const/let/var
        if kind in ("lexical_declaration", "variable_declaration"):
            first = node.children[0] if node.child_count else None
            decl_keyword = (
                first.type if first is not None and first.type in ("const", "let", "var")
                else "const"
            )
            for decl in node.children:
                if decl.type != "variable_declarator":
                    continue
                name_node = decl.child_by_field_name("name")
                value_node = decl.child_by_field_name("value")
                if not (name_node and value_node and value_node.type == "arrow_function"):
                    continue
                name = _extract_text(source, name_node)
                # ``(a, b) => ...`` has a ``parameters`` field; ``a => ...`` has ``parameter``.
                params_node = value_node.child_by_field_name("parameters")
                if params_node is None:
                    params_node = value_node.child_by_field_name("parameter")
                params = _extract_text(source, params_node) if params_node else "()"
                async_prefix = "async " if starts_with_async(value_node) else ""

                uid = SymbolNode.make_uid(file_path, name, node.start_point[0] + 1)
                nodes.append(SymbolNode(
                    uid=uid,
                    name=name,
                    node_type=NodeType.FUNCTION,
                    file_path=file_path,
                    line_start=node.start_point[0] + 1,
                    line_end=node.end_point[0] + 1,
                    language=language,
                    signature=f"{decl_keyword} {name} = {async_prefix}{params} =>",
                    body_text=_extract_text(source, node)[:2000],
                ))
                edges.append(Edge(
                    source_uid=uid, target_uid=file_uid, edge_type=EdgeType.DEFINED_IN
                ))
            push(node.children, parent_uid, class_name)
            continue

        # Class declarations (TS ``abstract class`` has its own node type)
        if kind in ("class_declaration", "abstract_class_declaration"):
            name_node = node.child_by_field_name("name")
            if not name_node:
                continue
            name = _extract_text(source, name_node)
            uid = SymbolNode.make_uid(file_path, name, node.start_point[0] + 1)
            bases = heritage_bases(node)
            prefix = "abstract class" if kind == "abstract_class_declaration" else "class"

            nodes.append(SymbolNode(
                uid=uid,
                name=name,
                node_type=NodeType.CLASS,
                file_path=file_path,
                line_start=node.start_point[0] + 1,
                line_end=node.end_point[0] + 1,
                language=language,
                signature=f"{prefix} {name}" + (f" extends {', '.join(bases)}" if bases else ""),
                properties={"bases": bases},
            ))
            edges.append(Edge(source_uid=uid, target_uid=file_uid, edge_type=EdgeType.DEFINED_IN))

            for base in bases:
                edges.append(Edge(
                    source_uid=uid,
                    target_uid=f"unresolved:{base}",
                    edge_type=EdgeType.EXTENDS,
                ))

            # Only the class body is walked, with this class as the context.
            body = node.child_by_field_name("body")
            if body:
                push(body.children, uid, name)
            continue

        # Methods inside classes
        if kind == "method_definition" and class_name:
            name_node = node.child_by_field_name("name")
            if not name_node:
                continue
            name = _extract_text(source, name_node)
            params_node = node.child_by_field_name("parameters")
            params = _extract_text(source, params_node) if params_node else "()"

            uid = SymbolNode.make_uid(file_path, name, node.start_point[0] + 1)
            nodes.append(SymbolNode(
                uid=uid,
                name=name,
                node_type=NodeType.METHOD,
                file_path=file_path,
                line_start=node.start_point[0] + 1,
                line_end=node.end_point[0] + 1,
                language=language,
                signature=f"{name}{params}",
                body_text=_extract_text(source, node)[:2000],
                properties={"class": class_name},
            ))
            edges.append(Edge(source_uid=uid, target_uid=file_uid, edge_type=EdgeType.DEFINED_IN))
            if parent_uid:
                edges.append(Edge(
                    source_uid=parent_uid, target_uid=uid, edge_type=EdgeType.HAS_METHOD
                ))
            # Method bodies are a new scope: object-literal methods inside them
            # are not members of this class.
            push(node.children, None, None)
            continue

        # Any other node: keep walking with the current class context.
        push(node.children, parent_uid, class_name)

    return nodes, edges


# ─── Public API ─────────────────────────────────────────────────────

def parse_file(file_node: FileNode) -> tuple[list[SymbolNode], list[Edge]]:
    """Read, parse and extract the symbols and edges of a single file.

    Side effect: once the file has been read, ``file_node.source_text`` holds
    its UTF-8-decoded contents (undecodable bytes replaced), so later stages
    can reuse the text without reading the file again. Unreadable files and
    unsupported languages yield ``([], [])``.
    """
    import tree_sitter

    raw = _read_file(file_node.path)
    if raw is None:
        return [], []

    # Keep the decoded text on the node for the resolver (no second disk read).
    file_node.source_text = raw.decode("utf-8", errors="replace")

    grammar = _get_language(file_node.language)
    if grammar is None:
        return [], []

    syntax_tree = tree_sitter.Parser(grammar).parse(raw)

    lang_name = file_node.language
    if lang_name == "python":
        return _extract_python(raw, syntax_tree, file_node.relative_path)
    if lang_name in ("javascript", "typescript"):
        return _extract_js_ts(raw, syntax_tree, file_node.relative_path, lang_name)
    return [], []


def parse_all_files(files: list[FileNode]) -> tuple[list[SymbolNode], list[Edge]]:
    """Parse every file concurrently and return all extracted symbols and edges.

    Each file's symbols are also stored on its ``FileNode.symbols``. A file
    whose parse raises is reported on the console (the first five messages
    are shown) and contributes no symbols. Results are gathered in the order
    of *files*, not in completion order, so the returned lists do not depend
    on thread scheduling.
    """
    from collections import Counter
    from concurrent.futures import ThreadPoolExecutor

    all_nodes: list[SymbolNode] = []
    all_edges: list[Edge] = []
    failed_files: list[str] = []

    # Use threads (GIL is released during tree-sitter C parsing + file I/O)
    max_workers = min(8, os.cpu_count() or 1)

    # No per-file timeout here: a running worker thread cannot be cancelled,
    # so a hung parse would still block executor shutdown. Any time limit has
    # to be enforced inside the parse itself.
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = [pool.submit(parse_file, f) for f in files]
        for f, future in zip(files, futures):
            try:
                nodes, edges = future.result()
            except Exception as exc:  # report the file and keep going
                failed_files.append(f"{f.relative_path}: {exc}")
                nodes, edges = [], []
            f.symbols = nodes
            all_nodes.extend(nodes)
            all_edges.extend(edges)

    if failed_files:
        console.print(f"  [yellow]![/yellow] Failed to parse {len(failed_files)} file(s):")
        for msg in failed_files[:5]:
            console.print(f"    [dim]{msg}[/dim]")
        if len(failed_files) > 5:
            console.print(f"    [dim]... and {len(failed_files) - 5} more[/dim]")

    type_counts = Counter(n.node_type for n in all_nodes)
    console.print(
        f"  [green]✓[/green] Parsing: {type_counts[NodeType.FUNCTION]} functions, "
        f"{type_counts[NodeType.CLASS]} classes, {type_counts[NodeType.METHOD]} methods"
    )
    return all_nodes, all_edges
