"""
|
|
RAG Lab endpoints for debugging and experimentation.
|
|
[AC-ASA-05, AC-ASA-19, AC-ASA-20, AC-ASA-21, AC-ASA-22] RAG experiment with AI output.
|
|
"""

import json
import logging
import time
from typing import Annotated, Any

from fastapi import APIRouter, Body, Depends
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, Field

from app.core.config import get_settings
from app.core.exceptions import MissingTenantIdException
from app.core.prompts import build_user_prompt_with_evidence, format_evidence_for_prompt
from app.core.tenant import get_tenant_id
from app.models import ErrorResponse
from app.services.llm.factory import get_llm_config_manager
from app.services.retrieval.base import RetrievalContext
from app.services.retrieval.optimized_retriever import get_optimized_retriever

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/admin/rag", tags=["RAG Lab"])


def get_current_tenant_id() -> str:
    """Dependency to get the current tenant ID or raise an exception."""
    tenant_id = get_tenant_id()
    if not tenant_id:
        raise MissingTenantIdException()
    return tenant_id


class RAGExperimentRequest(BaseModel):
    """Request payload for a RAG Lab experiment."""

    query: str = Field(..., description="Query text for retrieval")
    kb_ids: list[str] | None = Field(default=None, description="Knowledge base IDs to search")
    top_k: int = Field(default=5, description="Number of results to retrieve")
    score_threshold: float = Field(default=0.5, description="Minimum similarity score")
    generate_response: bool = Field(default=True, description="Whether to generate an AI response")
    llm_provider: str | None = Field(default=None, description="Specific LLM provider to use")
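
# For reference, a request body like the following exercises the fields above.
# Field names come straight from RAGExperimentRequest; the concrete values are
# illustrative only:
#
#   {
#       "query": "How do I reset my password?",
#       "kb_ids": ["kb_demo"],
#       "top_k": 5,
#       "score_threshold": 0.5,
#       "generate_response": true,
#       "llm_provider": null
#   }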


class AIResponse(BaseModel):
    """Content, token usage, and latency of a generated AI response."""

    content: str
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0
    latency_ms: float = 0
    model: str = ""


class RAGExperimentResult(BaseModel):
    """Shape of the experiment result (the endpoints below return this structure as plain JSON)."""

    query: str
    retrieval_results: list[dict] = []
    final_prompt: str = ""
    ai_response: AIResponse | None = None
    total_latency_ms: float = 0
    diagnostics: dict[str, Any] = {}


@router.post(
    "/experiments/run",
    operation_id="runRagExperiment",
    summary="Run RAG debugging experiment with AI output",
    description="[AC-ASA-05, AC-ASA-19, AC-ASA-21, AC-ASA-22] Trigger a RAG experiment with retrieval, prompt generation, and AI response.",
    responses={
        200: {"description": "Experiment results with retrieval, prompt, and AI response"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def run_rag_experiment(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    request: RAGExperimentRequest = Body(...),
) -> JSONResponse:
    """
    [AC-ASA-05, AC-ASA-19, AC-ASA-21, AC-ASA-22] Run a RAG experiment and return retrieval results with an AI response.
    """
    start_time = time.time()

    logger.info(
        f"[AC-ASA-05] Running RAG experiment: tenant={tenant_id}, "
        f"query={request.query[:50]}..., kb_ids={request.kb_ids}, "
        f"generate_response={request.generate_response}"
    )

    settings = get_settings()
    # Fall back to the configured defaults when the request carries falsy values.
    top_k = request.top_k or settings.rag_top_k
    threshold = request.score_threshold or settings.rag_score_threshold

    try:
        # Use the optimized retriever with RAG enhancements.
        retriever = await get_optimized_retriever()

        retrieval_ctx = RetrievalContext(
            tenant_id=tenant_id,
            query=request.query,
            session_id="rag_experiment",
            channel_type="admin",
            metadata={
                "kb_ids": request.kb_ids,
                # Forward the resolved limits; this assumes the retriever honors
                # these metadata keys (otherwise it applies its own defaults).
                "top_k": top_k,
                "score_threshold": threshold,
            },
        )

        result = await retriever.retrieve(retrieval_ctx)

        retrieval_results = [
            {
                "content": hit.text,
                "score": hit.score,
                "source": hit.source,
                "metadata": hit.metadata,
            }
            for hit in result.hits
        ]

        final_prompt = _build_final_prompt(request.query, retrieval_results)

        logger.info(
            f"[AC-ASA-05] RAG retrieval complete: hits={len(retrieval_results)}, "
            f"max_score={result.max_score:.3f}"
        )

        ai_response = None
        if request.generate_response:
            ai_response = await _generate_ai_response(
                final_prompt,
                provider=request.llm_provider,
            )

        total_latency_ms = (time.time() - start_time) * 1000

        return JSONResponse(
            content={
                "query": request.query,
                "retrieval_results": retrieval_results,
                "final_prompt": final_prompt,
                "ai_response": ai_response.model_dump() if ai_response else None,
                "total_latency_ms": round(total_latency_ms, 2),
                "diagnostics": result.diagnostics,
            }
        )

    except Exception as e:
        logger.error(f"[AC-ASA-05] RAG experiment failed: {e}")

        # Degrade gracefully: build the prompt from simulated evidence so the
        # lab UI stays usable while retrieval is down.
        fallback_results = _get_fallback_results(request.query)
        fallback_prompt = _build_final_prompt(request.query, fallback_results)

        ai_response = None
        if request.generate_response:
            ai_response = await _generate_ai_response(
                fallback_prompt,
                provider=request.llm_provider,
            )

        total_latency_ms = (time.time() - start_time) * 1000

        return JSONResponse(
            content={
                "query": request.query,
                "retrieval_results": fallback_results,
                "final_prompt": fallback_prompt,
                "ai_response": ai_response.model_dump() if ai_response else None,
                "total_latency_ms": round(total_latency_ms, 2),
                "diagnostics": {
                    "error": str(e),
                    "fallback": True,
                },
            }
        )
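
# A minimal smoke test against the endpoint above might look like the sketch
# below. The path follows the router prefix; the X-Tenant-Id header is an
# assumption, since tenant resolution happens in middleware behind
# get_tenant_id() and may differ per deployment:
#
#   curl -X POST http://localhost:8000/admin/rag/experiments/run \
#       -H "Content-Type: application/json" \
#       -H "X-Tenant-Id: tenant_demo" \
#       -d '{"query": "How do I reset my password?", "generate_response": false}'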


@router.post(
    "/experiments/stream",
    operation_id="runRagExperimentStream",
    summary="Run RAG experiment with streaming AI output",
    description="[AC-ASA-20] Trigger a RAG experiment with SSE streaming for the AI response.",
    responses={
        200: {"description": "SSE stream with retrieval results and AI response"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def run_rag_experiment_stream(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    request: RAGExperimentRequest = Body(...),
) -> StreamingResponse:
    """
    [AC-ASA-20] Run a RAG experiment with SSE streaming for the AI response.
    """
    logger.info(
        f"[AC-ASA-20] Running RAG experiment stream: tenant={tenant_id}, "
        f"query={request.query[:50]}..."
    )

    settings = get_settings()
    # Fall back to the configured defaults when the request carries falsy values.
    top_k = request.top_k or settings.rag_top_k
    threshold = request.score_threshold or settings.rag_score_threshold

    async def event_generator():
        try:
            # Use the optimized retriever with RAG enhancements.
            retriever = await get_optimized_retriever()

            retrieval_ctx = RetrievalContext(
                tenant_id=tenant_id,
                query=request.query,
                session_id="rag_experiment_stream",
                channel_type="admin",
                metadata={
                    "kb_ids": request.kb_ids,
                    # Forward the resolved limits; this assumes the retriever honors
                    # these metadata keys (otherwise it applies its own defaults).
                    "top_k": top_k,
                    "score_threshold": threshold,
                },
            )

            result = await retriever.retrieve(retrieval_ctx)

            retrieval_results = [
                {
                    "content": hit.text,
                    "score": hit.score,
                    "source": hit.source,
                    "metadata": hit.metadata,
                }
                for hit in result.hits
            ]

            final_prompt = _build_final_prompt(request.query, retrieval_results)

            logger.info("[AC-ASA-20] ========== RAG LAB STREAM FULL PROMPT ==========")
            logger.info(f"[AC-ASA-20] Prompt length: {len(final_prompt)}")
            logger.info(f"[AC-ASA-20] Prompt content:\n{final_prompt}")
            logger.info("[AC-ASA-20] ==============================================")

            yield f"event: retrieval\ndata: {json.dumps({'results': retrieval_results, 'count': len(retrieval_results)})}\n\n"

            yield f"event: prompt\ndata: {json.dumps({'prompt': final_prompt})}\n\n"

            if request.generate_response:
                manager = get_llm_config_manager()
                client = manager.get_client()

                full_content = ""
                async for chunk in client.stream_generate(
                    messages=[{"role": "user", "content": final_prompt}],
                ):
                    if chunk.delta:
                        full_content += chunk.delta
                        yield f"event: message\ndata: {json.dumps({'delta': chunk.delta})}\n\n"

                yield f"event: final\ndata: {json.dumps({'content': full_content, 'finish_reason': 'stop'})}\n\n"
            else:
                yield f"event: final\ndata: {json.dumps({'content': '', 'finish_reason': 'skipped'})}\n\n"

        except Exception as e:
            logger.error(f"[AC-ASA-20] RAG experiment stream failed: {e}")
            yield f"event: error\ndata: {json.dumps({'error': str(e)})}\n\n"

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            # Disable proxy buffering (nginx) so events flush immediately.
            "X-Accel-Buffering": "no",
        },
    )
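
# The stream above emits, in order: one `retrieval` event with the hit list,
# one `prompt` event with the final prompt, zero or more `message` events
# carrying content deltas, and a closing `final` event (finish_reason "stop",
# or "skipped" when generation is disabled); any failure surfaces as a single
# `error` event instead. To watch it live (the tenant header is an assumption,
# as above):
#
#   curl -N -X POST http://localhost:8000/admin/rag/experiments/stream \
#       -H "Content-Type: application/json" \
#       -H "X-Tenant-Id: tenant_demo" \
#       -d '{"query": "How do I reset my password?"}'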


async def _generate_ai_response(
    prompt: str,
    provider: str | None = None,
) -> AIResponse | None:
    """
    [AC-ASA-19, AC-ASA-21] Generate an AI response from the prompt.
    """
    logger.info("[AC-ASA-19] ========== RAG LAB FULL PROMPT ==========")
    logger.info(f"[AC-ASA-19] Prompt length: {len(prompt)}")
    logger.info(f"[AC-ASA-19] Prompt content:\n{prompt}")
    logger.info("[AC-ASA-19] ==========================================")

    try:
        manager = get_llm_config_manager()
        # NOTE: `provider` is currently not forwarded; get_client() is called
        # without it, so the manager's configured client is always used.
        client = manager.get_client()

        start_time = time.time()
        response = await client.generate(
            messages=[{"role": "user", "content": prompt}],
        )
        latency_ms = (time.time() - start_time) * 1000

        return AIResponse(
            content=response.content,
            prompt_tokens=response.usage.get("prompt_tokens", 0),
            completion_tokens=response.usage.get("completion_tokens", 0),
            total_tokens=response.usage.get("total_tokens", 0),
            latency_ms=round(latency_ms, 2),
            model=response.model,
        )

    except Exception as e:
        logger.error(f"[AC-ASA-19] AI response generation failed: {e}")
        return AIResponse(
            content=f"AI response generation failed: {str(e)}",
            latency_ms=0,
        )


def _build_final_prompt(query: str, retrieval_results: list[dict]) -> str:
    """
    Build the final prompt from the query and retrieval results.

    Uses the shared prompt configuration for consistency with the orchestrator.
    """
    evidence_text = format_evidence_for_prompt(retrieval_results, max_results=5, max_content_length=500)
    return build_user_prompt_with_evidence(query, evidence_text)
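
# Per the arguments above, evidence is capped at the top 5 hits of 500
# characters each before being embedded into the shared user-prompt template.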


def _get_fallback_results(query: str) -> list[dict]:
    """
    Provide fallback results when retrieval fails.
    """
    return [
        {
            "content": "The retrieval service is temporarily unavailable; this is a simulated result.",
            "score": 0.5,
            "source": "fallback",
        }
    ]
|