# Source: ai-robot-core/ai-service/app/api/admin/rag.py (Python, 162 lines, 5.0 KiB)

"""
RAG Lab endpoints for debugging and experimentation.
[AC-ASA-05] RAG experiment debugging with retrieval results and prompt visualization.
"""
import logging
from typing import Annotated, Any, List
from fastapi import APIRouter, Depends, Body
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import get_settings
from app.core.database import get_session
from app.core.exceptions import MissingTenantIdException
from app.core.tenant import get_tenant_id
from app.core.qdrant_client import get_qdrant_client
from app.models import ErrorResponse
from app.services.retrieval.vector_retriever import get_vector_retriever
from app.services.retrieval.base import RetrievalContext
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Router for the RAG Lab admin endpoints; every route below is mounted under /admin/rag.
router = APIRouter(prefix="/admin/rag", tags=["RAG Lab"])
def get_current_tenant_id() -> str:
    """Return the tenant ID reported by ``get_tenant_id()``.

    Used as a FastAPI dependency by the endpoints in this module.

    Returns:
        The non-empty tenant ID.

    Raises:
        MissingTenantIdException: If ``get_tenant_id()`` yields a falsy value
            (e.g. ``None`` or an empty string).
    """
    resolved = get_tenant_id()
    if resolved:
        return resolved
    raise MissingTenantIdException()
class RAGExperimentRequest(BaseModel):
    # Request body accepted by the RAG experiment endpoint.

    # Required query text; passed to the retriever as the retrieval query.
    query: str = Field(
        ...,
        description="Query text for retrieval",
    )

    # Optional knowledge base IDs; None when the caller does not restrict the search.
    kb_ids: List[str] | None = Field(
        default=None,
        description="Knowledge base IDs to search",
    )

    # Optional free-form parameters; the endpoint looks up "topK" and "threshold".
    params: dict[str, Any] | None = Field(
        default=None,
        description="Retrieval parameters",
    )
@router.post(
    "/experiments/run",
    operation_id="runRagExperiment",
    summary="Run RAG debugging experiment",
    description="[AC-ASA-05] Trigger RAG experiment with retrieval and prompt generation.",
    responses={
        200: {"description": "Experiment results with retrieval and prompt"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def run_rag_experiment(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    request: RAGExperimentRequest = Body(...),
) -> JSONResponse:
    """
    [AC-ASA-05] Run RAG experiment and return retrieval results with final prompt.

    The query is sent to the vector retriever, each hit is flattened into a
    plain dict, and a prompt is assembled from the query and those hits.

    Args:
        tenant_id: Tenant ID resolved by the ``get_current_tenant_id`` dependency.
        request: Experiment request (query, optional KB IDs, optional params).

    Returns:
        A JSON response with ``retrievalResults``, ``finalPrompt`` and
        ``diagnostics``. If anything inside the retrieval step raises, the
        response instead carries fallback results and a ``diagnostics`` object
        with ``error`` (the exception text) and ``fallback: true``; the
        endpoint does not propagate the failure.
    """
    settings = get_settings()
    params = request.params or {}
    top_k = params.get("topK", settings.rag_top_k)
    threshold = params.get("threshold", settings.rag_score_threshold)

    # NOTE(review): top_k / threshold are resolved and logged but NOT forwarded
    # to the retriever. The retriever/RetrievalContext API for overriding them
    # is not visible from this module -- confirm how they should be applied.
    logger.info(
        "[AC-ASA-05] Running RAG experiment: tenant=%s, query=%s..., kb_ids=%s, "
        "requested_top_k=%s, requested_threshold=%s",
        tenant_id,
        request.query[:50],
        request.kb_ids,
        top_k,
        threshold,
    )

    try:
        retriever = await get_vector_retriever()
        retrieval_ctx = RetrievalContext(
            tenant_id=tenant_id,
            query=request.query,
            session_id="rag_experiment",
            channel_type="admin",
            metadata={"kb_ids": request.kb_ids},
        )
        result = await retriever.retrieve(retrieval_ctx)

        retrieval_results = [
            {
                "content": hit.text,
                "score": hit.score,
                "source": hit.source,
                "metadata": hit.metadata,
            }
            for hit in result.hits
        ]
        final_prompt = _build_final_prompt(request.query, retrieval_results)

        # Lazy %-formatting: with the standard logging handlers, a max_score that
        # cannot be rendered with %.3f is reported by the logging module instead
        # of raising here and discarding an otherwise successful retrieval.
        logger.info(
            "[AC-ASA-05] RAG experiment complete: hits=%d, max_score=%.3f",
            len(retrieval_results),
            result.max_score,
        )

        return JSONResponse(
            content={
                "retrievalResults": retrieval_results,
                "finalPrompt": final_prompt,
                "diagnostics": result.diagnostics,
            }
        )
    except Exception as e:
        # Deliberate best-effort boundary: the lab UI always gets a response.
        # logger.exception keeps the traceback that a plain error log would drop.
        logger.exception("[AC-ASA-05] RAG experiment failed: %s", e)

        fallback_results = _get_fallback_results(request.query)
        fallback_prompt = _build_final_prompt(request.query, fallback_results)
        return JSONResponse(
            content={
                "retrievalResults": fallback_results,
                "finalPrompt": fallback_prompt,
                "diagnostics": {
                    "error": str(e),
                    "fallback": True,
                },
            }
        )
def _build_final_prompt(query: str, retrieval_results: list[dict]) -> str:
"""
Build the final prompt from query and retrieval results.
"""
if not retrieval_results:
return f"""用户问题:{query}
未找到相关检索结果,请基于通用知识回答用户问题。"""
evidence_text = "\n".join([
f"{i+1}. [Score: {hit['score']:.2f}] {hit['content'][:200]}{'...' if len(hit['content']) > 200 else ''}"
for i, hit in enumerate(retrieval_results[:5])
])
return f"""基于以下检索到的信息,回答用户问题:
用户问题:{query}
检索结果:
{evidence_text}
请基于以上信息生成专业、准确的回答。"""
def _get_fallback_results(query: str) -> list[dict]:
"""
Provide fallback results when retrieval fails.
"""
return [
{
"content": "检索服务暂时不可用,这是模拟结果。",
"score": 0.5,
"source": "fallback",
}
]