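"""Document chunker — ingests .md, .txt, .rst and .adoc files.

Splits documents into overlapping, roughly fixed-size word chunks and
links each chunk to code symbols via keyword matching.

Streaming per file: each file is read line-by-line and its chunks are
yielded one at a time, converted to nodes, and linked immediately, so the
chunk buffer holds about CHUNK_SIZE words (plus at most one input line)
regardless of file size. The returned node and edge lists still grow with
the total number of chunks across all files.
"""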

from __future__ import annotations

import os
import re
from collections.abc import Iterable
from typing import Generator

from ...config import settings
from ..models import Edge, EdgeType, NodeType, SymbolNode
from ._console import console

# Document extensions to process
DOC_EXTENSIONS = {".md", ".txt", ".rst", ".adoc"}

CHUNK_SIZE = 300       # tokens (approx words)
CHUNK_OVERLAP = 50     # overlap between chunks

# Precompiled patterns
_MD_HEADER_RE = re.compile(r"^#{1,4}\s+(.+)$")
_WORD_SPLIT_RE = re.compile(r"[a-zA-Z_][a-zA-Z0-9_]*")


def _find_doc_files(root_path: str) -> list[str]:
    """Return the paths of all document files under *root_path*.

    Walks the tree with ``os.walk``, pruning directories listed in
    ``settings.ignored_dirs`` as well as any hidden (dot-prefixed)
    directory. A file qualifies when its lower-cased extension is in
    ``DOC_EXTENSIONS``. Returned paths are ``os.path.join``-ed onto
    *root_path*.

    Directories and files are visited in sorted name order. ``os.walk``
    itself lists entries in arbitrary, filesystem-dependent order, and
    without sorting the resulting chunk/node order would differ between
    machines.
    """
    ignored = set(settings.ignored_dirs)
    docs: list[str] = []

    for dirpath, dirnames, filenames in os.walk(root_path):
        # Assign in place so os.walk skips pruned dirs and descends in sorted order.
        dirnames[:] = sorted(
            d for d in dirnames if d not in ignored and not d.startswith(".")
        )
        docs.extend(
            os.path.join(dirpath, name)
            for name in sorted(filenames)
            if os.path.splitext(name)[1].lower() in DOC_EXTENSIONS
        )

    return docs


def _stream_lines(path: str) -> Generator[str, None, None]:
    """Yield lines from a file without loading it all into memory."""
    try:
        with open(path, "r", encoding="utf-8", errors="replace") as f:
            for line in f:
                yield line
    except OSError:
        return


def _iter_chunks_markdown(path: str) -> Generator[dict, None, None]:
    """Yield markdown chunks one-by-one via streaming.

    Reads line-by-line, flushes a chunk as soon as word buffer
    reaches CHUNK_SIZE. Memory: O(CHUNK_SIZE) always.
    """
    current_header = ""
    word_buf: list[str] = []
    overlap_buf: list[str] = []

    for line in _stream_lines(path):
        header_match = _MD_HEADER_RE.match(line)

        if header_match:
            # New section — flush previous buffer
            if word_buf:
                yield {"text": " ".join(word_buf), "header": current_header}
                word_buf = []
            overlap_buf = []  # no overlap across sections
            current_header = header_match.group(1).strip()
            word_buf.extend(line.split())
        else:
            words = line.split()
            if not words:
                continue
            word_buf.extend(words)

            if len(word_buf) >= CHUNK_SIZE:
                yield {"text": " ".join(word_buf), "header": current_header}
                overlap_buf = word_buf[-CHUNK_OVERLAP:] if len(word_buf) > CHUNK_OVERLAP else word_buf[:]
                word_buf = overlap_buf[:]
                overlap_buf = []

    # Flush remaining
    if word_buf:
        yield {"text": " ".join(word_buf), "header": current_header}


def _iter_chunks_plain(path: str) -> Generator[dict, None, None]:
    """Yield plain-text chunks one-by-one via streaming."""
    word_buf: list[str] = []
    overlap_buf: list[str] = []

    for line in _stream_lines(path):
        words = line.split()
        if not words:
            continue

        word_buf.extend(words)

        if len(word_buf) >= CHUNK_SIZE:
            yield {"text": " ".join(word_buf), "header": ""}
            overlap_buf = word_buf[-CHUNK_OVERLAP:] if len(word_buf) > CHUNK_OVERLAP else word_buf[:]
            word_buf = overlap_buf[:]
            overlap_buf = []

    if word_buf:
        yield {"text": " ".join(word_buf), "header": ""}


def _build_symbol_index(code_symbols: list[SymbolNode]) -> dict[str, SymbolNode]:
    """Map lower-cased code-symbol names to their ``SymbolNode``.

    Only functions, classes and methods are indexed. Names shorter than
    three characters and the generic names ``__init__``, ``self`` and
    ``main`` are skipped. When several symbols share a lower-cased name,
    the last one in *code_symbols* wins.
    """
    linkable_types = {NodeType.FUNCTION, NodeType.CLASS, NodeType.METHOD}
    too_generic = ("__init__", "self", "main")
    return {
        sym.name.lower(): sym
        for sym in code_symbols
        if sym.node_type in linkable_types
        and len(sym.name) >= 3
        and sym.name not in too_generic
    }


def _link_node_to_symbols(
    node: SymbolNode,
    symbol_index: dict[str, SymbolNode],
) -> list[Edge]:
    """Link a single doc chunk node to matching code symbols."""
    if not node.body_text or not symbol_index:
        return []

    words = set(_WORD_SPLIT_RE.findall(node.body_text.lower()))
    matched = words & symbol_index.keys()

    return [
        Edge(
            source_uid=node.uid,
            target_uid=symbol_index[name].uid,
            edge_type=EdgeType.CONTAINS,
            properties={"relationship": "describes"},
        )
        for name in matched
    ]


def ingest_documents(
    root_path: str,
    code_symbols: list[SymbolNode],
) -> tuple[list[SymbolNode], list[Edge]]:
    """Ingest all document files, chunk them, and link them to the code graph.

    For each file returned by ``_find_doc_files(root_path)``, chunks are
    pulled one at a time from a streaming iterator (``_iter_chunks_markdown``
    for ``.md``, ``_iter_chunks_plain`` otherwise). Each chunk is turned into
    a node and linked to the code symbols it mentions right away, so no list
    of raw chunks is ever built. Every file that produced at least one chunk
    also gets a FILE node, appended after its chunk nodes.

    Args:
        root_path: Project root; node ``file_path`` values are relative to
            it and always use forward slashes.
        code_symbols: Already-extracted code symbols to link against.

    Returns:
        ``(doc_nodes, link_edges)``.
    """
    doc_files = _find_doc_files(root_path)
    if not doc_files:
        console.print("  [dim]No documents found[/dim]")
        return [], []

    symbol_index = _build_symbol_index(code_symbols)

    all_nodes: list[SymbolNode] = []
    all_edges: list[Edge] = []
    total_chunks = 0

    for doc_path in doc_files:
        rel_path = os.path.relpath(doc_path, root_path).replace("\\", "/")
        ext = os.path.splitext(doc_path)[1].lower()
        chunk_iter = (
            _iter_chunks_markdown(doc_path) if ext == ".md" else _iter_chunks_plain(doc_path)
        )

        # Consume one chunk at a time: build its node and link it immediately.
        file_chunk_count = 0
        for i, chunk in enumerate(chunk_iter):
            node = _doc_chunk_node(doc_path, rel_path, i, chunk)
            all_nodes.append(node)
            all_edges.extend(_link_node_to_symbols(node, symbol_index))
            file_chunk_count = i + 1

        if file_chunk_count == 0:
            continue  # empty or unreadable file: no FILE node either

        total_chunks += file_chunk_count
        all_nodes.append(_doc_file_node(doc_path, rel_path, file_chunk_count))

    console.print(
        f"  [green]✓[/green] Documents: {len(doc_files)} files, "
        f"{total_chunks} chunks, {len(all_edges)} links to code"
    )

    return all_nodes, all_edges


def _doc_chunk_node(doc_path: str, rel_path: str, index: int, chunk: dict) -> SymbolNode:
    """Build the graph node for chunk number *index* of one document.

    The node is named after the chunk's header, or "<basename> §<n>"
    (1-based) when the chunk has none.
    """
    text = chunk["text"]
    return SymbolNode(
        uid=f"doc:{rel_path}:chunk_{index}",
        name=chunk.get("header") or f"{os.path.basename(doc_path)} §{index + 1}",
        node_type=NodeType.VARIABLE,
        file_path=rel_path,
        # Use the chunk's own source line span when the iterator reports one
        # (``line_start``/``line_end`` keys); otherwise fall back to the legacy
        # estimate of CHUNK_SIZE per chunk index.
        # NOTE(review): confirm the reported numbering base (0- vs 1-based)
        # matches how code SymbolNodes number their lines.
        line_start=chunk.get("line_start", index * CHUNK_SIZE),
        line_end=chunk.get("line_end", (index + 1) * CHUNK_SIZE),
        language="document",
        docstring=text[:200],
        body_text=text,
        properties={"is_doc_chunk": True, "chunk_index": index},
    )


def _doc_file_node(doc_path: str, rel_path: str, chunk_count: int) -> SymbolNode:
    """Build the FILE node summarising one chunked document."""
    return SymbolNode(
        uid=f"doc:{rel_path}",
        name=os.path.basename(doc_path),
        node_type=NodeType.FILE,
        file_path=rel_path,
        line_start=0,
        line_end=0,
        language="document",
        properties={"is_document": True, "chunks": chunk_count},
    )
