"""BM25 search index for fast keyword search over symbols."""

from __future__ import annotations

import heapq
import re
from rank_bm25 import BM25Okapi

# Matches an ASCII lowercase->uppercase boundary (the "oB" in "fooBar") so the
# tokenizer can insert a space there and split camelCase identifiers into words.
_CAMEL_RE = re.compile(r"([a-z])([A-Z])")
# One search token: an ASCII letter followed by any run of ASCII letters/digits.
# Every other character (spaces, "_", "/", ".", punctuation) separates tokens.
_TOKEN_RE = re.compile(r"[a-zA-Z][a-zA-Z0-9]*")


class SymbolSearchIndex:
    """BM25 keyword search over code symbols.

    Call :meth:`build` with a list of symbol dicts, then :meth:`search` with a
    free-text query. Symbols and queries pass through the same tokenizer, so
    ``parseHttpRequest``, ``parse_http_request`` and ``parse http request``
    all yield the same tokens.
    """

    # Symbol fields whose values are concatenated into each document's text.
    _INDEXED_FIELDS = ("name", "signature", "docstring", "file_path", "node_type")

    def __init__(self):
        # Symbol dicts in index order; entry i corresponds to _corpus[i].
        self._documents: list[dict] = []
        # Token list for each document, as handed to BM25Okapi.
        self._corpus: list[list[str]] = []
        # None until build() has indexed at least one document with tokens.
        self._bm25: BM25Okapi | None = None

    def build(self, symbols: list[dict]):
        """Build (or rebuild) the BM25 index from symbol dicts.

        The values of ``name``, ``signature``, ``docstring``, ``file_path``
        and ``node_type`` are joined with spaces and tokenized. Missing or
        falsy values count as empty. The values are presumably strings; verify
        against the producer of the symbol dicts. Any previously built index
        is discarded.

        Args:
            symbols: Symbol dicts to index. The sequence is copied, so later
                mutation of the caller's list cannot desynchronize documents
                from their scores.
        """
        self._documents = list(symbols)
        self._corpus = [
            self._tokenize(
                " ".join(sym.get(k, "") or "" for k in self._INDEXED_FIELDS)
            )
            for sym in self._documents
        ]
        # rank_bm25's Okapi IDF step averages over the distinct terms, so a
        # corpus with no tokens at all (e.g. only one-letter names) would
        # raise ZeroDivisionError. Such a corpus can never match a query, so
        # leave the model unset. Always reassigning here also drops a stale
        # model when the index is rebuilt from an empty symbol list.
        self._bm25 = BM25Okapi(self._corpus) if any(self._corpus) else None

    def search(self, query: str, limit: int = 10) -> list[dict]:
        """Return up to ``limit`` symbols ranked by BM25 relevance to ``query``.

        Only symbols with a strictly positive score are returned, highest
        score first. Equal scores keep build order. Each result is a shallow
        copy of the symbol dict with an extra ``"_score"`` key, rounded to 4
        decimal places.

        Returns:
            The ranked results. The list is empty if nothing is indexed, the
            query has no usable tokens, or no symbol scores above zero.
        """
        if self._bm25 is None:
            return []

        tokens = self._tokenize(query)
        if not tokens:
            return []

        scores = self._bm25.get_scores(tokens)

        # heapq.nlargest is O(n log k) rather than a full O(n log n) sort.
        # With key= it is also stable, so ties go to the earlier-indexed
        # symbol, not the later one as (score, index) tuples would give.
        top_k = heapq.nlargest(
            limit,
            (i for i, score in enumerate(scores) if score > 0),
            key=scores.__getitem__,
        )

        results = []
        for idx in top_k:
            doc = self._documents[idx].copy()
            doc["_score"] = round(float(scores[idx]), 4)
            results.append(doc)
        return results

    @staticmethod
    def _tokenize(text: str) -> list[str]:
        """Split text into lowercase search tokens.

        The steps are:

        1. Split camelCase boundaries (an ASCII lowercase letter followed by
           an uppercase one).
        2. Take each ASCII letter followed by any letters/digits as a token.
           Every other character, including ``_``, ``/`` and ``.``, acts as a
           separator.
        3. Drop single-character tokens.
        """
        text = _CAMEL_RE.sub(r"\1 \2", text)
        return [t.lower() for t in _TOKEN_RE.findall(text) if len(t) > 1]
