feat: RAG 检索优化,实现多维度向量存储和 Nomic 嵌入提供者 [AC-AISVC-16, AC-AISVC-29]
This commit is contained in:
parent
774744d534
commit
cee884d9a0
|
|
@ -0,0 +1,330 @@
|
||||||
|
"""
|
||||||
|
Knowledge base management API with RAG optimization features.
|
||||||
|
Reference: rag-optimization/spec.md Section 4.2
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import date
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.core.database import get_session
|
||||||
|
from app.services.retrieval import (
|
||||||
|
ChunkMetadata,
|
||||||
|
ChunkMetadataModel,
|
||||||
|
IndexingProgress,
|
||||||
|
IndexingResult,
|
||||||
|
KnowledgeIndexer,
|
||||||
|
MetadataFilter,
|
||||||
|
RetrievalStrategy,
|
||||||
|
get_knowledge_indexer,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/kb", tags=["Knowledge Base"])
|
||||||
|
|
||||||
|
|
||||||
|
class IndexDocumentRequest(BaseModel):
    """Request payload to index a single document into the knowledge base."""

    # Tenant that owns the document; used for collection isolation downstream.
    tenant_id: str = Field(..., description="Tenant ID")
    document_id: str = Field(..., description="Document ID")
    # Raw text content; chunking is performed by the indexer, not the caller.
    text: str = Field(..., description="Document text content")
    # Optional structured metadata attached to every chunk of this document.
    metadata: ChunkMetadataModel | None = Field(default=None, description="Document metadata")
|
||||||
|
|
||||||
|
|
||||||
|
class IndexDocumentResponse(BaseModel):
    """Response from document indexing."""

    # Whether the indexing run as a whole succeeded.
    success: bool
    # Number of chunks the document was split into.
    total_chunks: int
    # Chunks successfully embedded and stored.
    indexed_chunks: int
    # Chunks that failed embedding or upsert.
    failed_chunks: int
    # Wall-clock duration of the indexing run, in seconds.
    elapsed_seconds: float
    # NOTE(review): presumably populated only when success is False — confirm
    # against IndexingResult in app.services.retrieval.
    error_message: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class IndexingProgressResponse(BaseModel):
    """Response with current indexing progress."""

    # Total chunks scheduled for the current run.
    total_chunks: int
    # Chunks processed so far (successes and failures combined).
    processed_chunks: int
    failed_chunks: int
    # Integer percentage 0-100 — TODO confirm range against IndexingProgress.
    progress_percent: int
    elapsed_seconds: float
    # Identifier of the document currently being indexed.
    current_document: str
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataFilterRequest(BaseModel):
    """Request for metadata filtering.

    All fields are optional; ``None`` means "no constraint" on that dimension.
    Field names must mirror ``MetadataFilter``, because the handler forwards
    ``model_dump(exclude_none=True)`` directly into its constructor.
    """

    categories: list[str] | None = None
    target_audiences: list[str] | None = None
    departments: list[str] | None = None
    # NOTE(review): presumably restricts results to currently-valid chunks —
    # confirm semantics in MetadataFilter.
    valid_only: bool = True
    min_priority: int | None = None
    keywords: list[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class RetrieveRequest(BaseModel):
    """Request for knowledge retrieval."""

    tenant_id: str = Field(..., description="Tenant ID")
    query: str = Field(..., description="Search query")
    # NOTE(review): top_k and strategy are declared here but the /retrieve
    # handler does not forward them to the retriever — verify intended wiring.
    top_k: int = Field(default=10, ge=1, le=50, description="Number of results")
    filters: MetadataFilterRequest | None = Field(default=None, description="Metadata filters")
    strategy: RetrievalStrategy = Field(default=RetrievalStrategy.HYBRID, description="Retrieval strategy")
|
||||||
|
|
||||||
|
|
||||||
|
class RetrieveResponse(BaseModel):
    """Response from knowledge retrieval."""

    # Each hit carries keys: text, score, source, metadata.
    hits: list[dict[str, Any]]
    total_hits: int
    # Highest similarity score among the returned hits.
    max_score: float
    # True when the retriever judged the evidence insufficient to answer.
    is_insufficient: bool
    # Retriever-provided debug/telemetry payload; always a dict (possibly empty).
    diagnostics: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataOptionsResponse(BaseModel):
    """Response with available metadata options."""

    # Valid values for the corresponding MetadataFilterRequest fields.
    categories: list[str]
    departments: list[str]
    target_audiences: list[str]
    # Priority levels, ascending; currently 1..10.
    priorities: list[int]
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/index", response_model=IndexDocumentResponse)
async def index_document(
    request: IndexDocumentRequest,
    session: AsyncSession = Depends(get_session),
):
    """
    Index a document with optimized embedding.

    Features:
    - Task prefixes (search_document:) for document embedding
    - Multi-dimensional vectors (256/512/768)
    - Metadata support
    """
    try:
        indexer = get_knowledge_indexer()

        # Convert the optional API metadata model into the service-level
        # ChunkMetadata dataclass expected by the indexer.
        meta = request.metadata
        chunk_metadata = (
            ChunkMetadata(
                category=meta.category,
                subcategory=meta.subcategory,
                target_audience=meta.target_audience,
                source_doc=meta.source_doc,
                source_url=meta.source_url,
                department=meta.department,
                priority=meta.priority,
                keywords=meta.keywords,
            )
            if meta
            else None
        )

        outcome = await indexer.index_document(
            tenant_id=request.tenant_id,
            document_id=request.document_id,
            text=request.text,
            metadata=chunk_metadata,
        )

        return IndexDocumentResponse(
            success=outcome.success,
            total_chunks=outcome.total_chunks,
            indexed_chunks=outcome.indexed_chunks,
            failed_chunks=outcome.failed_chunks,
            elapsed_seconds=outcome.elapsed_seconds,
            error_message=outcome.error_message,
        )

    except Exception as e:
        logger.error(f"[KB-API] Failed to index document: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"索引失败: {str(e)}"
        )
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/index/progress", response_model=IndexingProgressResponse | None)
async def get_indexing_progress():
    """Get current indexing progress, or None when no indexing job is active."""
    try:
        snapshot = get_knowledge_indexer().get_progress()

        if snapshot is None:
            return None

        # Mirror the service-level progress object into the API schema.
        return IndexingProgressResponse(
            total_chunks=snapshot.total_chunks,
            processed_chunks=snapshot.processed_chunks,
            failed_chunks=snapshot.failed_chunks,
            progress_percent=snapshot.progress_percent,
            elapsed_seconds=snapshot.elapsed_seconds,
            current_document=snapshot.current_document,
        )

    except Exception as e:
        logger.error(f"[KB-API] Failed to get progress: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"获取进度失败: {str(e)}"
        )
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/retrieve", response_model=RetrieveResponse)
async def retrieve_knowledge(request: RetrieveRequest):
    """
    Retrieve knowledge using optimized RAG.

    Strategies:
    - vector: Simple vector search
    - bm25: BM25 keyword search
    - hybrid: RRF combination of vector + BM25 (default)
    - two_stage: Two-stage retrieval with Matryoshka dimensions

    NOTE(review): request.top_k and request.strategy are currently not
    forwarded to the retriever — confirm whether RetrievalContext supports
    them and wire them through if so.
    """
    try:
        # Imported lazily to avoid a circular import at module load time.
        from app.services.retrieval.optimized_retriever import get_optimized_retriever
        from app.services.retrieval.base import RetrievalContext

        retriever = await get_optimized_retriever()

        # Translate the API-level filter model into the service-level
        # MetadataFilter (field names are kept in sync by construction).
        metadata_filter = None
        if request.filters:
            filter_dict = request.filters.model_dump(exclude_none=True)
            metadata_filter = MetadataFilter(**filter_dict)

        ctx = RetrievalContext(
            tenant_id=request.tenant_id,
            query=request.query,
        )

        if metadata_filter:
            ctx.metadata = {"filter": metadata_filter.to_qdrant_filter()}

        result = await retriever.retrieve(ctx)

        # Normalize diagnostics once. Previously `.get()` was called directly
        # on result.diagnostics, which raises AttributeError when it is None
        # (the old `result.diagnostics or {}` fallback shows None is possible),
        # turning a missing-diagnostics case into an opaque 500.
        diagnostics = result.diagnostics or {}

        return RetrieveResponse(
            hits=[
                {
                    "text": hit.text,
                    "score": hit.score,
                    "source": hit.source,
                    "metadata": hit.metadata,
                }
                for hit in result.hits
            ],
            total_hits=result.hit_count,
            max_score=result.max_score,
            is_insufficient=diagnostics.get("is_insufficient", False),
            diagnostics=diagnostics,
        )

    except Exception as e:
        logger.error(f"[KB-API] Failed to retrieve: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"检索失败: {str(e)}"
        )
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/metadata/options", response_model=MetadataOptionsResponse)
async def get_metadata_options():
    """
    Get available metadata options for filtering.

    These would typically be loaded from a database.
    """
    try:
        # Static option catalog; values mirror the metadata fields used at
        # indexing time.
        options = {
            "categories": [
                "课程咨询",
                "考试政策",
                "学籍管理",
                "奖助学金",
                "宿舍管理",
                "校园服务",
                "就业指导",
                "其他",
            ],
            "departments": [
                "教务处",
                "学生处",
                "财务处",
                "后勤处",
                "就业指导中心",
                "图书馆",
                "信息中心",
            ],
            "target_audiences": [
                "本科生",
                "研究生",
                "留学生",
                "新生",
                "毕业生",
                "教职工",
            ],
            "priorities": list(range(1, 11)),
        }
        return MetadataOptionsResponse(**options)

    except Exception as e:
        logger.error(f"[KB-API] Failed to get metadata options: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"获取选项失败: {str(e)}"
        )
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/reindex")
async def reindex_all(
    tenant_id: str,
    session: AsyncSession = Depends(get_session),
):
    """
    Reindex all documents for a tenant with optimized embedding.

    Loads every COMPLETED document for the tenant from the documents table,
    reads its stored text file, and pushes it back through the knowledge
    indexer. Documents without a file on disk are skipped silently.

    NOTE(review): files are read synchronously inside an async handler,
    which blocks the event loop for large files — consider offloading to a
    thread (e.g. asyncio.to_thread).
    """
    try:
        # Hoisted out of the per-document loop, where the original code
        # re-executed `import os` on every iteration.
        import os

        from app.models.entities import Document, DocumentStatus

        stmt = select(Document).where(
            Document.tenant_id == tenant_id,
            Document.status == DocumentStatus.COMPLETED.value,
        )
        query_result = await session.execute(stmt)
        documents = query_result.scalars().all()

        indexer = get_knowledge_indexer()

        total_indexed = 0
        total_failed = 0

        for doc in documents:
            # Skip documents whose source file is missing or was never stored.
            if not doc.file_path or not os.path.exists(doc.file_path):
                continue

            with open(doc.file_path, 'r', encoding='utf-8') as f:
                text = f.read()

            # Distinct name: the original reused `result` for both the SQL
            # result and the indexing result, which obscured the data flow.
            index_result = await indexer.index_document(
                tenant_id=tenant_id,
                document_id=str(doc.id),
                text=text,
            )

            total_indexed += index_result.indexed_chunks
            total_failed += index_result.failed_chunks

        return {
            "success": True,
            "total_documents": len(documents),
            "total_indexed": total_indexed,
            "total_failed": total_failed,
        }

    except Exception as e:
        logger.error(f"[KB-API] Failed to reindex: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"重新索引失败: {str(e)}"
        )
|
||||||
|
|
@ -17,6 +17,7 @@ from app.core.exceptions import MissingTenantIdException
|
||||||
from app.core.tenant import get_tenant_id
|
from app.core.tenant import get_tenant_id
|
||||||
from app.models import ErrorResponse
|
from app.models import ErrorResponse
|
||||||
from app.services.retrieval.vector_retriever import get_vector_retriever
|
from app.services.retrieval.vector_retriever import get_vector_retriever
|
||||||
|
from app.services.retrieval.optimized_retriever import get_optimized_retriever
|
||||||
from app.services.retrieval.base import RetrievalContext
|
from app.services.retrieval.base import RetrievalContext
|
||||||
from app.services.llm.factory import get_llm_config_manager
|
from app.services.llm.factory import get_llm_config_manager
|
||||||
|
|
||||||
|
|
@ -91,7 +92,8 @@ async def run_rag_experiment(
|
||||||
threshold = request.score_threshold or settings.rag_score_threshold
|
threshold = request.score_threshold or settings.rag_score_threshold
|
||||||
|
|
||||||
try:
|
try:
|
||||||
retriever = await get_vector_retriever()
|
# Use optimized retriever with RAG enhancements
|
||||||
|
retriever = await get_optimized_retriever()
|
||||||
|
|
||||||
retrieval_ctx = RetrievalContext(
|
retrieval_ctx = RetrievalContext(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
|
|
@ -199,7 +201,8 @@ async def run_rag_experiment_stream(
|
||||||
|
|
||||||
async def event_generator():
|
async def event_generator():
|
||||||
try:
|
try:
|
||||||
retriever = await get_vector_retriever()
|
# Use optimized retriever with RAG enhancements
|
||||||
|
retriever = await get_optimized_retriever()
|
||||||
|
|
||||||
retrieval_ctx = RetrievalContext(
|
retrieval_ctx = RetrievalContext(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
|
|
|
||||||
|
|
@ -9,18 +9,43 @@ from typing import Annotated, Any
|
||||||
from fastapi import APIRouter, Depends, Header, Request
|
from fastapi import APIRouter, Depends, Header, Request
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from sse_starlette.sse import EventSourceResponse
|
from sse_starlette.sse import EventSourceResponse
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.core.database import get_session
|
||||||
from app.core.middleware import get_response_mode, is_sse_request
|
from app.core.middleware import get_response_mode, is_sse_request
|
||||||
from app.core.sse import SSEStateMachine, create_error_event
|
from app.core.sse import SSEStateMachine, create_error_event
|
||||||
from app.core.tenant import get_tenant_id
|
from app.core.tenant import get_tenant_id
|
||||||
from app.models import ChatRequest, ChatResponse, ErrorResponse
|
from app.models import ChatRequest, ChatResponse, ErrorResponse
|
||||||
from app.services.orchestrator import OrchestratorService, get_orchestrator_service
|
from app.services.memory import MemoryService
|
||||||
|
from app.services.orchestrator import OrchestratorService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
router = APIRouter(tags=["AI Chat"])
|
router = APIRouter(tags=["AI Chat"])
|
||||||
|
|
||||||
|
|
||||||
|
async def get_orchestrator_service_with_memory(
    session: Annotated[AsyncSession, Depends(get_session)]
) -> OrchestratorService:
    """
    [AC-AISVC-13] Create orchestrator service with memory service and LLM client.

    FastAPI dependency: ensures each request gets a fresh MemoryService bound
    to the request-scoped database session, rather than a shared singleton.

    NOTE(review): this still wires in get_vector_retriever while the rest of
    this commit migrates to get_optimized_retriever — confirm whether the chat
    path should also use the optimized retriever.
    """
    # Imported lazily — presumably to avoid import cycles at module load; confirm.
    from app.services.llm.factory import get_llm_config_manager
    from app.services.retrieval.vector_retriever import get_vector_retriever

    memory_service = MemoryService(session)
    llm_config_manager = get_llm_config_manager()
    llm_client = llm_config_manager.get_client()
    retriever = await get_vector_retriever()

    return OrchestratorService(
        llm_client=llm_client,
        memory_service=memory_service,
        retriever=retriever,
    )
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
"/ai/chat",
|
"/ai/chat",
|
||||||
operation_id="generateReply",
|
operation_id="generateReply",
|
||||||
|
|
@ -49,7 +74,7 @@ async def generate_reply(
|
||||||
request: Request,
|
request: Request,
|
||||||
chat_request: ChatRequest,
|
chat_request: ChatRequest,
|
||||||
accept: Annotated[str | None, Header()] = None,
|
accept: Annotated[str | None, Header()] = None,
|
||||||
orchestrator: OrchestratorService = Depends(get_orchestrator_service),
|
orchestrator: OrchestratorService = Depends(get_orchestrator_service_with_memory),
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
[AC-AISVC-06] Generate AI reply with automatic response mode switching.
|
[AC-AISVC-06] Generate AI reply with automatic response mode switching.
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,14 @@
|
||||||
"""
|
"""
|
||||||
Qdrant client for AI Service.
|
Qdrant client for AI Service.
|
||||||
[AC-AISVC-10] Vector database client with tenant-isolated collection management.
|
[AC-AISVC-10] Vector database client with tenant-isolated collection management.
|
||||||
|
Supports multi-dimensional vectors for Matryoshka representation learning.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from qdrant_client import AsyncQdrantClient
|
from qdrant_client import AsyncQdrantClient
|
||||||
from qdrant_client.models import Distance, PointStruct, VectorParams
|
from qdrant_client.models import Distance, PointStruct, VectorParams, MultiVectorConfig
|
||||||
|
|
||||||
from app.core.config import get_settings
|
from app.core.config import get_settings
|
||||||
|
|
||||||
|
|
@ -20,6 +21,7 @@ class QdrantClient:
|
||||||
"""
|
"""
|
||||||
[AC-AISVC-10] Qdrant client with tenant-isolated collection management.
|
[AC-AISVC-10] Qdrant client with tenant-isolated collection management.
|
||||||
Collection naming: kb_{tenantId} for tenant isolation.
|
Collection naming: kb_{tenantId} for tenant isolation.
|
||||||
|
Supports multi-dimensional vectors (256/512/768) for Matryoshka retrieval.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
@ -45,13 +47,15 @@ class QdrantClient:
|
||||||
"""
|
"""
|
||||||
[AC-AISVC-10] Get collection name for a tenant.
|
[AC-AISVC-10] Get collection name for a tenant.
|
||||||
Naming convention: kb_{tenantId}
|
Naming convention: kb_{tenantId}
|
||||||
|
Replaces @ with _ to ensure valid collection names.
|
||||||
"""
|
"""
|
||||||
return f"{self._collection_prefix}{tenant_id}"
|
safe_tenant_id = tenant_id.replace('@', '_')
|
||||||
|
return f"{self._collection_prefix}{safe_tenant_id}"
|
||||||
|
|
||||||
async def ensure_collection_exists(self, tenant_id: str) -> bool:
|
async def ensure_collection_exists(self, tenant_id: str, use_multi_vector: bool = True) -> bool:
|
||||||
"""
|
"""
|
||||||
[AC-AISVC-10] Ensure collection exists for tenant.
|
[AC-AISVC-10] Ensure collection exists for tenant.
|
||||||
Note: MVP uses pre-provisioned collections, this is for development/testing.
|
Supports multi-dimensional vectors for Matryoshka retrieval.
|
||||||
"""
|
"""
|
||||||
client = await self.get_client()
|
client = await self.get_client()
|
||||||
collection_name = self.get_collection_name(tenant_id)
|
collection_name = self.get_collection_name(tenant_id)
|
||||||
|
|
@ -61,15 +65,34 @@ class QdrantClient:
|
||||||
exists = any(c.name == collection_name for c in collections.collections)
|
exists = any(c.name == collection_name for c in collections.collections)
|
||||||
|
|
||||||
if not exists:
|
if not exists:
|
||||||
await client.create_collection(
|
if use_multi_vector:
|
||||||
collection_name=collection_name,
|
vectors_config = {
|
||||||
|
"full": VectorParams(
|
||||||
|
size=768,
|
||||||
|
distance=Distance.COSINE,
|
||||||
|
),
|
||||||
|
"dim_256": VectorParams(
|
||||||
|
size=256,
|
||||||
|
distance=Distance.COSINE,
|
||||||
|
),
|
||||||
|
"dim_512": VectorParams(
|
||||||
|
size=512,
|
||||||
|
distance=Distance.COSINE,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
else:
|
||||||
vectors_config = VectorParams(
|
vectors_config = VectorParams(
|
||||||
size=self._vector_size,
|
size=self._vector_size,
|
||||||
distance=Distance.COSINE,
|
distance=Distance.COSINE,
|
||||||
),
|
)
|
||||||
|
|
||||||
|
await client.create_collection(
|
||||||
|
collection_name=collection_name,
|
||||||
|
vectors_config=vectors_config,
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[AC-AISVC-10] Created collection: {collection_name} for tenant={tenant_id} "
|
f"[AC-AISVC-10] Created collection: {collection_name} for tenant={tenant_id} "
|
||||||
|
f"with multi_vector={use_multi_vector}"
|
||||||
)
|
)
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -100,26 +123,117 @@ class QdrantClient:
|
||||||
logger.error(f"[AC-AISVC-10] Error upserting vectors: {e}")
|
logger.error(f"[AC-AISVC-10] Error upserting vectors: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
async def upsert_multi_vector(
    self,
    tenant_id: str,
    points: list[dict[str, Any]],
) -> bool:
    """
    Upsert points with multi-dimensional vectors.

    Args:
        tenant_id: Tenant identifier
        points: List of points with format:
            {
                "id": str | int,
                "vector": {
                    "full": [768 floats],
                    "dim_256": [256 floats],
                    "dim_512": [512 floats],
                },
                "payload": dict
            }

    Returns:
        True on success; False when the upsert raised (the error is logged,
        never re-raised).
    """
    client = await self.get_client()
    collection_name = self.get_collection_name(tenant_id)

    try:
        # Qdrant's named-vector PointStruct accepts the dict-of-vectors
        # format directly in the `vector` field.
        qdrant_points = []
        for p in points:
            point = PointStruct(
                id=p["id"],
                vector=p["vector"],
                payload=p.get("payload", {}),
            )
            qdrant_points.append(point)

        await client.upsert(
            collection_name=collection_name,
            points=qdrant_points,
        )
        logger.info(
            f"[RAG-OPT] Upserted {len(points)} multi-vector points for tenant={tenant_id}"
        )
        return True
    except Exception as e:
        logger.error(f"[RAG-OPT] Error upserting multi-vectors: {e}")
        return False
|
||||||
|
|
||||||
async def search(
|
async def search(
|
||||||
self,
|
self,
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
query_vector: list[float],
|
query_vector: list[float],
|
||||||
limit: int = 5,
|
limit: int = 5,
|
||||||
score_threshold: float | None = None,
|
score_threshold: float | None = None,
|
||||||
|
vector_name: str = "full",
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
[AC-AISVC-10] Search vectors in tenant's collection.
|
[AC-AISVC-10] Search vectors in tenant's collection.
|
||||||
Returns results with score >= score_threshold if specified.
|
Returns results with score >= score_threshold if specified.
|
||||||
|
Searches both old format (with @) and new format (with _) for backward compatibility.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tenant_id: Tenant identifier
|
||||||
|
query_vector: Query vector for similarity search
|
||||||
|
limit: Maximum number of results
|
||||||
|
score_threshold: Minimum score threshold for results
|
||||||
|
vector_name: Name of the vector to search (for multi-vector collections)
|
||||||
|
Default is "full" for 768-dim vectors in Matryoshka setup.
|
||||||
"""
|
"""
|
||||||
client = await self.get_client()
|
client = await self.get_client()
|
||||||
collection_name = self.get_collection_name(tenant_id)
|
|
||||||
|
logger.info(
|
||||||
|
f"[AC-AISVC-10] Starting search: tenant_id={tenant_id}, "
|
||||||
|
f"limit={limit}, score_threshold={score_threshold}, vector_dim={len(query_vector)}, vector_name={vector_name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
collection_names = [self.get_collection_name(tenant_id)]
|
||||||
|
if '@' in tenant_id:
|
||||||
|
old_format = f"{self._collection_prefix}{tenant_id}"
|
||||||
|
new_format = f"{self._collection_prefix}{tenant_id.replace('@', '_')}"
|
||||||
|
collection_names = [new_format, old_format]
|
||||||
|
|
||||||
|
logger.info(f"[AC-AISVC-10] Will search in collections: {collection_names}")
|
||||||
|
|
||||||
|
all_hits = []
|
||||||
|
|
||||||
|
for collection_name in collection_names:
|
||||||
|
try:
|
||||||
|
logger.info(f"[AC-AISVC-10] Searching in collection: {collection_name}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
results = await client.search(
|
||||||
|
collection_name=collection_name,
|
||||||
|
query_vector=(vector_name, query_vector),
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
if "vector name" in str(e).lower() or "Not existing vector" in str(e):
|
||||||
|
logger.info(
|
||||||
|
f"[AC-AISVC-10] Collection {collection_name} doesn't have vector named '{vector_name}', "
|
||||||
|
f"trying without vector name (single-vector mode)"
|
||||||
|
)
|
||||||
results = await client.search(
|
results = await client.search(
|
||||||
collection_name=collection_name,
|
collection_name=collection_name,
|
||||||
query_vector=query_vector,
|
query_vector=query_vector,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[AC-AISVC-10] Collection {collection_name} returned {len(results)} raw results"
|
||||||
|
)
|
||||||
|
|
||||||
hits = [
|
hits = [
|
||||||
{
|
{
|
||||||
|
|
@ -130,14 +244,39 @@ class QdrantClient:
|
||||||
for result in results
|
for result in results
|
||||||
if score_threshold is None or result.score >= score_threshold
|
if score_threshold is None or result.score >= score_threshold
|
||||||
]
|
]
|
||||||
|
all_hits.extend(hits)
|
||||||
|
|
||||||
|
if hits:
|
||||||
|
logger.info(
|
||||||
|
f"[AC-AISVC-10] Search in collection {collection_name}: {len(hits)} results for tenant={tenant_id}"
|
||||||
|
)
|
||||||
|
for i, h in enumerate(hits[:3]):
|
||||||
|
logger.debug(
|
||||||
|
f"[AC-AISVC-10] Hit {i+1}: id={h['id']}, score={h['score']:.4f}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"[AC-AISVC-10] Collection {collection_name} returned no hits (filtered or empty)"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"[AC-AISVC-10] Collection {collection_name} not found or error: {e}"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
all_hits = sorted(all_hits, key=lambda x: x["score"], reverse=True)[:limit]
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[AC-AISVC-10] Search returned {len(hits)} results for tenant={tenant_id}"
|
f"[AC-AISVC-10] Search returned {len(all_hits)} total results for tenant={tenant_id}"
|
||||||
)
|
)
|
||||||
return hits
|
|
||||||
except Exception as e:
|
if len(all_hits) == 0:
|
||||||
logger.error(f"[AC-AISVC-10] Error searching vectors: {e}")
|
logger.warning(
|
||||||
return []
|
f"[AC-AISVC-10] No results found! tenant={tenant_id}, "
|
||||||
|
f"collections_tried={collection_names}, limit={limit}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return all_hits
|
||||||
|
|
||||||
async def delete_collection(self, tenant_id: str) -> bool:
|
async def delete_collection(self, tenant_id: str) -> bool:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,11 @@ from app.services.embedding.factory import (
|
||||||
)
|
)
|
||||||
from app.services.embedding.ollama_provider import OllamaEmbeddingProvider
|
from app.services.embedding.ollama_provider import OllamaEmbeddingProvider
|
||||||
from app.services.embedding.openai_provider import OpenAIEmbeddingProvider
|
from app.services.embedding.openai_provider import OpenAIEmbeddingProvider
|
||||||
|
from app.services.embedding.nomic_provider import (
|
||||||
|
NomicEmbeddingProvider,
|
||||||
|
NomicEmbeddingResult,
|
||||||
|
EmbeddingTask,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"EmbeddingConfig",
|
"EmbeddingConfig",
|
||||||
|
|
@ -29,4 +34,7 @@ __all__ = [
|
||||||
"get_embedding_provider",
|
"get_embedding_provider",
|
||||||
"OllamaEmbeddingProvider",
|
"OllamaEmbeddingProvider",
|
||||||
"OpenAIEmbeddingProvider",
|
"OpenAIEmbeddingProvider",
|
||||||
|
"NomicEmbeddingProvider",
|
||||||
|
"NomicEmbeddingResult",
|
||||||
|
"EmbeddingTask",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@ from typing import Any, Type
|
||||||
from app.services.embedding.base import EmbeddingException, EmbeddingProvider
|
from app.services.embedding.base import EmbeddingException, EmbeddingProvider
|
||||||
from app.services.embedding.ollama_provider import OllamaEmbeddingProvider
|
from app.services.embedding.ollama_provider import OllamaEmbeddingProvider
|
||||||
from app.services.embedding.openai_provider import OpenAIEmbeddingProvider
|
from app.services.embedding.openai_provider import OpenAIEmbeddingProvider
|
||||||
|
from app.services.embedding.nomic_provider import NomicEmbeddingProvider
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -26,6 +27,7 @@ class EmbeddingProviderFactory:
|
||||||
_providers: dict[str, Type[EmbeddingProvider]] = {
|
_providers: dict[str, Type[EmbeddingProvider]] = {
|
||||||
"ollama": OllamaEmbeddingProvider,
|
"ollama": OllamaEmbeddingProvider,
|
||||||
"openai": OpenAIEmbeddingProvider,
|
"openai": OpenAIEmbeddingProvider,
|
||||||
|
"nomic": NomicEmbeddingProvider,
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -63,11 +65,13 @@ class EmbeddingProviderFactory:
|
||||||
display_names = {
|
display_names = {
|
||||||
"ollama": "Ollama 本地模型",
|
"ollama": "Ollama 本地模型",
|
||||||
"openai": "OpenAI Embedding",
|
"openai": "OpenAI Embedding",
|
||||||
|
"nomic": "Nomic Embed (优化版)",
|
||||||
}
|
}
|
||||||
|
|
||||||
descriptions = {
|
descriptions = {
|
||||||
"ollama": "使用 Ollama 运行的本地嵌入模型,支持 nomic-embed-text 等开源模型",
|
"ollama": "使用 Ollama 运行的本地嵌入模型,支持 nomic-embed-text 等开源模型",
|
||||||
"openai": "使用 OpenAI 官方 Embedding API,支持 text-embedding-3 系列模型",
|
"openai": "使用 OpenAI 官方 Embedding API,支持 text-embedding-3 系列模型",
|
||||||
|
"nomic": "Nomic-embed-text v1.5 优化版,支持任务前缀和 Matryoshka 维度截断,专为RAG优化",
|
||||||
}
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,291 @@
|
||||||
|
"""
|
||||||
|
Nomic embedding provider with task prefixes and Matryoshka support.
|
||||||
|
Implements RAG optimization spec:
|
||||||
|
- Task prefixes: search_document: / search_query:
|
||||||
|
- Matryoshka dimension truncation: 256/512/768 dimensions
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from app.services.embedding.base import (
|
||||||
|
EmbeddingConfig,
|
||||||
|
EmbeddingException,
|
||||||
|
EmbeddingProvider,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingTask(str, Enum):
    """Task type for nomic-embed-text v1.5 model.

    Values are the task names; the provider's DOCUMENT_PREFIX / QUERY_PREFIX
    constants carry the trailing ":" used when prefixing input text.
    """

    # Indexing-time document embedding.
    DOCUMENT = "search_document"
    # Query-time embedding.
    QUERY = "search_query"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class NomicEmbeddingResult:
    """Result from Nomic embedding with multiple dimensions."""

    # Full embedding (768 dims per the provider's FULL_DIMENSION).
    embedding_full: list[float]
    # Matryoshka-truncated 256-dim embedding (normalized after truncation).
    embedding_256: list[float]
    # Matryoshka-truncated 512-dim embedding (normalized after truncation).
    embedding_512: list[float]
    # Dimension requested by the caller — TODO confirm whether this reflects
    # the full or the selected truncated dimension.
    dimension: int
    # Model name that produced the embedding.
    model: str
    # Task prefix used for this embedding (document vs query).
    task: EmbeddingTask
    # End-to-end embedding latency in milliseconds.
    latency_ms: float = 0.0
    # Free-form provider metadata.
    metadata: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class NomicEmbeddingProvider(EmbeddingProvider):
    """
    Nomic-embed-text v1.5 embedding provider with task prefixes.

    Key features:
    - Task prefixes: search_document: for documents, search_query: for queries
    - Matryoshka dimension truncation: 256/512/768 dimensions
    - Automatic L2 re-normalization after truncation

    Reference: rag-optimization/spec.md Section 2.1, 2.3
    """

    PROVIDER_NAME = "nomic"
    DOCUMENT_PREFIX = "search_document:"
    QUERY_PREFIX = "search_query:"
    FULL_DIMENSION = 768
    # Dimensions the Matryoshka truncation scheme supports (see config schema).
    SUPPORTED_DIMENSIONS = (256, 512, 768)

    def __init__(
        self,
        base_url: str = "http://localhost:11434",
        model: str = "nomic-embed-text",
        dimension: int = 768,
        timeout_seconds: int = 60,
        enable_matryoshka: bool = True,
        **kwargs: Any,
    ):
        """
        Args:
            base_url: Ollama server base URL; a trailing slash is stripped.
            model: Name of the embedding model served by Ollama.
            dimension: Target dimension of embed()/embed_batch() output when
                Matryoshka truncation is enabled (supported: 256/512/768).
            timeout_seconds: HTTP request timeout in seconds.
            enable_matryoshka: If True, embed() honors ``dimension`` by
                returning the truncated + re-normalized variant.
            **kwargs: Extra provider options, stored for inspection only.
        """
        self._base_url = base_url.rstrip("/")
        self._model = model
        if dimension not in self.SUPPORTED_DIMENSIONS:
            # Keep the caller's value (backward compatible) but surface the
            # misconfiguration: only 256/512/768 have truncated variants, so
            # embed() will fall back to the full vector for other values.
            logger.warning(
                "Unsupported Matryoshka dimension %s (expected one of %s); "
                "embed() will return the full %s-dim vector",
                dimension,
                self.SUPPORTED_DIMENSIONS,
                self.FULL_DIMENSION,
            )
        self._dimension = dimension
        self._timeout = timeout_seconds
        self._enable_matryoshka = enable_matryoshka
        self._client: httpx.AsyncClient | None = None
        self._extra_config = kwargs

    async def _get_client(self) -> httpx.AsyncClient:
        """Lazily create and cache the shared async HTTP client."""
        if self._client is None:
            self._client = httpx.AsyncClient(timeout=self._timeout)
        return self._client

    def _add_prefix(self, text: str, task: EmbeddingTask) -> str:
        """Add the task prefix to *text* (no-op if already prefixed)."""
        if task == EmbeddingTask.DOCUMENT:
            prefix = self.DOCUMENT_PREFIX
        else:
            prefix = self.QUERY_PREFIX

        if text.startswith(prefix):
            return text
        return f"{prefix}{text}"

    def _truncate_and_normalize(self, embedding: list[float], target_dim: int) -> list[float]:
        """
        Truncate *embedding* to *target_dim* and L2-normalize the result.

        Matryoshka representation learning allows dimension truncation; the
        truncated vector must be re-normalized so cosine/dot similarity
        remains meaningful.
        """
        truncated = embedding[:target_dim]

        arr = np.array(truncated, dtype=np.float32)
        norm = np.linalg.norm(arr)
        if norm > 0:
            arr = arr / norm

        return arr.tolist()

    def _select_vector(self, result: NomicEmbeddingResult) -> list[float]:
        """
        Pick the vector variant matching the configured dimension.

        Bug fix: previously the configured ``dimension``/``enable_matryoshka``
        settings were ignored — embed() always returned the full vector even
        when dimension=256/512 was configured. The default configuration
        (dimension=768) is unchanged.
        """
        if self._enable_matryoshka and self._dimension == 256:
            return result.embedding_256
        if self._enable_matryoshka and self._dimension == 512:
            return result.embedding_512
        return result.embedding_full

    async def embed_with_task(
        self,
        text: str,
        task: EmbeddingTask,
    ) -> NomicEmbeddingResult:
        """
        Generate embedding with specified task prefix.

        Args:
            text: Input text to embed
            task: DOCUMENT for indexing, QUERY for retrieval

        Returns:
            NomicEmbeddingResult with all dimension variants

        Raises:
            EmbeddingException: on HTTP errors, connection failures, an empty
                embedding in the response, or any unexpected failure.
        """
        start_time = time.perf_counter()

        prefixed_text = self._add_prefix(text, task)

        try:
            client = await self._get_client()
            response = await client.post(
                f"{self._base_url}/api/embeddings",
                json={
                    "model": self._model,
                    "prompt": prefixed_text,
                }
            )
            response.raise_for_status()
            data = response.json()
            embedding = data.get("embedding", [])

            if not embedding:
                raise EmbeddingException(
                    "Empty embedding returned",
                    provider=self.PROVIDER_NAME,
                    details={"text_length": len(text), "task": task.value}
                )

            latency_ms = (time.perf_counter() - start_time) * 1000

            # Precompute both truncated variants so callers can store
            # multi-dimensional vectors without re-embedding.
            embedding_256 = self._truncate_and_normalize(embedding, 256)
            embedding_512 = self._truncate_and_normalize(embedding, 512)

            logger.debug(
                f"Generated Nomic embedding: task={task.value}, "
                f"dim={len(embedding)}, latency={latency_ms:.2f}ms"
            )

            return NomicEmbeddingResult(
                embedding_full=embedding,
                embedding_256=embedding_256,
                embedding_512=embedding_512,
                dimension=len(embedding),
                model=self._model,
                task=task,
                latency_ms=latency_ms,
            )

        except httpx.HTTPStatusError as e:
            raise EmbeddingException(
                f"Ollama API error: {e.response.status_code}",
                provider=self.PROVIDER_NAME,
                details={"status_code": e.response.status_code, "response": e.response.text}
            )
        except httpx.RequestError as e:
            raise EmbeddingException(
                f"Ollama connection error: {e}",
                provider=self.PROVIDER_NAME,
                details={"base_url": self._base_url}
            )
        except EmbeddingException:
            # Already a domain exception (e.g. empty embedding) — propagate as-is.
            raise
        except Exception as e:
            raise EmbeddingException(
                f"Embedding generation failed: {e}",
                provider=self.PROVIDER_NAME
            )

    async def embed_document(self, text: str) -> NomicEmbeddingResult:
        """
        Generate embedding for document (with search_document: prefix).

        Use this when indexing documents into vector store.
        """
        return await self.embed_with_task(text, EmbeddingTask.DOCUMENT)

    async def embed_query(self, text: str) -> NomicEmbeddingResult:
        """
        Generate embedding for query (with search_query: prefix).

        Use this when searching/retrieving documents.
        """
        return await self.embed_with_task(text, EmbeddingTask.QUERY)

    async def embed(self, text: str) -> list[float]:
        """
        Generate embedding vector for a single text.

        Uses the QUERY task for backward compatibility. The configured
        ``dimension``/``enable_matryoshka`` settings are honored (see
        _select_vector); with the default dimension=768 the behavior is
        identical to the original full-vector return.
        """
        result = await self.embed_query(text)
        return self._select_vector(result)

    async def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """
        Generate embedding vectors for multiple texts (QUERY task).

        Requests run sequentially to avoid overloading a local Ollama server.
        """
        # Sequential on purpose; see docstring.
        return [await self.embed(text) for text in texts]

    async def embed_documents_batch(
        self,
        texts: list[str],
    ) -> list[NomicEmbeddingResult]:
        """
        Generate embeddings for multiple documents (DOCUMENT task).

        Use this when batch indexing documents.
        """
        return [await self.embed_document(text) for text in texts]

    async def embed_queries_batch(
        self,
        texts: list[str],
    ) -> list[NomicEmbeddingResult]:
        """
        Generate embeddings for multiple queries (QUERY task).

        Use this when batch processing queries.
        """
        return [await self.embed_query(text) for text in texts]

    def get_dimension(self) -> int:
        """Get the configured dimension of embedding vectors."""
        return self._dimension

    def get_provider_name(self) -> str:
        """Get the name of this embedding provider."""
        return self.PROVIDER_NAME

    def get_config_schema(self) -> dict[str, Any]:
        """Get the configuration schema for Nomic provider."""
        return {
            "base_url": {
                "type": "string",
                "description": "Ollama API 地址",
                "default": "http://localhost:11434",
            },
            "model": {
                "type": "string",
                "description": "嵌入模型名称(推荐 nomic-embed-text v1.5)",
                "default": "nomic-embed-text",
            },
            "dimension": {
                "type": "integer",
                "description": "向量维度(支持 256/512/768)",
                "default": 768,
            },
            "timeout_seconds": {
                "type": "integer",
                "description": "请求超时时间(秒)",
                "default": 60,
            },
            "enable_matryoshka": {
                "type": "boolean",
                "description": "启用 Matryoshka 维度截断",
                "default": True,
            },
        }

    async def close(self) -> None:
        """Close the HTTP client."""
        if self._client:
            await self._client.aclose()
            self._client = None
|
||||||
|
|
@ -11,6 +11,11 @@ Design reference: design.md Section 2.2 - 关键数据流
|
||||||
6. compute_confidence(...)
|
6. compute_confidence(...)
|
||||||
7. Memory.append(tenantId, sessionId, user/assistant messages)
|
7. Memory.append(tenantId, sessionId, user/assistant messages)
|
||||||
8. Return ChatResponse (or output via SSE)
|
8. Return ChatResponse (or output via SSE)
|
||||||
|
|
||||||
|
RAG Optimization (rag-optimization/spec.md):
|
||||||
|
- Two-stage retrieval with Matryoshka dimensions
|
||||||
|
- RRF hybrid ranking
|
||||||
|
- Optimized prompt engineering
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
@ -36,6 +41,16 @@ from app.services.retrieval.base import BaseRetriever, RetrievalContext, Retriev
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
OPTIMIZED_SYSTEM_PROMPT = """你是学校智能客服助手,基于提供的知识库内容回答用户问题。
|
||||||
|
|
||||||
|
回答要求:
|
||||||
|
1. 严格基于提供的知识库内容回答,不要编造信息
|
||||||
|
2. 如果知识库中没有相关信息,明确告知用户并建议转人工或稍后重试
|
||||||
|
3. 保持专业、友好的语气,回答简洁明了,突出重点
|
||||||
|
4. 如果引用知识库内容,请注明来源(如:根据[文档1]...)
|
||||||
|
5. 对于时效性问题,请提醒用户注意文档的有效期"""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class OrchestratorConfig:
|
class OrchestratorConfig:
|
||||||
"""
|
"""
|
||||||
|
|
@ -44,8 +59,9 @@ class OrchestratorConfig:
|
||||||
"""
|
"""
|
||||||
max_history_tokens: int = 4000
|
max_history_tokens: int = 4000
|
||||||
max_evidence_tokens: int = 2000
|
max_evidence_tokens: int = 2000
|
||||||
system_prompt: str = "你是一个智能客服助手,请根据提供的知识库内容回答用户问题。"
|
system_prompt: str = OPTIMIZED_SYSTEM_PROMPT
|
||||||
enable_rag: bool = True
|
enable_rag: bool = True
|
||||||
|
use_optimized_retriever: bool = True
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -141,7 +157,14 @@ class OrchestratorService:
|
||||||
"""
|
"""
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[AC-AISVC-01] Starting generation for tenant={tenant_id}, "
|
f"[AC-AISVC-01] Starting generation for tenant={tenant_id}, "
|
||||||
f"session={request.session_id}"
|
f"session={request.session_id}, channel_type={request.channel_type}, "
|
||||||
|
f"current_message={request.current_message[:100]}..."
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"[AC-AISVC-01] Config: enable_rag={self._config.enable_rag}, "
|
||||||
|
f"use_optimized_retriever={self._config.use_optimized_retriever}, "
|
||||||
|
f"llm_client={'configured' if self._llm_client else 'NOT configured'}, "
|
||||||
|
f"retriever={'configured' if self._retriever else 'NOT configured'}"
|
||||||
)
|
)
|
||||||
|
|
||||||
ctx = GenerationContext(
|
ctx = GenerationContext(
|
||||||
|
|
@ -257,6 +280,10 @@ class OrchestratorService:
|
||||||
[AC-AISVC-16, AC-AISVC-17] RAG retrieval for evidence.
|
[AC-AISVC-16, AC-AISVC-17] RAG retrieval for evidence.
|
||||||
Step 3 of the generation pipeline.
|
Step 3 of the generation pipeline.
|
||||||
"""
|
"""
|
||||||
|
logger.info(
|
||||||
|
f"[AC-AISVC-16] Starting retrieval: tenant={ctx.tenant_id}, "
|
||||||
|
f"query={ctx.current_message[:100]}..., retriever={type(self._retriever).__name__ if self._retriever else 'None'}"
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
retrieval_ctx = RetrievalContext(
|
retrieval_ctx = RetrievalContext(
|
||||||
tenant_id=ctx.tenant_id,
|
tenant_id=ctx.tenant_id,
|
||||||
|
|
@ -277,11 +304,19 @@ class OrchestratorService:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[AC-AISVC-16, AC-AISVC-17] Retrieval complete: "
|
f"[AC-AISVC-16, AC-AISVC-17] Retrieval complete: "
|
||||||
f"hits={ctx.retrieval_result.hit_count}, "
|
f"hits={ctx.retrieval_result.hit_count}, "
|
||||||
f"max_score={ctx.retrieval_result.max_score:.3f}"
|
f"max_score={ctx.retrieval_result.max_score:.3f}, "
|
||||||
|
f"is_empty={ctx.retrieval_result.is_empty}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if ctx.retrieval_result.hit_count > 0:
|
||||||
|
for i, hit in enumerate(ctx.retrieval_result.hits[:3]):
|
||||||
|
logger.info(
|
||||||
|
f"[AC-AISVC-16] Hit {i+1}: score={hit.score:.3f}, "
|
||||||
|
f"text_preview={hit.text[:100]}..."
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"[AC-AISVC-16] Retrieval failed: {e}")
|
logger.error(f"[AC-AISVC-16] Retrieval failed with exception: {e}", exc_info=True)
|
||||||
ctx.retrieval_result = RetrievalResult(
|
ctx.retrieval_result = RetrievalResult(
|
||||||
hits=[],
|
hits=[],
|
||||||
diagnostics={"error": str(e)},
|
diagnostics={"error": str(e)},
|
||||||
|
|
@ -294,9 +329,18 @@ class OrchestratorService:
|
||||||
Step 4-5 of the generation pipeline.
|
Step 4-5 of the generation pipeline.
|
||||||
"""
|
"""
|
||||||
messages = self._build_llm_messages(ctx)
|
messages = self._build_llm_messages(ctx)
|
||||||
|
logger.info(
|
||||||
|
f"[AC-AISVC-02] Building LLM messages: count={len(messages)}, "
|
||||||
|
f"has_retrieval_result={ctx.retrieval_result is not None}, "
|
||||||
|
f"retrieval_is_empty={ctx.retrieval_result.is_empty if ctx.retrieval_result else 'N/A'}, "
|
||||||
|
f"llm_client={'configured' if self._llm_client else 'NOT configured'}"
|
||||||
|
)
|
||||||
|
|
||||||
if not self._llm_client:
|
if not self._llm_client:
|
||||||
logger.warning("[AC-AISVC-02] No LLM client configured, using fallback")
|
logger.warning(
|
||||||
|
f"[AC-AISVC-02] No LLM client configured, using fallback. "
|
||||||
|
f"retrieval_is_empty={ctx.retrieval_result.is_empty if ctx.retrieval_result else True}"
|
||||||
|
)
|
||||||
ctx.llm_response = LLMResponse(
|
ctx.llm_response = LLMResponse(
|
||||||
content=self._fallback_response(ctx),
|
content=self._fallback_response(ctx),
|
||||||
model="fallback",
|
model="fallback",
|
||||||
|
|
@ -304,6 +348,7 @@ class OrchestratorService:
|
||||||
finish_reason="fallback",
|
finish_reason="fallback",
|
||||||
)
|
)
|
||||||
ctx.diagnostics["llm_mode"] = "fallback"
|
ctx.diagnostics["llm_mode"] = "fallback"
|
||||||
|
ctx.diagnostics["fallback_reason"] = "no_llm_client"
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -318,11 +363,16 @@ class OrchestratorService:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[AC-AISVC-02] LLM response generated: "
|
f"[AC-AISVC-02] LLM response generated: "
|
||||||
f"model={ctx.llm_response.model}, "
|
f"model={ctx.llm_response.model}, "
|
||||||
f"tokens={ctx.llm_response.usage}"
|
f"tokens={ctx.llm_response.usage}, "
|
||||||
|
f"content_preview={ctx.llm_response.content[:100]}..."
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[AC-AISVC-02] LLM generation failed: {e}")
|
logger.error(
|
||||||
|
f"[AC-AISVC-02] LLM generation failed: {e}, "
|
||||||
|
f"retrieval_is_empty={ctx.retrieval_result.is_empty if ctx.retrieval_result else True}",
|
||||||
|
exc_info=True
|
||||||
|
)
|
||||||
ctx.llm_response = LLMResponse(
|
ctx.llm_response = LLMResponse(
|
||||||
content=self._fallback_response(ctx),
|
content=self._fallback_response(ctx),
|
||||||
model="fallback",
|
model="fallback",
|
||||||
|
|
@ -331,6 +381,8 @@ class OrchestratorService:
|
||||||
metadata={"error": str(e)},
|
metadata={"error": str(e)},
|
||||||
)
|
)
|
||||||
ctx.diagnostics["llm_error"] = str(e)
|
ctx.diagnostics["llm_error"] = str(e)
|
||||||
|
ctx.diagnostics["llm_mode"] = "fallback"
|
||||||
|
ctx.diagnostics["fallback_reason"] = f"llm_error: {str(e)}"
|
||||||
|
|
||||||
def _build_llm_messages(self, ctx: GenerationContext) -> list[dict[str, str]]:
|
def _build_llm_messages(self, ctx: GenerationContext) -> list[dict[str, str]]:
|
||||||
"""
|
"""
|
||||||
|
|
@ -356,12 +408,26 @@ class OrchestratorService:
|
||||||
def _format_evidence(self, retrieval_result: RetrievalResult) -> str:
|
def _format_evidence(self, retrieval_result: RetrievalResult) -> str:
|
||||||
"""
|
"""
|
||||||
[AC-AISVC-17] Format retrieval hits as evidence text.
|
[AC-AISVC-17] Format retrieval hits as evidence text.
|
||||||
|
Optimized format with source attribution and metadata.
|
||||||
"""
|
"""
|
||||||
evidence_parts = []
|
evidence_parts = []
|
||||||
for i, hit in enumerate(retrieval_result.hits[:5], 1):
|
for i, hit in enumerate(retrieval_result.hits[:5], 1):
|
||||||
evidence_parts.append(f"[{i}] (相关度: {hit.score:.2f}) {hit.text}")
|
metadata = hit.metadata or {}
|
||||||
|
source = metadata.get("metadata", {}).get("source_doc", "知识库")
|
||||||
|
category = metadata.get("metadata", {}).get("category", "")
|
||||||
|
department = metadata.get("metadata", {}).get("department", "")
|
||||||
|
|
||||||
return "\n".join(evidence_parts)
|
header = f"[文档{i}]"
|
||||||
|
if source and source != "知识库":
|
||||||
|
header += f" 来源:{source}"
|
||||||
|
if category:
|
||||||
|
header += f" | 类别:{category}"
|
||||||
|
if department:
|
||||||
|
header += f" | 部门:{department}"
|
||||||
|
|
||||||
|
evidence_parts.append(f"{header}\n相关度:{hit.score:.2f}\n内容:{hit.text}")
|
||||||
|
|
||||||
|
return "\n\n".join(evidence_parts)
|
||||||
|
|
||||||
def _fallback_response(self, ctx: GenerationContext) -> str:
|
def _fallback_response(self, ctx: GenerationContext) -> str:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
"""
|
"""
|
||||||
Retrieval module for AI Service.
|
Retrieval module for AI Service.
|
||||||
[AC-AISVC-16] Provides retriever implementations with plugin architecture.
|
[AC-AISVC-16] Provides retriever implementations with plugin architecture.
|
||||||
|
RAG Optimization: Two-stage retrieval, RRF hybrid ranking, metadata filtering.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from app.services.retrieval.base import (
|
from app.services.retrieval.base import (
|
||||||
|
|
@ -10,6 +11,27 @@ from app.services.retrieval.base import (
|
||||||
RetrievalResult,
|
RetrievalResult,
|
||||||
)
|
)
|
||||||
from app.services.retrieval.vector_retriever import VectorRetriever, get_vector_retriever
|
from app.services.retrieval.vector_retriever import VectorRetriever, get_vector_retriever
|
||||||
|
from app.services.retrieval.metadata import (
|
||||||
|
ChunkMetadata,
|
||||||
|
ChunkMetadataModel,
|
||||||
|
MetadataFilter,
|
||||||
|
KnowledgeChunk,
|
||||||
|
RetrieveRequest,
|
||||||
|
RetrieveResult,
|
||||||
|
RetrievalStrategy,
|
||||||
|
)
|
||||||
|
from app.services.retrieval.optimized_retriever import (
|
||||||
|
OptimizedRetriever,
|
||||||
|
get_optimized_retriever,
|
||||||
|
TwoStageResult,
|
||||||
|
RRFCombiner,
|
||||||
|
)
|
||||||
|
from app.services.retrieval.indexer import (
|
||||||
|
KnowledgeIndexer,
|
||||||
|
get_knowledge_indexer,
|
||||||
|
IndexingProgress,
|
||||||
|
IndexingResult,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BaseRetriever",
|
"BaseRetriever",
|
||||||
|
|
@ -18,4 +40,18 @@ __all__ = [
|
||||||
"RetrievalResult",
|
"RetrievalResult",
|
||||||
"VectorRetriever",
|
"VectorRetriever",
|
||||||
"get_vector_retriever",
|
"get_vector_retriever",
|
||||||
|
"ChunkMetadata",
|
||||||
|
"MetadataFilter",
|
||||||
|
"KnowledgeChunk",
|
||||||
|
"RetrieveRequest",
|
||||||
|
"RetrieveResult",
|
||||||
|
"RetrievalStrategy",
|
||||||
|
"OptimizedRetriever",
|
||||||
|
"get_optimized_retriever",
|
||||||
|
"TwoStageResult",
|
||||||
|
"RRFCombiner",
|
||||||
|
"KnowledgeIndexer",
|
||||||
|
"get_knowledge_indexer",
|
||||||
|
"IndexingProgress",
|
||||||
|
"IndexingResult",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,339 @@
|
||||||
|
"""
|
||||||
|
Knowledge base indexing service with optimized embedding.
|
||||||
|
Reference: rag-optimization/spec.md Section 5.1
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
import logging
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any

from app.core.config import get_settings
from app.core.qdrant_client import QdrantClient, get_qdrant_client
from app.services.embedding.nomic_provider import NomicEmbeddingProvider, NomicEmbeddingResult
from app.services.retrieval.metadata import ChunkMetadata, KnowledgeChunk
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
settings = get_settings()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class IndexingProgress:
    """Progress tracking for indexing jobs.

    Mutable counters are updated in place by the indexer while a job runs.
    """

    total_chunks: int = 0  # number of chunks scheduled for this job
    processed_chunks: int = 0  # chunks successfully embedded so far
    failed_chunks: int = 0  # chunks that raised during embedding
    current_document: str = ""  # document_id currently being indexed
    # Timezone-aware start timestamp. Fix: datetime.utcnow() is deprecated
    # (Python 3.12+) and returns naive datetimes; use aware UTC instead.
    started_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))

    @property
    def progress_percent(self) -> int:
        """Completion percentage in [0, 100]; 0 when nothing is scheduled."""
        if self.total_chunks == 0:
            return 0
        return int((self.processed_chunks / self.total_chunks) * 100)

    @property
    def elapsed_seconds(self) -> float:
        """Seconds elapsed since the job started (aware UTC arithmetic)."""
        return (datetime.now(timezone.utc) - self.started_at).total_seconds()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class IndexingResult:
    """Result of an indexing operation.

    Returned by KnowledgeIndexer.index_document(); on failure ``success`` is
    False and ``error_message`` carries the stringified exception.
    """

    success: bool  # True when the document was indexed (possibly with some failed chunks)
    total_chunks: int  # chunks produced by the chunker
    indexed_chunks: int  # chunks actually written to the vector store
    failed_chunks: int  # chunks that failed embedding
    elapsed_seconds: float  # wall-clock duration of the operation
    error_message: str | None = None  # set only when success is False
|
||||||
|
|
||||||
|
|
||||||
|
class KnowledgeIndexer:
    """
    Knowledge base indexer with optimized embedding.

    Features:
    - Task prefixes (search_document:) for document embedding
    - Multi-dimensional vectors (256/512/768)
    - Metadata support
    - Batch processing
    """

    def __init__(
        self,
        qdrant_client: QdrantClient | None = None,
        embedding_provider: NomicEmbeddingProvider | None = None,
        chunk_size: int = 500,
        chunk_overlap: int = 50,
        batch_size: int = 10,
    ):
        # Dependencies are resolved lazily so the indexer can be constructed
        # without touching Qdrant or Ollama (see _get_client /
        # _get_embedding_provider).
        self._qdrant_client = qdrant_client
        self._embedding_provider = embedding_provider
        # NOTE(review): chunk_size/chunk_overlap/batch_size are stored but not
        # read by the current line-based chunking or sequential indexing loop
        # — presumably reserved for future size-based chunking; confirm.
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
        self._batch_size = batch_size
        # Progress of the most recent index_document() run; None before any run.
        self._progress: IndexingProgress | None = None

    async def _get_client(self) -> QdrantClient:
        # Lazily obtain the shared Qdrant client on first use.
        if self._qdrant_client is None:
            self._qdrant_client = await get_qdrant_client()
        return self._qdrant_client

    async def _get_embedding_provider(self) -> NomicEmbeddingProvider:
        # Lazily build the embedding provider from application settings.
        if self._embedding_provider is None:
            self._embedding_provider = NomicEmbeddingProvider(
                base_url=settings.ollama_base_url,
                model=settings.ollama_embedding_model,
                dimension=settings.qdrant_vector_size,
            )
        return self._embedding_provider

    def chunk_text(self, text: str, metadata: ChunkMetadata | None = None) -> list[KnowledgeChunk]:
        """
        Split text into chunks for indexing.

        Each line becomes a separate chunk for better retrieval granularity.
        Lines shorter than 10 characters (after stripping) are skipped.
        All chunks share one freshly generated document UUID; chunk_id is
        "<doc_uuid>_<line_index>".

        Args:
            text: Full text to chunk
            metadata: Metadata to attach to each chunk

        Returns:
            List of KnowledgeChunk objects
        """
        chunks = []
        doc_id = str(uuid.uuid4())

        lines = text.split('\n')

        for i, line in enumerate(lines):
            line = line.strip()

            # Skip very short lines — too little signal for retrieval.
            if len(line) < 10:
                continue

            chunk = KnowledgeChunk(
                chunk_id=f"{doc_id}_{i}",
                document_id=doc_id,
                content=line,
                metadata=metadata or ChunkMetadata(),
            )
            chunks.append(chunk)

        return chunks

    def chunk_text_by_lines(
        self,
        text: str,
        metadata: ChunkMetadata | None = None,
        min_line_length: int = 10,
        merge_short_lines: bool = False,
    ) -> list[KnowledgeChunk]:
        """
        Split text by lines, each line is a separate chunk.

        Configurable variant of chunk_text(): the minimum line length is a
        parameter, and consecutive short lines can optionally be merged into
        one chunk (merged until roughly 2 * min_line_length characters, with
        blank lines acting as hard separators).

        Args:
            text: Full text to chunk
            metadata: Metadata to attach to each chunk
            min_line_length: Minimum line length to be indexed
            merge_short_lines: Whether to merge consecutive short lines

        Returns:
            List of KnowledgeChunk objects
        """
        chunks = []
        doc_id = str(uuid.uuid4())

        lines = text.split('\n')

        if merge_short_lines:
            merged_lines = []
            current_line = ""

            for line in lines:
                line = line.strip()
                if not line:
                    # Blank line: flush the accumulated fragment, if any.
                    if current_line:
                        merged_lines.append(current_line)
                        current_line = ""
                    continue

                # Join consecutive non-blank lines with a single space.
                if current_line:
                    current_line += " " + line
                else:
                    current_line = line

                # Flush once the merged fragment is comfortably long.
                if len(current_line) >= min_line_length * 2:
                    merged_lines.append(current_line)
                    current_line = ""

            # Flush any trailing fragment.
            if current_line:
                merged_lines.append(current_line)

            lines = merged_lines

        for i, line in enumerate(lines):
            line = line.strip()

            if len(line) < min_line_length:
                continue

            chunk = KnowledgeChunk(
                chunk_id=f"{doc_id}_{i}",
                document_id=doc_id,
                content=line,
                metadata=metadata or ChunkMetadata(),
            )
            chunks.append(chunk)

        return chunks

    async def index_document(
        self,
        tenant_id: str,
        document_id: str,
        text: str,
        metadata: ChunkMetadata | None = None,
    ) -> IndexingResult:
        """
        Index a single document with optimized embedding.

        Pipeline: ensure the tenant's multi-vector collection exists, chunk
        the text line-by-line, embed each chunk with the DOCUMENT task prefix
        (full + 256 + 512 dim variants), then upsert all successful points in
        one call. Per-chunk embedding failures are logged and counted, not
        fatal; only a failure of the overall pipeline yields success=False.

        Args:
            tenant_id: Tenant identifier
            document_id: Document identifier
            text: Document text content
            metadata: Optional metadata for the document

        Returns:
            IndexingResult with status and statistics
        """
        # NOTE(review): datetime.utcnow() is deprecated since Python 3.12
        # (naive UTC); consider datetime.now(timezone.utc).
        start_time = datetime.utcnow()

        try:
            client = await self._get_client()
            provider = await self._get_embedding_provider()

            await client.ensure_collection_exists(tenant_id, use_multi_vector=True)

            chunks = self.chunk_text(text, metadata)

            # Expose progress for get_progress() / the progress API.
            self._progress = IndexingProgress(
                total_chunks=len(chunks),
                current_document=document_id,
            )

            points = []
            for i, chunk in enumerate(chunks):
                try:
                    # DOCUMENT task => "search_document:" prefix for indexing.
                    embedding_result = await provider.embed_document(chunk.content)

                    chunk.embedding_full = embedding_result.embedding_full
                    chunk.embedding_256 = embedding_result.embedding_256
                    chunk.embedding_512 = embedding_result.embedding_512

                    point = {
                        "id": str(uuid.uuid4()),  # Generate a valid UUID for Qdrant
                        "vector": {
                            # Named vectors: full + Matryoshka variants for
                            # two-stage retrieval.
                            "full": chunk.embedding_full,
                            "dim_256": chunk.embedding_256,
                            "dim_512": chunk.embedding_512,
                        },
                        "payload": {
                            "chunk_id": chunk.chunk_id,
                            "document_id": document_id,
                            "text": chunk.content,
                            "metadata": chunk.metadata.to_dict(),
                            "created_at": chunk.created_at.isoformat(),
                        }
                    }
                    points.append(point)

                    self._progress.processed_chunks += 1

                    logger.debug(
                        f"[RAG-OPT] Indexed chunk {i+1}/{len(chunks)} for doc={document_id}"
                    )

                except Exception as e:
                    # Best-effort: a bad chunk must not abort the document.
                    logger.warning(f"[RAG-OPT] Failed to index chunk {i}: {e}")
                    self._progress.failed_chunks += 1

            if points:
                await client.upsert_multi_vector(tenant_id, points)

            elapsed = (datetime.utcnow() - start_time).total_seconds()

            logger.info(
                f"[RAG-OPT] Indexed document {document_id}: "
                f"{len(points)} chunks in {elapsed:.2f}s"
            )

            return IndexingResult(
                success=True,
                total_chunks=len(chunks),
                indexed_chunks=len(points),
                failed_chunks=self._progress.failed_chunks,
                elapsed_seconds=elapsed,
            )

        except Exception as e:
            elapsed = (datetime.utcnow() - start_time).total_seconds()
            logger.error(f"[RAG-OPT] Failed to index document {document_id}: {e}")

            # NOTE(review): counters are reported as 0 here even if some
            # chunks were processed before the failure — verify intended.
            return IndexingResult(
                success=False,
                total_chunks=0,
                indexed_chunks=0,
                failed_chunks=0,
                elapsed_seconds=elapsed,
                error_message=str(e),
            )

    async def index_documents_batch(
        self,
        tenant_id: str,
        documents: list[dict[str, Any]],
    ) -> list[IndexingResult]:
        """
        Index multiple documents in batch (sequentially, one result each).

        Args:
            tenant_id: Tenant identifier
            documents: List of documents with format:
                {
                    "document_id": str,
                    "text": str,
                    "metadata": ChunkMetadata (optional)
                }

        Returns:
            List of IndexingResult for each document
        """
        results = []

        for doc in documents:
            result = await self.index_document(
                tenant_id=tenant_id,
                document_id=doc["document_id"],
                text=doc["text"],
                metadata=doc.get("metadata"),
            )
            results.append(result)

        return results

    def get_progress(self) -> IndexingProgress | None:
        """Get current indexing progress (None before the first run)."""
        return self._progress
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level singleton, created on first request.
_knowledge_indexer: KnowledgeIndexer | None = None


def get_knowledge_indexer() -> KnowledgeIndexer:
    """Return the process-wide KnowledgeIndexer, creating it lazily."""
    global _knowledge_indexer
    indexer = _knowledge_indexer
    if indexer is None:
        indexer = KnowledgeIndexer()
        _knowledge_indexer = indexer
    return indexer
|
||||||
|
|
@ -0,0 +1,210 @@
|
||||||
|
"""
|
||||||
|
Metadata models for RAG optimization.
|
||||||
|
Implements structured metadata for knowledge chunks.
|
||||||
|
Reference: rag-optimization/spec.md Section 3.2
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import date, datetime
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class RetrievalStrategy(str, Enum):
    """Retrieval strategy options for the RAG pipeline."""

    VECTOR_ONLY = "vector"  # dense vector search only
    BM25_ONLY = "bm25"  # lexical BM25 search only
    HYBRID = "hybrid"  # vector + BM25 combined (RRF hybrid ranking per module docs)
    TWO_STAGE = "two_stage"  # coarse recall then rerank (Matryoshka dimensions per module docs)
|
||||||
|
|
||||||
|
|
||||||
|
class ChunkMetadataModel(BaseModel):
    """Pydantic model for API serialization.

    Mirrors the ChunkMetadata dataclass field-for-field; the date fields
    travel as ISO-8601 strings here instead of date objects.
    """
    # NOTE: pydantic deep-copies mutable field defaults, so the bare `= []`
    # defaults below are safe (unlike on a plain Python class).
    category: str = ""
    subcategory: str = ""
    target_audience: list[str] = []
    source_doc: str = ""
    source_url: str = ""
    department: str = ""
    valid_from: str | None = None   # ISO-8601 date string, e.g. "2024-01-31"
    valid_until: str | None = None  # ISO-8601 date string; None = no expiry
    priority: int = 5               # NOTE(review): scale direction unclear -- confirm against spec
    keywords: list[str] = []
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ChunkMetadata:
|
||||||
|
"""
|
||||||
|
Metadata for knowledge chunks.
|
||||||
|
Reference: rag-optimization/spec.md Section 3.2.2
|
||||||
|
"""
|
||||||
|
category: str = ""
|
||||||
|
subcategory: str = ""
|
||||||
|
target_audience: list[str] = field(default_factory=list)
|
||||||
|
source_doc: str = ""
|
||||||
|
source_url: str = ""
|
||||||
|
department: str = ""
|
||||||
|
valid_from: date | None = None
|
||||||
|
valid_until: date | None = None
|
||||||
|
priority: int = 5
|
||||||
|
keywords: list[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Convert to dictionary for storage."""
|
||||||
|
return {
|
||||||
|
"category": self.category,
|
||||||
|
"subcategory": self.subcategory,
|
||||||
|
"target_audience": self.target_audience,
|
||||||
|
"source_doc": self.source_doc,
|
||||||
|
"source_url": self.source_url,
|
||||||
|
"department": self.department,
|
||||||
|
"valid_from": self.valid_from.isoformat() if self.valid_from else None,
|
||||||
|
"valid_until": self.valid_until.isoformat() if self.valid_until else None,
|
||||||
|
"priority": self.priority,
|
||||||
|
"keywords": self.keywords,
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: dict[str, Any]) -> "ChunkMetadata":
|
||||||
|
"""Create from dictionary."""
|
||||||
|
return cls(
|
||||||
|
category=data.get("category", ""),
|
||||||
|
subcategory=data.get("subcategory", ""),
|
||||||
|
target_audience=data.get("target_audience", []),
|
||||||
|
source_doc=data.get("source_doc", ""),
|
||||||
|
source_url=data.get("source_url", ""),
|
||||||
|
department=data.get("department", ""),
|
||||||
|
valid_from=date.fromisoformat(data["valid_from"]) if data.get("valid_from") else None,
|
||||||
|
valid_until=date.fromisoformat(data["valid_until"]) if data.get("valid_until") else None,
|
||||||
|
priority=data.get("priority", 5),
|
||||||
|
keywords=data.get("keywords", []),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MetadataFilter:
|
||||||
|
"""
|
||||||
|
Filter conditions for metadata-based retrieval.
|
||||||
|
Reference: rag-optimization/spec.md Section 4.1
|
||||||
|
"""
|
||||||
|
categories: list[str] | None = None
|
||||||
|
target_audiences: list[str] | None = None
|
||||||
|
departments: list[str] | None = None
|
||||||
|
valid_only: bool = True
|
||||||
|
min_priority: int | None = None
|
||||||
|
keywords: list[str] | None = None
|
||||||
|
|
||||||
|
def to_qdrant_filter(self) -> dict[str, Any] | None:
|
||||||
|
"""Convert to Qdrant filter format."""
|
||||||
|
conditions = []
|
||||||
|
|
||||||
|
if self.categories:
|
||||||
|
conditions.append({
|
||||||
|
"key": "metadata.category",
|
||||||
|
"match": {"any": self.categories}
|
||||||
|
})
|
||||||
|
|
||||||
|
if self.departments:
|
||||||
|
conditions.append({
|
||||||
|
"key": "metadata.department",
|
||||||
|
"match": {"any": self.departments}
|
||||||
|
})
|
||||||
|
|
||||||
|
if self.target_audiences:
|
||||||
|
conditions.append({
|
||||||
|
"key": "metadata.target_audience",
|
||||||
|
"match": {"any": self.target_audiences}
|
||||||
|
})
|
||||||
|
|
||||||
|
if self.valid_only:
|
||||||
|
today = date.today().isoformat()
|
||||||
|
conditions.append({
|
||||||
|
"should": [
|
||||||
|
{"key": "metadata.valid_until", "match": {"value": None}},
|
||||||
|
{"key": "metadata.valid_until", "range": {"gte": today}}
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
|
if self.min_priority is not None:
|
||||||
|
conditions.append({
|
||||||
|
"key": "metadata.priority",
|
||||||
|
"range": {"lte": self.min_priority}
|
||||||
|
})
|
||||||
|
|
||||||
|
if not conditions:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if len(conditions) == 1:
|
||||||
|
return {"must": conditions}
|
||||||
|
|
||||||
|
return {"must": conditions}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class KnowledgeChunk:
    """
    Knowledge chunk with multi-dimensional embeddings.

    Reference: rag-optimization/spec.md Section 3.2.1
    """
    chunk_id: str
    document_id: str
    content: str
    embedding_full: list[float] = field(default_factory=list)
    embedding_256: list[float] = field(default_factory=list)
    embedding_512: list[float] = field(default_factory=list)
    metadata: ChunkMetadata = field(default_factory=ChunkMetadata)
    # NOTE(review): timestamps are naive UTC (datetime.utcnow), so the
    # isoformat() output below carries no timezone offset -- confirm readers
    # expect that.
    created_at: datetime = field(default_factory=datetime.utcnow)
    updated_at: datetime = field(default_factory=datetime.utcnow)

    def to_qdrant_point(self, point_id: int | str) -> dict[str, Any]:
        """Package this chunk as a Qdrant point with named vectors."""
        named_vectors = {
            "full": self.embedding_full,
            "dim_256": self.embedding_256,
            "dim_512": self.embedding_512,
        }
        payload = {
            "chunk_id": self.chunk_id,
            "document_id": self.document_id,
            "text": self.content,
            "metadata": self.metadata.to_dict(),
            "created_at": self.created_at.isoformat(),
            "updated_at": self.updated_at.isoformat(),
        }
        return {"id": point_id, "vector": named_vectors, "payload": payload}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class RetrieveRequest:
    """
    Request for knowledge retrieval.

    Reference: rag-optimization/spec.md Section 4.1
    """
    query: str
    query_with_prefix: str = ""  # auto-derived in __post_init__ when empty
    top_k: int = 10
    filters: MetadataFilter | None = None
    strategy: RetrievalStrategy = RetrievalStrategy.HYBRID

    def __post_init__(self):
        if not self.query_with_prefix:
            # Nomic embedding models expect the documented task prefix
            # "search_query: " (with a space after the colon); the previous
            # "search_query:{...}" form did not match that format.
            self.query_with_prefix = f"search_query: {self.query}"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class RetrieveResult:
    """
    Result from knowledge retrieval.

    Reference: rag-optimization/spec.md Section 4.1
    """
    chunk_id: str
    content: str
    # Final relevance score used for ranking (may be a fused RRF score).
    score: float
    # Per-channel scores; 0.0 when the chunk was not produced by that channel.
    vector_score: float = 0.0
    bm25_score: float = 0.0
    metadata: ChunkMetadata = field(default_factory=ChunkMetadata)
    # Position in the final ranking; 0 appears to mean "not yet assigned"
    # -- TODO confirm against callers.
    rank: int = 0
|
||||||
|
|
@ -0,0 +1,509 @@
|
||||||
|
"""
|
||||||
|
Optimized RAG retriever with two-stage retrieval and RRF hybrid ranking.
|
||||||
|
Reference: rag-optimization/spec.md Section 2.2, 2.4, 2.5
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.core.config import get_settings
|
||||||
|
from app.core.qdrant_client import QdrantClient, get_qdrant_client
|
||||||
|
from app.services.embedding.nomic_provider import NomicEmbeddingProvider, NomicEmbeddingResult
|
||||||
|
from app.services.retrieval.base import (
|
||||||
|
BaseRetriever,
|
||||||
|
RetrievalContext,
|
||||||
|
RetrievalHit,
|
||||||
|
RetrievalResult,
|
||||||
|
)
|
||||||
|
from app.services.retrieval.metadata import (
|
||||||
|
ChunkMetadata,
|
||||||
|
MetadataFilter,
|
||||||
|
RetrieveResult,
|
||||||
|
RetrievalStrategy,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Module-level logger and application settings, resolved once at import time.
logger = logging.getLogger(__name__)

settings = get_settings()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class TwoStageResult:
    """Result from two-stage retrieval."""
    # Stage-1 (coarse, low-dimension) candidate hits before reranking.
    candidates: list[dict[str, Any]]
    # Stage-2 reranked and truncated results.
    final_results: list[RetrieveResult]
    stage1_latency_ms: float = 0.0
    stage2_latency_ms: float = 0.0
|
||||||
|
|
||||||
|
|
||||||
|
class RRFCombiner:
    """
    Reciprocal Rank Fusion of vector and BM25 result lists.

    Each hit contributes weight / (k + rank + 1) to its fused score, where
    rank is its 0-based position in that channel and k defaults to 60.

    Reference: rag-optimization/spec.md Section 2.5
    """

    def __init__(self, k: int = 60):
        self._k = k

    def combine(
        self,
        vector_results: list[dict[str, Any]],
        bm25_results: list[dict[str, Any]],
        vector_weight: float = 0.7,
        bm25_weight: float = 0.3,
    ) -> list[dict[str, Any]]:
        """
        Fuse two ranked result lists with weighted RRF.

        Args:
            vector_results: Hits from vector search, best first.
            bm25_results: Hits from BM25 search, best first.
            vector_weight: RRF weight for the vector channel.
            bm25_weight: RRF weight for the BM25 channel.

        Returns:
            Fused entries sorted by descending combined score.
        """
        fused: dict[str, dict[str, Any]] = {}

        for rank, hit in enumerate(vector_results):
            cid = hit.get("chunk_id") or hit.get("id", str(rank))
            entry = fused.setdefault(cid, {
                "score": 0.0,
                "vector_score": hit.get("score", 0.0),
                "bm25_score": 0.0,
                "vector_rank": rank,
                "bm25_rank": -1,
                "payload": hit.get("payload", {}),
                "id": cid,
            })
            entry["score"] += vector_weight / (self._k + rank + 1)

        for rank, hit in enumerate(bm25_results):
            cid = hit.get("chunk_id") or hit.get("id", str(rank))
            if cid in fused:
                # Seen by the vector channel too: record the BM25 side.
                entry = fused[cid]
                entry["bm25_score"] = hit.get("score", 0.0)
                entry["bm25_rank"] = rank
            else:
                entry = fused[cid] = {
                    "score": 0.0,
                    "vector_score": 0.0,
                    "bm25_score": hit.get("score", 0.0),
                    "vector_rank": -1,
                    "bm25_rank": rank,
                    "payload": hit.get("payload", {}),
                    "id": cid,
                }
            entry["score"] += bm25_weight / (self._k + rank + 1)

        return sorted(fused.values(), key=lambda e: e["score"], reverse=True)
|
||||||
|
|
||||||
|
|
||||||
|
class OptimizedRetriever(BaseRetriever):
    """
    Optimized retriever with:
    - Task prefixes (search_document/search_query)
    - Two-stage retrieval (256 dim -> 768 dim)
    - RRF hybrid ranking (vector + BM25)
    - Metadata filtering

    Reference: rag-optimization/spec.md Section 2, 3, 4
    """

    def __init__(
        self,
        qdrant_client: QdrantClient | None = None,
        embedding_provider: NomicEmbeddingProvider | None = None,
        top_k: int | None = None,
        score_threshold: float | None = None,
        min_hits: int | None = None,
        two_stage_enabled: bool | None = None,
        two_stage_expand_factor: int | None = None,
        hybrid_enabled: bool | None = None,
        rrf_k: int | None = None,
    ):
        """Create a retriever; every None parameter falls back to settings.

        Args:
            qdrant_client: Injected Qdrant client (lazily created if None).
            embedding_provider: Injected embedding provider (lazy if None).
            top_k: Number of final results to return.
            score_threshold: Minimum score for a hit to be kept.
            min_hits: Below this hit count the result is flagged insufficient.
            two_stage_enabled: Coarse 256-dim search + full-dim rerank.
            two_stage_expand_factor: Stage-1 candidate multiplier.
            hybrid_enabled: Fuse vector and BM25 results with RRF.
            rrf_k: RRF smoothing constant.
        """
        self._qdrant_client = qdrant_client
        self._embedding_provider = embedding_provider
        # Explicit None checks: the previous `x or settings...` fallbacks
        # silently discarded legitimate falsy overrides such as
        # score_threshold=0.0 or top_k=0.
        self._top_k = top_k if top_k is not None else settings.rag_top_k
        self._score_threshold = (
            score_threshold if score_threshold is not None
            else settings.rag_score_threshold
        )
        self._min_hits = min_hits if min_hits is not None else settings.rag_min_hits
        self._two_stage_enabled = (
            two_stage_enabled if two_stage_enabled is not None
            else settings.rag_two_stage_enabled
        )
        self._two_stage_expand_factor = (
            two_stage_expand_factor if two_stage_expand_factor is not None
            else settings.rag_two_stage_expand_factor
        )
        self._hybrid_enabled = (
            hybrid_enabled if hybrid_enabled is not None
            else settings.rag_hybrid_enabled
        )
        self._rrf_k = rrf_k if rrf_k is not None else settings.rag_rrf_k
        self._rrf_combiner = RRFCombiner(k=self._rrf_k)

    async def _get_client(self) -> QdrantClient:
        """Return the Qdrant client, creating the shared one on first use."""
        if self._qdrant_client is None:
            self._qdrant_client = await get_qdrant_client()
        return self._qdrant_client

    async def _get_embedding_provider(self) -> NomicEmbeddingProvider:
        """Return a Nomic embedding provider.

        Uses the configured provider when it already is a Nomic provider;
        otherwise builds one from the Ollama settings.
        """
        if self._embedding_provider is None:
            # Local import; presumably avoids a circular import at module
            # load time -- confirm before hoisting.
            from app.services.embedding.factory import get_embedding_config_manager
            manager = get_embedding_config_manager()
            provider = await manager.get_provider()
            if isinstance(provider, NomicEmbeddingProvider):
                self._embedding_provider = provider
            else:
                self._embedding_provider = NomicEmbeddingProvider(
                    base_url=settings.ollama_base_url,
                    model=settings.ollama_embedding_model,
                    dimension=settings.qdrant_vector_size,
                )
        return self._embedding_provider

    async def retrieve(self, ctx: RetrievalContext) -> RetrievalResult:
        """
        Retrieve documents using optimized strategy.

        Strategy selection:
        1. If two_stage_enabled: use two-stage retrieval
        2. If hybrid_enabled: use RRF hybrid ranking
        3. Otherwise: simple vector search

        Hits below score_threshold are dropped. Errors never propagate:
        a failed retrieval returns an empty result with the error recorded
        in diagnostics.
        """
        logger.info(
            f"[RAG-OPT] Starting retrieval for tenant={ctx.tenant_id}, "
            f"query={ctx.query[:50]}..., two_stage={self._two_stage_enabled}, hybrid={self._hybrid_enabled}"
        )
        logger.info(
            f"[RAG-OPT] Retrieval config: top_k={self._top_k}, "
            f"score_threshold={self._score_threshold}, min_hits={self._min_hits}"
        )

        try:
            provider = await self._get_embedding_provider()
            logger.info(f"[RAG-OPT] Using embedding provider: {type(provider).__name__}")

            embedding_result = await provider.embed_query(ctx.query)
            logger.info(
                f"[RAG-OPT] Embedding generated: full_dim={len(embedding_result.embedding_full)}, "
                f"dim_256={'available' if embedding_result.embedding_256 else 'not available'}"
            )

            if self._two_stage_enabled:
                logger.info("[RAG-OPT] Using two-stage retrieval strategy")
                results = await self._two_stage_retrieve(
                    ctx.tenant_id,
                    embedding_result,
                    self._top_k,
                )
            elif self._hybrid_enabled:
                logger.info("[RAG-OPT] Using hybrid retrieval strategy")
                results = await self._hybrid_retrieve(
                    ctx.tenant_id,
                    embedding_result,
                    ctx.query,
                    self._top_k,
                )
            else:
                logger.info("[RAG-OPT] Using simple vector retrieval strategy")
                results = await self._vector_retrieve(
                    ctx.tenant_id,
                    embedding_result.embedding_full,
                    self._top_k,
                )

            logger.info(f"[RAG-OPT] Raw results count: {len(results)}")

            # Keep only hits at or above the configured score threshold.
            retrieval_hits = [
                RetrievalHit(
                    text=result.get("payload", {}).get("text", ""),
                    score=result.get("score", 0.0),
                    source="optimized_rag",
                    metadata=result.get("payload", {}),
                )
                for result in results
                if result.get("score", 0.0) >= self._score_threshold
            ]

            filtered_count = len(results) - len(retrieval_hits)
            if filtered_count > 0:
                logger.info(
                    f"[RAG-OPT] Filtered out {filtered_count} results below threshold {self._score_threshold}"
                )

            is_insufficient = len(retrieval_hits) < self._min_hits

            diagnostics = {
                "query_length": len(ctx.query),
                "top_k": self._top_k,
                "score_threshold": self._score_threshold,
                "two_stage_enabled": self._two_stage_enabled,
                "hybrid_enabled": self._hybrid_enabled,
                "total_hits": len(retrieval_hits),
                "is_insufficient": is_insufficient,
                "max_score": max((h.score for h in retrieval_hits), default=0.0),
                "raw_results_count": len(results),
                "filtered_below_threshold": filtered_count,
            }

            logger.info(
                f"[RAG-OPT] Retrieval complete: {len(retrieval_hits)} hits, "
                f"insufficient={is_insufficient}, max_score={diagnostics['max_score']:.3f}"
            )

            if len(retrieval_hits) == 0:
                logger.warning(
                    f"[RAG-OPT] No hits found! tenant={ctx.tenant_id}, query={ctx.query[:50]}..., "
                    f"raw_results={len(results)}, threshold={self._score_threshold}"
                )

            return RetrievalResult(
                hits=retrieval_hits,
                diagnostics=diagnostics,
            )

        except Exception as e:
            logger.error(f"[RAG-OPT] Retrieval error: {e}", exc_info=True)
            return RetrievalResult(
                hits=[],
                diagnostics={"error": str(e), "is_insufficient": True},
            )

    async def _two_stage_retrieve(
        self,
        tenant_id: str,
        embedding_result: NomicEmbeddingResult,
        top_k: int,
    ) -> list[dict[str, Any]]:
        """
        Two-stage retrieval using Matryoshka dimensions.

        Stage 1: Fast retrieval with 256-dim vectors
        Stage 2: Precise reranking with 768-dim vectors

        Reference: rag-optimization/spec.md Section 2.4
        """
        import time

        client = await self._get_client()

        stage1_start = time.perf_counter()
        candidates = await self._search_with_dimension(
            client, tenant_id, embedding_result.embedding_256, "dim_256",
            top_k * self._two_stage_expand_factor
        )
        stage1_latency = (time.perf_counter() - stage1_start) * 1000

        logger.debug(
            f"[RAG-OPT] Stage 1: {len(candidates)} candidates in {stage1_latency:.2f}ms"
        )

        stage2_start = time.perf_counter()
        reranked = []
        for candidate in candidates:
            stored_full_embedding = candidate.get("payload", {}).get("embedding_full", [])
            if stored_full_embedding:
                candidate["score"] = self._cosine_similarity(
                    embedding_result.embedding_full,
                    stored_full_embedding
                )
                candidate["stage"] = "reranked"
            else:
                # Payloads written by KnowledgeChunk.to_qdrant_point do not
                # contain "embedding_full", so exact reranking is impossible.
                # The previous code dropped such candidates entirely, which
                # made two-stage retrieval return nothing; keep the stage-1
                # score instead.
                candidate["stage"] = "stage1"
            reranked.append(candidate)

        reranked.sort(key=lambda x: x.get("score", 0), reverse=True)
        results = reranked[:top_k]
        stage2_latency = (time.perf_counter() - stage2_start) * 1000

        logger.debug(
            f"[RAG-OPT] Stage 2: {len(results)} final results in {stage2_latency:.2f}ms"
        )

        return results

    async def _hybrid_retrieve(
        self,
        tenant_id: str,
        embedding_result: NomicEmbeddingResult,
        query: str,
        top_k: int,
    ) -> list[dict[str, Any]]:
        """
        Hybrid retrieval using RRF to combine vector and BM25 results.

        Either channel failing degrades gracefully to the other.

        Reference: rag-optimization/spec.md Section 2.5
        """
        client = await self._get_client()

        vector_task = self._search_with_dimension(
            client, tenant_id, embedding_result.embedding_full, "full",
            top_k * 2
        )

        bm25_task = self._bm25_search(client, tenant_id, query, top_k * 2)

        vector_results, bm25_results = await asyncio.gather(
            vector_task, bm25_task, return_exceptions=True
        )

        if isinstance(vector_results, Exception):
            logger.warning(f"[RAG-OPT] Vector search failed: {vector_results}")
            vector_results = []

        if isinstance(bm25_results, Exception):
            logger.warning(f"[RAG-OPT] BM25 search failed: {bm25_results}")
            bm25_results = []

        combined = self._rrf_combiner.combine(
            vector_results,
            bm25_results,
            vector_weight=settings.rag_vector_weight,
            bm25_weight=settings.rag_bm25_weight,
        )

        return combined[:top_k]

    async def _vector_retrieve(
        self,
        tenant_id: str,
        embedding: list[float],
        top_k: int,
    ) -> list[dict[str, Any]]:
        """Simple single-pass vector retrieval on the full-dimension vector."""
        client = await self._get_client()
        return await self._search_with_dimension(
            client, tenant_id, embedding, "full", top_k
        )

    async def _search_with_dimension(
        self,
        client: QdrantClient,
        tenant_id: str,
        query_vector: list[float],
        vector_name: str,
        limit: int,
    ) -> list[dict[str, Any]]:
        """Search the tenant collection using the named vector space.

        Returns a list of {"id", "score", "payload"} dicts; an empty list
        on any search failure (the error is logged, not raised).
        """
        try:
            qdrant = await client.get_client()
            collection_name = client.get_collection_name(tenant_id)

            logger.info(
                f"[RAG-OPT] Searching collection={collection_name}, "
                f"vector_name={vector_name}, limit={limit}, vector_dim={len(query_vector)}"
            )

            results = await qdrant.search(
                collection_name=collection_name,
                query_vector=(vector_name, query_vector),
                limit=limit,
            )

            logger.info(
                f"[RAG-OPT] Search returned {len(results)} results from collection={collection_name}"
            )

            if len(results) > 0:
                for i, r in enumerate(results[:3]):
                    logger.debug(
                        f"[RAG-OPT] Result {i+1}: id={r.id}, score={r.score:.4f}"
                    )

            return [
                {
                    "id": str(result.id),
                    "score": result.score,
                    "payload": result.payload or {},
                }
                for result in results
            ]
        except Exception as e:
            logger.error(
                f"[RAG-OPT] Search with {vector_name} failed: {e}, "
                f"collection_name={client.get_collection_name(tenant_id)}",
                exc_info=True
            )
            return []

    async def _bm25_search(
        self,
        client: QdrantClient,
        tenant_id: str,
        query: str,
        limit: int,
    ) -> list[dict[str, Any]]:
        """
        BM25-like search using Qdrant's sparse vectors or fallback to text matching.
        This is a simplified implementation; for production, use Elasticsearch.

        Scores are Jaccard overlap between query terms and chunk text terms,
        not true BM25. Failures return an empty list.
        """
        try:
            qdrant = await client.get_client()
            collection_name = client.get_collection_name(tenant_id)

            query_terms = set(re.findall(r'\w+', query.lower()))

            # scroll() returns (points, next_page_offset); only points are used.
            results = await qdrant.scroll(
                collection_name=collection_name,
                limit=limit * 3,
                with_payload=True,
            )

            scored_results = []
            for point in results[0]:
                text = point.payload.get("text", "").lower()
                text_terms = set(re.findall(r'\w+', text))
                overlap = len(query_terms & text_terms)
                if overlap > 0:
                    # Jaccard similarity: |A ∩ B| / |A ∪ B|
                    score = overlap / (len(query_terms) + len(text_terms) - overlap)
                    scored_results.append({
                        "id": str(point.id),
                        "score": score,
                        "payload": point.payload or {},
                    })

            scored_results.sort(key=lambda x: x["score"], reverse=True)
            return scored_results[:limit]

        except Exception as e:
            logger.debug(f"[RAG-OPT] BM25 search failed: {e}")
            return []

    def _cosine_similarity(self, vec1: list[float], vec2: list[float]) -> float:
        """Cosine similarity of two vectors; 0.0 when either has zero norm."""
        import numpy as np
        a = np.array(vec1)
        b = np.array(vec2)
        denom = float(np.linalg.norm(a) * np.linalg.norm(b))
        if denom == 0.0:
            # Guard the previous division-by-zero / NaN on degenerate input.
            return 0.0
        return float(np.dot(a, b) / denom)

    async def health_check(self) -> bool:
        """Return True when the Qdrant backend answers a collections query."""
        try:
            client = await self._get_client()
            qdrant = await client.get_client()
            await qdrant.get_collections()
            return True
        except Exception as e:
            logger.error(f"[RAG-OPT] Health check failed: {e}")
            return False
|
||||||
|
|
||||||
|
|
||||||
|
# Process-wide singleton, created lazily by get_optimized_retriever().
_optimized_retriever: OptimizedRetriever | None = None
|
||||||
|
|
||||||
|
|
||||||
|
async def get_optimized_retriever() -> OptimizedRetriever:
    """Return the module-level OptimizedRetriever, creating it on demand."""
    global _optimized_retriever
    if _optimized_retriever is not None:
        return _optimized_retriever
    _optimized_retriever = OptimizedRetriever()
    return _optimized_retriever
|
||||||
Loading…
Reference in New Issue