""" RAG Lab endpoints for debugging and experimentation. [AC-ASA-05] RAG experiment debugging with retrieval results and prompt visualization. """ import logging from typing import Annotated, Any, List from fastapi import APIRouter, Depends, Body from fastapi.responses import JSONResponse from pydantic import BaseModel, Field from sqlalchemy.ext.asyncio import AsyncSession from app.core.config import get_settings from app.core.database import get_session from app.core.exceptions import MissingTenantIdException from app.core.tenant import get_tenant_id from app.core.qdrant_client import get_qdrant_client from app.models import ErrorResponse from app.services.retrieval.vector_retriever import get_vector_retriever from app.services.retrieval.base import RetrievalContext logger = logging.getLogger(__name__) router = APIRouter(prefix="/admin/rag", tags=["RAG Lab"]) def get_current_tenant_id() -> str: """Dependency to get current tenant ID or raise exception.""" tenant_id = get_tenant_id() if not tenant_id: raise MissingTenantIdException() return tenant_id class RAGExperimentRequest(BaseModel): query: str = Field(..., description="Query text for retrieval") kb_ids: List[str] | None = Field(default=None, description="Knowledge base IDs to search") params: dict[str, Any] | None = Field(default=None, description="Retrieval parameters") @router.post( "/experiments/run", operation_id="runRagExperiment", summary="Run RAG debugging experiment", description="[AC-ASA-05] Trigger RAG experiment with retrieval and prompt generation.", responses={ 200: {"description": "Experiment results with retrieval and prompt"}, 401: {"description": "Unauthorized", "model": ErrorResponse}, 403: {"description": "Forbidden", "model": ErrorResponse}, }, ) async def run_rag_experiment( tenant_id: Annotated[str, Depends(get_current_tenant_id)], request: RAGExperimentRequest = Body(...), ) -> JSONResponse: """ [AC-ASA-05] Run RAG experiment and return retrieval results with final prompt. """ logger.info( f"[AC-ASA-05] Running RAG experiment: tenant={tenant_id}, " f"query={request.query[:50]}..., kb_ids={request.kb_ids}" ) settings = get_settings() params = request.params or {} top_k = params.get("topK", settings.rag_top_k) threshold = params.get("threshold", settings.rag_score_threshold) try: retriever = await get_vector_retriever() retrieval_ctx = RetrievalContext( tenant_id=tenant_id, query=request.query, session_id="rag_experiment", channel_type="admin", metadata={"kb_ids": request.kb_ids}, ) result = await retriever.retrieve(retrieval_ctx) retrieval_results = [ { "content": hit.text, "score": hit.score, "source": hit.source, "metadata": hit.metadata, } for hit in result.hits ] final_prompt = _build_final_prompt(request.query, retrieval_results) logger.info( f"[AC-ASA-05] RAG experiment complete: hits={len(retrieval_results)}, " f"max_score={result.max_score:.3f}" ) return JSONResponse( content={ "retrievalResults": retrieval_results, "finalPrompt": final_prompt, "diagnostics": result.diagnostics, } ) except Exception as e: logger.error(f"[AC-ASA-05] RAG experiment failed: {e}") fallback_results = _get_fallback_results(request.query) fallback_prompt = _build_final_prompt(request.query, fallback_results) return JSONResponse( content={ "retrievalResults": fallback_results, "finalPrompt": fallback_prompt, "diagnostics": { "error": str(e), "fallback": True, }, } ) def _build_final_prompt(query: str, retrieval_results: list[dict]) -> str: """ Build the final prompt from query and retrieval results. """ if not retrieval_results: return f"""用户问题:{query} 未找到相关检索结果,请基于通用知识回答用户问题。""" evidence_text = "\n".join([ f"{i+1}. [Score: {hit['score']:.2f}] {hit['content'][:200]}{'...' if len(hit['content']) > 200 else ''}" for i, hit in enumerate(retrieval_results[:5]) ]) return f"""基于以下检索到的信息,回答用户问题: 用户问题:{query} 检索结果: {evidence_text} 请基于以上信息生成专业、准确的回答。""" def _get_fallback_results(query: str) -> list[dict]: """ Provide fallback results when retrieval fails. """ return [ { "content": "检索服务暂时不可用,这是模拟结果。", "score": 0.5, "source": "fallback", } ]