"""Hybrid search — Reciprocal Rank Fusion (RRF) of BM25 + vector search."""

from __future__ import annotations

import heapq
import logging
from typing import TYPE_CHECKING

from .bm25_index import SymbolSearchIndex
from .embedder import Embedder

if TYPE_CHECKING:
    from .vector_store import VectorStore


class HybridSearch:
    """Combine BM25 keyword search with vector semantic search using RRF.

    Reciprocal Rank Fusion scores each document as

        score(d) = Σ_i 1 / (k + rank_i(d))

    where ``rank_i(d)`` is the 1-based position of ``d`` in ranking ``i``
    (BM25 or vector) and ``k`` is ``rrf_k`` (60 by default). A ranking that
    does not contain ``d`` contributes nothing. The fused ranking rewards
    both exact keyword matches and semantic similarity.
    """

    def __init__(
        self,
        bm25: SymbolSearchIndex,
        vector_store: VectorStore,
        embedder: Embedder,
        rrf_k: int = 60,
        candidate_multiplier: int = 3,
    ):
        """Wire up the two retrieval backends.

        Args:
            bm25: Keyword index queried via ``search(query, limit=...)``.
            vector_store: Store queried via ``search(vector, limit=...)``.
            embedder: Turns the query text into a vector via ``embed_text``.
            rrf_k: RRF smoothing constant ``k``; larger values flatten the
                difference between top and lower ranks.
            candidate_multiplier: Each backend is asked for
                ``limit * candidate_multiplier`` candidates before fusion.

        Raises:
            ValueError: If ``candidate_multiplier`` is less than 1.
        """
        if candidate_multiplier < 1:
            raise ValueError(
                f"candidate_multiplier must be >= 1, got {candidate_multiplier}"
            )
        self.bm25 = bm25
        self.vector_store = vector_store
        self.embedder = embedder
        self.rrf_k = rrf_k
        self.candidate_multiplier = candidate_multiplier

    def search(self, query: str, limit: int = 10) -> list[dict]:
        """Run hybrid search combining BM25 and vector results.

        Args:
            query: Free-text query sent to both backends.
            limit: Maximum number of fused results to return.

        Returns:
            Up to ``limit`` result dicts ordered by descending ``_score``.
            Returns an empty list without querying anything when
            ``limit <= 0``.
        """
        if limit <= 0:
            return []

        # Over-fetch from each source: a document just outside one list's
        # top ``limit`` can still rank highly after fusion.
        fetch_limit = limit * self.candidate_multiplier

        bm25_results = self.bm25.search(query, limit=fetch_limit)
        vector_results = self._vector_search(query, limit=fetch_limit)
        return self._rrf_merge(bm25_results, vector_results, limit)

    def _vector_search(self, query: str, limit: int) -> list[dict]:
        """Embed ``query`` and return the vector store's top ``limit`` hits.

        Vector search is best-effort. If the embedder or the store raises,
        the error is logged and an empty list is returned, so ``search``
        degrades to BM25-only results.
        """
        try:
            query_vector = self.embedder.embed_text(query)
            return self.vector_store.search(query_vector, limit=limit)
        except Exception:
            # Deliberate fallback, but never silent: keep the traceback so a
            # broken embedder or store is visible in the logs.
            logging.getLogger(__name__).warning(
                "Vector search failed; falling back to BM25-only results",
                exc_info=True,
            )
            return []

    @staticmethod
    def _doc_key(doc: dict, source: str, rank: int) -> object:
        """Return the identity used to match ``doc`` across rankings.

        Uses the first non-None value of ``uid``, then ``name``. A document
        with neither gets a ``(source, rank)`` tuple. Anonymous hits from
        different rankings therefore never merge, and a tuple cannot collide
        with a string identifier.
        """
        for key_name in ("uid", "name"):
            value = doc.get(key_name)
            if value is not None:
                return value
        return (source, rank)

    def _rrf_merge(
        self, bm25_results: list[dict], vector_results: list[dict], limit: int = 10,
    ) -> list[dict]:
        """Merge two ranked lists using Reciprocal Rank Fusion.

        Args:
            bm25_results: BM25 hits, best first.
            vector_results: Vector hits, best first.
            limit: Maximum number of fused results to return.

        Returns:
            Up to ``limit`` shallow copies of the input dicts, ordered by
            descending fused score. Each copy has ``_score`` (rounded to six
            decimals) added, and ``_bm25_rank`` / ``_vector_rank`` removed.
            The input dicts are not modified.
        """
        k = self.rrf_k
        scores: dict[object, float] = {}
        docs: dict[object, dict] = {}

        for source, ranking in (("bm25", bm25_results), ("vector", vector_results)):
            seen: set[object] = set()
            for rank, doc in enumerate(ranking, start=1):
                key = self._doc_key(doc, source, rank)
                # Standard RRF counts a document once per ranking, at its
                # best (first) position; later duplicates add nothing.
                if key in seen:
                    continue
                seen.add(key)
                scores[key] = scores.get(key, 0.0) + 1.0 / (k + rank)
                # First occurrence wins, so BM25's copy is kept when both
                # sources return the same document.
                docs.setdefault(key, doc)

        # heapq top-k: O(n log limit) instead of a full sort. Like sorted(),
        # it keeps first-seen order among equal scores.
        top = heapq.nlargest(limit, scores.items(), key=lambda item: item[1])

        results = []
        for key, score in top:
            doc = docs[key].copy()
            doc["_score"] = round(score, 6)
            # Per-source ranks are meaningless after fusion; expose only _score.
            doc.pop("_bm25_rank", None)
            doc.pop("_vector_rank", None)
            results.append(doc)

        return results
