"""Vector store — Qdrant client for storing and searching embeddings.

Supports 3 modes, selected by the ``mode`` argument or ``settings.qdrant_mode``:
  1. "memory"  — in-process, no server needed (default, zero setup)
  2. "disk"    — local file storage, persists between restarts (no server needed)
  3. "server"  — remote Qdrant server via URL (for production / Docker)
"""

from __future__ import annotations

import hashlib
from pathlib import Path

try:
    from qdrant_client import QdrantClient
    from qdrant_client.models import (
        Distance,
        PointStruct,
        VectorParams,
        Filter,
        FieldCondition,
        MatchValue,
    )
    _HAS_QDRANT = True
except ImportError:
    _HAS_QDRANT = False

from ...config import settings


# Name of the single Qdrant collection that holds every symbol embedding.
COLLECTION_NAME: str = "codegraph_symbols"


class VectorStore:
    """Qdrant-backed vector store for symbol embeddings.

    All points live in the single collection ``COLLECTION_NAME``, which is
    created on demand with cosine distance and ``self.dimension``-sized vectors.

    Modes (``mode`` argument, falling back to ``settings.qdrant_mode``):
      - "memory": runs entirely in-process (no Docker, no install); nothing
        is persisted.
      - "disk": saves to a local folder (``path`` / ``settings.qdrant_path``),
        survives restarts.
      - "server": connects to a Qdrant server (``url`` / ``settings.qdrant_url``).

    Any other mode value silently falls back to in-memory storage.
    The underlying client is created lazily on first access of ``client``.
    """

    def __init__(
        self,
        mode: str | None = None,
        url: str | None = None,
        path: str | None = None,
        dimension: int | None = None,
    ):
        """Record configuration without connecting to Qdrant.

        Each falsy argument (``None``, ``""``, ``0``) is replaced by the
        corresponding value from ``settings``.

        Raises:
            ImportError: if ``qdrant_client`` is not installed.
        """
        if not _HAS_QDRANT:
            raise ImportError(
                "qdrant_client is required for vector search. "
                "Install it with: pip install qdrant-client"
            )
        self.mode = mode or settings.qdrant_mode
        self.url = url or settings.qdrant_url
        self.path = path or settings.qdrant_path
        self.dimension = dimension or settings.embedding_dimension
        # Created lazily by the ``client`` property.
        self._client = None

    @property
    def client(self) -> QdrantClient:
        """Return the Qdrant client, constructing it on first access.

        "server" connects to ``self.url``; "disk" creates ``self.path`` (with
        parents) if missing and stores data there; every other mode uses an
        in-process ``:memory:`` store with no persistence.
        """
        if self._client is None:
            if self.mode == "server":
                self._client = QdrantClient(url=self.url)
            elif self.mode == "disk":
                storage_path = Path(self.path).resolve()
                storage_path.mkdir(parents=True, exist_ok=True)
                self._client = QdrantClient(path=str(storage_path))
            else:
                # memory mode — fastest, no persistence
                self._client = QdrantClient(location=":memory:")
        return self._client

    def ensure_collection(self):
        """Create ``COLLECTION_NAME`` (cosine, ``self.dimension``) if absent.

        An existing collection is left untouched, even if its vector size
        differs from ``self.dimension``.
        """
        existing = {c.name for c in self.client.get_collections().collections}
        if COLLECTION_NAME not in existing:
            self.client.create_collection(
                collection_name=COLLECTION_NAME,
                vectors_config=VectorParams(
                    size=self.dimension,
                    distance=Distance.COSINE,
                ),
            )

    def upsert(
        self,
        ids: list[str],
        vectors: list[list[float]],
        payloads: list[dict],
        *,
        batch_size: int = 500,
    ):
        """Insert or update vectors with payloads.

        Args:
            ids: String identifiers. Each is hashed to a Qdrant point id via
                ``_hash_id`` and also stored in the payload under ``"uid"``.
            vectors: Embeddings, one per id.
            payloads: Metadata mappings, one per id. They are copied, never
                modified; a ``"uid"`` key in them is overwritten.
            batch_size: Maximum number of points sent per upsert request.

        Raises:
            ValueError: if the three sequences differ in length, or if
                ``batch_size`` is less than 1.
        """
        ids, vectors, payloads = list(ids), list(vectors), list(payloads)
        # zip() would silently drop the unmatched tail; refuse instead.
        if not (len(ids) == len(vectors) == len(payloads)):
            raise ValueError(
                "ids, vectors and payloads must have equal lengths "
                f"(got {len(ids)}, {len(vectors)}, {len(payloads)})"
            )
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        self.ensure_collection()

        points = [
            PointStruct(
                id=self._hash_id(uid),
                vector=vector,
                # Build a new dict so the caller's payload is not mutated;
                # "uid" makes the original string id recoverable from hits.
                payload={**payload, "uid": uid},
            )
            for uid, vector, payload in zip(ids, vectors, payloads)
        ]

        for start in range(0, len(points), batch_size):
            self.client.upsert(
                collection_name=COLLECTION_NAME,
                points=points[start : start + batch_size],
            )

    def search(
        self,
        query_vector: list[float],
        limit: int = 10,
        node_type: str | None = None,
    ) -> list[dict]:
        """Return up to ``limit`` nearest points to ``query_vector``.

        Args:
            query_vector: The query embedding.
            limit: Maximum number of hits.
            node_type: If truthy, only points whose payload ``"node_type"``
                equals this value are considered.

        Returns:
            One dict per hit in the order Qdrant returns them: the stored
            payload plus a ``"_score"`` key holding the hit's score.
        """
        self.ensure_collection()

        query_filter = None
        if node_type:
            query_filter = Filter(
                must=[FieldCondition(key="node_type", match=MatchValue(value=node_type))]
            )

        results = self.client.query_points(
            collection_name=COLLECTION_NAME,
            query=query_vector,
            limit=limit,
            query_filter=query_filter,
            with_payload=True,
        )

        # A point may carry no payload (None); treat that as an empty dict.
        return [
            {**(hit.payload or {}), "_score": hit.score}
            for hit in results.points
        ]

    def delete_collection(self):
        """Drop the collection (for re-indexing).

        Best effort: any exception raised while deleting is swallowed, so a
        missing collection (or an unreachable backend) is not reported.
        """
        try:
            self.client.delete_collection(COLLECTION_NAME)
        except Exception:
            pass

    def count(self) -> int:
        """Return the number of points reported for the collection.

        Returns 0 when the collection info cannot be fetched (for example, when
        the collection does not exist); any exception is swallowed.
        """
        try:
            info = self.client.get_collection(COLLECTION_NAME)
        except Exception:
            return 0
        # points_count may be None in qdrant-client's CollectionInfo model;
        # keep the declared int return type.
        return info.points_count or 0

    @staticmethod
    def _hash_id(uid: str) -> int:
        """Map a string UID to a deterministic unsigned 64-bit point id.

        Takes the first 64 bits (16 hex digits) of the SHA-256 digest of the
        UTF-8-encoded UID, so the result is in ``[0, 2**64)``. Distinct UIDs
        can collide, and colliding UIDs map to the same Qdrant point.
        """
        return int(hashlib.sha256(uid.encode()).hexdigest()[:16], 16)
