# ai-robot-core/ai-service/app/api/admin/flow_test.py

"""
Flow test API for AI Service Admin.
[AC-AISVC-93~AC-AISVC-95] Complete 12-step flow execution testing.
"""
import logging
import uuid
from datetime import datetime
from typing import Any
from fastapi import APIRouter, Depends, Header, HTTPException, Query
from pydantic import BaseModel
from sqlalchemy import desc, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_session
from app.models.entities import FlowTestRecord, FlowTestRecordStatus
# Module-level logger; uses lazy handler resolution via the module name.
logger = logging.getLogger(__name__)
# All flow-test endpoints are mounted under /admin/test.
router = APIRouter(prefix="/admin/test", tags=["Flow Test"])
def get_tenant_id(x_tenant_id: str = Header(..., alias="X-Tenant-Id")) -> str:
    """Resolve the caller's tenant ID from the ``X-Tenant-Id`` request header.

    FastAPI already rejects a *missing* header (the field is required); this
    guard additionally rejects an *empty* value.

    Raises:
        HTTPException: 400 when the header value is empty.
    """
    if x_tenant_id:
        return x_tenant_id
    raise HTTPException(status_code=400, detail="X-Tenant-Id header is required")
class FlowExecutionRequest(BaseModel):
    """Request for flow execution test."""

    # User message routed through the full generation pipeline.
    message: str
    # Existing session to reuse; when None a fresh "test_*" session is generated.
    session_id: str | None = None
    # Optional user identity for the test run.
    user_id: str | None = None
    # Per-stage toggles. NOTE(review): only `message` and `session_id` are read
    # by execute_flow_test in this file — confirm the toggles are consumed
    # downstream before relying on them.
    enable_flow: bool = True
    enable_intent: bool = True
    enable_rag: bool = True
    enable_guardrail: bool = True
    enable_memory: bool = True
    compare_mode: bool = False
class FlowExecutionResponse(BaseModel):
    """Response for flow execution test."""

    # Identifier of the persisted FlowTestRecord (stringified UUID).
    test_id: str
    # Session the test ran under (caller-supplied or generated).
    session_id: str
    # Overall outcome: FlowTestRecordStatus value (success/partial/failed).
    status: str
    # Per-step execution log as emitted in result.metadata["execution_steps"].
    steps: list[dict[str, Any]]
    # Final reply payload (reply / confidence / should_transfer), None on failure.
    final_response: dict[str, Any] | None
    # Wall-clock duration of the whole run in milliseconds.
    total_duration_ms: int
    # ISO-8601 creation timestamp of the record.
    created_at: str
@router.post(
    "/flow-execution",
    operation_id="executeFlowTest",
    summary="Execute complete 12-step flow",
    description="[AC-AISVC-93] Execute complete 12-step generation flow with detailed step logging.",
)
async def execute_flow_test(
    request: FlowExecutionRequest,
    tenant_id: str = Depends(get_tenant_id),
    session: AsyncSession = Depends(get_session),
) -> FlowExecutionResponse:
    """
    [AC-AISVC-93] Execute complete 12-step flow for testing.

    Steps:
    1. InputScanner - Scan input for forbidden words
    2. FlowEngine - Check if flow is active
    3. IntentRouter - Match intent rules
    4. QueryRewriter - Rewrite query for better retrieval
    5. MultiKBRetrieval - Retrieve from multiple knowledge bases
    6. ResultRanker - Rank and filter results
    7. PromptBuilder - Build prompt from template
    8. LLMGenerate - Generate response via LLM
    9. OutputFilter - Filter output for forbidden words
    10. Confidence - Calculate confidence score
    11. Memory - Store conversation in memory
    12. Response - Return final response

    Raises:
        HTTPException: 500 when the orchestrator run fails; the failure is
            also persisted as a FAILED record on a best-effort basis.
    """
    import time

    from app.models import ChatRequest, ChannelType
    from app.services.llm.factory import get_llm_config_manager
    from app.services.memory import MemoryService
    from app.services.orchestrator import OrchestratorService
    from app.services.retrieval.optimized_retriever import get_optimized_retriever

    logger.info(
        f"[AC-AISVC-93] Executing flow test for tenant={tenant_id}, "
        f"message={request.message[:50]}..."
    )
    test_session_id = request.session_id or f"test_{uuid.uuid4().hex[:8]}"
    start_time = time.time()

    memory_service = MemoryService(session)
    llm_config_manager = get_llm_config_manager()
    llm_client = llm_config_manager.get_client()
    retriever = await get_optimized_retriever()
    orchestrator = OrchestratorService(
        llm_client=llm_client,
        memory_service=memory_service,
        retriever=retriever,
    )
    try:
        chat_request = ChatRequest(
            session_id=test_session_id,
            current_message=request.message,
            channel_type=ChannelType.WECHAT,
            history=[],
        )
        result = await orchestrator.generate(
            tenant_id=tenant_id,
            request=chat_request,
        )
        steps = result.metadata.get("execution_steps", []) if result.metadata else []
        total_duration_ms = int((time.time() - start_time) * 1000)

        # Overall status: any failed step wins over any skipped step.
        has_failure = any(s.get("status") == "failed" for s in steps)
        has_partial = any(s.get("status") == "skipped" for s in steps)
        if has_failure:
            status = FlowTestRecordStatus.FAILED.value
        elif has_partial:
            status = FlowTestRecordStatus.PARTIAL.value
        else:
            status = FlowTestRecordStatus.SUCCESS.value

        test_record = FlowTestRecord(
            tenant_id=tenant_id,
            session_id=test_session_id,
            status=status,
            steps=steps,
            final_response={
                "reply": result.reply,
                "confidence": result.confidence,
                "should_transfer": result.should_transfer,
            },
            total_duration_ms=total_duration_ms,
        )
        # Persisting the record is best effort: a storage outage must not
        # turn a successful generation into a 500.
        saved = True
        try:
            session.add(test_record)
            await session.commit()
            await session.refresh(test_record)
        except Exception as db_error:
            logger.warning(f"Failed to save test record: {db_error}")
            await session.rollback()
            saved = False

        # Bug fix: after a rolled-back save the DB-generated columns (id,
        # created_at) are unpopulated, so the old code crashed on
        # `created_at.isoformat()`. Fall back to synthesized values instead.
        if saved:
            test_id = str(test_record.id)
            created_at = test_record.created_at.isoformat()
        else:
            test_id = f"unsaved_{uuid.uuid4().hex[:8]}"
            created_at = datetime.utcnow().isoformat()

        logger.info(
            f"[AC-AISVC-93] Flow test completed: id={test_id}, "
            f"status={status}, duration={total_duration_ms}ms"
        )
        return FlowExecutionResponse(
            test_id=test_id,
            session_id=test_session_id,
            status=status,
            steps=steps,
            final_response=test_record.final_response,
            total_duration_ms=total_duration_ms,
            created_at=created_at,
        )
    except Exception as e:
        logger.exception(f"[AC-AISVC-93] Flow test failed: {e}")
        total_duration_ms = int((time.time() - start_time) * 1000)
        await session.rollback()
        # Best-effort persistence of the failure; guard it so a second DB
        # error cannot mask the original exception.
        try:
            test_record = FlowTestRecord(
                tenant_id=tenant_id,
                session_id=test_session_id,
                status=FlowTestRecordStatus.FAILED.value,
                steps=[{
                    "step": 0,
                    "name": "Error",
                    "status": "failed",
                    "error": str(e),
                }],
                final_response=None,
                total_duration_ms=total_duration_ms,
            )
            session.add(test_record)
            await session.commit()
        except Exception as db_error:
            logger.warning(f"Failed to save failed test record: {db_error}")
            await session.rollback()
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.get(
    "/flow-execution/{test_id}",
    operation_id="getFlowTestResult",
    summary="Get flow test result",
    description="[AC-AISVC-94] Get detailed result of a flow execution test.",
)
async def get_flow_test_result(
    test_id: uuid.UUID,
    tenant_id: str = Depends(get_tenant_id),
    session: AsyncSession = Depends(get_session),
) -> dict[str, Any]:
    """
    [AC-AISVC-94] Get detailed result of a flow execution test.

    Returns step-by-step execution details for debugging. The lookup is
    tenant-scoped, so a record belonging to another tenant yields 404.

    Raises:
        HTTPException: 404 when no matching record exists for this tenant.
    """
    logger.info(
        f"[AC-AISVC-94] Getting flow test result for tenant={tenant_id}, "
        f"test_id={test_id}"
    )
    query = select(FlowTestRecord).where(
        FlowTestRecord.id == test_id,
        FlowTestRecord.tenant_id == tenant_id,
    )
    record = (await session.execute(query)).scalar_one_or_none()
    if record is None:
        raise HTTPException(status_code=404, detail="Test record not found")
    return {
        "testId": str(record.id),
        "sessionId": record.session_id,
        "status": record.status,
        "steps": record.steps,
        "finalResponse": record.final_response,
        "totalDurationMs": record.total_duration_ms,
        "createdAt": record.created_at.isoformat(),
    }
@router.get(
    "/flow-executions",
    operation_id="listFlowTests",
    summary="List flow test records",
    description="[AC-AISVC-95] List flow test execution records.",
)
async def list_flow_tests(
    tenant_id: str = Depends(get_tenant_id),
    session_id: str | None = Query(None, description="Filter by session ID"),
    status: str | None = Query(None, description="Filter by status"),
    page: int = Query(1, ge=1, description="Page number"),
    page_size: int = Query(20, ge=1, le=100, description="Page size"),
    session: AsyncSession = Depends(get_session),
) -> dict[str, Any]:
    """
    [AC-AISVC-95] List flow test execution records.

    Records are retained for 7 days. Results are tenant-scoped, optionally
    filtered by session/status, newest first, and paginated.
    """
    logger.info(
        f"[AC-AISVC-95] Listing flow tests for tenant={tenant_id}, "
        f"session={session_id}, page={page}"
    )
    # Assemble the filter set once, then reuse it for both count and page.
    criteria = [FlowTestRecord.tenant_id == tenant_id]
    if session_id:
        criteria.append(FlowTestRecord.session_id == session_id)
    if status:
        criteria.append(FlowTestRecord.status == status)
    base_query = select(FlowTestRecord).where(*criteria)

    count_query = select(func.count()).select_from(base_query.subquery())
    total = (await session.execute(count_query)).scalar() or 0

    page_query = (
        base_query.order_by(desc(FlowTestRecord.created_at))
        .offset((page - 1) * page_size)
        .limit(page_size)
    )
    records = (await session.execute(page_query)).scalars().all()

    return {
        "data": [
            {
                "testId": str(r.id),
                "sessionId": r.session_id,
                "status": r.status,
                "stepCount": len(r.steps),
                "totalDurationMs": r.total_duration_ms,
                "createdAt": r.created_at.isoformat(),
            }
            for r in records
        ],
        "page": page,
        "pageSize": page_size,
        "total": total,
    }
class CompareRequest(BaseModel):
    """Request for comparison test."""

    # Message sent identically to both the baseline and the test run.
    message: str
    # NOTE(review): these configs are accepted but compare_flow_test does not
    # pass them to the orchestrator — confirm whether they are wired elsewhere.
    baseline_config: dict[str, Any] | None = None
    test_config: dict[str, Any] | None = None
@router.post(
    "/compare",
    operation_id="compareFlowTest",
    summary="Compare two flow executions",
    description="[AC-AISVC-95] Compare baseline and test configurations.",
)
async def compare_flow_test(
    request: CompareRequest,
    tenant_id: str = Depends(get_tenant_id),
    session: AsyncSession = Depends(get_session),
) -> dict[str, Any]:
    """
    [AC-AISVC-95] Compare two flow executions with different configurations.

    Useful for:
    - A/B testing prompt templates
    - Comparing RAG retrieval strategies
    - Testing guardrail effectiveness

    NOTE(review): ``baseline_config`` / ``test_config`` are accepted but not
    yet applied — both runs execute against the live tenant configuration,
    differing only in session ID. The requested configs are echoed back in
    ``comparison`` so callers can see what was (not) applied.
    TODO: thread the configs into ``orchestrator.generate``.
    """
    import time

    from app.models import ChatRequest, ChannelType
    from app.services.llm.factory import get_llm_config_manager
    from app.services.memory import MemoryService
    from app.services.orchestrator import OrchestratorService
    from app.services.retrieval.optimized_retriever import get_optimized_retriever

    logger.info(
        f"[AC-AISVC-95] Running comparison test for tenant={tenant_id}"
    )
    baseline_session_id = f"compare_baseline_{uuid.uuid4().hex[:8]}"
    test_session_id = f"compare_test_{uuid.uuid4().hex[:8]}"

    memory_service = MemoryService(session)
    llm_config_manager = get_llm_config_manager()
    llm_client = llm_config_manager.get_client()
    retriever = await get_optimized_retriever()
    orchestrator = OrchestratorService(
        llm_client=llm_client,
        memory_service=memory_service,
        retriever=retriever,
    )

    async def _run(run_session_id: str) -> tuple[Any, int]:
        """Execute one flow run; return (result, duration in ms)."""
        chat_request = ChatRequest(
            session_id=run_session_id,
            current_message=request.message,
            channel_type=ChannelType.WECHAT,
            history=[],
        )
        started = time.time()
        result = await orchestrator.generate(
            tenant_id=tenant_id,
            request=chat_request,
        )
        return result, int((time.time() - started) * 1000)

    def _summarize(run_session_id: str, result: Any, duration_ms: int) -> dict[str, Any]:
        """Serialize one run for the comparison payload."""
        steps = result.metadata.get("execution_steps", []) if result.metadata else []
        return {
            "sessionId": run_session_id,
            "reply": result.reply,
            "confidence": result.confidence,
            "durationMs": duration_ms,
            "steps": steps,
        }

    baseline_result, baseline_duration = await _run(baseline_session_id)
    test_result, test_duration = await _run(test_session_id)

    return {
        "baseline": _summarize(baseline_session_id, baseline_result, baseline_duration),
        "test": _summarize(test_session_id, test_result, test_duration),
        "comparison": {
            "durationDiffMs": test_duration - baseline_duration,
            "confidenceDiff": (test_result.confidence or 0) - (baseline_result.confidence or 0),
            # Echo the requested configs (not yet applied — see docstring).
            "baselineConfig": request.baseline_config,
            "testConfig": request.test_config,
        },
    }