ai-robot-core/ai-service/app/core/qdrant_client.py

549 lines
20 KiB
Python

"""
Qdrant client for AI Service.
[AC-AISVC-10] Vector database client with tenant-isolated collection management.
Supports multi-dimensional vectors for Matryoshka representation learning.
"""
import logging
from typing import Any
from qdrant_client import AsyncQdrantClient
from qdrant_client.models import Distance, PointStruct, VectorParams
from app.core.config import get_settings
# Module-wide logger and application settings, both resolved once at import time.
logger: logging.Logger = logging.getLogger(name=__name__)
settings = get_settings()
class QdrantClient:
    """
    [AC-AISVC-10, AC-AISVC-59] Qdrant client with tenant-isolated collection management.

    Collection naming conventions (the prefix comes from
    ``settings.qdrant_collection_prefix``, presumably ``kb_`` — confirm in config):
    - Legacy (single KB): {prefix}{tenantId}
    - Multi-KB: {prefix}{tenantId}_{kbId}

    Supports multi-dimensional vectors (256/512/768) for Matryoshka retrieval.
    """

    # Named vectors created for multi-vector (Matryoshka) collections, in creation order.
    # NOTE: "full" is fixed at 768 dims regardless of settings.qdrant_vector_size,
    # which is only used when a single unnamed vector collection is created.
    _MULTI_VECTOR_SIZES: dict[str, int] = {
        "full": 768,
        "dim_256": 256,
        "dim_512": 512,
    }

    def __init__(self):
        # The AsyncQdrantClient is created lazily by get_client().
        self._client: AsyncQdrantClient | None = None
        self._collection_prefix = settings.qdrant_collection_prefix
        self._vector_size = settings.qdrant_vector_size

    async def get_client(self) -> AsyncQdrantClient:
        """Get or create the underlying AsyncQdrantClient instance."""
        if self._client is None:
            self._client = AsyncQdrantClient(url=settings.qdrant_url)
            logger.info(f"[AC-AISVC-10] Qdrant client initialized: {settings.qdrant_url}")
        return self._client

    async def close(self) -> None:
        """Close the Qdrant client connection, if one was opened."""
        if self._client:
            await self._client.close()
            self._client = None
            logger.info("Qdrant client connection closed")

    # ------------------------------------------------------------------
    # Collection naming
    # ------------------------------------------------------------------

    def get_collection_name(self, tenant_id: str) -> str:
        """
        [AC-AISVC-10] Get legacy collection name for a tenant.

        Naming convention: {prefix}{tenantId}, with '@' replaced by '_' so the
        name is a valid collection name.

        Note: kept for backward compatibility. For multi-KB, use
        get_kb_collection_name() instead.
        """
        safe_tenant_id = tenant_id.replace('@', '_')
        return f"{self._collection_prefix}{safe_tenant_id}"

    def get_kb_collection_name(self, tenant_id: str, kb_id: str | None = None) -> str:
        """
        [AC-AISVC-59, AC-AISVC-63] Get collection name for a specific knowledge base.

        Naming convention:
        - kb_id is None, "" or "default": legacy {prefix}{tenantId}
        - otherwise: {prefix}{tenantId}_{kbId}, where kbId has '-' replaced by '_'
          and is truncated to its first 8 characters.

        Args:
            tenant_id: Tenant identifier
            kb_id: Knowledge base ID (optional, defaults to legacy naming)

        Returns:
            Collection name for the knowledge base
        """
        if not kb_id or kb_id == "default":
            return self.get_collection_name(tenant_id)
        # NOTE(review): truncating to 8 chars means two KB ids sharing a prefix map
        # to the same collection. Kept as-is because existing collections were
        # created with this naming; changing it would require a data migration.
        safe_kb_id = kb_id.replace('-', '_')[:8]
        return f"{self.get_collection_name(tenant_id)}_{safe_kb_id}"

    # ------------------------------------------------------------------
    # Collection lifecycle
    # ------------------------------------------------------------------

    def _build_vectors_config(self, use_multi_vector: bool) -> dict[str, VectorParams] | VectorParams:
        """Return a vectors config: named Matryoshka vectors, or one unnamed vector."""
        if use_multi_vector:
            return {
                name: VectorParams(size=size, distance=Distance.COSINE)
                for name, size in self._MULTI_VECTOR_SIZES.items()
            }
        return VectorParams(size=self._vector_size, distance=Distance.COSINE)

    async def _create_collection_if_missing(
        self,
        client: AsyncQdrantClient,
        collection_name: str,
        use_multi_vector: bool,
    ) -> bool:
        """Create ``collection_name`` when absent; return True only if this call created it."""
        if await client.collection_exists(collection_name):
            return False
        await client.create_collection(
            collection_name=collection_name,
            vectors_config=self._build_vectors_config(use_multi_vector),
        )
        return True

    async def ensure_collection_exists(self, tenant_id: str, use_multi_vector: bool = True) -> bool:
        """
        [AC-AISVC-10] Ensure collection exists for tenant (legacy single-KB mode).

        Supports multi-dimensional vectors for Matryoshka retrieval.

        Returns:
            True if the collection exists or was created; False if a Qdrant call
            raised (the error is logged, not propagated).
        """
        client = await self.get_client()
        collection_name = self.get_collection_name(tenant_id)
        try:
            if await self._create_collection_if_missing(client, collection_name, use_multi_vector):
                logger.info(
                    f"[AC-AISVC-10] Created collection: {collection_name} for tenant={tenant_id} "
                    f"with multi_vector={use_multi_vector}"
                )
            return True
        except Exception as e:
            logger.error(f"[AC-AISVC-10] Error ensuring collection: {e}")
            return False

    async def ensure_kb_collection_exists(
        self,
        tenant_id: str,
        kb_id: str | None = None,
        use_multi_vector: bool = True,
    ) -> bool:
        """
        [AC-AISVC-59] Ensure collection exists for a specific knowledge base.

        Args:
            tenant_id: Tenant identifier
            kb_id: Knowledge base ID (optional, defaults to legacy naming)
            use_multi_vector: Whether to use multi-dimensional vectors

        Returns:
            True if collection exists or was created successfully; False if a
            Qdrant call raised (the error is logged, not propagated).
        """
        client = await self.get_client()
        collection_name = self.get_kb_collection_name(tenant_id, kb_id)
        try:
            if await self._create_collection_if_missing(client, collection_name, use_multi_vector):
                logger.info(
                    f"[AC-AISVC-59] Created KB collection: {collection_name} for tenant={tenant_id}, kb_id={kb_id} "
                    f"with multi_vector={use_multi_vector}"
                )
            return True
        except Exception as e:
            logger.error(f"[AC-AISVC-59] Error ensuring KB collection: {e}")
            return False

    # ------------------------------------------------------------------
    # Writes
    # ------------------------------------------------------------------

    async def upsert_vectors(
        self,
        tenant_id: str,
        points: list[PointStruct],
        kb_id: str | None = None,
    ) -> bool:
        """
        [AC-AISVC-10, AC-AISVC-63] Upsert vectors into tenant's collection.

        Args:
            tenant_id: Tenant identifier
            points: List of PointStruct to upsert
            kb_id: Knowledge base ID (optional, uses legacy naming if not provided)

        Returns:
            True on success; False if the upsert raised (logged, not propagated).
        """
        client = await self.get_client()
        collection_name = self.get_kb_collection_name(tenant_id, kb_id)
        try:
            await client.upsert(
                collection_name=collection_name,
                points=points,
            )
            logger.info(
                f"[AC-AISVC-10] Upserted {len(points)} vectors for tenant={tenant_id}, kb_id={kb_id}"
            )
            return True
        except Exception as e:
            logger.error(f"[AC-AISVC-10] Error upserting vectors: {e}")
            return False

    async def upsert_multi_vector(
        self,
        tenant_id: str,
        points: list[dict[str, Any]],
        kb_id: str | None = None,
    ) -> bool:
        """
        Upsert points with multi-dimensional vectors.

        Args:
            tenant_id: Tenant identifier
            points: List of points with format:
                {
                    "id": str | int,
                    "vector": {
                        "full": [768 floats],
                        "dim_256": [256 floats],
                        "dim_512": [512 floats],
                    },
                    "payload": dict  # optional, defaults to {}
                }
            kb_id: Knowledge base ID (optional, uses legacy naming if not provided)

        Returns:
            True on success; False if conversion or upsert raised (logged, not propagated).
        """
        client = await self.get_client()
        collection_name = self.get_kb_collection_name(tenant_id, kb_id)
        try:
            qdrant_points = [
                PointStruct(
                    id=p["id"],
                    vector=p["vector"],
                    payload=p.get("payload", {}),
                )
                for p in points
            ]
            await client.upsert(
                collection_name=collection_name,
                points=qdrant_points,
            )
            logger.info(
                f"[RAG-OPT] Upserted {len(points)} multi-vector points for tenant={tenant_id}, kb_id={kb_id}"
            )
            return True
        except Exception as e:
            logger.error(f"[RAG-OPT] Error upserting multi-vectors: {e}")
            return False

    # ------------------------------------------------------------------
    # Search
    # ------------------------------------------------------------------

    @staticmethod
    def _is_missing_vector_name_error(exc: Exception) -> bool:
        """Heuristically detect an error caused by querying an unknown named vector.

        NOTE(review): the "using" substring check is broad and may match unrelated
        errors. In that case the unnamed retry fails too, and the caller reports it.
        """
        message = str(exc)
        lowered = message.lower()
        return "vector name" in lowered or "Not existing vector" in message or "using" in lowered

    async def _query_collection(
        self,
        client: AsyncQdrantClient,
        collection_name: str,
        query_vector: list[float],
        limit: int,
        score_threshold: float | None,
        vector_name: str,
        with_vectors: bool,
    ) -> list[Any]:
        """Query one collection by ``vector_name``, retrying without it for single-vector collections.

        Returns the ``points`` of the query response. Errors not recognised as a
        missing-vector-name error are re-raised.
        """
        query_kwargs: dict[str, Any] = {
            "collection_name": collection_name,
            "query": query_vector,
            "limit": limit,
            "with_vectors": with_vectors,
            "score_threshold": score_threshold,
        }
        try:
            response = await client.query_points(using=vector_name, **query_kwargs)
        except Exception as e:
            if not self._is_missing_vector_name_error(e):
                raise
            logger.info(
                f"[AC-AISVC-10] Collection {collection_name} doesn't have vector named '{vector_name}', "
                f"trying without vector name (single-vector mode)"
            )
            response = await client.query_points(**query_kwargs)
        return response.points

    @staticmethod
    def _to_hit(point: Any, with_vectors: bool, **extra: Any) -> dict[str, Any]:
        """Convert a scored point into the hit dict returned by the search methods.

        ``extra`` keys are inserted after "payload". "vector" is added only when
        requested and present.
        """
        hit: dict[str, Any] = {
            "id": str(point.id),
            "score": point.score,
            "payload": point.payload or {},
            **extra,
        }
        if with_vectors and point.vector:
            hit["vector"] = point.vector
        return hit

    async def search(
        self,
        tenant_id: str,
        query_vector: list[float],
        limit: int = 5,
        score_threshold: float | None = None,
        vector_name: str = "full",
        with_vectors: bool = False,
    ) -> list[dict[str, Any]]:
        """
        [AC-AISVC-10] Search vectors in tenant's legacy collection.

        Returns results with score >= score_threshold if specified. For tenant ids
        containing '@', both the sanitized ('_') and the old raw ('@') collection
        names are searched, for backward compatibility. Per-collection errors are
        logged and skipped.

        Args:
            tenant_id: Tenant identifier
            query_vector: Query vector for similarity search
            limit: Maximum number of results (per collection and overall)
            score_threshold: Minimum score threshold for results
            vector_name: Name of the vector to search (for multi-vector collections).
                Default is "full" for 768-dim vectors in Matryoshka setup.
            with_vectors: Whether to return vectors in results (for two-stage reranking)

        Returns:
            Hits sorted by descending score, at most ``limit`` of them.
        """
        client = await self.get_client()
        logger.info(
            f"[AC-AISVC-10] Starting search: tenant_id={tenant_id}, "
            f"limit={limit}, score_threshold={score_threshold}, vector_dim={len(query_vector)}, vector_name={vector_name}"
        )
        # Sanitized name first; the raw '@' name is also searched so that data
        # written before sanitization remains reachable.
        collection_names = [self.get_collection_name(tenant_id)]
        if '@' in tenant_id:
            collection_names.append(f"{self._collection_prefix}{tenant_id}")
        logger.info(f"[AC-AISVC-10] Will search in collections: {collection_names}")

        all_hits: list[dict[str, Any]] = []
        for collection_name in collection_names:
            try:
                logger.info(f"[AC-AISVC-10] Searching in collection: {collection_name}")
                if not await client.collection_exists(collection_name):
                    logger.warning(f"[AC-AISVC-10] Collection {collection_name} does not exist")
                    continue
                points = await self._query_collection(
                    client,
                    collection_name,
                    query_vector,
                    limit,
                    score_threshold,
                    vector_name,
                    with_vectors,
                )
                logger.info(
                    f"[AC-AISVC-10] Collection {collection_name} returned {len(points)} raw results"
                )
                hits = [self._to_hit(point, with_vectors) for point in points]
                all_hits.extend(hits)
                if hits:
                    logger.info(
                        f"[AC-AISVC-10] Search in collection {collection_name}: {len(hits)} results for tenant={tenant_id}"
                    )
                    for i, h in enumerate(hits[:3]):
                        logger.debug(
                            f"[AC-AISVC-10] Hit {i+1}: id={h['id']}, score={h['score']:.4f}"
                        )
                else:
                    logger.warning(
                        f"[AC-AISVC-10] Collection {collection_name} returned no hits (filtered or empty)"
                    )
            except Exception as e:
                # Best-effort: one failing collection must not abort the whole search.
                logger.warning(
                    f"[AC-AISVC-10] Collection {collection_name} not found or error: {e}"
                )

        all_hits = sorted(all_hits, key=lambda x: x["score"], reverse=True)[:limit]
        logger.info(
            f"[AC-AISVC-10] Search returned {len(all_hits)} total results for tenant={tenant_id}"
        )
        if not all_hits:
            logger.warning(
                f"[AC-AISVC-10] No results found! tenant={tenant_id}, "
                f"collections_tried={collection_names}, limit={limit}"
            )
        return all_hits

    async def delete_collection(self, tenant_id: str) -> bool:
        """
        [AC-AISVC-10] Delete tenant's legacy collection.

        Used when tenant is removed. Returns False if the delete call raised
        (logged, not propagated).
        """
        client = await self.get_client()
        collection_name = self.get_collection_name(tenant_id)
        try:
            await client.delete_collection(collection_name=collection_name)
            logger.info(f"[AC-AISVC-10] Deleted collection: {collection_name}")
            return True
        except Exception as e:
            logger.error(f"[AC-AISVC-10] Error deleting collection: {e}")
            return False

    async def delete_kb_collection(self, tenant_id: str, kb_id: str) -> bool:
        """
        [AC-AISVC-62] Delete a specific knowledge base's collection.

        Args:
            tenant_id: Tenant identifier
            kb_id: Knowledge base ID

        Returns:
            True if the collection was deleted or did not exist; False if a
            Qdrant call raised (logged, not propagated).
        """
        client = await self.get_client()
        collection_name = self.get_kb_collection_name(tenant_id, kb_id)
        try:
            if await client.collection_exists(collection_name):
                await client.delete_collection(collection_name=collection_name)
                logger.info(f"[AC-AISVC-62] Deleted KB collection: {collection_name} for kb_id={kb_id}")
            else:
                logger.info(f"[AC-AISVC-62] KB collection {collection_name} does not exist, nothing to delete")
            return True
        except Exception as e:
            logger.error(f"[AC-AISVC-62] Error deleting KB collection: {e}")
            return False

    async def search_kb(
        self,
        tenant_id: str,
        query_vector: list[float],
        kb_ids: list[str] | None = None,
        limit: int = 5,
        score_threshold: float | None = None,
        vector_name: str = "full",
        with_vectors: bool = False,
    ) -> list[dict[str, Any]]:
        """
        [AC-AISVC-64] Search vectors across multiple knowledge base collections.

        Args:
            tenant_id: Tenant identifier
            query_vector: Query vector for similarity search
            kb_ids: Knowledge base IDs to search. If None or empty, delegates to
                search() on the legacy collection. IDs that resolve to the same
                collection are searched only once; hits are tagged with the
                first such ID.
            limit: Maximum number of results fetched from each collection AND
                maximum number of merged results returned.
            score_threshold: Minimum score threshold for results
            vector_name: Name of the vector to search
            with_vectors: Whether to return vectors in results

        Returns:
            Combined results (each tagged with "kb_id"), sorted by descending
            score, at most ``limit`` of them. Per-collection errors are logged
            and skipped.
        """
        if not kb_ids:
            return await self.search(
                tenant_id=tenant_id,
                query_vector=query_vector,
                limit=limit,
                score_threshold=score_threshold,
                vector_name=vector_name,
                with_vectors=with_vectors,
            )

        client = await self.get_client()
        logger.info(
            f"[AC-AISVC-64] Starting multi-KB search: tenant_id={tenant_id}, "
            f"kb_ids={kb_ids}, limit={limit}, score_threshold={score_threshold}"
        )

        # De-duplicate by collection: repeated ids (or ids colliding after
        # name normalisation) would otherwise return the same points twice and
        # crowd out other results from the final top-`limit`.
        targets: dict[str, str] = {}
        for kb_id in kb_ids:
            targets.setdefault(self.get_kb_collection_name(tenant_id, kb_id), kb_id)

        all_hits: list[dict[str, Any]] = []
        for collection_name, kb_id in targets.items():
            try:
                if not await client.collection_exists(collection_name):
                    logger.warning(f"[AC-AISVC-64] Collection {collection_name} does not exist")
                    continue
                points = await self._query_collection(
                    client,
                    collection_name,
                    query_vector,
                    limit,
                    score_threshold,
                    vector_name,
                    with_vectors,
                )
                all_hits.extend(self._to_hit(point, with_vectors, kb_id=kb_id) for point in points)
                logger.info(
                    f"[AC-AISVC-64] Collection {collection_name} returned {len(points)} results"
                )
            except Exception as e:
                # Best-effort: one failing KB collection must not abort the search.
                logger.warning(f"[AC-AISVC-64] Error searching collection {collection_name}: {e}")

        all_hits = sorted(all_hits, key=lambda x: x["score"], reverse=True)[:limit]
        logger.info(
            f"[AC-AISVC-64] Multi-KB search returned {len(all_hits)} total results"
        )
        return all_hits
# Shared QdrantClient wrapper for this process; created on demand by get_qdrant_client().
_qdrant_client: QdrantClient | None = None


async def get_qdrant_client() -> QdrantClient:
    """Return the shared QdrantClient wrapper, constructing it on first use.

    Only the wrapper is built here; the underlying AsyncQdrantClient is opened
    lazily by QdrantClient.get_client().
    """
    global _qdrant_client
    if _qdrant_client is not None:
        return _qdrant_client
    _qdrant_client = QdrantClient()
    return _qdrant_client


async def close_qdrant_client() -> None:
    """Close the shared wrapper's connection (if any) and forget the wrapper."""
    global _qdrant_client
    shared = _qdrant_client
    if shared is None:
        return
    await shared.close()
    _qdrant_client = None