"""Embedding model wrapper — local sentence-transformers only.

Uses all-MiniLM-L6-v2 (~80MB) by default. Runs on CPU, free,
no API key needed. Claude handles all AI analysis via MCP tools.
"""

from __future__ import annotations

import hashlib
import json
from pathlib import Path

from ...config import settings


class Embedder:
    """Generate vector embeddings for text using local sentence-transformers.

    Free, runs on CPU, no API key needed. ~80MB model download on first use.
    """

    def __init__(self, cache_dir: str | None = None):
        self.dimension = settings.embedding_dimension
        self._model = None
        self._cache_dir = Path(cache_dir) if cache_dir else None

    def _init_model(self):
        try:
            from sentence_transformers import SentenceTransformer
        except ImportError:
            raise ImportError(
                "sentence-transformers not installed. "
                "Install with: pip install sentence-transformers"
            )
        self._model = SentenceTransformer(settings.local_embedding_model)
        self.dimension = self._model.get_sentence_embedding_dimension()

    def embed_texts(self, texts: list[str]) -> list[list[float]]:
        """Embed a batch of texts into vectors."""
        if not texts:
            return []

        # Check cache
        cached, missing_indices = self._check_cache(texts)
        if not missing_indices:
            return cached

        # Compute missing embeddings
        missing_texts = [texts[i] for i in missing_indices]
        new_vectors = self._embed_local(missing_texts)

        # Fill in results
        for idx, vec in zip(missing_indices, new_vectors):
            cached[idx] = vec

        # Save to cache
        self._save_cache(texts, cached)

        return cached

    def embed_text(self, text: str) -> list[float]:
        """Embed a single text."""
        return self.embed_texts([text])[0]

    def _embed_local(self, texts: list[str]) -> list[list[float]]:
        if self._model is None:
            self._init_model()

        # Encode in smaller batches to avoid memory spikes
        all_vectors = []
        batch_size = 128
        for i in range(0, len(texts), batch_size):
            batch = texts[i : i + batch_size]
            vecs = self._model.encode(batch, show_progress_bar=False, normalize_embeddings=True)
            all_vectors.extend(v.tolist() for v in vecs)
        return all_vectors

    # ─── Cache ─────────────────────────────────────────────────────

    def _cache_key(self, text: str) -> str:
        return hashlib.md5(text.encode()).hexdigest()

    def _check_cache(self, texts: list[str]) -> tuple[list[list[float] | None], list[int]]:
        """Return cached vectors and indices of missing ones."""
        if not self._cache_dir:
            return [None] * len(texts), list(range(len(texts)))

        results: list[list[float] | None] = [None] * len(texts)
        missing = []

        for i, text in enumerate(texts):
            cache_file = self._cache_dir / f"{self._cache_key(text)}.json"
            if cache_file.exists():
                results[i] = json.loads(cache_file.read_text())
            else:
                missing.append(i)

        return results, missing

    def _save_cache(self, texts: list[str], vectors: list[list[float] | None]):
        if not self._cache_dir:
            return

        self._cache_dir.mkdir(parents=True, exist_ok=True)
        for text, vec in zip(texts, vectors):
            if vec is not None:
                cache_file = self._cache_dir / f"{self._cache_key(text)}.json"
                cache_file.write_text(json.dumps(vec))
