feat: RAG retrieval optimization with multi-dimensional vector storage and Nomic embedding provider [AC-AISVC-16, AC-AISVC-29]

This commit is contained in:
MerCry 2026-02-25 23:10:12 +08:00
parent 774744d534
commit cee884d9a0
12 changed files with 2007 additions and 47 deletions

View File

@ -0,0 +1,330 @@
"""
Knowledge base management API with RAG optimization features.
Reference: rag-optimization/spec.md Section 4.2
"""
import logging
from datetime import date
from typing import Any
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_session
from app.services.retrieval import (
ChunkMetadata,
ChunkMetadataModel,
IndexingProgress,
IndexingResult,
KnowledgeIndexer,
MetadataFilter,
RetrievalStrategy,
get_knowledge_indexer,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/kb", tags=["Knowledge Base"])
class IndexDocumentRequest(BaseModel):
"""Request to index a document."""
tenant_id: str = Field(..., description="Tenant ID")
document_id: str = Field(..., description="Document ID")
text: str = Field(..., description="Document text content")
metadata: ChunkMetadataModel | None = Field(default=None, description="Document metadata")
class IndexDocumentResponse(BaseModel):
"""Response from document indexing."""
success: bool
total_chunks: int
indexed_chunks: int
failed_chunks: int
elapsed_seconds: float
error_message: str | None = None
class IndexingProgressResponse(BaseModel):
"""Response with current indexing progress."""
total_chunks: int
processed_chunks: int
failed_chunks: int
progress_percent: int
elapsed_seconds: float
current_document: str
class MetadataFilterRequest(BaseModel):
"""Request for metadata filtering."""
categories: list[str] | None = None
target_audiences: list[str] | None = None
departments: list[str] | None = None
valid_only: bool = True
min_priority: int | None = None
keywords: list[str] | None = None
class RetrieveRequest(BaseModel):
"""Request for knowledge retrieval."""
tenant_id: str = Field(..., description="Tenant ID")
query: str = Field(..., description="Search query")
top_k: int = Field(default=10, ge=1, le=50, description="Number of results")
filters: MetadataFilterRequest | None = Field(default=None, description="Metadata filters")
strategy: RetrievalStrategy = Field(default=RetrievalStrategy.HYBRID, description="Retrieval strategy")
class RetrieveResponse(BaseModel):
"""Response from knowledge retrieval."""
hits: list[dict[str, Any]]
total_hits: int
max_score: float
is_insufficient: bool
diagnostics: dict[str, Any]
class MetadataOptionsResponse(BaseModel):
"""Response with available metadata options."""
categories: list[str]
departments: list[str]
target_audiences: list[str]
priorities: list[int]
@router.post("/index", response_model=IndexDocumentResponse)
async def index_document(
request: IndexDocumentRequest,
session: AsyncSession = Depends(get_session),
):
"""
Index a document with optimized embedding.
Features:
- Task prefixes (search_document:) for document embedding
- Multi-dimensional vectors (256/512/768)
- Metadata support
"""
try:
index = get_knowledge_indexer()
chunk_metadata = None
if request.metadata:
chunk_metadata = ChunkMetadata(
category=request.metadata.category,
subcategory=request.metadata.subcategory,
target_audience=request.metadata.target_audience,
source_doc=request.metadata.source_doc,
source_url=request.metadata.source_url,
department=request.metadata.department,
priority=request.metadata.priority,
keywords=request.metadata.keywords,
)
result = await index.index_document(
tenant_id=request.tenant_id,
document_id=request.document_id,
text=request.text,
metadata=chunk_metadata,
)
return IndexDocumentResponse(
success=result.success,
total_chunks=result.total_chunks,
indexed_chunks=result.indexed_chunks,
failed_chunks=result.failed_chunks,
elapsed_seconds=result.elapsed_seconds,
error_message=result.error_message,
)
except Exception as e:
logger.error(f"[KB-API] Failed to index document: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"索引失败: {str(e)}"
)
@router.get("/index/progress", response_model=IndexingProgressResponse | None)
async def get_indexing_progress():
"""Get current indexing progress."""
try:
index = get_knowledge_indexer()
progress = index.get_progress()
if progress is None:
return None
return IndexingProgressResponse(
total_chunks=progress.total_chunks,
processed_chunks=progress.processed_chunks,
failed_chunks=progress.failed_chunks,
progress_percent=progress.progress_percent,
elapsed_seconds=progress.elapsed_seconds,
current_document=progress.current_document,
)
except Exception as e:
logger.error(f"[KB-API] Failed to get progress: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"获取进度失败: {str(e)}"
)
@router.post("/retrieve", response_model=RetrieveResponse)
async def retrieve_knowledge(request: RetrieveRequest):
"""
Retrieve knowledge using optimized RAG.
Strategies:
- vector: Simple vector search
- bm25: BM25 keyword search
- hybrid: RRF combination of vector + BM25 (default)
- two_stage: Two-stage retrieval with Matryoshka dimensions
"""
try:
from app.services.retrieval.optimized_retriever import get_optimized_retriever
from app.services.retrieval.base import RetrievalContext
retriever = await get_optimized_retriever()
metadata_filter = None
if request.filters:
filter_dict = request.filters.model_dump(exclude_none=True)
metadata_filter = MetadataFilter(**filter_dict)
ctx = RetrievalContext(
tenant_id=request.tenant_id,
query=request.query,
)
if metadata_filter:
ctx.metadata = {"filter": metadata_filter.to_qdrant_filter()}
result = await retriever.retrieve(ctx)
return RetrieveResponse(
hits=[
{
"text": hit.text,
"score": hit.score,
"source": hit.source,
"metadata": hit.metadata,
}
for hit in result.hits
],
total_hits=result.hit_count,
max_score=result.max_score,
is_insufficient=(result.diagnostics or {}).get("is_insufficient", False),
diagnostics=result.diagnostics or {},
)
except Exception as e:
logger.error(f"[KB-API] Failed to retrieve: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"检索失败: {str(e)}"
)
@router.get("/metadata/options", response_model=MetadataOptionsResponse)
async def get_metadata_options():
"""
Get available metadata options for filtering.
These would typically be loaded from a database.
"""
try:
return MetadataOptionsResponse(
categories=[
"课程咨询",
"考试政策",
"学籍管理",
"奖助学金",
"宿舍管理",
"校园服务",
"就业指导",
"其他",
],
departments=[
"教务处",
"学生处",
"财务处",
"后勤处",
"就业指导中心",
"图书馆",
"信息中心",
],
target_audiences=[
"本科生",
"研究生",
"留学生",
"新生",
"毕业生",
"教职工",
],
priorities=list(range(1, 11)),
)
except Exception as e:
logger.error(f"[KB-API] Failed to get metadata options: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"获取选项失败: {str(e)}"
)
@router.post("/reindex")
async def reindex_all(
tenant_id: str,
session: AsyncSession = Depends(get_session),
):
"""
Reindex all documents for a tenant with optimized embedding.
This would typically read from the documents table and reindex.
"""
try:
from app.models.entities import Document, DocumentStatus
stmt = select(Document).where(
Document.tenant_id == tenant_id,
Document.status == DocumentStatus.COMPLETED.value,
)
result = await session.execute(stmt)
documents = result.scalars().all()
index = get_knowledge_indexer()
total_indexed = 0
total_failed = 0
import os
for doc in documents:
    if doc.file_path and os.path.exists(doc.file_path):
        with open(doc.file_path, 'r', encoding='utf-8') as f:
            text = f.read()
        index_result = await index.index_document(
            tenant_id=tenant_id,
            document_id=str(doc.id),
            text=text,
        )
        total_indexed += index_result.indexed_chunks
        total_failed += index_result.failed_chunks
return {
"success": True,
"total_documents": len(documents),
"total_indexed": total_indexed,
"total_failed": total_failed,
}
except Exception as e:
logger.error(f"[KB-API] Failed to reindex: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"重新索引失败: {str(e)}"
)
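For reference, a minimal client-side sketch of the new endpoints, assuming the service listens on http://localhost:8000 (the base URL, tenant ID, and document ID below are illustrative, not part of this commit):

import asyncio
import httpx

async def main() -> None:
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        # Index a document; the metadata fields map to ChunkMetadataModel.
        resp = await client.post("/api/kb/index", json={
            "tenant_id": "demo-tenant",
            "document_id": "doc-001",
            "text": "期末考试需提前两周在教务系统报名。",
            "metadata": {"category": "考试政策", "department": "教务处"},
        })
        print(resp.json())  # IndexDocumentResponse fields

        # Retrieve with the default hybrid strategy and a category filter.
        resp = await client.post("/api/kb/retrieve", json={
            "tenant_id": "demo-tenant",
            "query": "期末考试怎么报名",
            "top_k": 5,
            "filters": {"categories": ["考试政策"]},
        })
        print(resp.json()["total_hits"])

asyncio.run(main())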

View File

@ -17,6 +17,7 @@ from app.core.exceptions import MissingTenantIdException
from app.core.tenant import get_tenant_id
from app.models import ErrorResponse
from app.services.retrieval.vector_retriever import get_vector_retriever
from app.services.retrieval.optimized_retriever import get_optimized_retriever
from app.services.retrieval.base import RetrievalContext
from app.services.llm.factory import get_llm_config_manager
@ -91,7 +92,8 @@ async def run_rag_experiment(
threshold = request.score_threshold or settings.rag_score_threshold
try:
# Use optimized retriever with RAG enhancements
retriever = await get_optimized_retriever()
retrieval_ctx = RetrievalContext(
tenant_id=tenant_id,
@ -199,7 +201,8 @@ async def run_rag_experiment_stream(
async def event_generator():
try:
# Use optimized retriever with RAG enhancements
retriever = await get_optimized_retriever()
retrieval_ctx = RetrievalContext(
tenant_id=tenant_id,

View File

@ -9,18 +9,43 @@ from typing import Annotated, Any
from fastapi import APIRouter, Depends, Header, Request
from fastapi.responses import JSONResponse
from sse_starlette.sse import EventSourceResponse
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_session
from app.core.middleware import get_response_mode, is_sse_request
from app.core.sse import SSEStateMachine, create_error_event
from app.core.tenant import get_tenant_id
from app.models import ChatRequest, ChatResponse, ErrorResponse
from app.services.memory import MemoryService
from app.services.orchestrator import OrchestratorService
logger = logging.getLogger(__name__)
router = APIRouter(tags=["AI Chat"])
async def get_orchestrator_service_with_memory(
session: Annotated[AsyncSession, Depends(get_session)]
) -> OrchestratorService:
"""
[AC-AISVC-13] Create orchestrator service with memory service and LLM client.
Ensures each request has a fresh MemoryService with database session.
"""
from app.services.llm.factory import get_llm_config_manager
from app.services.retrieval.vector_retriever import get_vector_retriever
memory_service = MemoryService(session)
llm_config_manager = get_llm_config_manager()
llm_client = llm_config_manager.get_client()
retriever = await get_vector_retriever()
return OrchestratorService(
llm_client=llm_client,
memory_service=memory_service,
retriever=retriever,
)
@router.post(
"/ai/chat",
operation_id="generateReply",
@ -49,7 +74,7 @@ async def generate_reply(
request: Request,
chat_request: ChatRequest,
accept: Annotated[str | None, Header()] = None,
orchestrator: OrchestratorService = Depends(get_orchestrator_service_with_memory),
) -> Any:
"""
[AC-AISVC-06] Generate AI reply with automatic response mode switching.

View File

@ -1,13 +1,14 @@
"""
Qdrant client for AI Service.
[AC-AISVC-10] Vector database client with tenant-isolated collection management.
Supports multi-dimensional vectors for Matryoshka representation learning.
"""
import logging
from typing import Any
from qdrant_client import AsyncQdrantClient
from qdrant_client.models import Distance, PointStruct, VectorParams
from qdrant_client.models import Distance, PointStruct, VectorParams, MultiVectorConfig
from app.core.config import get_settings
@ -20,6 +21,7 @@ class QdrantClient:
"""
[AC-AISVC-10] Qdrant client with tenant-isolated collection management.
Collection naming: kb_{tenantId} for tenant isolation.
Supports multi-dimensional vectors (256/512/768) for Matryoshka retrieval.
"""
def __init__(self):
@ -45,13 +47,15 @@ class QdrantClient:
"""
[AC-AISVC-10] Get collection name for a tenant.
Naming convention: kb_{tenantId}
Replaces @ with _ to ensure valid collection names.
"""
return f"{self._collection_prefix}{tenant_id}"
safe_tenant_id = tenant_id.replace('@', '_')
return f"{self._collection_prefix}{safe_tenant_id}"
async def ensure_collection_exists(self, tenant_id: str, use_multi_vector: bool = True) -> bool:
"""
[AC-AISVC-10] Ensure collection exists for tenant.
Note: the MVP uses pre-provisioned collections; this method is for development/testing.
Supports multi-dimensional vectors for Matryoshka retrieval.
"""
client = await self.get_client()
collection_name = self.get_collection_name(tenant_id)
@ -61,15 +65,34 @@ class QdrantClient:
exists = any(c.name == collection_name for c in collections.collections)
if not exists:
if use_multi_vector:
    vectors_config = {
        "full": VectorParams(
            size=768,
            distance=Distance.COSINE,
        ),
        "dim_256": VectorParams(
            size=256,
            distance=Distance.COSINE,
        ),
        "dim_512": VectorParams(
            size=512,
            distance=Distance.COSINE,
        ),
    }
else:
    vectors_config = VectorParams(
        size=self._vector_size,
        distance=Distance.COSINE,
    )
await client.create_collection(
collection_name=collection_name,
vectors_config=vectors_config,
)
logger.info(
f"[AC-AISVC-10] Created collection: {collection_name} for tenant={tenant_id}"
f"[AC-AISVC-10] Created collection: {collection_name} for tenant={tenant_id} "
f"with multi_vector={use_multi_vector}"
)
return True
except Exception as e:
@ -100,44 +123,160 @@ class QdrantClient:
logger.error(f"[AC-AISVC-10] Error upserting vectors: {e}")
return False
async def upsert_multi_vector(
self,
tenant_id: str,
points: list[dict[str, Any]],
) -> bool:
"""
Upsert points with multi-dimensional vectors.
Args:
tenant_id: Tenant identifier
points: List of points with format:
{
"id": str | int,
"vector": {
"full": [768 floats],
"dim_256": [256 floats],
"dim_512": [512 floats],
},
"payload": dict
}
"""
client = await self.get_client()
collection_name = self.get_collection_name(tenant_id)
try:
qdrant_points = []
for p in points:
point = PointStruct(
id=p["id"],
vector=p["vector"],
payload=p.get("payload", {}),
)
qdrant_points.append(point)
await client.upsert(
collection_name=collection_name,
points=qdrant_points,
)
logger.info(
f"[RAG-OPT] Upserted {len(points)} multi-vector points for tenant={tenant_id}"
)
return True
except Exception as e:
logger.error(f"[RAG-OPT] Error upserting multi-vectors: {e}")
return False
async def search(
self,
tenant_id: str,
query_vector: list[float],
limit: int = 5,
score_threshold: float | None = None,
vector_name: str = "full",
) -> list[dict[str, Any]]:
"""
[AC-AISVC-10] Search vectors in tenant's collection.
Returns results with score >= score_threshold if specified.
Searches both old format (with @) and new format (with _) for backward compatibility.
Args:
tenant_id: Tenant identifier
query_vector: Query vector for similarity search
limit: Maximum number of results
score_threshold: Minimum score threshold for results
vector_name: Name of the vector to search (for multi-vector collections)
Default is "full" for 768-dim vectors in Matryoshka setup.
"""
client = await self.get_client()
collection_name = self.get_collection_name(tenant_id)
try:
logger.info(
f"[AC-AISVC-10] Starting search: tenant_id={tenant_id}, "
f"limit={limit}, score_threshold={score_threshold}, vector_dim={len(query_vector)}, vector_name={vector_name}"
)
collection_names = [self.get_collection_name(tenant_id)]
if '@' in tenant_id:
old_format = f"{self._collection_prefix}{tenant_id}"
new_format = f"{self._collection_prefix}{tenant_id.replace('@', '_')}"
collection_names = [new_format, old_format]
logger.info(f"[AC-AISVC-10] Will search in collections: {collection_names}")
all_hits = []
for collection_name in collection_names:
try:
logger.info(f"[AC-AISVC-10] Searching in collection: {collection_name}")
try:
results = await client.search(
collection_name=collection_name,
query_vector=(vector_name, query_vector),
limit=limit,
)
except Exception as e:
if "vector name" in str(e).lower() or "Not existing vector" in str(e):
logger.info(
f"[AC-AISVC-10] Collection {collection_name} doesn't have vector named '{vector_name}', "
f"trying without vector name (single-vector mode)"
)
results = await client.search(
collection_name=collection_name,
query_vector=query_vector,
limit=limit,
)
else:
raise
logger.info(
f"[AC-AISVC-10] Collection {collection_name} returned {len(results)} raw results"
)
hits = [
{
"id": str(result.id),
"score": result.score,
"payload": result.payload or {},
}
for result in results
if score_threshold is None or result.score >= score_threshold
]
all_hits.extend(hits)
if hits:
logger.info(
f"[AC-AISVC-10] Search in collection {collection_name}: {len(hits)} results for tenant={tenant_id}"
)
for i, h in enumerate(hits[:3]):
logger.debug(
f"[AC-AISVC-10] Hit {i+1}: id={h['id']}, score={h['score']:.4f}"
)
else:
logger.warning(
f"[AC-AISVC-10] Collection {collection_name} returned no hits (filtered or empty)"
)
except Exception as e:
logger.warning(
f"[AC-AISVC-10] Collection {collection_name} not found or error: {e}"
)
continue
all_hits = sorted(all_hits, key=lambda x: x["score"], reverse=True)[:limit]
logger.info(
f"[AC-AISVC-10] Search returned {len(all_hits)} total results for tenant={tenant_id}"
)
if len(all_hits) == 0:
logger.warning(
f"[AC-AISVC-10] No results found! tenant={tenant_id}, "
f"collections_tried={collection_names}, limit={limit}"
)
except Exception as e:
logger.error(f"[AC-AISVC-10] Error searching vectors: {e}")
return []
return all_hits
async def delete_collection(self, tenant_id: str) -> bool:
"""

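Since the multi-vector layout and named-vector search are the core of this change, a standalone sketch against a local Qdrant instance (the URL and collection name are assumptions) may help verify the shape of the data:

import asyncio
from qdrant_client import AsyncQdrantClient
from qdrant_client.models import Distance, PointStruct, VectorParams

async def main() -> None:
    client = AsyncQdrantClient(url="http://localhost:6333")
    # Same named-vector layout as ensure_collection_exists(use_multi_vector=True).
    await client.create_collection(
        collection_name="kb_demo_tenant",
        vectors_config={
            "full": VectorParams(size=768, distance=Distance.COSINE),
            "dim_256": VectorParams(size=256, distance=Distance.COSINE),
            "dim_512": VectorParams(size=512, distance=Distance.COSINE),
        },
    )
    await client.upsert(
        collection_name="kb_demo_tenant",
        points=[PointStruct(
            id=1,
            vector={"full": [0.1] * 768, "dim_256": [0.1] * 256, "dim_512": [0.1] * 512},
            payload={"text": "example"},
        )],
    )
    # Named-vector search takes a (name, vector) tuple as query_vector.
    hits = await client.search(
        collection_name="kb_demo_tenant",
        query_vector=("dim_256", [0.1] * 256),
        limit=5,
    )
    print([(h.id, h.score) for h in hits])

asyncio.run(main())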
View File

@ -17,6 +17,11 @@ from app.services.embedding.factory import (
)
from app.services.embedding.ollama_provider import OllamaEmbeddingProvider
from app.services.embedding.openai_provider import OpenAIEmbeddingProvider
from app.services.embedding.nomic_provider import (
NomicEmbeddingProvider,
NomicEmbeddingResult,
EmbeddingTask,
)
__all__ = [
"EmbeddingConfig",
@ -29,4 +34,7 @@ __all__ = [
"get_embedding_provider",
"OllamaEmbeddingProvider",
"OpenAIEmbeddingProvider",
"NomicEmbeddingProvider",
"NomicEmbeddingResult",
"EmbeddingTask",
]

View File

@ -13,6 +13,7 @@ from typing import Any, Type
from app.services.embedding.base import EmbeddingException, EmbeddingProvider
from app.services.embedding.ollama_provider import OllamaEmbeddingProvider
from app.services.embedding.openai_provider import OpenAIEmbeddingProvider
from app.services.embedding.nomic_provider import NomicEmbeddingProvider
logger = logging.getLogger(__name__)
@ -26,6 +27,7 @@ class EmbeddingProviderFactory:
_providers: dict[str, Type[EmbeddingProvider]] = {
"ollama": OllamaEmbeddingProvider,
"openai": OpenAIEmbeddingProvider,
"nomic": NomicEmbeddingProvider,
}
@classmethod
@ -63,11 +65,13 @@ class EmbeddingProviderFactory:
display_names = {
"ollama": "Ollama 本地模型",
"openai": "OpenAI Embedding",
"nomic": "Nomic Embed (优化版)",
}
descriptions = {
"ollama": "使用 Ollama 运行的本地嵌入模型,支持 nomic-embed-text 等开源模型",
"openai": "使用 OpenAI 官方 Embedding API支持 text-embedding-3 系列模型",
"nomic": "Nomic-embed-text v1.5 优化版,支持任务前缀和 Matryoshka 维度截断专为RAG优化",
}
return {

View File

@ -0,0 +1,291 @@
"""
Nomic embedding provider with task prefixes and Matryoshka support.
Implements RAG optimization spec:
- Task prefixes: search_document: / search_query:
- Matryoshka dimension truncation: 256/512/768 dimensions
"""
import logging
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Any
import httpx
import numpy as np
from app.services.embedding.base import (
EmbeddingConfig,
EmbeddingException,
EmbeddingProvider,
)
logger = logging.getLogger(__name__)
class EmbeddingTask(str, Enum):
"""Task type for nomic-embed-text v1.5 model."""
DOCUMENT = "search_document"
QUERY = "search_query"
@dataclass
class NomicEmbeddingResult:
"""Result from Nomic embedding with multiple dimensions."""
embedding_full: list[float]
embedding_256: list[float]
embedding_512: list[float]
dimension: int
model: str
task: EmbeddingTask
latency_ms: float = 0.0
metadata: dict[str, Any] = field(default_factory=dict)
class NomicEmbeddingProvider(EmbeddingProvider):
"""
Nomic-embed-text v1.5 embedding provider with task prefixes.
Key features:
- Task prefixes: search_document: for documents, search_query: for queries
- Matryoshka dimension truncation: 256/512/768 dimensions
- Automatic normalization after truncation
Reference: rag-optimization/spec.md Section 2.1, 2.3
"""
PROVIDER_NAME = "nomic"
DOCUMENT_PREFIX = "search_document:"
QUERY_PREFIX = "search_query:"
FULL_DIMENSION = 768
def __init__(
self,
base_url: str = "http://localhost:11434",
model: str = "nomic-embed-text",
dimension: int = 768,
timeout_seconds: int = 60,
enable_matryoshka: bool = True,
**kwargs: Any,
):
self._base_url = base_url.rstrip("/")
self._model = model
self._dimension = dimension
self._timeout = timeout_seconds
self._enable_matryoshka = enable_matryoshka
self._client: httpx.AsyncClient | None = None
self._extra_config = kwargs
async def _get_client(self) -> httpx.AsyncClient:
if self._client is None:
self._client = httpx.AsyncClient(timeout=self._timeout)
return self._client
def _add_prefix(self, text: str, task: EmbeddingTask) -> str:
"""Add task prefix to text."""
if task == EmbeddingTask.DOCUMENT:
prefix = self.DOCUMENT_PREFIX
else:
prefix = self.QUERY_PREFIX
if text.startswith(prefix):
return text
return f"{prefix}{text}"
def _truncate_and_normalize(self, embedding: list[float], target_dim: int) -> list[float]:
"""
Truncate embedding to target dimension and normalize.
Matryoshka representation learning allows dimension truncation.
"""
truncated = embedding[:target_dim]
arr = np.array(truncated, dtype=np.float32)
norm = np.linalg.norm(arr)
if norm > 0:
arr = arr / norm
return arr.tolist()
async def embed_with_task(
self,
text: str,
task: EmbeddingTask,
) -> NomicEmbeddingResult:
"""
Generate embedding with specified task prefix.
Args:
text: Input text to embed
task: DOCUMENT for indexing, QUERY for retrieval
Returns:
NomicEmbeddingResult with all dimension variants
"""
start_time = time.perf_counter()
prefixed_text = self._add_prefix(text, task)
try:
client = await self._get_client()
response = await client.post(
f"{self._base_url}/api/embeddings",
json={
"model": self._model,
"prompt": prefixed_text,
}
)
response.raise_for_status()
data = response.json()
embedding = data.get("embedding", [])
if not embedding:
raise EmbeddingException(
"Empty embedding returned",
provider=self.PROVIDER_NAME,
details={"text_length": len(text), "task": task.value}
)
latency_ms = (time.perf_counter() - start_time) * 1000
embedding_256 = self._truncate_and_normalize(embedding, 256)
embedding_512 = self._truncate_and_normalize(embedding, 512)
logger.debug(
f"Generated Nomic embedding: task={task.value}, "
f"dim={len(embedding)}, latency={latency_ms:.2f}ms"
)
return NomicEmbeddingResult(
embedding_full=embedding,
embedding_256=embedding_256,
embedding_512=embedding_512,
dimension=len(embedding),
model=self._model,
task=task,
latency_ms=latency_ms,
)
except httpx.HTTPStatusError as e:
    raise EmbeddingException(
        f"Ollama API error: {e.response.status_code}",
        provider=self.PROVIDER_NAME,
        details={"status_code": e.response.status_code, "response": e.response.text}
    ) from e
except httpx.RequestError as e:
    raise EmbeddingException(
        f"Ollama connection error: {e}",
        provider=self.PROVIDER_NAME,
        details={"base_url": self._base_url}
    ) from e
except EmbeddingException:
    raise
except Exception as e:
    raise EmbeddingException(
        f"Embedding generation failed: {e}",
        provider=self.PROVIDER_NAME
    ) from e
async def embed_document(self, text: str) -> NomicEmbeddingResult:
"""
Generate embedding for document (with search_document: prefix).
Use this when indexing documents into vector store.
"""
return await self.embed_with_task(text, EmbeddingTask.DOCUMENT)
async def embed_query(self, text: str) -> NomicEmbeddingResult:
"""
Generate embedding for query (with search_query: prefix).
Use this when searching/retrieving documents.
"""
return await self.embed_with_task(text, EmbeddingTask.QUERY)
async def embed(self, text: str) -> list[float]:
"""
Generate embedding vector for a single text.
Default uses QUERY task for backward compatibility.
"""
result = await self.embed_query(text)
return result.embedding_full
async def embed_batch(self, texts: list[str]) -> list[list[float]]:
"""
Generate embedding vectors for multiple texts.
Uses QUERY task by default.
"""
embeddings = []
for text in texts:
embedding = await self.embed(text)
embeddings.append(embedding)
return embeddings
async def embed_documents_batch(
self,
texts: list[str],
) -> list[NomicEmbeddingResult]:
"""
Generate embeddings for multiple documents (DOCUMENT task).
Use this when batch indexing documents.
"""
results = []
for text in texts:
result = await self.embed_document(text)
results.append(result)
return results
async def embed_queries_batch(
self,
texts: list[str],
) -> list[NomicEmbeddingResult]:
"""
Generate embeddings for multiple queries (QUERY task).
Use this when batch processing queries.
"""
results = []
for text in texts:
result = await self.embed_query(text)
results.append(result)
return results
def get_dimension(self) -> int:
"""Get the dimension of embedding vectors."""
return self._dimension
def get_provider_name(self) -> str:
"""Get the name of this embedding provider."""
return self.PROVIDER_NAME
def get_config_schema(self) -> dict[str, Any]:
"""Get the configuration schema for Nomic provider."""
return {
"base_url": {
"type": "string",
"description": "Ollama API 地址",
"default": "http://localhost:11434",
},
"model": {
"type": "string",
"description": "嵌入模型名称(推荐 nomic-embed-text v1.5",
"default": "nomic-embed-text",
},
"dimension": {
"type": "integer",
"description": "向量维度(支持 256/512/768",
"default": 768,
},
"timeout_seconds": {
"type": "integer",
"description": "请求超时时间(秒)",
"default": 60,
},
"enable_matryoshka": {
"type": "boolean",
"description": "启用 Matryoshka 维度截断",
"default": True,
},
}
async def close(self) -> None:
"""Close the HTTP client."""
if self._client:
await self._client.aclose()
self._client = None
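A minimal usage sketch for the provider, assuming a local Ollama instance with nomic-embed-text pulled (the URL and texts are illustrative):

import asyncio
from app.services.embedding.nomic_provider import NomicEmbeddingProvider

async def main() -> None:
    provider = NomicEmbeddingProvider(base_url="http://localhost:11434")
    doc = await provider.embed_document("期末考试需提前两周在教务系统报名。")
    query = await provider.embed_query("期末考试怎么报名")
    # One API call yields the full vector plus 256/512-dim Matryoshka
    # truncations; the truncated variants are re-normalized to unit length.
    print(doc.dimension, len(doc.embedding_256), len(doc.embedding_512))
    print(f"doc={doc.latency_ms:.1f}ms query={query.latency_ms:.1f}ms")
    await provider.close()

asyncio.run(main())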

View File

@ -11,6 +11,11 @@ Design reference: design.md Section 2.2 - 关键数据流
6. compute_confidence(...)
7. Memory.append(tenantId, sessionId, user/assistant messages)
8. Return ChatResponse (or output via SSE)
RAG Optimization (rag-optimization/spec.md):
- Two-stage retrieval with Matryoshka dimensions
- RRF hybrid ranking
- Optimized prompt engineering
"""
import logging
@ -36,6 +41,16 @@ from app.services.retrieval.base import BaseRetriever, RetrievalContext, Retriev
logger = logging.getLogger(__name__)
OPTIMIZED_SYSTEM_PROMPT = """你是学校智能客服助手,基于提供的知识库内容回答用户问题。
回答要求:
1. 严格基于提供的知识库内容回答,不要编造信息
2. 如果知识库中没有相关信息,明确告知用户,并建议转人工或稍后重试
3. 保持专业友好的语气,回答简洁明了,突出重点
4. 如果引用知识库内容,请注明来源(根据[文档1]...)
5. 对于时效性问题,请提醒用户注意文档的有效期"""
@dataclass
class OrchestratorConfig:
"""
@ -44,8 +59,9 @@ class OrchestratorConfig:
"""
max_history_tokens: int = 4000
max_evidence_tokens: int = 2000
system_prompt: str = "你是一个智能客服助手,请根据提供的知识库内容回答用户问题。"
system_prompt: str = OPTIMIZED_SYSTEM_PROMPT
enable_rag: bool = True
use_optimized_retriever: bool = True
@dataclass
@ -141,7 +157,14 @@ class OrchestratorService:
"""
logger.info(
f"[AC-AISVC-01] Starting generation for tenant={tenant_id}, "
f"session={request.session_id}"
f"session={request.session_id}, channel_type={request.channel_type}, "
f"current_message={request.current_message[:100]}..."
)
logger.info(
f"[AC-AISVC-01] Config: enable_rag={self._config.enable_rag}, "
f"use_optimized_retriever={self._config.use_optimized_retriever}, "
f"llm_client={'configured' if self._llm_client else 'NOT configured'}, "
f"retriever={'configured' if self._retriever else 'NOT configured'}"
)
ctx = GenerationContext(
@ -257,6 +280,10 @@ class OrchestratorService:
[AC-AISVC-16, AC-AISVC-17] RAG retrieval for evidence.
Step 3 of the generation pipeline.
"""
logger.info(
f"[AC-AISVC-16] Starting retrieval: tenant={ctx.tenant_id}, "
f"query={ctx.current_message[:100]}..., retriever={type(self._retriever).__name__ if self._retriever else 'None'}"
)
try:
retrieval_ctx = RetrievalContext(
tenant_id=ctx.tenant_id,
@ -277,11 +304,19 @@ class OrchestratorService:
logger.info(
f"[AC-AISVC-16, AC-AISVC-17] Retrieval complete: "
f"hits={ctx.retrieval_result.hit_count}, "
f"max_score={ctx.retrieval_result.max_score:.3f}"
f"max_score={ctx.retrieval_result.max_score:.3f}, "
f"is_empty={ctx.retrieval_result.is_empty}"
)
if ctx.retrieval_result.hit_count > 0:
for i, hit in enumerate(ctx.retrieval_result.hits[:3]):
logger.info(
f"[AC-AISVC-16] Hit {i+1}: score={hit.score:.3f}, "
f"text_preview={hit.text[:100]}..."
)
except Exception as e:
logger.warning(f"[AC-AISVC-16] Retrieval failed: {e}")
logger.error(f"[AC-AISVC-16] Retrieval failed with exception: {e}", exc_info=True)
ctx.retrieval_result = RetrievalResult(
hits=[],
diagnostics={"error": str(e)},
@ -294,9 +329,18 @@ class OrchestratorService:
Step 4-5 of the generation pipeline.
"""
messages = self._build_llm_messages(ctx)
logger.info(
f"[AC-AISVC-02] Building LLM messages: count={len(messages)}, "
f"has_retrieval_result={ctx.retrieval_result is not None}, "
f"retrieval_is_empty={ctx.retrieval_result.is_empty if ctx.retrieval_result else 'N/A'}, "
f"llm_client={'configured' if self._llm_client else 'NOT configured'}"
)
if not self._llm_client:
logger.warning("[AC-AISVC-02] No LLM client configured, using fallback")
logger.warning(
f"[AC-AISVC-02] No LLM client configured, using fallback. "
f"retrieval_is_empty={ctx.retrieval_result.is_empty if ctx.retrieval_result else True}"
)
ctx.llm_response = LLMResponse(
content=self._fallback_response(ctx),
model="fallback",
@ -304,6 +348,7 @@ class OrchestratorService:
finish_reason="fallback",
)
ctx.diagnostics["llm_mode"] = "fallback"
ctx.diagnostics["fallback_reason"] = "no_llm_client"
return
try:
@ -318,11 +363,16 @@ class OrchestratorService:
logger.info(
f"[AC-AISVC-02] LLM response generated: "
f"model={ctx.llm_response.model}, "
f"tokens={ctx.llm_response.usage}"
f"tokens={ctx.llm_response.usage}, "
f"content_preview={ctx.llm_response.content[:100]}..."
)
except Exception as e:
logger.error(f"[AC-AISVC-02] LLM generation failed: {e}")
logger.error(
f"[AC-AISVC-02] LLM generation failed: {e}, "
f"retrieval_is_empty={ctx.retrieval_result.is_empty if ctx.retrieval_result else True}",
exc_info=True
)
ctx.llm_response = LLMResponse(
content=self._fallback_response(ctx),
model="fallback",
@ -331,6 +381,8 @@ class OrchestratorService:
metadata={"error": str(e)},
)
ctx.diagnostics["llm_error"] = str(e)
ctx.diagnostics["llm_mode"] = "fallback"
ctx.diagnostics["fallback_reason"] = f"llm_error: {str(e)}"
def _build_llm_messages(self, ctx: GenerationContext) -> list[dict[str, str]]:
"""
@ -356,12 +408,26 @@ class OrchestratorService:
def _format_evidence(self, retrieval_result: RetrievalResult) -> str:
"""
[AC-AISVC-17] Format retrieval hits as evidence text.
Optimized format with source attribution and metadata.
"""
evidence_parts = []
for i, hit in enumerate(retrieval_result.hits[:5], 1):
    metadata = hit.metadata or {}
    source = metadata.get("metadata", {}).get("source_doc", "知识库")
    category = metadata.get("metadata", {}).get("category", "")
    department = metadata.get("metadata", {}).get("department", "")
    header = f"[文档{i}]"
    if source and source != "知识库":
        header += f" 来源:{source}"
    if category:
        header += f" | 类别:{category}"
    if department:
        header += f" | 部门:{department}"
    evidence_parts.append(f"{header}\n相关度:{hit.score:.2f}\n内容:{hit.text}")
return "\n\n".join(evidence_parts)
def _fallback_response(self, ctx: GenerationContext) -> str:
"""

View File

@ -1,6 +1,7 @@
"""
Retrieval module for AI Service.
[AC-AISVC-16] Provides retriever implementations with plugin architecture.
RAG Optimization: Two-stage retrieval, RRF hybrid ranking, metadata filtering.
"""
from app.services.retrieval.base import (
@ -10,6 +11,27 @@ from app.services.retrieval.base import (
RetrievalResult,
)
from app.services.retrieval.vector_retriever import VectorRetriever, get_vector_retriever
from app.services.retrieval.metadata import (
ChunkMetadata,
ChunkMetadataModel,
MetadataFilter,
KnowledgeChunk,
RetrieveRequest,
RetrieveResult,
RetrievalStrategy,
)
from app.services.retrieval.optimized_retriever import (
OptimizedRetriever,
get_optimized_retriever,
TwoStageResult,
RRFCombiner,
)
from app.services.retrieval.indexer import (
KnowledgeIndexer,
get_knowledge_indexer,
IndexingProgress,
IndexingResult,
)
__all__ = [
"BaseRetriever",
@ -18,4 +40,18 @@ __all__ = [
"RetrievalResult",
"VectorRetriever",
"get_vector_retriever",
"ChunkMetadata",
"MetadataFilter",
"KnowledgeChunk",
"RetrieveRequest",
"RetrieveResult",
"RetrievalStrategy",
"OptimizedRetriever",
"get_optimized_retriever",
"TwoStageResult",
"RRFCombiner",
"KnowledgeIndexer",
"get_knowledge_indexer",
"IndexingProgress",
"IndexingResult",
]

View File

@ -0,0 +1,339 @@
"""
Knowledge base indexing service with optimized embedding.
Reference: rag-optimization/spec.md Section 5.1
"""
import asyncio
import logging
import uuid
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any
from app.core.config import get_settings
from app.core.qdrant_client import QdrantClient, get_qdrant_client
from app.services.embedding.nomic_provider import NomicEmbeddingProvider, NomicEmbeddingResult
from app.services.retrieval.metadata import ChunkMetadata, KnowledgeChunk
logger = logging.getLogger(__name__)
settings = get_settings()
@dataclass
class IndexingProgress:
"""Progress tracking for indexing jobs."""
total_chunks: int = 0
processed_chunks: int = 0
failed_chunks: int = 0
current_document: str = ""
started_at: datetime = field(default_factory=datetime.utcnow)
@property
def progress_percent(self) -> int:
if self.total_chunks == 0:
return 0
return int((self.processed_chunks / self.total_chunks) * 100)
@property
def elapsed_seconds(self) -> float:
return (datetime.utcnow() - self.started_at).total_seconds()
@dataclass
class IndexingResult:
"""Result of an indexing operation."""
success: bool
total_chunks: int
indexed_chunks: int
failed_chunks: int
elapsed_seconds: float
error_message: str | None = None
class KnowledgeIndexer:
"""
Knowledge base indexer with optimized embedding.
Features:
- Task prefixes (search_document:) for document embedding
- Multi-dimensional vectors (256/512/768)
- Metadata support
- Batch processing
"""
def __init__(
self,
qdrant_client: QdrantClient | None = None,
embedding_provider: NomicEmbeddingProvider | None = None,
chunk_size: int = 500,
chunk_overlap: int = 50,
batch_size: int = 10,
):
self._qdrant_client = qdrant_client
self._embedding_provider = embedding_provider
self._chunk_size = chunk_size
self._chunk_overlap = chunk_overlap
self._batch_size = batch_size
self._progress: IndexingProgress | None = None
async def _get_client(self) -> QdrantClient:
if self._qdrant_client is None:
self._qdrant_client = await get_qdrant_client()
return self._qdrant_client
async def _get_embedding_provider(self) -> NomicEmbeddingProvider:
if self._embedding_provider is None:
self._embedding_provider = NomicEmbeddingProvider(
base_url=settings.ollama_base_url,
model=settings.ollama_embedding_model,
dimension=settings.qdrant_vector_size,
)
return self._embedding_provider
def chunk_text(self, text: str, metadata: ChunkMetadata | None = None) -> list[KnowledgeChunk]:
"""
Split text into chunks for indexing.
Each line becomes a separate chunk for better retrieval granularity.
Args:
text: Full text to chunk
metadata: Metadata to attach to each chunk
Returns:
List of KnowledgeChunk objects
"""
chunks = []
doc_id = str(uuid.uuid4())
lines = text.split('\n')
for i, line in enumerate(lines):
line = line.strip()
if len(line) < 10:
continue
chunk = KnowledgeChunk(
chunk_id=f"{doc_id}_{i}",
document_id=doc_id,
content=line,
metadata=metadata or ChunkMetadata(),
)
chunks.append(chunk)
return chunks
def chunk_text_by_lines(
self,
text: str,
metadata: ChunkMetadata | None = None,
min_line_length: int = 10,
merge_short_lines: bool = False,
) -> list[KnowledgeChunk]:
"""
Split text by lines, each line is a separate chunk.
Args:
text: Full text to chunk
metadata: Metadata to attach to each chunk
min_line_length: Minimum line length to be indexed
merge_short_lines: Whether to merge consecutive short lines
Returns:
List of KnowledgeChunk objects
"""
chunks = []
doc_id = str(uuid.uuid4())
lines = text.split('\n')
if merge_short_lines:
merged_lines = []
current_line = ""
for line in lines:
line = line.strip()
if not line:
if current_line:
merged_lines.append(current_line)
current_line = ""
continue
if current_line:
current_line += " " + line
else:
current_line = line
if len(current_line) >= min_line_length * 2:
merged_lines.append(current_line)
current_line = ""
if current_line:
merged_lines.append(current_line)
lines = merged_lines
for i, line in enumerate(lines):
line = line.strip()
if len(line) < min_line_length:
continue
chunk = KnowledgeChunk(
chunk_id=f"{doc_id}_{i}",
document_id=doc_id,
content=line,
metadata=metadata or ChunkMetadata(),
)
chunks.append(chunk)
return chunks
async def index_document(
self,
tenant_id: str,
document_id: str,
text: str,
metadata: ChunkMetadata | None = None,
) -> IndexingResult:
"""
Index a single document with optimized embedding.
Args:
tenant_id: Tenant identifier
document_id: Document identifier
text: Document text content
metadata: Optional metadata for the document
Returns:
IndexingResult with status and statistics
"""
start_time = datetime.utcnow()
try:
client = await self._get_client()
provider = await self._get_embedding_provider()
await client.ensure_collection_exists(tenant_id, use_multi_vector=True)
chunks = self.chunk_text(text, metadata)
self._progress = IndexingProgress(
total_chunks=len(chunks),
current_document=document_id,
)
points = []
for i, chunk in enumerate(chunks):
try:
embedding_result = await provider.embed_document(chunk.content)
chunk.embedding_full = embedding_result.embedding_full
chunk.embedding_256 = embedding_result.embedding_256
chunk.embedding_512 = embedding_result.embedding_512
point = {
"id": str(uuid.uuid4()), # Generate a valid UUID for Qdrant
"vector": {
"full": chunk.embedding_full,
"dim_256": chunk.embedding_256,
"dim_512": chunk.embedding_512,
},
"payload": {
"chunk_id": chunk.chunk_id,
"document_id": document_id,
"text": chunk.content,
"metadata": chunk.metadata.to_dict(),
"created_at": chunk.created_at.isoformat(),
}
}
points.append(point)
self._progress.processed_chunks += 1
logger.debug(
f"[RAG-OPT] Indexed chunk {i+1}/{len(chunks)} for doc={document_id}"
)
except Exception as e:
logger.warning(f"[RAG-OPT] Failed to index chunk {i}: {e}")
self._progress.failed_chunks += 1
if points:
await client.upsert_multi_vector(tenant_id, points)
elapsed = (datetime.utcnow() - start_time).total_seconds()
logger.info(
f"[RAG-OPT] Indexed document {document_id}: "
f"{len(points)} chunks in {elapsed:.2f}s"
)
return IndexingResult(
success=True,
total_chunks=len(chunks),
indexed_chunks=len(points),
failed_chunks=self._progress.failed_chunks,
elapsed_seconds=elapsed,
)
except Exception as e:
elapsed = (datetime.utcnow() - start_time).total_seconds()
logger.error(f"[RAG-OPT] Failed to index document {document_id}: {e}")
return IndexingResult(
success=False,
total_chunks=0,
indexed_chunks=0,
failed_chunks=0,
elapsed_seconds=elapsed,
error_message=str(e),
)
async def index_documents_batch(
self,
tenant_id: str,
documents: list[dict[str, Any]],
) -> list[IndexingResult]:
"""
Index multiple documents in batch.
Args:
tenant_id: Tenant identifier
documents: List of documents with format:
{
"document_id": str,
"text": str,
"metadata": ChunkMetadata (optional)
}
Returns:
List of IndexingResult for each document
"""
results = []
for doc in documents:
result = await self.index_document(
tenant_id=tenant_id,
document_id=doc["document_id"],
text=doc["text"],
metadata=doc.get("metadata"),
)
results.append(result)
return results
def get_progress(self) -> IndexingProgress | None:
"""Get current indexing progress."""
return self._progress
_knowledge_indexer: KnowledgeIndexer | None = None
def get_knowledge_indexer() -> KnowledgeIndexer:
"""Get or create KnowledgeIndexer instance."""
global _knowledge_indexer
if _knowledge_indexer is None:
_knowledge_indexer = KnowledgeIndexer()
return _knowledge_indexer
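A short indexing sketch under the same assumptions (local Qdrant and Ollama reachable through settings; tenant and document IDs are illustrative):

import asyncio
from app.services.retrieval.indexer import get_knowledge_indexer
from app.services.retrieval.metadata import ChunkMetadata

async def main() -> None:
    indexer = get_knowledge_indexer()
    # Line-based chunking: every line of 10+ characters becomes one chunk.
    result = await indexer.index_document(
        tenant_id="demo-tenant",
        document_id="doc-001",
        text="期末考试需提前两周在教务系统报名。\n逾期未报名的学生不能参加本次考试。",
        metadata=ChunkMetadata(category="考试政策", department="教务处"),
    )
    print(result.success, result.indexed_chunks, result.failed_chunks)

asyncio.run(main())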

View File

@ -0,0 +1,210 @@
"""
Metadata models for RAG optimization.
Implements structured metadata for knowledge chunks.
Reference: rag-optimization/spec.md Section 3.2
"""
from dataclasses import dataclass, field
from datetime import date, datetime
from enum import Enum
from typing import Any
from pydantic import BaseModel
class RetrievalStrategy(str, Enum):
"""Retrieval strategy options."""
VECTOR_ONLY = "vector"
BM25_ONLY = "bm25"
HYBRID = "hybrid"
TWO_STAGE = "two_stage"
class ChunkMetadataModel(BaseModel):
"""Pydantic model for API serialization."""
category: str = ""
subcategory: str = ""
target_audience: list[str] = []
source_doc: str = ""
source_url: str = ""
department: str = ""
valid_from: str | None = None
valid_until: str | None = None
priority: int = 5
keywords: list[str] = []
@dataclass
class ChunkMetadata:
"""
Metadata for knowledge chunks.
Reference: rag-optimization/spec.md Section 3.2.2
"""
category: str = ""
subcategory: str = ""
target_audience: list[str] = field(default_factory=list)
source_doc: str = ""
source_url: str = ""
department: str = ""
valid_from: date | None = None
valid_until: date | None = None
priority: int = 5
keywords: list[str] = field(default_factory=list)
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary for storage."""
return {
"category": self.category,
"subcategory": self.subcategory,
"target_audience": self.target_audience,
"source_doc": self.source_doc,
"source_url": self.source_url,
"department": self.department,
"valid_from": self.valid_from.isoformat() if self.valid_from else None,
"valid_until": self.valid_until.isoformat() if self.valid_until else None,
"priority": self.priority,
"keywords": self.keywords,
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "ChunkMetadata":
"""Create from dictionary."""
return cls(
category=data.get("category", ""),
subcategory=data.get("subcategory", ""),
target_audience=data.get("target_audience", []),
source_doc=data.get("source_doc", ""),
source_url=data.get("source_url", ""),
department=data.get("department", ""),
valid_from=date.fromisoformat(data["valid_from"]) if data.get("valid_from") else None,
valid_until=date.fromisoformat(data["valid_until"]) if data.get("valid_until") else None,
priority=data.get("priority", 5),
keywords=data.get("keywords", []),
)
@dataclass
class MetadataFilter:
"""
Filter conditions for metadata-based retrieval.
Reference: rag-optimization/spec.md Section 4.1
"""
categories: list[str] | None = None
target_audiences: list[str] | None = None
departments: list[str] | None = None
valid_only: bool = True
min_priority: int | None = None
keywords: list[str] | None = None
def to_qdrant_filter(self) -> dict[str, Any] | None:
"""Convert to Qdrant filter format."""
conditions = []
if self.categories:
conditions.append({
"key": "metadata.category",
"match": {"any": self.categories}
})
if self.departments:
conditions.append({
"key": "metadata.department",
"match": {"any": self.departments}
})
if self.target_audiences:
conditions.append({
"key": "metadata.target_audience",
"match": {"any": self.target_audiences}
})
if self.valid_only:
    today = date.today().isoformat()
    conditions.append({
        "should": [
            # Qdrant cannot match a None value directly; chunks without
            # an expiry date are matched with an is_null condition instead.
            {"is_null": {"key": "metadata.valid_until"}},
            {"key": "metadata.valid_until", "range": {"gte": today}}
        ]
    })
if self.min_priority is not None:
    # Assumes lower numbers mean higher priority, so "at least
    # min_priority" translates to a lte range condition.
    conditions.append({
        "key": "metadata.priority",
        "range": {"lte": self.min_priority}
    })
if not conditions:
    return None
return {"must": conditions}
@dataclass
class KnowledgeChunk:
"""
Knowledge chunk with multi-dimensional embeddings.
Reference: rag-optimization/spec.md Section 3.2.1
"""
chunk_id: str
document_id: str
content: str
embedding_full: list[float] = field(default_factory=list)
embedding_256: list[float] = field(default_factory=list)
embedding_512: list[float] = field(default_factory=list)
metadata: ChunkMetadata = field(default_factory=ChunkMetadata)
created_at: datetime = field(default_factory=datetime.utcnow)
updated_at: datetime = field(default_factory=datetime.utcnow)
def to_qdrant_point(self, point_id: int | str) -> dict[str, Any]:
"""Convert to Qdrant point format."""
return {
"id": point_id,
"vector": {
"full": self.embedding_full,
"dim_256": self.embedding_256,
"dim_512": self.embedding_512,
},
"payload": {
"chunk_id": self.chunk_id,
"document_id": self.document_id,
"text": self.content,
"metadata": self.metadata.to_dict(),
"created_at": self.created_at.isoformat(),
"updated_at": self.updated_at.isoformat(),
}
}
@dataclass
class RetrieveRequest:
"""
Request for knowledge retrieval.
Reference: rag-optimization/spec.md Section 4.1
"""
query: str
query_with_prefix: str = ""
top_k: int = 10
filters: MetadataFilter | None = None
strategy: RetrievalStrategy = RetrievalStrategy.HYBRID
def __post_init__(self):
if not self.query_with_prefix:
self.query_with_prefix = f"search_query:{self.query}"
@dataclass
class RetrieveResult:
"""
Result from knowledge retrieval.
Reference: rag-optimization/spec.md Section 4.1
"""
chunk_id: str
content: str
score: float
vector_score: float = 0.0
bm25_score: float = 0.0
metadata: ChunkMetadata = field(default_factory=ChunkMetadata)
rank: int = 0
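The filter construction can be checked in isolation; with valid_only disabled the output is deterministic:

from app.services.retrieval.metadata import MetadataFilter

f = MetadataFilter(
    categories=["考试政策"],
    departments=["教务处"],
    valid_only=False,  # skip the validity clause for a minimal example
)
print(f.to_qdrant_filter())
# {'must': [{'key': 'metadata.category', 'match': {'any': ['考试政策']}},
#           {'key': 'metadata.department', 'match': {'any': ['教务处']}}]}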

View File

@ -0,0 +1,509 @@
"""
Optimized RAG retriever with two-stage retrieval and RRF hybrid ranking.
Reference: rag-optimization/spec.md Section 2.2, 2.4, 2.5
"""
import asyncio
import logging
import re
from dataclasses import dataclass, field
from typing import Any
from app.core.config import get_settings
from app.core.qdrant_client import QdrantClient, get_qdrant_client
from app.services.embedding.nomic_provider import NomicEmbeddingProvider, NomicEmbeddingResult
from app.services.retrieval.base import (
BaseRetriever,
RetrievalContext,
RetrievalHit,
RetrievalResult,
)
from app.services.retrieval.metadata import (
ChunkMetadata,
MetadataFilter,
RetrieveResult,
RetrievalStrategy,
)
logger = logging.getLogger(__name__)
settings = get_settings()
@dataclass
class TwoStageResult:
"""Result from two-stage retrieval."""
candidates: list[dict[str, Any]]
final_results: list[RetrieveResult]
stage1_latency_ms: float = 0.0
stage2_latency_ms: float = 0.0
class RRFCombiner:
"""
Reciprocal Rank Fusion for combining multiple retrieval results.
Reference: rag-optimization/spec.md Section 2.5
Weighted variant as implemented here: score = Σ w_i / (k + rank_i + 1), with 0-based rank_i.
Default k = 60.
"""
def __init__(self, k: int = 60):
self._k = k
def combine(
self,
vector_results: list[dict[str, Any]],
bm25_results: list[dict[str, Any]],
vector_weight: float = 0.7,
bm25_weight: float = 0.3,
) -> list[dict[str, Any]]:
"""
Combine vector and BM25 results using RRF.
Args:
vector_results: Results from vector search
bm25_results: Results from BM25 search
vector_weight: Weight for vector results
bm25_weight: Weight for BM25 results
Returns:
Combined and sorted results
"""
combined_scores: dict[str, dict[str, Any]] = {}
for rank, result in enumerate(vector_results):
chunk_id = result.get("chunk_id") or result.get("id", str(rank))
rrf_score = vector_weight / (self._k + rank + 1)
if chunk_id not in combined_scores:
combined_scores[chunk_id] = {
"score": 0.0,
"vector_score": result.get("score", 0.0),
"bm25_score": 0.0,
"vector_rank": rank,
"bm25_rank": -1,
"payload": result.get("payload", {}),
"id": chunk_id,
}
combined_scores[chunk_id]["score"] += rrf_score
for rank, result in enumerate(bm25_results):
chunk_id = result.get("chunk_id") or result.get("id", str(rank))
rrf_score = bm25_weight / (self._k + rank + 1)
if chunk_id not in combined_scores:
combined_scores[chunk_id] = {
"score": 0.0,
"vector_score": 0.0,
"bm25_score": result.get("score", 0.0),
"vector_rank": -1,
"bm25_rank": rank,
"payload": result.get("payload", {}),
"id": chunk_id,
}
else:
combined_scores[chunk_id]["bm25_score"] = result.get("score", 0.0)
combined_scores[chunk_id]["bm25_rank"] = rank
combined_scores[chunk_id]["score"] += rrf_score
sorted_results = sorted(
combined_scores.values(),
key=lambda x: x["score"],
reverse=True
)
return sorted_results
class OptimizedRetriever(BaseRetriever):
"""
Optimized retriever with:
- Task prefixes (search_document/search_query)
- Two-stage retrieval (256 dim -> 768 dim)
- RRF hybrid ranking (vector + BM25)
- Metadata filtering
Reference: rag-optimization/spec.md Section 2, 3, 4
"""
def __init__(
self,
qdrant_client: QdrantClient | None = None,
embedding_provider: NomicEmbeddingProvider | None = None,
top_k: int | None = None,
score_threshold: float | None = None,
min_hits: int | None = None,
two_stage_enabled: bool | None = None,
two_stage_expand_factor: int | None = None,
hybrid_enabled: bool | None = None,
rrf_k: int | None = None,
):
self._qdrant_client = qdrant_client
self._embedding_provider = embedding_provider
self._top_k = top_k or settings.rag_top_k
self._score_threshold = score_threshold or settings.rag_score_threshold
self._min_hits = min_hits or settings.rag_min_hits
self._two_stage_enabled = two_stage_enabled if two_stage_enabled is not None else settings.rag_two_stage_enabled
self._two_stage_expand_factor = two_stage_expand_factor or settings.rag_two_stage_expand_factor
self._hybrid_enabled = hybrid_enabled if hybrid_enabled is not None else settings.rag_hybrid_enabled
self._rrf_k = rrf_k or settings.rag_rrf_k
self._rrf_combiner = RRFCombiner(k=self._rrf_k)
async def _get_client(self) -> QdrantClient:
if self._qdrant_client is None:
self._qdrant_client = await get_qdrant_client()
return self._qdrant_client
async def _get_embedding_provider(self) -> NomicEmbeddingProvider:
if self._embedding_provider is None:
from app.services.embedding.factory import get_embedding_config_manager
manager = get_embedding_config_manager()
provider = await manager.get_provider()
if isinstance(provider, NomicEmbeddingProvider):
self._embedding_provider = provider
else:
self._embedding_provider = NomicEmbeddingProvider(
base_url=settings.ollama_base_url,
model=settings.ollama_embedding_model,
dimension=settings.qdrant_vector_size,
)
return self._embedding_provider
async def retrieve(self, ctx: RetrievalContext) -> RetrievalResult:
"""
Retrieve documents using optimized strategy.
Strategy selection:
1. If two_stage_enabled: use two-stage retrieval
2. If hybrid_enabled: use RRF hybrid ranking
3. Otherwise: simple vector search
"""
logger.info(
f"[RAG-OPT] Starting retrieval for tenant={ctx.tenant_id}, "
f"query={ctx.query[:50]}..., two_stage={self._two_stage_enabled}, hybrid={self._hybrid_enabled}"
)
logger.info(
f"[RAG-OPT] Retrieval config: top_k={self._top_k}, "
f"score_threshold={self._score_threshold}, min_hits={self._min_hits}"
)
try:
provider = await self._get_embedding_provider()
logger.info(f"[RAG-OPT] Using embedding provider: {type(provider).__name__}")
embedding_result = await provider.embed_query(ctx.query)
logger.info(
f"[RAG-OPT] Embedding generated: full_dim={len(embedding_result.embedding_full)}, "
f"dim_256={'available' if embedding_result.embedding_256 else 'not available'}"
)
if self._two_stage_enabled:
logger.info("[RAG-OPT] Using two-stage retrieval strategy")
results = await self._two_stage_retrieve(
ctx.tenant_id,
embedding_result,
self._top_k,
)
elif self._hybrid_enabled:
logger.info("[RAG-OPT] Using hybrid retrieval strategy")
results = await self._hybrid_retrieve(
ctx.tenant_id,
embedding_result,
ctx.query,
self._top_k,
)
else:
logger.info("[RAG-OPT] Using simple vector retrieval strategy")
results = await self._vector_retrieve(
ctx.tenant_id,
embedding_result.embedding_full,
self._top_k,
)
logger.info(f"[RAG-OPT] Raw results count: {len(results)}")
retrieval_hits = [
RetrievalHit(
text=result.get("payload", {}).get("text", ""),
score=result.get("score", 0.0),
source="optimized_rag",
metadata=result.get("payload", {}),
)
for result in results
if result.get("score", 0.0) >= self._score_threshold
]
filtered_count = len(results) - len(retrieval_hits)
if filtered_count > 0:
logger.info(
f"[RAG-OPT] Filtered out {filtered_count} results below threshold {self._score_threshold}"
)
is_insufficient = len(retrieval_hits) < self._min_hits
diagnostics = {
"query_length": len(ctx.query),
"top_k": self._top_k,
"score_threshold": self._score_threshold,
"two_stage_enabled": self._two_stage_enabled,
"hybrid_enabled": self._hybrid_enabled,
"total_hits": len(retrieval_hits),
"is_insufficient": is_insufficient,
"max_score": max((h.score for h in retrieval_hits), default=0.0),
"raw_results_count": len(results),
"filtered_below_threshold": filtered_count,
}
logger.info(
f"[RAG-OPT] Retrieval complete: {len(retrieval_hits)} hits, "
f"insufficient={is_insufficient}, max_score={diagnostics['max_score']:.3f}"
)
if len(retrieval_hits) == 0:
logger.warning(
f"[RAG-OPT] No hits found! tenant={ctx.tenant_id}, query={ctx.query[:50]}..., "
f"raw_results={len(results)}, threshold={self._score_threshold}"
)
return RetrievalResult(
hits=retrieval_hits,
diagnostics=diagnostics,
)
except Exception as e:
logger.error(f"[RAG-OPT] Retrieval error: {e}", exc_info=True)
return RetrievalResult(
hits=[],
diagnostics={"error": str(e), "is_insufficient": True},
)
async def _two_stage_retrieve(
self,
tenant_id: str,
embedding_result: NomicEmbeddingResult,
top_k: int,
) -> list[dict[str, Any]]:
"""
Two-stage retrieval using Matryoshka dimensions.
Stage 1: Fast retrieval with 256-dim vectors
Stage 2: Precise reranking with 768-dim vectors
Reference: rag-optimization/spec.md Section 2.4
"""
import time
client = await self._get_client()
stage1_start = time.perf_counter()
candidates = await self._search_with_dimension(
client, tenant_id, embedding_result.embedding_256, "dim_256",
top_k * self._two_stage_expand_factor
)
stage1_latency = (time.perf_counter() - stage1_start) * 1000
logger.debug(
f"[RAG-OPT] Stage 1: {len(candidates)} candidates in {stage1_latency:.2f}ms"
)
stage2_start = time.perf_counter()
reranked = []
for candidate in candidates:
    stored_full_embedding = candidate.get("payload", {}).get("embedding_full", [])
    if stored_full_embedding:
        similarity = self._cosine_similarity(
            embedding_result.embedding_full,
            stored_full_embedding
        )
        candidate["score"] = similarity
        candidate["stage"] = "reranked"
    else:
        # The indexer does not store the full embedding in the payload, so
        # keep the stage-1 score rather than silently dropping the candidate.
        candidate["stage"] = "stage1_only"
    reranked.append(candidate)
reranked.sort(key=lambda x: x.get("score", 0), reverse=True)
results = reranked[:top_k]
stage2_latency = (time.perf_counter() - stage2_start) * 1000
logger.debug(
f"[RAG-OPT] Stage 2: {len(results)} final results in {stage2_latency:.2f}ms"
)
return results
async def _hybrid_retrieve(
self,
tenant_id: str,
embedding_result: NomicEmbeddingResult,
query: str,
top_k: int,
) -> list[dict[str, Any]]:
"""
Hybrid retrieval using RRF to combine vector and BM25 results.
Reference: rag-optimization/spec.md Section 2.5
"""
client = await self._get_client()
vector_task = self._search_with_dimension(
client, tenant_id, embedding_result.embedding_full, "full",
top_k * 2
)
bm25_task = self._bm25_search(client, tenant_id, query, top_k * 2)
vector_results, bm25_results = await asyncio.gather(
vector_task, bm25_task, return_exceptions=True
)
if isinstance(vector_results, Exception):
logger.warning(f"[RAG-OPT] Vector search failed: {vector_results}")
vector_results = []
if isinstance(bm25_results, Exception):
logger.warning(f"[RAG-OPT] BM25 search failed: {bm25_results}")
bm25_results = []
combined = self._rrf_combiner.combine(
vector_results,
bm25_results,
vector_weight=settings.rag_vector_weight,
bm25_weight=settings.rag_bm25_weight,
)
return combined[:top_k]
async def _vector_retrieve(
self,
tenant_id: str,
embedding: list[float],
top_k: int,
) -> list[dict[str, Any]]:
"""Simple vector retrieval."""
client = await self._get_client()
return await self._search_with_dimension(
client, tenant_id, embedding, "full", top_k
)
async def _search_with_dimension(
self,
client: QdrantClient,
tenant_id: str,
query_vector: list[float],
vector_name: str,
limit: int,
) -> list[dict[str, Any]]:
"""Search using specified vector dimension."""
try:
qdrant = await client.get_client()
collection_name = client.get_collection_name(tenant_id)
logger.info(
f"[RAG-OPT] Searching collection={collection_name}, "
f"vector_name={vector_name}, limit={limit}, vector_dim={len(query_vector)}"
)
results = await qdrant.search(
collection_name=collection_name,
query_vector=(vector_name, query_vector),
limit=limit,
)
logger.info(
f"[RAG-OPT] Search returned {len(results)} results from collection={collection_name}"
)
if len(results) > 0:
for i, r in enumerate(results[:3]):
logger.debug(
f"[RAG-OPT] Result {i+1}: id={r.id}, score={r.score:.4f}"
)
return [
{
"id": str(result.id),
"score": result.score,
"payload": result.payload or {},
}
for result in results
]
except Exception as e:
logger.error(
f"[RAG-OPT] Search with {vector_name} failed: {e}, "
f"collection_name={client.get_collection_name(tenant_id)}",
exc_info=True
)
return []
async def _bm25_search(
self,
client: QdrantClient,
tenant_id: str,
query: str,
limit: int,
) -> list[dict[str, Any]]:
"""
Keyword search approximated by scrolling the collection and scoring Jaccard term overlap.
This is a simplified stand-in for BM25; for production, use Qdrant sparse vectors or a dedicated engine such as Elasticsearch.
"""
try:
qdrant = await client.get_client()
collection_name = client.get_collection_name(tenant_id)
query_terms = set(re.findall(r'\w+', query.lower()))
results = await qdrant.scroll(
collection_name=collection_name,
limit=limit * 3,
with_payload=True,
)
scored_results = []
for point in results[0]:
text = point.payload.get("text", "").lower()
text_terms = set(re.findall(r'\w+', text))
overlap = len(query_terms & text_terms)
if overlap > 0:
score = overlap / (len(query_terms) + len(text_terms) - overlap)
scored_results.append({
"id": str(point.id),
"score": score,
"payload": point.payload or {},
})
scored_results.sort(key=lambda x: x["score"], reverse=True)
return scored_results[:limit]
except Exception as e:
logger.debug(f"[RAG-OPT] BM25 search failed: {e}")
return []
def _cosine_similarity(self, vec1: list[float], vec2: list[float]) -> float:
"""Calculate cosine similarity between two vectors."""
import numpy as np
a = np.array(vec1)
b = np.array(vec2)
denom = float(np.linalg.norm(a) * np.linalg.norm(b))
# Guard against zero-length vectors to avoid division by zero.
return float(np.dot(a, b) / denom) if denom else 0.0
async def health_check(self) -> bool:
"""Check if retriever is healthy."""
try:
client = await self._get_client()
qdrant = await client.get_client()
await qdrant.get_collections()
return True
except Exception as e:
logger.error(f"[RAG-OPT] Health check failed: {e}")
return False
_optimized_retriever: OptimizedRetriever | None = None
async def get_optimized_retriever() -> OptimizedRetriever:
"""Get or create OptimizedRetriever instance."""
global _optimized_retriever
if _optimized_retriever is None:
_optimized_retriever = OptimizedRetriever()
return _optimized_retriever
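A worked example of the weighted RRF combination (IDs and scores are made up):

from app.services.retrieval.optimized_retriever import RRFCombiner

vector_results = [
    {"id": "a", "score": 0.82, "payload": {"text": "A"}},
    {"id": "b", "score": 0.74, "payload": {"text": "B"}},
]
bm25_results = [
    {"id": "b", "score": 3.1, "payload": {"text": "B"}},
    {"id": "c", "score": 2.4, "payload": {"text": "C"}},
]

combined = RRFCombiner(k=60).combine(vector_results, bm25_results)
for r in combined:
    print(r["id"], round(r["score"], 5))
# "b" wins by appearing in both lists: 0.7/62 + 0.3/61 ≈ 0.01621,
# ahead of "a" (0.7/61 ≈ 0.01148) and "c" (0.3/62 ≈ 0.00484).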