From b4eb98e7c422f0ae1d26f404986aad560eb551de Mon Sep 17 00:00:00 2001 From: MerCry Date: Thu, 5 Mar 2026 18:13:34 +0800 Subject: [PATCH] feat: implement mid-platform core services and APIs [AC-IDMP-01~20, AC-MARH-01~12] - Add dialogue, messages, sessions, share API endpoints - Add mid-platform schemas and models (memory, tool_registry, tool_trace) - Add core services: agent_orchestrator, policy_router, runtime_observer - Add tool services: kb_search_dynamic, memory_adapter, tool_registry - Add guardrail services: output_guardrail_executor, high_risk_handler - Add utility services: timeout_governor, segment_humanizer, metrics_collector --- ai-service/app/api/mid/__init__.py | 30 + ai-service/app/api/mid/dialogue.py | 1075 +++++++++++++++++ ai-service/app/api/mid/messages.py | 104 ++ ai-service/app/api/mid/sessions.py | 105 ++ ai-service/app/models/mid/__init__.py | 224 ++++ ai-service/app/models/mid/memory.py | 182 +++ ai-service/app/models/mid/schemas.py | 407 +++++++ ai-service/app/models/mid/tool_registry.py | 222 ++++ ai-service/app/models/mid/tool_trace.py | 174 +++ ai-service/app/services/mid/__init__.py | 57 + .../app/services/mid/agent_orchestrator.py | 370 ++++++ .../services/mid/default_kb_tool_runner.py | 244 ++++ ai-service/app/services/mid/feature_flags.py | 100 ++ .../app/services/mid/high_risk_handler.py | 232 ++++ .../mid/interrupt_context_enricher.py | 161 +++ .../services/mid/kb_search_dynamic_tool.py | 487 ++++++++ ai-service/app/services/mid/memory_adapter.py | 355 ++++++ .../app/services/mid/metrics_collector.py | 219 ++++ .../services/mid/output_guardrail_executor.py | 152 +++ ai-service/app/services/mid/policy_router.py | 411 +++++++ .../app/services/mid/runtime_observer.py | 289 +++++ .../app/services/mid/segment_humanizer.py | 282 +++++ .../app/services/mid/timeout_governor.py | 166 +++ .../app/services/mid/tool_call_recorder.py | 324 +++++ ai-service/app/services/mid/tool_registry.py | 337 ++++++ ai-service/app/services/mid/trace_logger.py 
| 269 +++++ 26 files changed, 6978 insertions(+) create mode 100644 ai-service/app/api/mid/__init__.py create mode 100644 ai-service/app/api/mid/dialogue.py create mode 100644 ai-service/app/api/mid/messages.py create mode 100644 ai-service/app/api/mid/sessions.py create mode 100644 ai-service/app/models/mid/__init__.py create mode 100644 ai-service/app/models/mid/memory.py create mode 100644 ai-service/app/models/mid/schemas.py create mode 100644 ai-service/app/models/mid/tool_registry.py create mode 100644 ai-service/app/models/mid/tool_trace.py create mode 100644 ai-service/app/services/mid/__init__.py create mode 100644 ai-service/app/services/mid/agent_orchestrator.py create mode 100644 ai-service/app/services/mid/default_kb_tool_runner.py create mode 100644 ai-service/app/services/mid/feature_flags.py create mode 100644 ai-service/app/services/mid/high_risk_handler.py create mode 100644 ai-service/app/services/mid/interrupt_context_enricher.py create mode 100644 ai-service/app/services/mid/kb_search_dynamic_tool.py create mode 100644 ai-service/app/services/mid/memory_adapter.py create mode 100644 ai-service/app/services/mid/metrics_collector.py create mode 100644 ai-service/app/services/mid/output_guardrail_executor.py create mode 100644 ai-service/app/services/mid/policy_router.py create mode 100644 ai-service/app/services/mid/runtime_observer.py create mode 100644 ai-service/app/services/mid/segment_humanizer.py create mode 100644 ai-service/app/services/mid/timeout_governor.py create mode 100644 ai-service/app/services/mid/tool_call_recorder.py create mode 100644 ai-service/app/services/mid/tool_registry.py create mode 100644 ai-service/app/services/mid/trace_logger.py diff --git a/ai-service/app/api/mid/__init__.py b/ai-service/app/api/mid/__init__.py new file mode 100644 index 0000000..55e79bc --- /dev/null +++ b/ai-service/app/api/mid/__init__.py @@ -0,0 +1,30 @@ +""" +Mid Platform API endpoints. 
"""
Mid Platform API endpoints.

[AC-IDMP-01~20] Mid platform dialogue, messages, and session management.
[AC-IDMP-SHARE] Share session via unique token.
[AC-MRS-09,10] Runtime slot query endpoints.
"""

from fastapi import APIRouter

from .dialogue import router as dialogue_router
from .messages import router as messages_router
from .sessions import router as sessions_router

# NOTE(review): the `share` and `slots` modules are imported here but do not
# appear in this patch's created-files list — verify they exist in the tree,
# otherwise this package fails to import at startup.
from .share import router as share_router
from .slots import router as slots_router

# Aggregate router for the whole mid-platform API; sub-routers are
# mounted in declaration order.
router = APIRouter()
for _sub in (
    dialogue_router,
    messages_router,
    sessions_router,
    share_router,
    slots_router,
):
    router.include_router(_sub)

__all__ = [
    "router",
    "dialogue_router",
    "messages_router",
    "sessions_router",
    "share_router",
    "slots_router",
]
"""
Dialogue Controller for Mid Platform.

[AC-MARH-01, AC-MARH-02, AC-MARH-03, AC-MARH-04, AC-MARH-05, AC-MARH-06,
 AC-MARH-07, AC-MARH-08, AC-MARH-09, AC-MARH-10, AC-MARH-11, AC-MARH-12]

Core endpoint: POST /mid/dialogue/respond
"""

# --- standard library ---
import logging
import time
import uuid
from typing import Annotated, Any

# --- third party ---
from fastapi import APIRouter, Depends, Request
from sqlalchemy.ext.asyncio import AsyncSession

# --- application core / models ---
from app.core.database import get_session
from app.core.tenant import get_tenant_id
from app.models.mid.schemas import (
    DialogueRequest,
    DialogueResponse,
    ExecutionMode,
    Segment,
    TraceInfo,
)

# --- mid-platform services (alphabetical) ---
# NOTE(review): high_risk_check_tool / intent_hint_tool / memory_recall_tool
# are not in this patch's created-files list — presumably added elsewhere;
# confirm they exist or these imports fail at startup.
from app.services.mid.agent_orchestrator import AgentOrchestrator
from app.services.mid.default_kb_tool_runner import DefaultKbToolRunner
from app.services.mid.feature_flags import FeatureFlagService
from app.services.mid.high_risk_check_tool import (
    HighRiskCheckConfig,
    HighRiskCheckTool,
    register_high_risk_check_tool,
)
from app.services.mid.high_risk_handler import HighRiskHandler
from app.services.mid.intent_hint_tool import (
    IntentHintConfig,
    IntentHintTool,
    register_intent_hint_tool,
)
from app.services.mid.interrupt_context_enricher import InterruptContextEnricher
from app.services.mid.kb_search_dynamic_tool import (
    KbSearchDynamicConfig,
    KbSearchDynamicTool,
)
from app.services.mid.memory_recall_tool import (
    MemoryRecallConfig,
    MemoryRecallTool,
    register_memory_recall_tool,
)
from app.services.mid.metrics_collector import MetricsCollector
from app.services.mid.output_guardrail_executor import OutputGuardrailExecutor
from app.services.mid.policy_router import IntentMatch, PolicyRouter
from app.services.mid.runtime_observer import RuntimeObserver
from app.services.mid.segment_humanizer import HumanizeConfig, SegmentHumanizer
from app.services.mid.timeout_governor import TimeoutGovernor
from app.services.mid.tool_registry import ToolRegistry
from app.services.mid.trace_logger import TraceLogger

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/mid", tags=["Mid Platform Dialogue"])

# Process-wide lazily-built service singletons, keyed by service name.
# NOTE(review): creation is not lock-protected — two concurrent first
# requests may each build an instance and one wins; presumably benign
# because the services hold no per-request state — TODO confirm.
_mid_services: dict[str, Any] = {}


def _get_or_create(key: str, factory) -> Any:
    """Return the cached service under *key*, building it via ``factory()`` on first use.

    Extracted to remove the identical check-then-create boilerplate that was
    duplicated across every ``get_*`` accessor below.
    """
    if key not in _mid_services:
        _mid_services[key] = factory()
    return _mid_services[key]


def get_policy_router() -> PolicyRouter:
    """Get or create PolicyRouter instance."""
    return _get_or_create("policy_router", PolicyRouter)


def get_high_risk_handler() -> HighRiskHandler:
    """Get or create HighRiskHandler instance."""
    return _get_or_create("high_risk_handler", HighRiskHandler)


def get_timeout_governor() -> TimeoutGovernor:
    """Get or create TimeoutGovernor instance."""
    return _get_or_create("timeout_governor", TimeoutGovernor)


def get_feature_flag_service() -> FeatureFlagService:
    """Get or create FeatureFlagService instance."""
    return _get_or_create("feature_flag_service", FeatureFlagService)


def get_trace_logger() -> TraceLogger:
    """Get or create TraceLogger instance."""
    return _get_or_create("trace_logger", TraceLogger)


def get_metrics_collector() -> MetricsCollector:
    """Get or create MetricsCollector instance."""
    return _get_or_create("metrics_collector", MetricsCollector)


def get_tool_registry() -> ToolRegistry:
    """Get or create ToolRegistry instance, wired to the shared TimeoutGovernor."""
    return _get_or_create(
        "tool_registry",
        lambda: ToolRegistry(timeout_governor=get_timeout_governor()),
    )


# One-shot latches consumed by the ensure_*_registered helpers below so each
# tool is registered into the shared ToolRegistry at most once per process.
_kb_search_dynamic_registered: bool = False
_intent_hint_registered: bool = False
_high_risk_check_registered: bool = False
_memory_recall_registered: bool = False
def ensure_kb_search_dynamic_registered(
    registry: ToolRegistry,
    session: AsyncSession,
) -> None:
    """[AC-MARH-05] Register the kb_search_dynamic tool exactly once per process."""
    global _kb_search_dynamic_registered
    if _kb_search_dynamic_registered:
        return

    # Imported lazily to avoid a module-level import cycle with the tool module.
    from app.services.mid.kb_search_dynamic_tool import register_kb_search_dynamic_tool

    # NOTE(review): the request-scoped AsyncSession passed by the first caller
    # is handed to a process-wide registry — confirm the tool does not retain
    # it past the request lifetime.
    tool_config = KbSearchDynamicConfig(
        enabled=True,
        top_k=5,
        timeout_ms=2000,
        min_score_threshold=0.5,
    )
    register_kb_search_dynamic_tool(
        registry=registry,
        session=session,
        timeout_governor=get_timeout_governor(),
        config=tool_config,
    )

    _kb_search_dynamic_registered = True
    logger.info("[AC-MARH-05] kb_search_dynamic tool registered to registry")


def ensure_intent_hint_registered(
    registry: ToolRegistry,
    session: AsyncSession,
) -> None:
    """[AC-IDMP-02, AC-IDMP-16] Register the intent_hint tool exactly once per process."""
    global _intent_hint_registered
    if _intent_hint_registered:
        return

    tool_config = IntentHintConfig(
        enabled=True,
        timeout_ms=500,
        top_n=3,
        low_confidence_threshold=0.3,
    )
    register_intent_hint_tool(
        registry=registry,
        session=session,
        config=tool_config,
    )

    _intent_hint_registered = True
    logger.info("[AC-IDMP-02] intent_hint tool registered to registry")


def ensure_high_risk_check_registered(
    registry: ToolRegistry,
    session: AsyncSession,
) -> None:
    """[AC-IDMP-05, AC-IDMP-20] Register the high_risk_check tool exactly once per process."""
    global _high_risk_check_registered
    if _high_risk_check_registered:
        return

    tool_config = HighRiskCheckConfig(
        enabled=True,
        timeout_ms=500,
        default_confidence=0.9,
    )
    register_high_risk_check_tool(
        registry=registry,
        session=session,
        config=tool_config,
    )

    _high_risk_check_registered = True
    logger.info("[AC-IDMP-05] high_risk_check tool registered to registry")


def ensure_memory_recall_registered(
    registry: ToolRegistry,
    session: AsyncSession,
) -> None:
    """[AC-IDMP-13] Register the memory_recall tool exactly once per process."""
    global _memory_recall_registered
    if _memory_recall_registered:
        return

    tool_config = MemoryRecallConfig(
        enabled=True,
        timeout_ms=1000,
        max_recent_messages=8,
    )
    register_memory_recall_tool(
        registry=registry,
        session=session,
        timeout_governor=get_timeout_governor(),
        config=tool_config,
    )

    _memory_recall_registered = True
    logger.info("[AC-IDMP-13] memory_recall tool registered to registry")
def get_output_guardrail_executor() -> OutputGuardrailExecutor:
    """Lazily build and cache the shared OutputGuardrailExecutor."""
    try:
        return _mid_services["output_guardrail_executor"]
    except KeyError:
        _mid_services["output_guardrail_executor"] = OutputGuardrailExecutor()
        return _mid_services["output_guardrail_executor"]


def get_interrupt_context_enricher() -> InterruptContextEnricher:
    """Lazily build and cache the shared InterruptContextEnricher."""
    try:
        return _mid_services["interrupt_context_enricher"]
    except KeyError:
        _mid_services["interrupt_context_enricher"] = InterruptContextEnricher()
        return _mid_services["interrupt_context_enricher"]


def get_default_kb_tool_runner() -> DefaultKbToolRunner:
    """Lazily build and cache the shared DefaultKbToolRunner (wired to the shared TimeoutGovernor)."""
    try:
        return _mid_services["default_kb_tool_runner"]
    except KeyError:
        _mid_services["default_kb_tool_runner"] = DefaultKbToolRunner(
            timeout_governor=get_timeout_governor()
        )
        return _mid_services["default_kb_tool_runner"]


def get_segment_humanizer() -> SegmentHumanizer:
    """Lazily build and cache the shared SegmentHumanizer."""
    try:
        return _mid_services["segment_humanizer"]
    except KeyError:
        _mid_services["segment_humanizer"] = SegmentHumanizer()
        return _mid_services["segment_humanizer"]


def get_runtime_observer() -> RuntimeObserver:
    """Lazily build and cache the shared RuntimeObserver."""
    try:
        return _mid_services["runtime_observer"]
    except KeyError:
        _mid_services["runtime_observer"] = RuntimeObserver()
        return _mid_services["runtime_observer"]
@router.post(
    "/dialogue/respond",
    operation_id="respondDialogue",
    summary="Generate mid platform dialogue response",
    description="""
    [AC-MARH-01~12] Core dialogue response endpoint for mid platform.

    Returns segments[] with trace info including:
    - guardrail_triggered, guardrail_rule_id
    - interrupt_consumed
    - kb_tool_called, kb_hit
    - timeout_profile, segment_stats
    """,
)
async def respond_dialogue(
    request: Request,
    dialogue_request: DialogueRequest,
    session: Annotated[AsyncSession, Depends(get_session)],
    policy_router: PolicyRouter = Depends(get_policy_router),
    high_risk_handler: HighRiskHandler = Depends(get_high_risk_handler),
    timeout_governor: TimeoutGovernor = Depends(get_timeout_governor),
    feature_flag_service: FeatureFlagService = Depends(get_feature_flag_service),
    trace_logger: TraceLogger = Depends(get_trace_logger),
    metrics_collector: MetricsCollector = Depends(get_metrics_collector),
    output_guardrail_executor: OutputGuardrailExecutor = Depends(get_output_guardrail_executor),
    interrupt_context_enricher: InterruptContextEnricher = Depends(get_interrupt_context_enricher),
    default_kb_tool_runner: DefaultKbToolRunner = Depends(get_default_kb_tool_runner),
    segment_humanizer: SegmentHumanizer = Depends(get_segment_humanizer),
    runtime_observer: RuntimeObserver = Depends(get_runtime_observer),
) -> DialogueResponse:
    """
    [AC-MARH-01~12] Generate dialogue response with segments and trace.

    Flow:
    1. Validate request and get tenant context
    2. Start runtime observation
    3. Process interrupted segments (AC-MARH-03/04)
    4. Check feature flags for grayscale/rollback
    5. Detect high-risk scenarios
    6. Route to appropriate execution mode
    7. For Agent mode: call KB tool (AC-MARH-05/06)
    8. Execute output guardrail (AC-MARH-01/02)
    9. Apply segment humanizer (AC-MARH-10/11)
    10. Collect trace and return (AC-MARH-12)

    Fix vs. previous revision: HumanizeConfig fields used ``value or default``,
    so ``enabled=False`` and ``min_delay_ms=0`` were silently overridden by
    the defaults; explicit ``is not None`` checks now honor falsy values.
    """
    start_time = time.time()
    tenant_id = get_tenant_id()

    if not tenant_id:
        from app.core.exceptions import MissingTenantIdException
        raise MissingTenantIdException()

    request_id = str(uuid.uuid4())
    generation_id = str(uuid.uuid4())

    logger.info(
        f"[AC-MARH-01] Dialogue request: tenant={tenant_id}, "
        f"session={dialogue_request.session_id}, request_id={request_id}"
    )

    # Start observation/trace bookkeeping; the return values are keyed by
    # request_id internally, so the observation handle is not retained here.
    runtime_observer.start_observation(
        tenant_id=tenant_id,
        session_id=dialogue_request.session_id,
        request_id=request_id,
        generation_id=generation_id,
    )

    trace = trace_logger.start_trace(
        tenant_id=tenant_id,
        session_id=dialogue_request.session_id,
        request_id=request_id,
        generation_id=generation_id,
    )

    metrics_collector.start_session(dialogue_request.session_id)

    # Idempotent per-process tool registration (see ensure_*_registered).
    tool_registry = get_tool_registry()
    ensure_kb_search_dynamic_registered(tool_registry, session)
    ensure_intent_hint_registered(tool_registry, session)
    ensure_high_risk_check_registered(tool_registry, session)
    ensure_memory_recall_registered(tool_registry, session)

    try:
        # [AC-MARH-03/04] Fold any interrupted segments into the context.
        interrupt_ctx = interrupt_context_enricher.enrich(
            dialogue_request.interrupted_segments,
            generation_id,
        )
        runtime_observer.record_interrupt(request_id, interrupt_ctx.consumed)

        # Request-supplied flags win over the per-session flag service.
        feature_flags = dialogue_request.feature_flags or feature_flag_service.get_flags(
            dialogue_request.session_id
        )

        if feature_flags.rollback_to_legacy:
            logger.info(f"[AC-MARH-17] Rollback to legacy for session: {dialogue_request.session_id}")
            return await _handle_legacy_response(
                tenant_id=tenant_id,
                request=dialogue_request,
                trace=trace,
                trace_logger=trace_logger,
                start_time=start_time,
            )

        # [AC-IDMP-05/20] High-risk detection runs before intent routing.
        high_risk_check_tool = HighRiskCheckTool(
            session=session,
            config=HighRiskCheckConfig(enabled=True, timeout_ms=500),
        )
        high_risk_result = await high_risk_check_tool.execute(
            message=dialogue_request.user_message,
            tenant_id=tenant_id,
        )

        # duration_ms > 0 is used as "the tool actually ran" marker.
        if high_risk_result.duration_ms > 0:
            hr_trace = high_risk_check_tool.create_trace(high_risk_result, tenant_id)
            trace_logger.update_trace(
                request_id=request_id,
                tool_calls=[hr_trace],
            )

        logger.info(
            f"[AC-IDMP-05, AC-IDMP-20] High risk check result: "
            f"matched={high_risk_result.matched}, scenario={high_risk_result.risk_scenario}, "
            f"duration_ms={high_risk_result.duration_ms}"
        )

        if high_risk_result.matched and high_risk_result.risk_scenario:
            logger.info(
                f"[AC-IDMP-05] High-risk matched from tool: {high_risk_result.risk_scenario.value}"
            )
            return await _handle_high_risk_check_response(
                tenant_id=tenant_id,
                request=dialogue_request,
                high_risk_result=high_risk_result,
                trace=trace,
                trace_logger=trace_logger,
                start_time=start_time,
                session=session,
            )

        # [AC-IDMP-02] Lightweight intent hint to inform the policy router.
        intent_hint_tool = IntentHintTool(
            session=session,
            config=IntentHintConfig(enabled=True, timeout_ms=500),
        )
        intent_hint = await intent_hint_tool.execute(
            message=dialogue_request.user_message,
            tenant_id=tenant_id,
            history=[h.model_dump() for h in dialogue_request.history] if dialogue_request.history else None,
        )

        if intent_hint.duration_ms > 0:
            hint_trace = intent_hint_tool.create_trace(intent_hint)
            trace_logger.update_trace(
                request_id=request_id,
                tool_calls=[hint_trace],
            )

        logger.info(
            f"[AC-IDMP-02] Intent hint result: intent={intent_hint.intent}, "
            f"confidence={intent_hint.confidence}, suggested_mode={intent_hint.suggested_mode}"
        )

        intent_match = await _match_intent(tenant_id, dialogue_request, session)

        router_result = policy_router.route(
            user_message=dialogue_request.user_message,
            session_mode="BOT_ACTIVE",
            feature_flags=feature_flags,
            intent_match=intent_match,
            intent_hint=intent_hint,
        )

        runtime_observer.update_mode(request_id, router_result.mode, router_result.intent)
        runtime_observer.record_timeout_profile(request_id, timeout_governor.profile)

        trace_logger.update_trace(
            request_id=request_id,
            mode=router_result.mode,
            intent=router_result.intent,
            fallback_reason_code=router_result.fallback_reason_code,
        )

        # Dispatch on the routed execution mode; TRANSFER is the fallthrough.
        if router_result.mode == ExecutionMode.AGENT:
            response = await _execute_agent_mode(
                tenant_id=tenant_id,
                request=dialogue_request,
                request_id=request_id,
                trace=trace,
                trace_logger=trace_logger,
                timeout_governor=timeout_governor,
                metrics_collector=metrics_collector,
                default_kb_tool_runner=default_kb_tool_runner,
                runtime_observer=runtime_observer,
                interrupt_ctx=interrupt_ctx,
                start_time=start_time,
                session=session,
                tool_registry=tool_registry,
            )
        elif router_result.mode == ExecutionMode.MICRO_FLOW:
            response = await _execute_micro_flow_mode(
                tenant_id=tenant_id,
                request=dialogue_request,
                router_result=router_result,
                trace=trace,
                trace_logger=trace_logger,
                session=session,
                start_time=start_time,
            )
        elif router_result.mode == ExecutionMode.FIXED:
            response = await _execute_fixed_mode(
                tenant_id=tenant_id,
                request=dialogue_request,
                router_result=router_result,
                trace=trace,
                trace_logger=trace_logger,
                start_time=start_time,
            )
        else:
            response = await _execute_transfer_mode(
                tenant_id=tenant_id,
                request=dialogue_request,
                router_result=router_result,
                trace=trace,
                trace_logger=trace_logger,
                start_time=start_time,
            )

        # [AC-MARH-01/02] Output guardrail on the produced segments.
        filtered_segments, guardrail_result = await output_guardrail_executor.filter_segments(
            response.segments, tenant_id
        )
        runtime_observer.record_guardrail(
            request_id, guardrail_result.triggered, guardrail_result.rule_id
        )

        # [AC-MARH-10/11] Build the humanizer config.
        # BUGFIX: previous code used `value or default`, so enabled=False
        # became True and a legitimate 0ms delay became the default. Use
        # explicit None checks so falsy caller values are honored.
        humanize_config = None
        if dialogue_request.humanize_config:
            hc = dialogue_request.humanize_config
            humanize_config = HumanizeConfig(
                enabled=hc.enabled if hc.enabled is not None else True,
                min_delay_ms=hc.min_delay_ms if hc.min_delay_ms is not None else 50,
                max_delay_ms=hc.max_delay_ms if hc.max_delay_ms is not None else 500,
                length_bucket_strategy=hc.length_bucket_strategy or "simple",
            )

        final_segments, segment_stats = segment_humanizer.humanize(
            "\n\n".join(s.text for s in filtered_segments),
            humanize_config,
        )
        runtime_observer.record_segment_stats(request_id, segment_stats)

        final_trace = runtime_observer.end_observation(request_id)
        final_trace.segment_stats = segment_stats
        final_trace.guardrail_triggered = guardrail_result.triggered
        final_trace.guardrail_rule_id = guardrail_result.rule_id

        latency_ms = int((time.time() - start_time) * 1000)

        metrics_collector.record_turn(
            session_id=dialogue_request.session_id,
            tenant_id=tenant_id,
            latency_ms=latency_ms,
            task_completed=True,
        )

        # [AC-MARH-12] Persist the audit record (return value unused here).
        trace_logger.end_trace(
            request_id=request_id,
            tenant_id=tenant_id,
            session_id=dialogue_request.session_id,
            latency_ms=latency_ms,
        )

        logger.info(
            f"[AC-MARH-12] Audit record: request_id={request_id}, "
            f"mode={final_trace.mode.value}, latency_ms={latency_ms}, "
            f"guardrail={guardrail_result.triggered}, kb_hit={final_trace.kb_hit}"
        )

        return DialogueResponse(
            segments=final_segments,
            trace=final_trace,
        )

    except Exception as e:
        # [AC-IDMP-06] Catch-all degradation to a fixed apology reply; the
        # trace is still closed so the failed turn remains auditable.
        latency_ms = int((time.time() - start_time) * 1000)
        logger.exception(f"[AC-IDMP-06] Dialogue error: {e}")

        trace_logger.update_trace(
            request_id=request_id,
            mode=ExecutionMode.FIXED,
            fallback_reason_code=f"error: {str(e)[:50]}",
        )

        trace_logger.end_trace(
            request_id=request_id,
            tenant_id=tenant_id,
            session_id=dialogue_request.session_id,
            latency_ms=latency_ms,
        )

        return DialogueResponse(
            segments=[Segment(
                text="抱歉,服务暂时不可用,请稍后重试或联系人工客服。",
                delay_after=0,
            )],
            trace=TraceInfo(
                mode=ExecutionMode.FIXED,
                request_id=request_id,
                generation_id=generation_id,
                fallback_reason_code="service_error",
            ),
        )
async def _match_intent(
    tenant_id: str,
    request: DialogueRequest,
    session: AsyncSession,
) -> IntentMatch | None:
    """Match an enabled intent rule against the user message.

    Returns None when no rule is enabled, nothing matches, or matching fails
    (best-effort: failures are logged, never raised to the caller).
    """
    try:
        # Lazy imports keep the intent subsystem optional at module load.
        from app.services.intent.router import IntentRouter
        from app.services.intent.rule_service import IntentRuleService

        rule_service = IntentRuleService(session)
        rules = await rule_service.get_enabled_rules_for_matching(tenant_id)

        if not rules:
            return None

        # Renamed from `router` — it shadowed the module-level APIRouter.
        intent_router = IntentRouter()
        result = intent_router.match(request.user_message, rules)

        if result:
            return IntentMatch(
                intent_id=str(result.rule.id),
                intent_name=result.rule.name,
                confidence=0.8,  # fixed confidence for rule-based matches
                response_type=result.rule.response_type,
                target_kb_ids=result.rule.target_kb_ids,
                flow_id=str(result.rule.flow_id) if result.rule.flow_id else None,
                fixed_reply=result.rule.fixed_reply,
                transfer_message=result.rule.transfer_message,
            )

        return None

    except Exception as e:
        logger.warning(f"[AC-IDMP-02] Intent match failed: {e}")
        return None


async def _handle_legacy_response(
    tenant_id: str,
    request: DialogueRequest,
    trace: TraceInfo,
    trace_logger: TraceLogger,
    start_time: float,
) -> DialogueResponse:
    """Handle rollback to legacy pipeline.

    NOTE(review): currently returns a fixed placeholder reply; tenant_id,
    request, trace_logger and start_time are kept for signature symmetry with
    the other handlers. (Removed a dead `latency_ms` computation.)
    """
    return DialogueResponse(
        segments=[Segment(
            text="正在使用传统模式处理您的请求...",
            delay_after=0,
        )],
        trace=TraceInfo(
            mode=ExecutionMode.FIXED,
            request_id=trace.request_id,
            generation_id=trace.generation_id,
            fallback_reason_code="rollback_to_legacy",
        ),
    )


async def _handle_high_risk_check_response(
    tenant_id: str,
    request: DialogueRequest,
    high_risk_result: Any,
    trace: TraceInfo,
    trace_logger: TraceLogger,
    start_time: float,
    session: AsyncSession,
) -> DialogueResponse:
    """
    [AC-IDMP-05, AC-IDMP-20] Handle a high-risk scenario reported by the
    high_risk_check tool.

    High risk takes priority over normal intent routing. TRANSFER scenarios
    short-circuit to a human-handoff reply; otherwise a micro-flow reply is
    produced, optionally validated against the persisted HighRiskPolicy.

    (Cleanups vs. previous revision: dropped an unused `latency_ms` local and
    a nested `import uuid` that shadowed the module-level import.)
    """
    from app.models.mid.schemas import HighRiskCheckResult

    # Tolerate a plain-dict payload from the tool layer.
    if not isinstance(high_risk_result, HighRiskCheckResult):
        high_risk_result = HighRiskCheckResult(**high_risk_result)

    recommended_mode = high_risk_result.recommended_mode or ExecutionMode.MICRO_FLOW
    risk_scenario = high_risk_result.risk_scenario

    trace_logger.update_trace(
        request_id=trace.request_id or "",
        mode=recommended_mode,
        fallback_reason_code=f"high_risk_{risk_scenario.value if risk_scenario else 'unknown'}",
    )

    if recommended_mode == ExecutionMode.TRANSFER:
        # Scenario-specific hand-off wording; generic message otherwise.
        transfer_msg = "正在为您转接人工客服..."
        if risk_scenario:
            if risk_scenario.value == "complaint_escalation":
                transfer_msg = "检测到您可能需要投诉处理,正在为您转接人工客服..."
            elif risk_scenario.value == "refund":
                transfer_msg = "您的退款请求需要人工处理,正在为您转接..."

        return DialogueResponse(
            segments=[Segment(
                text=transfer_msg,
                delay_after=0,
            )],
            trace=TraceInfo(
                mode=ExecutionMode.TRANSFER,
                request_id=trace.request_id,
                generation_id=trace.generation_id,
                high_risk_policy_set=[risk_scenario] if risk_scenario else None,
                fallback_reason_code=high_risk_result.rule_id,
            ),
        )

    if high_risk_result.rule_id:
        # Best-effort lookup of the persisted policy; a DB failure falls
        # through to the generic micro-flow reply below.
        try:
            from sqlalchemy import select
            from app.models.entities import HighRiskPolicy

            stmt = select(HighRiskPolicy).where(
                HighRiskPolicy.id == uuid.UUID(high_risk_result.rule_id)
            )
            result = await session.execute(stmt)
            policy = result.scalar_one_or_none()

            if policy and policy.flow_id:
                return DialogueResponse(
                    segments=[Segment(
                        text="检测到您的请求需要特殊处理,正在为您安排...",
                        delay_after=0,
                    )],
                    trace=TraceInfo(
                        mode=ExecutionMode.MICRO_FLOW,
                        request_id=trace.request_id,
                        generation_id=trace.generation_id,
                        high_risk_policy_set=[risk_scenario] if risk_scenario else None,
                        fallback_reason_code=high_risk_result.rule_id,
                    ),
                )
        except Exception as e:
            logger.warning(f"[AC-IDMP-05] Failed to load high risk policy: {e}")

    return DialogueResponse(
        segments=[Segment(
            text="检测到您的请求需要特殊处理,正在为您安排...",
            delay_after=0,
        )],
        trace=TraceInfo(
            mode=ExecutionMode.MICRO_FLOW,
            request_id=trace.request_id,
            generation_id=trace.generation_id,
            high_risk_policy_set=[risk_scenario] if risk_scenario else None,
            fallback_reason_code=high_risk_result.rule_id or f"high_risk_{risk_scenario.value if risk_scenario else 'unknown'}",
        ),
    )
async def _handle_high_risk_response(
    tenant_id: str,
    request: DialogueRequest,
    high_risk_match: Any,
    high_risk_handler: HighRiskHandler,
    trace: TraceInfo,
    trace_logger: TraceLogger,
    start_time: float,
) -> DialogueResponse:
    """Handle a high-risk scenario via the HighRiskHandler routing decision.

    TRANSFER decisions return the handler's transfer message (or a generic
    hand-off line); every other mode yields the micro-flow placeholder reply.

    NOTE(review): no call site is visible in this file — respond_dialogue uses
    _handle_high_risk_check_response instead; confirm this path is still
    reachable or remove it. (Removed a dead `latency_ms` computation.)
    """
    router_result = high_risk_handler.handle(high_risk_match)

    trace_logger.update_trace(
        request_id=trace.request_id or "",
        mode=router_result.mode,
        fallback_reason_code=router_result.fallback_reason_code,
    )

    if router_result.mode == ExecutionMode.TRANSFER:
        return DialogueResponse(
            segments=[Segment(
                text=router_result.transfer_message or "正在为您转接人工客服...",
                delay_after=0,
            )],
            trace=TraceInfo(
                mode=ExecutionMode.TRANSFER,
                request_id=trace.request_id,
                generation_id=trace.generation_id,
                high_risk_policy_set=[high_risk_match.scenario],
                fallback_reason_code=router_result.fallback_reason_code,
            ),
        )

    return DialogueResponse(
        segments=[Segment(
            text="检测到您的请求需要特殊处理,正在为您安排...",
            delay_after=0,
        )],
        trace=TraceInfo(
            mode=ExecutionMode.MICRO_FLOW,
            request_id=trace.request_id,
            generation_id=trace.generation_id,
            high_risk_policy_set=[high_risk_match.scenario],
            fallback_reason_code=router_result.fallback_reason_code,
        ),
    )
async def _execute_agent_mode(
    tenant_id: str,
    request: DialogueRequest,
    request_id: str,
    trace: TraceInfo,
    trace_logger: TraceLogger,
    timeout_governor: TimeoutGovernor,
    metrics_collector: MetricsCollector,
    default_kb_tool_runner: DefaultKbToolRunner,
    runtime_observer: RuntimeObserver,
    interrupt_ctx: Any = None,
    start_time: float = 0,
    session: AsyncSession | None = None,
    tool_registry: ToolRegistry | None = None,
) -> DialogueResponse:
    """[AC-MARH-05, AC-MARH-06, AC-MARH-07, AC-IDMP-13] Execute agent mode with ReAct loop, KB tool, and memory recall."""
    from app.services.llm.factory import get_llm_config_manager

    # Obtain an LLM client; the orchestrator tolerates None (degraded mode).
    try:
        llm_client = get_llm_config_manager().get_client()
    except Exception as e:
        logger.warning(f"[AC-MARH-07] Failed to get LLM client: {e}")
        llm_client = None

    # Base agent context: conversation history, if any.
    agent_context: dict[str, Any] = (
        {"history": [h.model_dump() for h in request.history]} if request.history else {}
    )

    # [AC-MARH-03] Fold consumed interrupt data into the context.
    if interrupt_ctx and interrupt_ctx.consumed:
        agent_context["interrupted_content"] = interrupt_ctx.interrupted_content
        agent_context["interrupted_segment_ids"] = interrupt_ctx.interrupted_segment_ids
        logger.info(
            f"[AC-MARH-03] Agent context enriched with interrupt: "
            f"{len(interrupt_ctx.interrupted_content or '')} chars"
        )

    # [AC-IDMP-13] Memory recall — only when we can identify the user.
    memory_context = ""
    memory_missing_slots: list[str] = []
    if session and request.user_id:
        recall_tool = MemoryRecallTool(
            session=session,
            timeout_governor=timeout_governor,
            config=MemoryRecallConfig(
                enabled=True,
                timeout_ms=1000,
                max_recent_messages=8,
            ),
        )
        memory_result = await recall_tool.execute(
            tenant_id=tenant_id,
            user_id=request.user_id,
            session_id=request.session_id,
        )

        # duration_ms > 0 marks that the tool actually ran.
        if memory_result.duration_ms > 0:
            recall_trace = recall_tool.create_trace(memory_result, tenant_id)
            trace_logger.update_trace(
                request_id=request_id,
                tool_calls=[recall_trace],
            )

        memory_context = memory_result.get_context_for_prompt()
        memory_missing_slots = memory_result.missing_slots

        if memory_context:
            agent_context["memory_context"] = memory_context
            logger.info(
                f"[AC-IDMP-13] Memory recall succeeded: "
                f"profile={len(memory_result.profile)}, facts={len(memory_result.facts)}, "
                f"slots={len(memory_result.slots)}, missing_slots={len(memory_missing_slots)}, "
                f"duration_ms={memory_result.duration_ms}"
            )
        elif memory_result.fallback_reason_code:
            logger.warning(
                f"[AC-IDMP-13] Memory recall fallback: reason={memory_result.fallback_reason_code}"
            )

    # [AC-MARH-05/06] Knowledge-base retrieval: dynamic tool preferred,
    # default runner as fallback when session/registry are unavailable.
    kb_hits: list = []
    kb_success = False
    kb_fallback_reason = None
    kb_applied_filter: dict = {}
    kb_missing_slots: list = []

    if session and tool_registry:
        dyn_tool = KbSearchDynamicTool(
            session=session,
            timeout_governor=timeout_governor,
            config=KbSearchDynamicConfig(
                enabled=True,
                top_k=5,
                timeout_ms=2000,
                min_score_threshold=0.5,
            ),
        )
        dyn_result = await dyn_tool.execute(
            query=request.user_message,
            tenant_id=tenant_id,
            scene="open_consult",
            top_k=5,
            context=agent_context,
        )

        kb_success = dyn_result.success
        kb_hits = dyn_result.hits
        kb_fallback_reason = dyn_result.fallback_reason_code
        kb_applied_filter = dyn_result.applied_filter
        kb_missing_slots = dyn_result.missing_required_slots

        if dyn_result.tool_trace:
            trace_logger.update_trace(
                request_id=request_id,
                tool_calls=[dyn_result.tool_trace],
            )

        logger.info(
            f"[AC-MARH-05] KB dynamic search: success={kb_success}, "
            f"hits={len(kb_hits)}, filter={kb_applied_filter}, "
            f"missing_slots={kb_missing_slots}"
        )
    else:
        default_result = await default_kb_tool_runner.execute(
            tenant_id=tenant_id,
            query=request.user_message,
        )
        kb_success = default_result.success
        kb_hits = default_result.hits
        kb_fallback_reason = default_result.fallback_reason_code

    runtime_observer.record_kb(
        request_id,
        tool_called=True,
        hit=kb_success and len(kb_hits) > 0,
        fallback_reason=kb_fallback_reason,
    )

    if kb_success and kb_hits:
        # Inject the top 3 hits (truncated to 200 chars each) into the prompt.
        agent_context["kb_context"] = "\n".join([
            f"[知识库] {hit.get('content', '')[:200]}"
            for hit in kb_hits[:3]
        ])
        logger.info(
            f"[AC-MARH-05] KB retrieval succeeded: hits={len(kb_hits)}"
        )
    elif kb_fallback_reason:
        logger.warning(
            f"[AC-MARH-06] KB retrieval fallback: reason={kb_fallback_reason}"
        )

    # [AC-MARH-07] ReAct loop.
    orchestrator = AgentOrchestrator(
        max_iterations=5,
        timeout_governor=timeout_governor,
        llm_client=llm_client,
        tool_registry=tool_registry,
    )
    final_answer, react_ctx, agent_trace = await orchestrator.execute(
        user_message=request.user_message,
        context=agent_context,
    )

    runtime_observer.record_react(request_id, react_ctx.iteration, react_ctx.tool_calls)
    trace_logger.update_trace(
        request_id=request_id,
        react_iterations=react_ctx.iteration,
        tool_calls=react_ctx.tool_calls,
    )

    return DialogueResponse(
        segments=_text_to_segments(final_answer),
        trace=TraceInfo(
            mode=ExecutionMode.AGENT,
            request_id=trace.request_id,
            generation_id=trace.generation_id,
            react_iterations=react_ctx.iteration,
            tools_used=[tc.tool_name for tc in react_ctx.tool_calls] if react_ctx.tool_calls else None,
            tool_calls=react_ctx.tool_calls,
            timeout_profile=timeout_governor.profile,
            kb_tool_called=True,
            kb_hit=kb_success and len(kb_hits) > 0,
            fallback_reason_code=kb_fallback_reason,
        ),
    )
+ except Exception as e: + logger.warning(f"[AC-IDMP-05] Micro flow start failed: {e}") + + return DialogueResponse( + segments=[Segment( + text="正在为您处理,请稍候...", + delay_after=0, + )], + trace=TraceInfo( + mode=ExecutionMode.MICRO_FLOW, + request_id=trace.request_id, + generation_id=trace.generation_id, + intent=router_result.intent, + fallback_reason_code=router_result.fallback_reason_code, + ), + ) + + +async def _execute_fixed_mode( + tenant_id: str, + request: DialogueRequest, + router_result: Any, + trace: TraceInfo, + trace_logger: TraceLogger, + start_time: float, +) -> DialogueResponse: + """Execute fixed reply mode.""" + text = router_result.fixed_reply or "收到您的消息,我们会尽快处理。" + + return DialogueResponse( + segments=_text_to_segments(text), + trace=TraceInfo( + mode=ExecutionMode.FIXED, + request_id=trace.request_id, + generation_id=trace.generation_id, + intent=router_result.intent, + fallback_reason_code=router_result.fallback_reason_code, + ), + ) + + +async def _execute_transfer_mode( + tenant_id: str, + request: DialogueRequest, + router_result: Any, + trace: TraceInfo, + trace_logger: TraceLogger, + start_time: float, +) -> DialogueResponse: + """Execute transfer to human mode.""" + text = router_result.transfer_message or "正在为您转接人工客服,请稍候..." 
+ + return DialogueResponse( + segments=_text_to_segments(text), + trace=TraceInfo( + mode=ExecutionMode.TRANSFER, + request_id=trace.request_id, + generation_id=trace.generation_id, + intent=router_result.intent, + fallback_reason_code=router_result.fallback_reason_code, + ), + ) + + +def _text_to_segments(text: str) -> list[Segment]: + """Convert text to segments.""" + paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()] + + if not paragraphs: + paragraphs = [text] + + return [ + Segment(text=p, delay_after=100 if i < len(paragraphs) - 1 else 0) + for i, p in enumerate(paragraphs) + ] diff --git a/ai-service/app/api/mid/messages.py b/ai-service/app/api/mid/messages.py new file mode 100644 index 0000000..9c1d2bf --- /dev/null +++ b/ai-service/app/api/mid/messages.py @@ -0,0 +1,104 @@ +""" +Messages Controller for Mid Platform. +[AC-IDMP-08] Message report endpoint: POST /mid/messages/report +""" + +import logging +import time +from typing import Annotated, Any + +from fastapi import APIRouter, Depends, Request +from fastapi.responses import JSONResponse +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.database import get_session +from app.core.tenant import get_tenant_id +from app.models.mid.schemas import MessageReportRequest +from app.services.memory import MemoryService + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/mid", tags=["Mid Platform Messages"]) + + +@router.post( + "/messages/report", + operation_id="reportMessages", + summary="Report session messages and events", + description=""" + [AC-IDMP-08] Report messages from channel to mid platform. + + Accepts messages from bot, human, or channel sources for session closure. + Returns 202 Accepted for async processing. 
+ """, + responses={ + 202: {"description": "Accepted for async processing"}, + 400: {"description": "Invalid request"}, + }, +) +async def report_messages( + request: Request, + report_request: MessageReportRequest, + session: Annotated[AsyncSession, Depends(get_session)], +) -> JSONResponse: + """ + [AC-IDMP-08] Report messages from channel. + + Accepts and stores messages for session data completeness. + Messages are stored asynchronously without blocking response. + """ + tenant_id = get_tenant_id() + + if not tenant_id: + from app.core.exceptions import MissingTenantIdException + raise MissingTenantIdException() + + logger.info( + f"[AC-IDMP-08] Message report: tenant={tenant_id}, " + f"session={report_request.session_id}, count={len(report_request.messages)}" + ) + + try: + memory_service = MemoryService(session) + + await memory_service.get_or_create_session( + tenant_id=tenant_id, + session_id=report_request.session_id, + ) + + messages_to_save = [] + for msg in report_request.messages: + role = msg.role + if msg.source == "human": + role = "human" + elif msg.source == "bot": + role = "assistant" + + messages_to_save.append({ + "role": role, + "content": msg.content, + }) + + if messages_to_save: + await memory_service.append_messages( + tenant_id=tenant_id, + session_id=report_request.session_id, + messages=messages_to_save, + ) + + logger.info( + f"[AC-IDMP-08] Messages saved: tenant={tenant_id}, " + f"session={report_request.session_id}, count={len(messages_to_save)}" + ) + + return JSONResponse( + status_code=202, + content={"status": "accepted", "session_id": report_request.session_id}, + ) + + except Exception as e: + logger.error(f"[AC-IDMP-08] Message report failed: {e}") + return JSONResponse( + status_code=202, + content={"status": "accepted", "warning": str(e)}, + ) diff --git a/ai-service/app/api/mid/sessions.py b/ai-service/app/api/mid/sessions.py new file mode 100644 index 0000000..c2d2007 --- /dev/null +++ b/ai-service/app/api/mid/sessions.py 
"""
Sessions Controller for Mid Platform.
[AC-IDMP-09] Session mode switch endpoint: POST /mid/sessions/{sessionId}/mode
"""

import logging
from typing import Annotated

from fastapi import APIRouter, Depends, Path
from fastapi.responses import JSONResponse
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.database import get_session
from app.core.tenant import get_tenant_id
from app.models.mid.schemas import (
    SessionMode,
    SwitchModeRequest,
    SwitchModeResponse,
)
from app.services.memory import MemoryService

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/mid", tags=["Mid Platform Sessions"])

# In-memory mode store keyed by "{tenant_id}:{session_id}".
# NOTE(review): process-local state — modes are lost on restart and are not
# shared across workers; confirm single-process deployment or move this to a
# shared store (DB/Redis).
_session_modes: dict[str, SessionMode] = {}


@router.post(
    "/sessions/{sessionId}/mode",
    operation_id="switchSessionMode",
    summary="Switch session mode",
    description="""
    [AC-IDMP-09] Switch session mode between BOT_ACTIVE and HUMAN_ACTIVE.

    When mode is HUMAN_ACTIVE, dialogue responses will route to transfer mode.
    """,
    responses={
        200: {"description": "Mode switched successfully", "model": SwitchModeResponse},
        400: {"description": "Invalid request"},
    },
)
async def switch_session_mode(
    sessionId: Annotated[str, Path(description="Session ID")],
    switch_request: SwitchModeRequest,
    session: Annotated[AsyncSession, Depends(get_session)],
) -> SwitchModeResponse:
    """
    [AC-IDMP-09] Switch session mode.

    Modes:
    - BOT_ACTIVE: Bot handles responses
    - HUMAN_ACTIVE: Transfer to human agent

    The in-memory mode is recorded unconditionally before the best-effort
    session persistence, so a storage failure can no longer leave the mode
    reported in the response out of sync with what get_session_mode()
    returns (previously, on exception the new mode was returned to the
    caller but never stored).
    """
    tenant_id = get_tenant_id()

    if not tenant_id:
        # Lazy import mirrors the existing style for this exception.
        from app.core.exceptions import MissingTenantIdException
        raise MissingTenantIdException()

    logger.info(
        f"[AC-IDMP-09] Mode switch: tenant={tenant_id}, "
        f"session={sessionId}, mode={switch_request.mode.value}"
    )

    # Record the mode FIRST so the response and any subsequent
    # get_session_mode() call agree even if the DB write below fails.
    session_key = f"{tenant_id}:{sessionId}"
    _session_modes[session_key] = switch_request.mode

    try:
        memory_service = MemoryService(session)

        # Best-effort: make sure a session row exists for this session.
        await memory_service.get_or_create_session(
            tenant_id=tenant_id,
            session_id=sessionId,
        )

        logger.info(
            f"[AC-IDMP-09] Mode switched: session={sessionId}, "
            f"mode={switch_request.mode.value}, reason={switch_request.reason}"
        )
    except Exception as e:
        # Persistence errors are swallowed deliberately: the mode switch
        # itself has already taken effect in memory, and the endpoint keeps
        # its original always-200 contract.
        logger.error(f"[AC-IDMP-09] Mode switch failed: {e}")

    return SwitchModeResponse(
        session_id=sessionId,
        mode=switch_request.mode,
    )


def get_session_mode(tenant_id: str, session_id: str) -> SessionMode:
    """Get the current session mode; defaults to BOT_ACTIVE when unset."""
    return _session_modes.get(f"{tenant_id}:{session_id}", SessionMode.BOT_ACTIVE)


def clear_session_mode(tenant_id: str, session_id: str) -> None:
    """Clear session mode (reset to BOT_ACTIVE)."""
    # pop with a default is the idiomatic "delete if present".
    _session_modes.pop(f"{tenant_id}:{session_id}", None)
# [AC-IDMP-01~20] Unified mid-platform response-protocol models.
#
# NOTE(review): several definitions below parallel app.models.mid.schemas
# (Mode vs ExecutionMode, Segment, TraceInfo, ReportedMessage, ...) and
# appear to have drifted slightly — e.g. ReportedMessage.timestamp is a
# datetime here but a str in schemas.py. Confirm which module is canonical.

import uuid
from datetime import datetime
from enum import Enum
from typing import Any

from pydantic import BaseModel, Field


class Mode(str, Enum):
    """[AC-IDMP-02] Execution mode."""
    AGENT = "agent"
    MICRO_FLOW = "micro_flow"
    FIXED = "fixed"
    TRANSFER = "transfer"


class SessionMode(str, Enum):
    """[AC-IDMP-09] Session mode."""
    BOT_ACTIVE = "BOT_ACTIVE"
    HUMAN_ACTIVE = "HUMAN_ACTIVE"


class HighRiskScenario(str, Enum):
    """[AC-IDMP-20] Minimal set of high-risk scenarios."""
    REFUND = "refund"
    COMPLAINT_ESCALATION = "complaint_escalation"
    PRIVACY_SENSITIVE_PROMISE = "privacy_sensitive_promise"
    TRANSFER = "transfer"


class ToolCallStatus(str, Enum):
    """[AC-IDMP-15] Tool-call status."""
    OK = "ok"
    TIMEOUT = "timeout"
    ERROR = "error"
    REJECTED = "rejected"


class ToolType(str, Enum):
    """[AC-IDMP-19] Tool type."""
    INTERNAL = "internal"
    MCP = "mcp"


class HistoryMessage(BaseModel):
    """[AC-IDMP-03] A delivered history message."""
    role: str = Field(..., description="消息角色: user/assistant/human")
    content: str = Field(..., description="消息内容")


class InterruptedSegment(BaseModel):
    """[AC-IDMP-04] A segment that was interrupted mid-delivery."""
    segment_id: str = Field(..., description="分段ID")
    content: str = Field(..., description="分段内容")


class FeatureFlags(BaseModel):
    """[AC-IDMP-17] Feature flags."""
    agent_enabled: bool | None = Field(default=None, description="会话级 Agent 灰度开关")
    rollback_to_legacy: bool | None = Field(default=None, description="强制回滚传统链路")


class DialogueRequest(BaseModel):
    """[AC-IDMP-01~04] Dialogue response request."""
    session_id: str = Field(..., description="会话ID")
    user_id: str | None = Field(default=None, description="用户ID,用于记忆召回与更新")
    user_message: str = Field(..., min_length=1, max_length=2000, description="用户消息")
    history: list[HistoryMessage] = Field(default_factory=list, description="已送达历史")
    interrupted_segments: list[InterruptedSegment] | None = Field(default=None, description="打断的分段")
    feature_flags: FeatureFlags | None = Field(default=None, description="特性开关")


class Segment(BaseModel):
    """[AC-IDMP-01] Response segment."""
    segment_id: str = Field(..., description="分段ID")
    text: str = Field(..., description="分段文本")
    delay_after: int = Field(default=0, ge=0, description="分段后延迟(ms)")


class TimeoutProfile(BaseModel):
    """[AC-IDMP-12] Timeout configuration."""
    per_tool_timeout_ms: int | None = Field(default=30000, le=60000, description="单工具超时(ms)")
    end_to_end_timeout_ms: int | None = Field(default=120000, le=180000, description="端到端超时(ms)")


class MetricsSnapshot(BaseModel):
    """[AC-IDMP-18] Runtime metrics snapshot."""
    task_completion_rate: float | None = Field(default=None, ge=0, le=1, description="任务达成率")
    slot_completion_rate: float | None = Field(default=None, ge=0, le=1, description="槽位完整率")
    wrong_transfer_rate: float | None = Field(default=None, ge=0, le=1, description="误转人工率")
    no_recall_rate: float | None = Field(default=None, ge=0, le=1, description="无召回率")
    avg_latency_ms: float | None = Field(default=None, ge=0, description="平均时延(ms)")


class ToolCallTraceModel(BaseModel):
    """[AC-IDMP-15/19] Tool-call trace (pydantic model)."""
    tool_name: str = Field(..., description="工具名称")
    tool_type: ToolType | None = Field(default=None, description="工具类型: internal/mcp")
    registry_version: str | None = Field(default=None, description="注册版本")
    auth_applied: bool | None = Field(default=None, description="是否应用鉴权")
    duration_ms: int = Field(..., ge=0, description="耗时(ms)")
    status: ToolCallStatus = Field(..., description="状态: ok/timeout/error/rejected")
    error_code: str | None = Field(default=None, description="错误码")
    args_digest: str | None = Field(default=None, description="参数摘要")
    result_digest: str | None = Field(default=None, description="结果摘要")


class TraceInfo(BaseModel):
    """[AC-IDMP-02/07] Trace information."""
    mode: Mode = Field(..., description="执行模式")
    intent: str | None = Field(default=None, description="意图")
    request_id: str | None = Field(default=None, description="请求ID")
    generation_id: str | None = Field(default=None, description="生成ID")
    guardrail_triggered: bool | None = Field(default=False, description="护栏是否触发")
    fallback_reason_code: str | None = Field(default=None, description="降级原因码")
    react_iterations: int | None = Field(default=None, ge=0, le=5, description="ReAct循环次数")
    timeout_profile: TimeoutProfile | None = Field(default=None, description="超时配置")
    metrics_snapshot: MetricsSnapshot | None = Field(default=None, description="指标快照")
    high_risk_policy_set: list[HighRiskScenario] | None = Field(
        default=None,
        description="当前启用的高风险最小场景集"
    )
    tools_used: list[str] | None = Field(default=None, description="使用的工具列表")
    tool_calls: list[ToolCallTraceModel] | None = Field(default=None, description="工具调用追踪")


class DialogueResponse(BaseModel):
    """[AC-IDMP-01/02] Dialogue response."""
    segments: list[Segment] = Field(..., description="响应分段列表")
    trace: TraceInfo = Field(..., description="追踪信息")


class ReportedMessage(BaseModel):
    """[AC-IDMP-08] A reported message."""
    role: str = Field(..., description="角色: user/assistant/human/system")
    content: str = Field(..., description="消息内容")
    source: str = Field(..., description="来源: bot/human/channel")
    timestamp: datetime = Field(..., description="时间戳")
    segment_id: str | None = Field(default=None, description="分段ID")


class MessageReportRequest(BaseModel):
    """[AC-IDMP-08] Message report request."""
    session_id: str = Field(..., description="会话ID")
    messages: list[ReportedMessage] = Field(..., description="消息列表")


class SwitchModeRequest(BaseModel):
    """[AC-IDMP-09] Switch-mode request."""
    mode: SessionMode = Field(..., description="目标模式")
    reason: str | None = Field(default=None, description="切换原因")


class SwitchModeResponse(BaseModel):
    """[AC-IDMP-09] Switch-mode response."""
    session_id: str = Field(..., description="会话ID")
    mode: SessionMode = Field(..., description="当前模式")


# NOTE(review): these re-export imports sit at the bottom of the module,
# presumably to avoid circular imports between sibling modules — confirm
# before moving them to the top.
from app.models.mid.memory import (
    RecallRequest,
    RecallResponse,
    UpdateRequest,
    MemoryProfile,
    MemoryFact,
    MemoryPreferences,
)
from app.models.mid.tool_trace import (
    ToolCallTrace,
    ToolCallBuilder,
    ToolCallStatus as ToolCallStatusEnum,
    ToolType as ToolTypeEnum,
)
from app.models.mid.tool_registry import (
    ToolDefinition,
    ToolAuthConfig,
    ToolTimeoutPolicy,
    ToolRegistryEntity,
    ToolRegistryCreate,
    ToolRegistryUpdate,
)

__all__ = [
    "Mode",
    "SessionMode",
    "HighRiskScenario",
    "ToolCallStatus",
    "ToolType",
    "HistoryMessage",
    "InterruptedSegment",
    "FeatureFlags",
    "DialogueRequest",
    "Segment",
    "TimeoutProfile",
    "MetricsSnapshot",
    "ToolCallTraceModel",
    "TraceInfo",
    "DialogueResponse",
    "ReportedMessage",
    "MessageReportRequest",
    "SwitchModeRequest",
    "SwitchModeResponse",
    "RecallRequest",
    "RecallResponse",
    "UpdateRequest",
    "MemoryProfile",
    "MemoryFact",
    "MemoryPreferences",
    "ToolCallTrace",
    "ToolCallBuilder",
    "ToolCallStatusEnum",
    "ToolTypeEnum",
    "ToolDefinition",
    "ToolAuthConfig",
    "ToolTimeoutPolicy",
    "ToolRegistryEntity",
    "ToolRegistryCreate",
    "ToolRegistryUpdate",
]
+[AC-IDMP-13] 记忆召回数据模型 +[AC-IDMP-14] 记忆更新数据模型 + +Reference: spec/intent-driven-mid-platform/openapi.deps.yaml +""" + +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any +from uuid import UUID + + +@dataclass +class MemoryProfile: + """ + [AC-IDMP-13] 用户基础属性记忆 + 包含年级、地区、渠道等基础信息 + """ + grade: str | None = None + region: str | None = None + channel: str | None = None + vip_level: str | None = None + registration_date: datetime | None = None + extra: dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict[str, Any]: + result = {} + if self.grade: + result["grade"] = self.grade + if self.region: + result["region"] = self.region + if self.channel: + result["channel"] = self.channel + if self.vip_level: + result["vip_level"] = self.vip_level + if self.registration_date: + result["registration_date"] = self.registration_date.isoformat() + if self.extra: + result.update(self.extra) + return result + + +@dataclass +class MemoryFact: + """ + [AC-IDMP-13] 事实型记忆 + 包含已购课程、学习结论等客观事实 + """ + content: str + source: str | None = None + confidence: float | None = None + created_at: datetime | None = None + expires_at: datetime | None = None + + def to_string(self) -> str: + return self.content + + +@dataclass +class MemoryPreferences: + """ + [AC-IDMP-13] 偏好记忆 + 包含语气偏好、关注科目等用户偏好 + """ + tone: str | None = None + focus_subjects: list[str] = field(default_factory=list) + communication_style: str | None = None + preferred_time: str | None = None + extra: dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict[str, Any]: + result = {} + if self.tone: + result["tone"] = self.tone + if self.focus_subjects: + result["focus_subjects"] = self.focus_subjects + if self.communication_style: + result["communication_style"] = self.communication_style + if self.preferred_time: + result["preferred_time"] = self.preferred_time + if self.extra: + result.update(self.extra) + return result + + +@dataclass +class 
RecallRequest: + """ + [AC-IDMP-13] 记忆召回请求 + Reference: openapi.deps.yaml - RecallRequest + """ + user_id: str + session_id: str + + def to_dict(self) -> dict[str, Any]: + return { + "user_id": self.user_id, + "session_id": self.session_id, + } + + +@dataclass +class RecallResponse: + """ + [AC-IDMP-13] 记忆召回响应 + Reference: openapi.deps.yaml - RecallResponse + """ + profile: MemoryProfile | None = None + facts: list[MemoryFact] = field(default_factory=list) + preferences: MemoryPreferences | None = None + last_summary: str | None = None + + def to_dict(self) -> dict[str, Any]: + result = {} + if self.profile: + result["profile"] = self.profile.to_dict() + if self.facts: + result["facts"] = [f.to_string() for f in self.facts] + if self.preferences: + result["preferences"] = self.preferences.to_dict() + if self.last_summary: + result["last_summary"] = self.last_summary + return result + + def get_context_for_prompt(self) -> str: + """ + 生成用于注入 Prompt 的上下文字符串 + """ + parts = [] + + if self.profile: + profile_parts = [] + if self.profile.grade: + profile_parts.append(f"年级: {self.profile.grade}") + if self.profile.region: + profile_parts.append(f"地区: {self.profile.region}") + if self.profile.vip_level: + profile_parts.append(f"会员等级: {self.profile.vip_level}") + if profile_parts: + parts.append("【用户属性】" + "、".join(profile_parts)) + + if self.facts: + fact_strs = [f.content for f in self.facts[:5]] + parts.append("【已知事实】" + ";".join(fact_strs)) + + if self.preferences: + pref_parts = [] + if self.preferences.tone: + pref_parts.append(f"语气偏好: {self.preferences.tone}") + if self.preferences.focus_subjects: + pref_parts.append(f"关注科目: {', '.join(self.preferences.focus_subjects)}") + if pref_parts: + parts.append("【用户偏好】" + "、".join(pref_parts)) + + if self.last_summary: + parts.append(f"【上次会话摘要】{self.last_summary}") + + return "\n".join(parts) if parts else "" + + +@dataclass +class UpdateRequest: + """ + [AC-IDMP-14] 记忆更新请求 + Reference: openapi.deps.yaml - UpdateRequest + 
""" + user_id: str + session_id: str + messages: list[dict[str, Any]] + summary: str | None = None + + def to_dict(self) -> dict[str, Any]: + result = { + "user_id": self.user_id, + "session_id": self.session_id, + "messages": self.messages, + } + if self.summary: + result["summary"] = self.summary + return result diff --git a/ai-service/app/models/mid/schemas.py b/ai-service/app/models/mid/schemas.py new file mode 100644 index 0000000..beb4e41 --- /dev/null +++ b/ai-service/app/models/mid/schemas.py @@ -0,0 +1,407 @@ +""" +Mid Platform schemas. +[AC-IDMP-01, AC-IDMP-02, AC-IDMP-07, AC-IDMP-11, AC-IDMP-12, AC-IDMP-15, AC-IDMP-17, AC-IDMP-18, AC-IDMP-19, AC-IDMP-20] +Aligned with spec/intent-driven-mid-platform/openapi.provider.yaml +""" + +from __future__ import annotations + +import uuid +from enum import Enum +from typing import Any + +from pydantic import BaseModel, Field + + +class ExecutionMode(str, Enum): + """[AC-IDMP-02] Execution mode for dialogue response.""" + AGENT = "agent" + MICRO_FLOW = "micro_flow" + FIXED = "fixed" + TRANSFER = "transfer" + + +class HighRiskScenario(str, Enum): + """[AC-IDMP-20] High risk scenario types for mandatory takeover.""" + REFUND = "refund" + COMPLAINT_ESCALATION = "complaint_escalation" + PRIVACY_SENSITIVE_PROMISE = "privacy_sensitive_promise" + TRANSFER = "transfer" + + +class ToolCallStatus(str, Enum): + """[AC-IDMP-15] Tool call status.""" + OK = "ok" + TIMEOUT = "timeout" + ERROR = "error" + REJECTED = "rejected" + + +class ToolType(str, Enum): + """[AC-IDMP-19] Tool type for registry governance.""" + INTERNAL = "internal" + MCP = "mcp" + + +class SessionMode(str, Enum): + """[AC-IDMP-09] Session mode for bot/human switching.""" + BOT_ACTIVE = "BOT_ACTIVE" + HUMAN_ACTIVE = "HUMAN_ACTIVE" + + +class HistoryMessage(BaseModel): + """[AC-IDMP-03] History message with only delivered content.""" + role: str = Field(..., description="Message role: user, assistant, or human") + content: str = Field(..., description="Message 
class InterruptedSegment(BaseModel):
    """[AC-IDMP-04] Interrupted segment for handling user interruption."""
    segment_id: str = Field(..., description="Segment ID")
    content: str = Field(..., description="Segment content")


class FeatureFlags(BaseModel):
    """[AC-IDMP-17] Feature flags for session-level grayscale and rollback."""
    agent_enabled: bool | None = Field(default=True, description="Session-level Agent grayscale switch")
    rollback_to_legacy: bool | None = Field(default=False, description="Force rollback to legacy pipeline")


class HumanizeConfigRequest(BaseModel):
    """[AC-MARH-11] Humanize configuration request."""
    enabled: bool | None = Field(default=True, description="Enable humanize strategy")
    min_delay_ms: int | None = Field(default=50, ge=0, description="Minimum delay in milliseconds")
    max_delay_ms: int | None = Field(default=500, ge=0, description="Maximum delay in milliseconds")
    length_bucket_strategy: str | None = Field(default="simple", description="Strategy: simple or semantic")


class DialogueRequest(BaseModel):
    """[AC-IDMP-01, AC-IDMP-03, AC-IDMP-04, AC-IDMP-17, AC-MARH-11] Dialogue request schema."""
    session_id: str = Field(..., description="Session ID for conversation tracking")
    user_id: str | None = Field(default=None, description="User ID for memory recall and update")
    user_message: str = Field(..., min_length=1, max_length=2000, description="User message content")
    history: list[HistoryMessage] = Field(default_factory=list, description="Only delivered history")
    interrupted_segments: list[InterruptedSegment] | None = Field(default=None, description="Interrupted segments")
    feature_flags: FeatureFlags | None = Field(default=None, description="Feature flags for grayscale control")
    humanize_config: HumanizeConfigRequest | None = Field(
        default=None, description="Humanize config for segment delay"
    )


class Segment(BaseModel):
    """[AC-IDMP-01] Response segment with delay control."""
    segment_id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Segment ID")
    text: str = Field(..., description="Segment text content")
    delay_after: int = Field(default=0, ge=0, description="Delay after this segment in milliseconds")


class TimeoutProfile(BaseModel):
    """[AC-MARH-08, AC-MARH-09] Timeout configuration profile."""
    per_tool_timeout_ms: int = Field(default=30000, le=60000, description="Per-tool timeout in milliseconds")
    llm_timeout_ms: int = Field(default=60000, le=120000, description="LLM call timeout in milliseconds")
    end_to_end_timeout_ms: int = Field(default=120000, le=180000, description="End-to-end timeout in milliseconds")


class MetricsSnapshot(BaseModel):
    """[AC-IDMP-18] Runtime metrics snapshot."""
    task_completion_rate: float | None = Field(default=None, ge=0.0, le=1.0, description="Task completion rate")
    slot_completion_rate: float | None = Field(default=None, ge=0.0, le=1.0, description="Slot completion rate")
    wrong_transfer_rate: float | None = Field(default=None, ge=0.0, le=1.0, description="Wrong transfer rate")
    no_recall_rate: float | None = Field(default=None, ge=0.0, le=1.0, description="No recall rate")
    avg_latency_ms: float | None = Field(default=None, ge=0.0, description="Average latency in milliseconds")


class ToolCallTrace(BaseModel):
    """[AC-IDMP-15, AC-IDMP-19] Tool call trace for observability."""
    tool_name: str = Field(..., description="Tool name")
    tool_type: ToolType | None = Field(default=ToolType.INTERNAL, description="Tool type: internal or mcp")
    registry_version: str | None = Field(default=None, description="Tool registry version")
    auth_applied: bool | None = Field(default=False, description="Whether auth was applied")
    duration_ms: int = Field(..., ge=0, description="Duration in milliseconds")
    status: ToolCallStatus = Field(..., description="Tool call status")
    error_code: str | None = Field(default=None, description="Error code if failed")
    args_digest: str | None = Field(default=None, description="Arguments digest for logging")
    result_digest: str | None = Field(default=None, description="Result digest for logging")


class SegmentStats(BaseModel):
    """[AC-MARH-12] Segment statistics for humanize strategy."""
    segment_count: int = Field(default=0, ge=0, description="Number of segments")
    avg_segment_length: float = Field(default=0.0, ge=0.0, description="Average segment length")
    humanize_strategy: str | None = Field(default=None, description="Humanize strategy used")


class TraceInfo(BaseModel):
    """[AC-MARH-02, AC-MARH-03, AC-MARH-05, AC-MARH-06, AC-MARH-07, AC-MARH-11,
    AC-MARH-12, AC-MARH-18, AC-MARH-19, AC-MARH-20] Trace info for observability."""
    mode: ExecutionMode = Field(..., description="Execution mode")
    intent: str | None = Field(default=None, description="Matched intent")
    request_id: str | None = Field(
        default_factory=lambda: str(uuid.uuid4()), description="Request ID"
    )
    generation_id: str | None = Field(
        default_factory=lambda: str(uuid.uuid4()),
        description="Generation ID for interrupt handling",
    )
    guardrail_triggered: bool | None = Field(default=False, description="Whether guardrail was triggered")
    guardrail_rule_id: str | None = Field(default=None, description="Guardrail rule ID that triggered")
    interrupt_consumed: bool | None = Field(default=False, description="Whether interrupted segments were consumed")
    kb_tool_called: bool | None = Field(default=False, description="Whether KB tool was called")
    kb_hit: bool | None = Field(default=False, description="Whether KB search had results")
    fallback_reason_code: str | None = Field(default=None, description="Fallback reason code")
    react_iterations: int | None = Field(default=0, ge=0, le=5, description="ReAct loop iterations")
    timeout_profile: TimeoutProfile | None = Field(default=None, description="Timeout profile")
    segment_stats: SegmentStats | None = Field(default=None, description="Segment statistics")
    metrics_snapshot: MetricsSnapshot | None = Field(default=None, description="Metrics snapshot")
    high_risk_policy_set: list[HighRiskScenario] | None = Field(default=None, description="Active high-risk policy set")
    tools_used: list[str] | None = Field(default=None, description="Tools used in this request")
    tool_calls: list[ToolCallTrace] | None = Field(default=None, description="Tool call traces")


class DialogueResponse(BaseModel):
    """[AC-IDMP-01, AC-IDMP-02] Dialogue response with segments and trace."""
    segments: list[Segment] = Field(..., description="Response segments")
    trace: TraceInfo = Field(..., description="Trace info for observability")


class ReportedMessage(BaseModel):
    """[AC-IDMP-08] Reported message for message report API.

    NOTE(review): timestamp is a str here but a datetime in
    app.models.mid.__init__.ReportedMessage — confirm the intended wire type.
    """
    role: str = Field(..., description="Message role: user, assistant, human, or system")
    content: str = Field(..., description="Message content")
    source: str = Field(..., description="Message source: bot, human, or channel")
    timestamp: str = Field(..., description="Message timestamp in ISO format")
    segment_id: str | None = Field(default=None, description="Segment ID if applicable")


class MessageReportRequest(BaseModel):
    """[AC-IDMP-08] Message report request schema."""
    session_id: str = Field(..., description="Session ID")
    messages: list[ReportedMessage] = Field(..., description="Messages to report")


class SwitchModeRequest(BaseModel):
    """[AC-IDMP-09] Switch session mode request."""
    mode: SessionMode = Field(..., description="Target mode: BOT_ACTIVE or HUMAN_ACTIVE")
    reason: str | None = Field(default=None, description="Reason for mode switch")


class SwitchModeResponse(BaseModel):
    """[AC-IDMP-09] Switch session mode response."""
    session_id: str = Field(..., description="Session ID")
    mode: SessionMode = Field(..., description="Current mode after switch")


class MidSessionState(BaseModel):
    """Internal session state for mid platform."""
    session_id: str
    tenant_id: str
    mode: SessionMode = SessionMode.BOT_ACTIVE
    generation_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    active_flow_id: str | None = None
    context: dict[str, Any] | None = None
    created_at: str | None = None
    updated_at: str | None = None


class PolicyRouterResult(BaseModel):
    """[AC-IDMP-02, AC-IDMP-05, AC-IDMP-16] Policy router decision result."""
    mode: ExecutionMode = Field(..., description="Decided execution mode")
    intent: str | None = Field(default=None, description="Matched intent")
    confidence: float | None = Field(default=None, ge=0.0, le=1.0, description="Intent confidence")
    fallback_reason_code: str | None = Field(default=None, description="Fallback reason if applicable")
    high_risk_triggered: bool = Field(default=False, description="Whether high-risk scenario triggered")
    target_flow_id: str | None = Field(default=None, description="Target flow ID for micro_flow mode")
    fixed_reply: str | None = Field(default=None, description="Fixed reply for fixed mode")
    transfer_message: str | None = Field(default=None, description="Transfer message for transfer mode")


class ReActContext(BaseModel):
    """[AC-IDMP-11] ReAct loop context for iteration control."""
    iteration: int = Field(default=0, ge=0, le=5, description="Current iteration count")
    max_iterations: int = Field(default=5, ge=3, le=5, description="Maximum iterations allowed")
    tool_calls: list[ToolCallTrace] = Field(default_factory=list, description="Tool call history")
    should_continue: bool = Field(default=True, description="Whether to continue ReAct loop")
    final_answer: str | None = Field(default=None, description="Final answer if completed")


class CreateShareRequest(BaseModel):
    """[AC-IDMP-SHARE] Request to create a shared session."""
    title: str | None = Field(default=None, max_length=255, description="Share title")
    description: str | None = Field(default=None, max_length=1000, description="Share description")
    expires_in_days: int = Field(default=7, ge=1, le=365, description="Expiration time in days")
    max_concurrent_users: int = Field(default=10, ge=1, le=100, description="Maximum concurrent users")


class ShareResponse(BaseModel):
    """[AC-IDMP-SHARE] Response after creating a share."""
    share_token: str = Field(..., description="Unique share token")
    share_url: str = Field(..., description="Full share URL")
    expires_at: str = Field(..., description="Expiration time in ISO format")
    title: str | None = Field(default=None, description="Share title")
    description: str | None = Field(default=None, description="Share description")
    max_concurrent_users: int = Field(..., description="Maximum concurrent users")


class SharedSessionInfo(BaseModel):
    """[AC-IDMP-SHARE] Information about a shared session."""
    session_id: str = Field(..., description="Session ID")
    title: str | None = Field(default=None, description="Share title")
    description: str | None = Field(default=None, description="Share description")
    expires_at: str = Field(..., description="Expiration time in ISO format")
    max_concurrent_users: int = Field(..., description="Maximum concurrent users")
    current_users: int = Field(..., description="Current online users")
    history: list[HistoryMessage] = Field(default_factory=list, description="Historical messages")


class SharedMessageRequest(BaseModel):
    """[AC-IDMP-SHARE] Request to send a message via shared session."""
    user_message: str = Field(..., min_length=1, max_length=2000, description="User message content")


class ShareListItem(BaseModel):
    """[AC-IDMP-SHARE] Share list item for listing all shares of a session."""
    share_token: str = Field(..., description="Share token")
    share_url: str = Field(..., description="Full share URL")
    title: str | None = Field(default=None, description="Share title")
    description: str | None = Field(default=None, description="Share description")
    expires_at: str = Field(..., description="Expiration time in ISO format")
    is_active: bool = Field(..., description="Whether share is active")
    max_concurrent_users: int = Field(..., description="Maximum concurrent users")
    current_users: int = Field(..., description="Current online users")
    created_at: str = Field(..., description="Creation time in ISO format")


class ShareListResponse(BaseModel):
    """[AC-IDMP-SHARE] Response for listing shares."""
    shares: list[ShareListItem] = Field(..., description="List of shares")


class KbSearchDynamicHit(BaseModel):
    """[AC-MARH-05] Single KB search hit."""
    id: str = Field(..., description="Hit ID")
    content: str = Field(..., description="Hit content")
    score: float = Field(..., ge=0.0, le=1.0, description="Relevance score")
    metadata: dict[str, Any] = Field(default_factory=dict, description="Hit metadata")


class MissingRequiredSlot(BaseModel):
    """[AC-MARH-05] Missing required slot info."""
    field_key: str = Field(..., description="Field key")
    label: str = Field(..., description="Field label")
    reason: str = Field(..., description="Missing reason")


class KbSearchDynamicResultSchema(BaseModel):
    """[AC-MARH-05, AC-MARH-06] KB dynamic search result schema."""
    success: bool = Field(..., description="Whether search succeeded")
    hits: list[KbSearchDynamicHit] = Field(default_factory=list, description="Search hits")
    applied_filter: dict[str, Any] = Field(default_factory=dict, description="Applied filter")
    missing_required_slots: list[MissingRequiredSlot] = Field(
        default_factory=list, description="Missing required slots"
    )
    filter_debug: dict[str, Any] = Field(default_factory=dict, description="Filter debug info")
    fallback_reason_code: str | None = Field(default=None, description="Fallback reason code")
    duration_ms: int = Field(default=0, ge=0, description="Duration in milliseconds")


class IntentHintOutput(BaseModel):
    """[AC-IDMP-02, AC-IDMP-16] 轻量意图提示工具输出。"""
    intent: str | None = Field(default=None, description="识别到的意图名称")
confidence: float = Field(default=0.0, ge=0.0, le=1.0, description="置信度 0~1") + response_type: str | None = Field( + default=None, + description="响应类型: fixed|rag|flow|transfer|null" + ) + suggested_mode: ExecutionMode | None = Field( + default=None, + description="建议执行模式: agent|micro_flow|fixed|transfer" + ) + target_flow_id: str | None = Field(default=None, description="目标流程ID(flow模式)") + target_kb_ids: list[str] | None = Field(default=None, description="目标知识库ID列表") + fallback_reason_code: str | None = Field(default=None, description="降级原因码") + high_risk_detected: bool = Field(default=False, description="是否检测到高风险场景") + duration_ms: int = Field(default=0, ge=0, description="执行耗时(毫秒)") + + +class HighRiskCheckResult(BaseModel): + """[AC-IDMP-05, AC-IDMP-20] 高风险检测工具输出。""" + matched: bool = Field(default=False, description="是否命中高风险场景") + risk_scenario: HighRiskScenario | None = Field( + default=None, + description="风险场景: refund|complaint_escalation|privacy_sensitive_promise|transfer|none" + ) + confidence: float = Field(default=0.0, ge=0.0, le=1.0, description="置信度 0~1") + recommended_mode: ExecutionMode | None = Field( + default=None, + description="推荐执行模式: micro_flow|transfer|agent" + ) + rule_id: str | None = Field(default=None, description="匹配的规则ID") + reason: str | None = Field(default=None, description="匹配原因说明") + fallback_reason_code: str | None = Field(default=None, description="降级原因码(工具失败时)") + duration_ms: int = Field(default=0, ge=0, description="执行耗时(毫秒)") + matched_text: str | None = Field(default=None, description="匹配到的文本片段") + matched_pattern: str | None = Field(default=None, description="匹配到的模式(关键词或正则)") + + +class SlotSource(str, Enum): + """[AC-IDMP-13] 槽位来源类型。""" + USER_CONFIRMED = "user_confirmed" + RULE_EXTRACTED = "rule_extracted" + LLM_INFERRED = "llm_inferred" + DEFAULT = "default" + + +class MemorySlot(BaseModel): + """[AC-IDMP-13] 单个槽位信息。""" + key: str = Field(..., description="槽位键名") + value: Any = Field(..., description="槽位值") + source: 
SlotSource = Field(default=SlotSource.DEFAULT, description="槽位来源") + confidence: float = Field(default=1.0, ge=0.0, le=1.0, description="置信度") + updated_at: str | None = Field(default=None, description="最后更新时间") + + +class MemoryRecallResult(BaseModel): + """[AC-IDMP-13] 记忆召回工具输出。""" + profile: dict[str, Any] = Field(default_factory=dict, description="用户基础属性") + facts: list[str] = Field(default_factory=list, description="事实型记忆列表") + preferences: dict[str, Any] = Field(default_factory=dict, description="用户偏好") + last_summary: str | None = Field(default=None, description="最近会话摘要") + slots: dict[str, MemorySlot] = Field(default_factory=dict, description="结构化槽位") + missing_slots: list[str] = Field(default_factory=list, description="缺失的必填槽位") + fallback_reason_code: str | None = Field(default=None, description="降级原因码") + duration_ms: int = Field(default=0, ge=0, description="执行耗时(毫秒)") + + def get_context_for_prompt(self) -> str: + """生成用于注入 Prompt 的上下文字符串。""" + parts = [] + + if self.profile: + profile_parts = [] + for key, value in self.profile.items(): + if value: + profile_parts.append(f"{key}: {value}") + if profile_parts: + parts.append("【用户属性】" + "、".join(profile_parts)) + + if self.facts: + parts.append("【已知事实】" + ";".join(self.facts[:5])) + + if self.preferences: + pref_parts = [] + for key, value in self.preferences.items(): + if value: + pref_parts.append(f"{key}: {value}") + if pref_parts: + parts.append("【用户偏好】" + "、".join(pref_parts)) + + if self.last_summary: + parts.append(f"【上次会话摘要】{self.last_summary}") + + if self.slots: + slot_parts = [] + for key, slot in self.slots.items(): + slot_parts.append(f"{key}={slot.value}") + if slot_parts: + parts.append("【已知槽位】" + ", ".join(slot_parts)) + + return "\n".join(parts) if parts else "" diff --git a/ai-service/app/models/mid/tool_registry.py b/ai-service/app/models/mid/tool_registry.py new file mode 100644 index 0000000..de30170 --- /dev/null +++ b/ai-service/app/models/mid/tool_registry.py @@ -0,0 +1,222 @@ 
+""" +Tool Registry models for Mid Platform. +[AC-IDMP-19] Tool Registry 治理模型 + +Reference: spec/intent-driven-mid-platform/openapi.provider.yaml +""" + +import uuid +from dataclasses import dataclass, field +from datetime import datetime +from enum import Enum +from typing import Any + +from sqlalchemy import JSON, Column +from sqlmodel import Field, Index, SQLModel + + +class ToolStatus(str, Enum): + """工具状态""" + ENABLED = "enabled" + DISABLED = "disabled" + DEPRECATED = "deprecated" + + +class ToolAuthType(str, Enum): + """工具鉴权类型""" + NONE = "none" + API_KEY = "api_key" + OAUTH = "oauth" + CUSTOM = "custom" + + +@dataclass +class ToolAuthConfig: + """ + [AC-IDMP-19] 工具鉴权配置 + """ + auth_type: ToolAuthType = ToolAuthType.NONE + required_scopes: list[str] = field(default_factory=list) + api_key_header: str | None = None + oauth_url: str | None = None + custom_validator: str | None = None + + def to_dict(self) -> dict[str, Any]: + result = {"auth_type": self.auth_type.value} + if self.required_scopes: + result["required_scopes"] = self.required_scopes + if self.api_key_header: + result["api_key_header"] = self.api_key_header + if self.oauth_url: + result["oauth_url"] = self.oauth_url + if self.custom_validator: + result["custom_validator"] = self.custom_validator + return result + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "ToolAuthConfig": + return cls( + auth_type=ToolAuthType(data.get("auth_type", "none")), + required_scopes=data.get("required_scopes", []), + api_key_header=data.get("api_key_header"), + oauth_url=data.get("oauth_url"), + custom_validator=data.get("custom_validator"), + ) + + +@dataclass +class ToolTimeoutPolicy: + """ + [AC-IDMP-19] 工具超时策略 + Reference: openapi.provider.yaml - TimeoutProfile + """ + per_tool_timeout_ms: int = 30000 + end_to_end_timeout_ms: int = 120000 + retry_count: int = 0 + retry_delay_ms: int = 100 + + def to_dict(self) -> dict[str, Any]: + return { + "per_tool_timeout_ms": self.per_tool_timeout_ms, + 
"end_to_end_timeout_ms": self.end_to_end_timeout_ms, + "retry_count": self.retry_count, + "retry_delay_ms": self.retry_delay_ms, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "ToolTimeoutPolicy": + return cls( + per_tool_timeout_ms=data.get("per_tool_timeout_ms", 30000), + end_to_end_timeout_ms=data.get("end_to_end_timeout_ms", 120000), + retry_count=data.get("retry_count", 0), + retry_delay_ms=data.get("retry_delay_ms", 100), + ) + + +@dataclass +class ToolDefinition: + """ + [AC-IDMP-19] 工具定义(内存模型) + + 包含: + - name: 工具名称 + - type: 工具类型 (internal | mcp) + - version: 版本号 + - timeout_policy: 超时策略 + - auth_config: 鉴权配置 + - is_enabled: 启停状态 + """ + name: str + type: str = "internal" + version: str = "1.0.0" + description: str | None = None + timeout_policy: ToolTimeoutPolicy = field(default_factory=ToolTimeoutPolicy) + auth_config: ToolAuthConfig = field(default_factory=ToolAuthConfig) + is_enabled: bool = True + metadata: dict[str, Any] = field(default_factory=dict) + created_at: datetime = field(default_factory=datetime.utcnow) + updated_at: datetime = field(default_factory=datetime.utcnow) + + def to_dict(self) -> dict[str, Any]: + return { + "name": self.name, + "type": self.type, + "version": self.version, + "description": self.description, + "timeout_policy": self.timeout_policy.to_dict(), + "auth_config": self.auth_config.to_dict(), + "is_enabled": self.is_enabled, + "metadata": self.metadata, + "created_at": self.created_at.isoformat(), + "updated_at": self.updated_at.isoformat(), + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "ToolDefinition": + return cls( + name=data["name"], + type=data.get("type", "internal"), + version=data.get("version", "1.0.0"), + description=data.get("description"), + timeout_policy=ToolTimeoutPolicy.from_dict(data.get("timeout_policy", {})), + auth_config=ToolAuthConfig.from_dict(data.get("auth_config", {})), + is_enabled=data.get("is_enabled", True), + metadata=data.get("metadata", {}), + 
created_at=datetime.fromisoformat(data["created_at"]) if data.get("created_at") else datetime.utcnow(), + updated_at=datetime.fromisoformat(data["updated_at"]) if data.get("updated_at") else datetime.utcnow(), + ) + + +class ToolRegistryEntity(SQLModel, table=True): + """ + [AC-IDMP-19] 工具注册表数据库实体 + 支持动态配置更新 + """ + __tablename__ = "tool_registry" + __table_args__ = ( + Index("ix_tool_registry_tenant_name", "tenant_id", "name", unique=True), + Index("ix_tool_registry_tenant_enabled", "tenant_id", "is_enabled"), + ) + + id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) + tenant_id: str = Field(..., description="租户ID", index=True) + name: str = Field(..., description="工具名称", max_length=128) + type: str = Field(default="internal", description="工具类型: internal | mcp") + version: str = Field(default="1.0.0", description="版本号", max_length=32) + description: str | None = Field(default=None, description="工具描述") + timeout_policy: dict[str, Any] | None = Field( + default=None, + sa_column=Column("timeout_policy", JSON, nullable=True), + description="超时策略配置" + ) + auth_config: dict[str, Any] | None = Field( + default=None, + sa_column=Column("auth_config", JSON, nullable=True), + description="鉴权配置" + ) + is_enabled: bool = Field(default=True, description="是否启用") + metadata_: dict[str, Any] | None = Field( + default=None, + sa_column=Column("metadata", JSON, nullable=True), + description="扩展元数据" + ) + created_at: datetime = Field(default_factory=datetime.utcnow, description="创建时间") + updated_at: datetime = Field(default_factory=datetime.utcnow, description="更新时间") + + def to_definition(self) -> ToolDefinition: + """转换为内存模型""" + return ToolDefinition( + name=self.name, + type=self.type, + version=self.version, + description=self.description, + timeout_policy=ToolTimeoutPolicy.from_dict(self.timeout_policy or {}), + auth_config=ToolAuthConfig.from_dict(self.auth_config or {}), + is_enabled=self.is_enabled, + metadata=self.metadata_ or {}, + 
created_at=self.created_at, + updated_at=self.updated_at, + ) + + +class ToolRegistryCreate(SQLModel): + """创建工具注册请求""" + name: str = Field(..., max_length=128) + type: str = "internal" + version: str = "1.0.0" + description: str | None = None + timeout_policy: dict[str, Any] | None = None + auth_config: dict[str, Any] | None = None + is_enabled: bool = True + metadata_: dict[str, Any] | None = None + + +class ToolRegistryUpdate(SQLModel): + """更新工具注册请求""" + type: str | None = None + version: str | None = None + description: str | None = None + timeout_policy: dict[str, Any] | None = None + auth_config: dict[str, Any] | None = None + is_enabled: bool | None = None + metadata_: dict[str, Any] | None = None diff --git a/ai-service/app/models/mid/tool_trace.py b/ai-service/app/models/mid/tool_trace.py new file mode 100644 index 0000000..359d672 --- /dev/null +++ b/ai-service/app/models/mid/tool_trace.py @@ -0,0 +1,174 @@ +""" +Tool trace models for Mid Platform. +[AC-IDMP-15] 工具调用结构化记录 + +Reference: spec/intent-driven-mid-platform/openapi.provider.yaml - ToolCallTrace +""" + +import hashlib +import json +from dataclasses import dataclass, field +from datetime import datetime +from enum import Enum +from typing import Any + + +class ToolCallStatus(str, Enum): + """工具调用状态""" + OK = "ok" + TIMEOUT = "timeout" + ERROR = "error" + REJECTED = "rejected" + + +class ToolType(str, Enum): + """工具类型""" + INTERNAL = "internal" + MCP = "mcp" + + +@dataclass +class ToolCallTrace: + """ + [AC-IDMP-15] 工具调用追踪记录 + Reference: openapi.provider.yaml - ToolCallTrace + + 记录字段: + - tool_name: 工具名称 + - tool_type: 工具类型 (internal | mcp) + - registry_version: 注册表版本 + - auth_applied: 是否应用鉴权 + - duration_ms: 调用耗时(毫秒) + - status: 调用状态 (ok | timeout | error | rejected) + - error_code: 错误码 + - args_digest: 参数摘要(脱敏) + - result_digest: 结果摘要 + """ + tool_name: str + duration_ms: int + status: ToolCallStatus + tool_type: ToolType = ToolType.INTERNAL + registry_version: str | None = None + auth_applied: 
bool = False + error_code: str | None = None + args_digest: str | None = None + result_digest: str | None = None + started_at: datetime = field(default_factory=datetime.utcnow) + completed_at: datetime | None = None + + def to_dict(self) -> dict[str, Any]: + result = { + "tool_name": self.tool_name, + "duration_ms": self.duration_ms, + "status": self.status.value, + } + if self.tool_type != ToolType.INTERNAL: + result["tool_type"] = self.tool_type.value + if self.registry_version: + result["registry_version"] = self.registry_version + if self.auth_applied: + result["auth_applied"] = self.auth_applied + if self.error_code: + result["error_code"] = self.error_code + if self.args_digest: + result["args_digest"] = self.args_digest + if self.result_digest: + result["result_digest"] = self.result_digest + return result + + @staticmethod + def compute_digest(data: Any, max_length: int = 64) -> str: + """ + 计算数据摘要(用于脱敏记录) + + Args: + data: 原始数据 + max_length: 最大长度限制 + + Returns: + 摘要字符串 + """ + if data is None: + return "" + + if isinstance(data, (dict, list)): + data_str = json.dumps(data, ensure_ascii=False, sort_keys=True) + else: + data_str = str(data) + + if len(data_str) <= max_length: + return data_str + + hash_value = hashlib.sha256(data_str.encode("utf-8")).hexdigest()[:16] + preview = data_str[:32] + return f"{preview}...[hash:{hash_value}]" + + +@dataclass +class ToolCallBuilder: + """ + [AC-IDMP-15] 工具调用记录构建器 + 用于在工具执行过程中逐步构建追踪记录 + """ + tool_name: str + tool_type: ToolType = ToolType.INTERNAL + registry_version: str | None = None + auth_applied: bool = False + _started_at: datetime = field(default_factory=datetime.utcnow) + _args: Any = None + _result: Any = None + _error: Exception | None = None + _status: ToolCallStatus = ToolCallStatus.OK + _error_code: str | None = None + + def with_args(self, args: Any) -> "ToolCallBuilder": + """设置调用参数""" + self._args = args + return self + + def with_registry_info(self, version: str, auth_applied: bool) -> 
"ToolCallBuilder": + """设置注册表信息""" + self.registry_version = version + self.auth_applied = auth_applied + return self + + def with_result(self, result: Any) -> "ToolCallBuilder": + """设置调用结果""" + self._result = result + self._status = ToolCallStatus.OK + return self + + def with_error(self, error: Exception, error_code: str | None = None) -> "ToolCallBuilder": + """设置错误信息""" + self._error = error + self._error_code = error_code + if isinstance(error, TimeoutError): + self._status = ToolCallStatus.TIMEOUT + else: + self._status = ToolCallStatus.ERROR + return self + + def with_rejected(self, reason: str) -> "ToolCallBuilder": + """设置拒绝状态""" + self._status = ToolCallStatus.REJECTED + self._error_code = reason + return self + + def build(self) -> ToolCallTrace: + """构建追踪记录""" + completed_at = datetime.utcnow() + duration_ms = int((completed_at - self._started_at).total_seconds() * 1000) + + return ToolCallTrace( + tool_name=self.tool_name, + tool_type=self.tool_type, + registry_version=self.registry_version, + auth_applied=self.auth_applied, + duration_ms=duration_ms, + status=self._status, + error_code=self._error_code, + args_digest=ToolCallTrace.compute_digest(self._args) if self._args else None, + result_digest=ToolCallTrace.compute_digest(self._result) if self._result else None, + started_at=self._started_at, + completed_at=completed_at, + ) diff --git a/ai-service/app/services/mid/__init__.py b/ai-service/app/services/mid/__init__.py new file mode 100644 index 0000000..41391aa --- /dev/null +++ b/ai-service/app/services/mid/__init__.py @@ -0,0 +1,57 @@ +""" +Mid Platform services. 
[AC-IDMP-02, AC-IDMP-05, AC-IDMP-11, AC-IDMP-12, AC-IDMP-13, AC-IDMP-14, AC-IDMP-15, AC-IDMP-16, AC-IDMP-17, AC-IDMP-19, AC-IDMP-20]
[AC-MARH-05, AC-MARH-06, AC-MARH-10, AC-MARH-11, AC-MARH-12]
"""

# Routing / orchestration
# NOTE(review): PolicyRouterResult and ReActContext are also defined in
# app.models.mid.schemas; confirm which definition is canonical for callers
# importing from this package.
from .policy_router import PolicyRouter, PolicyRouterResult, IntentMatch
from .agent_orchestrator import AgentOrchestrator, ReActContext
# Governance / flags / risk
from .timeout_governor import TimeoutGovernor
from .feature_flags import FeatureFlagService
from .high_risk_handler import HighRiskHandler, HighRiskMatch
# Observability
from .trace_logger import TraceLogger, AuditRecord
from .metrics_collector import MetricsCollector, SessionMetrics, AggregatedMetrics
# Tooling
from .tool_registry import ToolRegistry, ToolDefinition, ToolExecutionResult, get_tool_registry, init_tool_registry
from .tool_call_recorder import ToolCallRecorder, ToolCallStatistics, get_tool_call_recorder
# Memory / KB / output shaping
from .memory_adapter import MemoryAdapter, UserMemory
from .default_kb_tool_runner import DefaultKbToolRunner, KbToolResult, KbToolConfig, get_default_kb_tool_runner
from .segment_humanizer import SegmentHumanizer, HumanizeConfig, LengthBucket, get_segment_humanizer
from .runtime_observer import RuntimeObserver, RuntimeContext, get_runtime_observer

# Public API of the package; mirrors the imports above.
__all__ = [
    "PolicyRouter",
    "PolicyRouterResult",
    "IntentMatch",
    "AgentOrchestrator",
    "ReActContext",
    "TimeoutGovernor",
    "FeatureFlagService",
    "HighRiskHandler",
    "HighRiskMatch",
    "TraceLogger",
    "AuditRecord",
    "MetricsCollector",
    "SessionMetrics",
    "AggregatedMetrics",
    "ToolRegistry",
    "ToolDefinition",
    "ToolExecutionResult",
    "get_tool_registry",
    "init_tool_registry",
    "ToolCallRecorder",
    "ToolCallStatistics",
    "get_tool_call_recorder",
    "MemoryAdapter",
    "UserMemory",
    "DefaultKbToolRunner",
    "KbToolResult",
    "KbToolConfig",
    "get_default_kb_tool_runner",
    "SegmentHumanizer",
    "HumanizeConfig",
    "LengthBucket",
    "get_segment_humanizer",
    "RuntimeObserver",
    "RuntimeContext",
    "get_runtime_observer",
]
diff --git
a/ai-service/app/services/mid/agent_orchestrator.py b/ai-service/app/services/mid/agent_orchestrator.py new file mode 100644 index 0000000..74d9a7f --- /dev/null +++ b/ai-service/app/services/mid/agent_orchestrator.py @@ -0,0 +1,370 @@ +""" +Agent Orchestrator for Mid Platform. +[AC-MARH-07] ReAct loop with iteration limit (3-5 iterations). + +ReAct Flow: +1. Thought: Agent thinks about what to do +2. Action: Agent decides to use a tool +3. Observation: Tool execution result +4. Repeat until final answer or max iterations reached +""" + +import asyncio +import logging +import time +import uuid +from dataclasses import dataclass +from typing import Any + +from app.models.mid.schemas import ( + ExecutionMode, + ReActContext, + ToolCallStatus, + ToolCallTrace, + ToolType, + TraceInfo, +) +from app.services.mid.timeout_governor import TimeoutGovernor + +logger = logging.getLogger(__name__) + +DEFAULT_MAX_ITERATIONS = 5 +MIN_ITERATIONS = 3 + + +@dataclass +class ToolResult: + """Tool execution result.""" + success: bool + output: str | None = None + error: str | None = None + duration_ms: int = 0 + + +@dataclass +class AgentThought: + """Agent thought in ReAct loop.""" + content: str + action: str | None = None + action_input: dict[str, Any] | None = None + + +class AgentOrchestrator: + """ + [AC-MARH-07] Agent orchestrator with ReAct loop control. 
+ + Features: + - ReAct loop with max 5 iterations (min 3) + - Per-tool timeout (2s) and end-to-end timeout (8s) + - Automatic fallback on iteration limit or timeout + """ + + def __init__( + self, + max_iterations: int = DEFAULT_MAX_ITERATIONS, + timeout_governor: TimeoutGovernor | None = None, + llm_client: Any = None, + tool_registry: Any = None, + ): + self._max_iterations = max(min(max_iterations, 5), MIN_ITERATIONS) + self._timeout_governor = timeout_governor or TimeoutGovernor() + self._llm_client = llm_client + self._tool_registry = tool_registry + + async def execute( + self, + user_message: str, + context: dict[str, Any] | None = None, + on_thought: Any = None, + on_action: Any = None, + ) -> tuple[str, ReActContext, TraceInfo]: + """ + [AC-MARH-07] Execute ReAct loop with iteration control. + + Args: + user_message: User input message + context: Execution context (history, retrieval results, etc.) + on_thought: Callback for thought events + on_action: Callback for action events + + Returns: + Tuple of (final_answer, react_context, trace_info) + """ + react_ctx = ReActContext(max_iterations=self._max_iterations) + tool_calls: list[ToolCallTrace] = [] + start_time = time.time() + + logger.info( + f"[AC-MARH-07] Starting ReAct loop: max_iterations={self._max_iterations}" + ) + + try: + overall_start = time.time() + end_to_end_timeout = self._timeout_governor.end_to_end_timeout_seconds + llm_timeout = ( + self._timeout_governor.llm_timeout_seconds + if hasattr(self._timeout_governor, 'llm_timeout_seconds') + else 15.0 + ) + + while react_ctx.should_continue and react_ctx.iteration < react_ctx.max_iterations: + react_ctx.iteration += 1 + + elapsed = time.time() - overall_start + remaining_time = end_to_end_timeout - elapsed + if remaining_time <= 0: + logger.warning( + "[AC-MARH-09] ReAct loop exceeded end-to-end timeout" + ) + react_ctx.final_answer = "抱歉,处理超时,请稍后重试或联系人工客服。" + break + + logger.info( + f"[AC-MARH-07] ReAct iteration {react_ctx.iteration}/" + 
f"{react_ctx.max_iterations}, remaining_time={remaining_time:.1f}s" + ) + + thought = await asyncio.wait_for( + self._think(user_message, context, react_ctx), + timeout=min(llm_timeout, remaining_time) + ) + if on_thought: + await on_thought(thought) + + if not thought.action: + logger.info( + f"[AC-MARH-07] No action, setting final_answer: " + f"{thought.content[:200] if thought.content else 'None'}" + ) + react_ctx.final_answer = thought.content + react_ctx.should_continue = False + break + + tool_result, tool_trace = await self._act(thought, react_ctx) + tool_calls.append(tool_trace) + react_ctx.tool_calls.append(tool_trace) + + if on_action: + await on_action(thought.action, tool_result) + + if tool_result.success: + context = context or {} + context["last_observation"] = tool_result.output + else: + if tool_trace.status == ToolCallStatus.TIMEOUT: + logger.warning(f"[AC-MARH-08] Tool timeout: {thought.action}") + react_ctx.final_answer = "抱歉,操作超时,请稍后重试或联系人工客服。" + react_ctx.should_continue = False + break + + if react_ctx.should_continue and not react_ctx.final_answer: + logger.warning(f"[AC-MARH-07] ReAct reached max iterations: {react_ctx.iteration}") + react_ctx.final_answer = await self._force_final_answer(user_message, context, react_ctx) + + except asyncio.TimeoutError: + logger.error("[AC-MARH-09] ReAct loop timed out (end-to-end)") + react_ctx.final_answer = "抱歉,处理超时,请稍后重试或联系人工客服。" + tool_calls.append(ToolCallTrace( + tool_name="react_loop", + tool_type=ToolType.INTERNAL, + duration_ms=int((time.time() - start_time) * 1000), + status=ToolCallStatus.TIMEOUT, + error_code="E2E_TIMEOUT", + )) + + total_duration_ms = int((time.time() - start_time) * 1000) + trace = TraceInfo( + mode=ExecutionMode.AGENT, + request_id=str(uuid.uuid4()), + generation_id=str(uuid.uuid4()), + react_iterations=react_ctx.iteration, + tools_used=[tc.tool_name for tc in tool_calls if tc.tool_name != "react_loop"], + tool_calls=tool_calls if tool_calls else None, + ) + + logger.info( 
+ f"[AC-MARH-07] ReAct completed: iterations={react_ctx.iteration}, " + f"duration_ms={total_duration_ms}" + ) + + return react_ctx.final_answer or "抱歉,我暂时无法处理您的请求。", react_ctx, trace + + async def _think( + self, + user_message: str, + context: dict[str, Any] | None, + react_ctx: ReActContext, + ) -> AgentThought: + """ + [AC-MARH-07] Agent thinks about next action. + + In real implementation, this would call LLM with ReAct prompt. + For now, returns a simple thought without action. + """ + if not self._llm_client: + return AgentThought(content=f"思考中... 用户消息: {user_message}") + + try: + observations = [] + if context and "last_observation" in context: + observations.append(f"上一步结果: {context['last_observation']}") + + for tc in react_ctx.tool_calls[-3:]: + observations.append(f"工具 {tc.tool_name}: {tc.result_digest or '无结果'}") + + prompt = self._build_react_prompt(user_message, observations) + response = await self._llm_client.generate([{"role": "user", "content": prompt}]) + + logger.info(f"[AC-MARH-07] LLM response content: {response.content[:500] if response.content else 'None'}") + + return self._parse_thought(response.content) + + except Exception as e: + logger.error(f"[AC-MARH-07] Think failed: {e}") + return AgentThought(content=f"思考失败: {str(e)}") + + def _build_react_prompt(self, user_message: str, observations: list[str]) -> str: + """Build ReAct prompt for LLM.""" + obs_text = "\n".join(observations) if observations else "无" + return f"""你是一个智能助手,正在使用 ReAct 模式处理用户请求。 + +用户消息: {user_message} + +历史观察: +{obs_text} + +请思考下一步行动。如果已经有足够信息回答用户,请直接给出最终答案。 +如果需要使用工具,请按以下格式回复: +Thought: [你的思考] +Action: [工具名称] +Action Input: {{"param1": "value1"}} +""" + + def _parse_thought(self, content: str) -> AgentThought: + """Parse LLM response into AgentThought.""" + action = None + action_input = None + + if "Action:" in content: + lines = content.split("\n") + for line in lines: + if line.startswith("Action:"): + action = line.replace("Action:", "").strip() + elif 
line.startswith("Action Input:"): + import json + try: + action_input = json.loads(line.replace("Action Input:", "").strip()) + except json.JSONDecodeError: + action_input = {} + + return AgentThought(content=content, action=action, action_input=action_input) + + async def _act( + self, + thought: AgentThought, + react_ctx: ReActContext, + ) -> tuple[ToolResult, ToolCallTrace]: + """ + [AC-MARH-07, AC-MARH-08] Execute tool action with timeout. + """ + tool_name = thought.action or "unknown" + start_time = time.time() + + if not self._tool_registry: + duration_ms = int((time.time() - start_time) * 1000) + return ToolResult( + success=False, + error="Tool registry not configured", + duration_ms=duration_ms, + ), ToolCallTrace( + tool_name=tool_name, + tool_type=ToolType.INTERNAL, + duration_ms=duration_ms, + status=ToolCallStatus.ERROR, + error_code="NO_REGISTRY", + ) + + try: + result = await asyncio.wait_for( + self._tool_registry.execute( + tool_name=tool_name, + args=thought.action_input or {}, + ), + timeout=self._timeout_governor.per_tool_timeout_seconds + ) + + duration_ms = int((time.time() - start_time) * 1000) + return ToolResult( + success=result.get("success", False), + output=result.get("output"), + error=result.get("error"), + duration_ms=duration_ms, + ), ToolCallTrace( + tool_name=tool_name, + tool_type=ToolType.INTERNAL, + duration_ms=duration_ms, + status=ToolCallStatus.OK if result.get("success") else ToolCallStatus.ERROR, + args_digest=str(thought.action_input)[:100] if thought.action_input else None, + result_digest=str(result.get("output"))[:100] if result.get("output") else None, + ) + + except asyncio.TimeoutError: + duration_ms = int((time.time() - start_time) * 1000) + logger.warning(f"[AC-MARH-08] Tool timeout: {tool_name}, duration={duration_ms}ms") + return ToolResult( + success=False, + error="Tool timeout", + duration_ms=duration_ms, + ), ToolCallTrace( + tool_name=tool_name, + tool_type=ToolType.INTERNAL, + duration_ms=duration_ms, + 
status=ToolCallStatus.TIMEOUT, + error_code="TOOL_TIMEOUT", + ) + + except Exception as e: + duration_ms = int((time.time() - start_time) * 1000) + logger.error(f"[AC-MARH-07] Tool error: {tool_name}, error={e}") + return ToolResult( + success=False, + error=str(e), + duration_ms=duration_ms, + ), ToolCallTrace( + tool_name=tool_name, + tool_type=ToolType.INTERNAL, + duration_ms=duration_ms, + status=ToolCallStatus.ERROR, + error_code="TOOL_ERROR", + ) + + async def _force_final_answer( + self, + user_message: str, + context: dict[str, Any] | None, + react_ctx: ReActContext, + ) -> str: + """Force final answer when max iterations reached.""" + observations = [] + for tc in react_ctx.tool_calls: + if tc.result_digest: + observations.append(f"- {tc.tool_name}: {tc.result_digest}") + + obs_text = "\n".join(observations) if observations else "无" + + if self._llm_client: + try: + prompt = f"""基于以下信息,请给出最终回答: + +用户消息: {user_message} + +收集到的信息: +{obs_text} + +请直接给出回答,不要再调用工具。""" + response = await self._llm_client.generate([{"role": "user", "content": prompt}]) + return response.content + except Exception as e: + logger.error(f"[AC-MARH-07] Force final answer failed: {e}") + + return "抱歉,我已经尽力处理您的请求,但可能需要更多信息。请稍后重试或联系人工客服。" diff --git a/ai-service/app/services/mid/default_kb_tool_runner.py b/ai-service/app/services/mid/default_kb_tool_runner.py new file mode 100644 index 0000000..0eb8b03 --- /dev/null +++ b/ai-service/app/services/mid/default_kb_tool_runner.py @@ -0,0 +1,244 @@ +""" +Default KB Tool Runner for Mid Platform. 
+[AC-MARH-05] Agent 默认 KB 检索工具调用。 +[AC-MARH-06] KB 失败时可观测降级。 + +当 Agent 模式处理开放咨询时,默认尝试调用 KB 检索工具获取事实依据。 +""" + +import logging +import time +import uuid +from dataclasses import dataclass, field +from typing import Any + +from app.models.mid.schemas import ToolCallStatus, ToolCallTrace, ToolType +from app.services.mid.timeout_governor import TimeoutGovernor + +logger = logging.getLogger(__name__) + +DEFAULT_KB_TOP_K = 5 +DEFAULT_KB_TIMEOUT_MS = 2000 + + +@dataclass +class KbToolResult: + """KB 检索结果。""" + success: bool + hits: list[dict[str, Any]] = field(default_factory=list) + error: str | None = None + fallback_reason_code: str | None = None + duration_ms: int = 0 + tool_trace: ToolCallTrace | None = None + + +@dataclass +class KbToolConfig: + """KB 工具配置。""" + enabled: bool = True + top_k: int = DEFAULT_KB_TOP_K + timeout_ms: int = DEFAULT_KB_TIMEOUT_MS + min_score_threshold: float = 0.5 + + +class DefaultKbToolRunner: + """ + [AC-MARH-05] Agent 默认 KB 检索工具执行器。 + + Features: + - Agent 模式下默认调用 KB 检索 + - 支持超时控制 + - 失败时返回可观测降级信号 + - 记录 kb_tool_called, kb_hit 状态 + """ + + def __init__( + self, + timeout_governor: TimeoutGovernor | None = None, + config: KbToolConfig | None = None, + ): + self._timeout_governor = timeout_governor or TimeoutGovernor() + self._config = config or KbToolConfig() + self._vector_retriever = None + + async def execute( + self, + tenant_id: str, + query: str, + metadata_filter: dict[str, Any] | None = None, + ) -> KbToolResult: + """ + [AC-MARH-05] 执行 KB 检索。 + + Args: + tenant_id: 租户 ID + query: 检索查询 + metadata_filter: 元数据过滤条件 + + Returns: + KbToolResult 包含检索结果和追踪信息 + """ + if not self._config.enabled: + logger.info(f"[AC-MARH-05] KB tool disabled for tenant={tenant_id}") + return KbToolResult( + success=False, + fallback_reason_code="KB_DISABLED", + ) + + start_time = time.time() + tool_trace_id = str(uuid.uuid4()) + + logger.info( + f"[AC-MARH-05] Starting KB retrieval: tenant={tenant_id}, " + f"query={query[:50]}..., 
top_k={self._config.top_k}" + ) + + try: + hits = await self._retrieve_with_timeout( + tenant_id=tenant_id, + query=query, + metadata_filter=metadata_filter, + ) + + duration_ms = int((time.time() - start_time) * 1000) + kb_hit = len(hits) > 0 + + tool_trace = ToolCallTrace( + tool_name="kb_retrieval", + tool_type=ToolType.INTERNAL, + duration_ms=duration_ms, + status=ToolCallStatus.OK, + args_digest=f"query={query[:50]}", + result_digest=f"hits={len(hits)}", + ) + + logger.info( + f"[AC-MARH-05] KB retrieval completed: tenant={tenant_id}, " + f"hits={len(hits)}, duration_ms={duration_ms}, kb_hit={kb_hit}" + ) + + return KbToolResult( + success=True, + hits=hits, + duration_ms=duration_ms, + tool_trace=tool_trace, + ) + + except TimeoutError: + duration_ms = int((time.time() - start_time) * 1000) + logger.warning( + f"[AC-MARH-06] KB retrieval timeout: tenant={tenant_id}, " + f"duration_ms={duration_ms}" + ) + + tool_trace = ToolCallTrace( + tool_name="kb_retrieval", + tool_type=ToolType.INTERNAL, + duration_ms=duration_ms, + status=ToolCallStatus.TIMEOUT, + error_code="KB_TIMEOUT", + ) + + return KbToolResult( + success=False, + error="KB retrieval timeout", + fallback_reason_code="KB_TIMEOUT", + duration_ms=duration_ms, + tool_trace=tool_trace, + ) + + except Exception as e: + duration_ms = int((time.time() - start_time) * 1000) + logger.error( + f"[AC-MARH-06] KB retrieval failed: tenant={tenant_id}, " + f"error={e}" + ) + + tool_trace = ToolCallTrace( + tool_name="kb_retrieval", + tool_type=ToolType.INTERNAL, + duration_ms=duration_ms, + status=ToolCallStatus.ERROR, + error_code="KB_ERROR", + ) + + return KbToolResult( + success=False, + error=str(e), + fallback_reason_code="KB_ERROR", + duration_ms=duration_ms, + tool_trace=tool_trace, + ) + + async def _retrieve_with_timeout( + self, + tenant_id: str, + query: str, + metadata_filter: dict[str, Any] | None = None, + ) -> list[dict[str, Any]]: + """带超时控制的检索。""" + import asyncio + + timeout_seconds = 
self._config.timeout_ms / 1000.0 + + try: + return await asyncio.wait_for( + self._do_retrieve(tenant_id, query, metadata_filter), + timeout=timeout_seconds, + ) + except asyncio.TimeoutError: + raise TimeoutError("KB retrieval timeout") + + async def _do_retrieve( + self, + tenant_id: str, + query: str, + metadata_filter: dict[str, Any] | None = None, + ) -> list[dict[str, Any]]: + """执行实际检索。""" + if self._vector_retriever is None: + from app.services.retrieval.vector_retriever import get_vector_retriever + self._vector_retriever = await get_vector_retriever() + + from app.services.retrieval.base import RetrievalContext + + ctx = RetrievalContext( + tenant_id=tenant_id, + query=query, + metadata_filter=metadata_filter, + ) + + result = await self._vector_retriever.retrieve(ctx) + + hits = [] + for hit in result.hits: + if hit.score >= self._config.min_score_threshold: + hits.append({ + "id": hit.metadata.get("chunk_id", str(uuid.uuid4())), + "content": hit.text, + "score": hit.score, + "metadata": hit.metadata, + }) + + return hits[:self._config.top_k] + + def get_config(self) -> KbToolConfig: + """获取当前配置。""" + return self._config + + +_default_kb_tool_runner: DefaultKbToolRunner | None = None + + +def get_default_kb_tool_runner( + timeout_governor: TimeoutGovernor | None = None, + config: KbToolConfig | None = None, +) -> DefaultKbToolRunner: + """获取或创建 DefaultKbToolRunner 实例。""" + global _default_kb_tool_runner + if _default_kb_tool_runner is None: + _default_kb_tool_runner = DefaultKbToolRunner( + timeout_governor=timeout_governor, + config=config, + ) + return _default_kb_tool_runner diff --git a/ai-service/app/services/mid/feature_flags.py b/ai-service/app/services/mid/feature_flags.py new file mode 100644 index 0000000..dd43106 --- /dev/null +++ b/ai-service/app/services/mid/feature_flags.py @@ -0,0 +1,100 @@ +""" +Feature Flags Service for Mid Platform. +[AC-IDMP-17] Session-level grayscale and rollback support. 
+""" + +import logging +from dataclasses import dataclass +from typing import Any + +from app.models.mid.schemas import FeatureFlags + +logger = logging.getLogger(__name__) + + +@dataclass +class FeatureFlagConfig: + """Feature flag configuration.""" + agent_enabled: bool = True + rollback_to_legacy: bool = False + react_max_iterations: int = 5 + enable_tool_registry: bool = True + enable_trace_logging: bool = True + + +class FeatureFlagService: + """ + [AC-IDMP-17] Feature flag service for session-level control. + + Supports: + - Session-level Agent enable/disable + - Force rollback to legacy pipeline + - Dynamic configuration per session + """ + + def __init__(self): + self._session_flags: dict[str, FeatureFlagConfig] = {} + self._global_config = FeatureFlagConfig() + + def get_flags(self, session_id: str) -> FeatureFlags: + """ + [AC-IDMP-17] Get feature flags for a session. + + Args: + session_id: Session ID + + Returns: + FeatureFlags for the session + """ + config = self._session_flags.get(session_id, self._global_config) + + return FeatureFlags( + agent_enabled=config.agent_enabled, + rollback_to_legacy=config.rollback_to_legacy, + ) + + def set_flags(self, session_id: str, flags: FeatureFlags) -> None: + """ + [AC-IDMP-17] Set feature flags for a session. 
+
+        Args:
+            session_id: Session ID
+            flags: Feature flags to set
+        """
+        import dataclasses
+        config = dataclasses.replace(self._global_config,
+            agent_enabled=flags.agent_enabled if flags.agent_enabled is not None else self._global_config.agent_enabled,
+            rollback_to_legacy=flags.rollback_to_legacy if flags.rollback_to_legacy is not None else self._global_config.rollback_to_legacy)
+
+        self._session_flags[session_id] = config
+
+        logger.info(
+            f"[AC-IDMP-17] Feature flags set for session {session_id}: "
+            f"agent_enabled={config.agent_enabled}, rollback_to_legacy={config.rollback_to_legacy}"
+        )
+
+    def clear_flags(self, session_id: str) -> None:
+        """Clear feature flags for a session."""
+        if session_id in self._session_flags:
+            del self._session_flags[session_id]
+            logger.info(f"[AC-IDMP-17] Feature flags cleared for session {session_id}")
+
+    def is_agent_enabled(self, session_id: str) -> bool:
+        """Check if Agent mode is enabled for a session."""
+        config = self._session_flags.get(session_id, self._global_config)
+        return config.agent_enabled
+
+    def should_rollback(self, session_id: str) -> bool:
+        """Check if should rollback to legacy for a session."""
+        config = self._session_flags.get(session_id, self._global_config)
+        return config.rollback_to_legacy
+
+    def set_global_config(self, config: FeatureFlagConfig) -> None:
+        """Set global default configuration."""
+        self._global_config = config
+        logger.info(f"[AC-IDMP-17] Global config updated: {config}")
+
+    def get_react_max_iterations(self, session_id: str) -> int:
+        """Get ReAct max iterations for a session."""
+        config = self._session_flags.get(session_id, self._global_config)
+        return config.react_max_iterations
diff --git a/ai-service/app/services/mid/high_risk_handler.py b/ai-service/app/services/mid/high_risk_handler.py
new file mode 100644
index 0000000..b50c627
--- /dev/null
+++ b/ai-service/app/services/mid/high_risk_handler.py
@@ -0,0 +1,232 @@
+"""
+High Risk Handler for Mid Platform.
+[AC-IDMP-05, AC-IDMP-20] High-risk scenario detection and mandatory takeover. + +High-Risk Scenarios (minimum set): +1. Refund (退款) +2. Complaint Escalation (投诉升级) +3. Privacy/Sensitive Promise (隐私与敏感承诺) +4. Transfer (转人工) +""" + +import logging +from dataclasses import dataclass +from typing import Any + +from app.models.mid.schemas import ( + ExecutionMode, + HighRiskScenario, + PolicyRouterResult, +) + +logger = logging.getLogger(__name__) + +DEFAULT_HIGH_RISK_SCENARIOS = [ + HighRiskScenario.REFUND, + HighRiskScenario.COMPLAINT_ESCALATION, + HighRiskScenario.PRIVACY_SENSITIVE_PROMISE, + HighRiskScenario.TRANSFER, +] + +HIGH_RISK_PATTERNS = { + HighRiskScenario.REFUND: [ + r"退款", + r"退货", + r"退钱", + r"退费", + r"还钱", + r"申请退款", + r"我要退", + ], + HighRiskScenario.COMPLAINT_ESCALATION: [ + r"投诉", + r"升级投诉", + r"举报", + r"12315", + r"消费者协会", + r"工商局", + r"市场监督管理局", + ], + HighRiskScenario.PRIVACY_SENSITIVE_PROMISE: [ + r"承诺.{0,10}(退款|赔偿|补偿)", + r"保证.{0,10}(退款|赔偿|补偿)", + r"一定.{0,10}(退款|赔偿|补偿)", + r"肯定能.{0,10}(退款|赔偿|补偿)", + r"绝对.{0,10}(退款|赔偿|补偿)", + r"担保", + ], + HighRiskScenario.TRANSFER: [ + r"转人工", + r"人工客服", + r"人工服务", + r"真人", + r"人工", + r"转接人工", + ], +} + + +@dataclass +class HighRiskMatch: + """High-risk scenario match result.""" + scenario: HighRiskScenario + matched_pattern: str + matched_text: str + confidence: float = 1.0 + + +class HighRiskHandler: + """ + [AC-IDMP-05, AC-IDMP-20] High-risk scenario handler. 
+ + Features: + - Configurable high-risk scenario set (minimum set required) + - Pattern-based detection + - Mandatory takeover to micro_flow or transfer + """ + + def __init__( + self, + enabled_scenarios: list[HighRiskScenario] | None = None, + custom_patterns: dict[HighRiskScenario, list[str]] | None = None, + ): + self._enabled_scenarios = enabled_scenarios or DEFAULT_HIGH_RISK_SCENARIOS.copy() + + if not self._enabled_scenarios: + raise ValueError("[AC-IDMP-20] High-risk scenario set cannot be empty") + + self._patterns = HIGH_RISK_PATTERNS.copy() + if custom_patterns: + for scenario, patterns in custom_patterns.items(): + if scenario not in self._patterns: + self._patterns[scenario] = [] + self._patterns[scenario].extend(patterns) + + self._compiled_patterns = self._compile_patterns() + + logger.info( + f"[AC-IDMP-20] HighRiskHandler initialized with scenarios: " + f"{[s.value for s in self._enabled_scenarios]}" + ) + + def _compile_patterns(self) -> dict[HighRiskScenario, list]: + """Compile regex patterns for better performance.""" + import re + compiled = {} + for scenario in self._enabled_scenarios: + patterns = self._patterns.get(scenario, []) + compiled[scenario] = [ + re.compile(p, re.IGNORECASE) for p in patterns + ] + return compiled + + def detect(self, message: str) -> HighRiskMatch | None: + """ + [AC-IDMP-05, AC-IDMP-20] Detect high-risk scenario in message. 
+ + Args: + message: User message to check + + Returns: + HighRiskMatch if detected, None otherwise + """ + for scenario in self._enabled_scenarios: + patterns = self._compiled_patterns.get(scenario, []) + for pattern in patterns: + match = pattern.search(message) + if match: + logger.info( + f"[AC-IDMP-05] High-risk scenario detected: " + f"scenario={scenario.value}, pattern={pattern.pattern}" + ) + return HighRiskMatch( + scenario=scenario, + matched_pattern=pattern.pattern, + matched_text=match.group(), + ) + + return None + + def handle( + self, + match: HighRiskMatch, + context: dict[str, Any] | None = None, + ) -> PolicyRouterResult: + """ + [AC-IDMP-05] Handle high-risk scenario by routing to appropriate mode. + + Args: + match: High-risk match result + context: Additional context (intent match, flow config, etc.) + + Returns: + PolicyRouterResult with execution mode decision + """ + context = context or {} + + if match.scenario == HighRiskScenario.TRANSFER: + return PolicyRouterResult( + mode=ExecutionMode.TRANSFER, + high_risk_triggered=True, + transfer_message="正在为您转接人工客服,请稍候...", + ) + + flow_id = context.get("flow_id") + if flow_id: + return PolicyRouterResult( + mode=ExecutionMode.MICRO_FLOW, + high_risk_triggered=True, + target_flow_id=flow_id, + ) + + if match.scenario == HighRiskScenario.REFUND: + return PolicyRouterResult( + mode=ExecutionMode.MICRO_FLOW, + high_risk_triggered=True, + fallback_reason_code="high_risk_refund", + ) + + if match.scenario == HighRiskScenario.COMPLAINT_ESCALATION: + return PolicyRouterResult( + mode=ExecutionMode.TRANSFER, + high_risk_triggered=True, + transfer_message="检测到您可能需要投诉处理,正在为您转接人工客服...", + ) + + if match.scenario == HighRiskScenario.PRIVACY_SENSITIVE_PROMISE: + return PolicyRouterResult( + mode=ExecutionMode.MICRO_FLOW, + high_risk_triggered=True, + fallback_reason_code="high_risk_privacy_promise", + ) + + return PolicyRouterResult( + mode=ExecutionMode.TRANSFER, + high_risk_triggered=True, + 
transfer_message="正在为您转接人工客服...",
+        )
+
+    def get_enabled_scenarios(self) -> list[HighRiskScenario]:
+        """[AC-IDMP-20] Get enabled high-risk scenarios."""
+        return self._enabled_scenarios.copy()
+
+    def add_scenario(self, scenario: HighRiskScenario) -> None:
+        """Add a high-risk scenario to the enabled set."""
+        if scenario not in self._enabled_scenarios:
+            self._enabled_scenarios.append(scenario)
+            if scenario in HIGH_RISK_PATTERNS:
+                import re
+                self._compiled_patterns[scenario] = [
+                    re.compile(p, re.IGNORECASE) for p in HIGH_RISK_PATTERNS[scenario]
+                ]
+            logger.info(f"[AC-IDMP-20] Added high-risk scenario: {scenario.value}")
+
+    def remove_scenario(self, scenario: HighRiskScenario) -> bool:
+        """Remove a high-risk scenario from the enabled set."""
+        if scenario in self._enabled_scenarios and len(self._enabled_scenarios) > 1:
+            self._enabled_scenarios.remove(scenario)
+            if scenario in self._compiled_patterns:
+                del self._compiled_patterns[scenario]
+            logger.info(f"[AC-IDMP-20] Removed high-risk scenario: {scenario.value}")
+            return True
+        return False
diff --git a/ai-service/app/services/mid/interrupt_context_enricher.py b/ai-service/app/services/mid/interrupt_context_enricher.py
new file mode 100644
index 0000000..6ee0b83
--- /dev/null
+++ b/ai-service/app/services/mid/interrupt_context_enricher.py
@@ -0,0 +1,161 @@
+"""
+Interrupt Context Enricher for Mid Platform.
+[AC-MARH-03, AC-MARH-04] Interrupted segments processing and fallback handling.
+ +Features: +- Consumes interrupted_segments from request +- Provides context for re-planning +- Fallback handling for invalid/empty interrupt data +""" + +import logging +from dataclasses import dataclass +from typing import Any + +from app.models.mid.schemas import InterruptedSegment + +logger = logging.getLogger(__name__) + + +@dataclass +class InterruptContext: + """Context derived from interrupted segments.""" + consumed: bool = False + interrupted_content: str | None = None + interrupted_segment_ids: list[str] | None = None + fallback_triggered: bool = False + fallback_reason: str | None = None + + +class InterruptContextEnricher: + """ + [AC-MARH-03, AC-MARH-04] Interrupt context enricher for handling user interruption. + + This component processes interrupted_segments from the request and provides + context for re-planning. It handles edge cases where interrupt data is + invalid or empty, ensuring the main flow continues without disruption. + """ + + def __init__(self): + pass + + def enrich( + self, + interrupted_segments: list[InterruptedSegment] | None, + generation_id: str | None = None, + ) -> InterruptContext: + """ + [AC-MARH-03, AC-MARH-04] Process interrupted segments and return context. 
+ + Args: + interrupted_segments: List of interrupted segments from request + generation_id: Current generation ID for matching + + Returns: + InterruptContext with processed interrupt information + """ + if not interrupted_segments: + logger.debug("[AC-MARH-04] No interrupted segments provided") + return InterruptContext( + consumed=False, + fallback_triggered=False, + ) + + try: + valid_segments = [ + s for s in interrupted_segments + if s.content and s.content.strip() and s.segment_id + ] + + if not valid_segments: + logger.info("[AC-MARH-04] Interrupted segments empty or invalid, using fallback") + return InterruptContext( + consumed=False, + fallback_triggered=True, + fallback_reason="empty_or_invalid_segments", + ) + + interrupted_content = "\n".join(s.content for s in valid_segments) + segment_ids = [s.segment_id for s in valid_segments] + + logger.info( + f"[AC-MARH-03] Interrupted segments consumed: " + f"count={len(valid_segments)}, segment_ids={segment_ids[:3]}..." + ) + + return InterruptContext( + consumed=True, + interrupted_content=interrupted_content, + interrupted_segment_ids=segment_ids, + fallback_triggered=False, + ) + + except Exception as e: + logger.warning(f"[AC-MARH-04] Failed to process interrupted segments: {e}") + return InterruptContext( + consumed=False, + fallback_triggered=True, + fallback_reason=f"error:{str(e)[:50]}", + ) + + def build_replan_context( + self, + interrupt_context: InterruptContext, + base_context: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """ + [AC-MARH-03] Build context for re-planning after interruption. 
+ + Args: + interrupt_context: Processed interrupt context + base_context: Existing context to enrich + + Returns: + Enriched context for re-planning + """ + context = base_context.copy() if base_context else {} + + if interrupt_context.consumed and interrupt_context.interrupted_content: + context["interrupted_content"] = interrupt_context.interrupted_content + context["interrupted_segment_ids"] = interrupt_context.interrupted_segment_ids + context["avoid_duplicate"] = True + logger.debug( + f"[AC-MARH-03] Re-plan context enriched with interrupted content: " + f"{len(interrupt_context.interrupted_content)} chars" + ) + elif interrupt_context.fallback_triggered: + context["interrupt_fallback"] = True + context["interrupt_fallback_reason"] = interrupt_context.fallback_reason + logger.debug( + f"[AC-MARH-04] Re-plan context marked with fallback: " + f"{interrupt_context.fallback_reason}" + ) + + return context + + def should_skip_content( + self, + content: str, + interrupt_context: InterruptContext, + ) -> bool: + """ + [AC-MARH-03] Check if content should be skipped to avoid duplicate. + + Args: + content: Content to check + interrupt_context: Processed interrupt context + + Returns: + True if content matches interrupted content and should be skipped + """ + if not interrupt_context.consumed or not interrupt_context.interrupted_content: + return False + + if content.strip() in interrupt_context.interrupted_content: + logger.info( + f"[AC-MARH-03] Content matches interrupted segment, skipping: " + f"{content[:50]}..." + ) + return True + + return False diff --git a/ai-service/app/services/mid/kb_search_dynamic_tool.py b/ai-service/app/services/mid/kb_search_dynamic_tool.py new file mode 100644 index 0000000..2c22c8c --- /dev/null +++ b/ai-service/app/services/mid/kb_search_dynamic_tool.py @@ -0,0 +1,487 @@ +""" +KB Search Dynamic Tool for Mid Platform. 
+[AC-MARH-05] Agent 默认 KB 检索工具,支持元数据驱动参数。 +[AC-MARH-06] KB 失败时可观测降级。 + +核心特性: +- 通过元数据配置动态生成检索参数/过滤器 +- 必填字段缺失时返回 missing_required_slots +- 工具执行可观测(tool_call/tool_result) +- 超时降级返回 fallback_reason_code +""" + +from __future__ import annotations + +import asyncio +import logging +import time +import uuid +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.mid.schemas import ToolCallStatus, ToolCallTrace, ToolType +from app.services.mid.metadata_filter_builder import ( + FilterBuildResult, + MetadataFilterBuilder, +) +from app.services.mid.timeout_governor import TimeoutGovernor + +if TYPE_CHECKING: + from app.services.mid.tool_registry import ToolRegistry + +logger = logging.getLogger(__name__) + +DEFAULT_TOP_K = 5 +DEFAULT_TIMEOUT_MS = 2000 +KB_SEARCH_DYNAMIC_TOOL_NAME = "kb_search_dynamic" + + +@dataclass +class KbSearchDynamicResult: + """KB 动态检索结果。""" + success: bool = True + hits: list[dict[str, Any]] = field(default_factory=list) + applied_filter: dict[str, Any] = field(default_factory=dict) + missing_required_slots: list[dict[str, str]] = field(default_factory=list) + filter_debug: dict[str, Any] = field(default_factory=dict) + fallback_reason_code: str | None = None + duration_ms: int = 0 + tool_trace: ToolCallTrace | None = None + + +@dataclass +class KbSearchDynamicConfig: + """KB 动态检索配置。""" + enabled: bool = True + top_k: int = DEFAULT_TOP_K + timeout_ms: int = DEFAULT_TIMEOUT_MS + min_score_threshold: float = 0.5 + + +class KbSearchDynamicTool: + """ + [AC-MARH-05] KB 动态检索工具。 + + 支持通过元数据配置动态生成检索参数/过滤器,而不是固定入参写死。 + + 固定外壳入参: + - query: 检索查询 + - scene: 场景标识 + - tenant_id: 租户 ID + - top_k: 返回数量 + - context: 上下文(包含动态过滤值) + + 返回结构: + - hits: 检索结果 + - applied_filter: 已应用的过滤条件 + - missing_required_slots: 缺失的必填字段 + - filter_debug: 过滤器调试信息 + - fallback_reason_code: 降级原因码 + """ + + def __init__( + self, + session: AsyncSession, + timeout_governor: TimeoutGovernor | 
None = None, + config: KbSearchDynamicConfig | None = None, + ): + self._session = session + self._timeout_governor = timeout_governor or TimeoutGovernor() + self._config = config or KbSearchDynamicConfig() + self._vector_retriever = None + self._filter_builder: MetadataFilterBuilder | None = None + + @property + def name(self) -> str: + """工具名称。""" + return KB_SEARCH_DYNAMIC_TOOL_NAME + + @property + def description(self) -> str: + """工具描述。""" + return ( + "知识库动态检索工具。" + "根据租户配置的元数据字段定义,动态构建检索过滤器。" + "支持必填字段检测和可观测降级。" + ) + + def get_tool_schema(self) -> dict[str, Any]: + """ + 获取工具 Schema,用于 Agent 工具描述。 + 动态生成基于租户配置的过滤字段。 + """ + return { + "name": self.name, + "description": self.description, + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "检索查询文本", + }, + "scene": { + "type": "string", + "description": "场景标识,如 'open_consult', 'intent_match'", + }, + "top_k": { + "type": "integer", + "description": "返回结果数量", + "default": DEFAULT_TOP_K, + }, + "context": { + "type": "object", + "description": "上下文信息,包含动态过滤字段值", + }, + }, + "required": ["query"], + }, + } + + async def execute( + self, + query: str, + tenant_id: str, + scene: str = "open_consult", + top_k: int | None = None, + context: dict[str, Any] | None = None, + ) -> KbSearchDynamicResult: + """ + [AC-MARH-05] 执行 KB 动态检索。 + + Args: + query: 检索查询 + tenant_id: 租户 ID + scene: 场景标识 + top_k: 返回数量 + context: 上下文(包含动态过滤值) + + Returns: + KbSearchDynamicResult 包含检索结果和追踪信息 + """ + if not self._config.enabled: + logger.info(f"[AC-MARH-05] KB search dynamic disabled for tenant={tenant_id}") + return KbSearchDynamicResult( + success=False, + fallback_reason_code="KB_DISABLED", + ) + + start_time = time.time() + top_k = top_k or self._config.top_k + + logger.info( + f"[AC-MARH-05] Starting KB dynamic search: tenant={tenant_id}, " + f"query={query[:50]}..., scene={scene}, top_k={top_k}" + ) + + filter_result: FilterBuildResult | None = None + + try: + if 
self._filter_builder is None: + self._filter_builder = MetadataFilterBuilder(self._session) + + filter_result = await self._filter_builder.build_filter( + tenant_id=tenant_id, + context=context, + ) + + if filter_result.missing_required_slots: + logger.warning( + f"[AC-MARH-05] Missing required slots: " + f"{filter_result.missing_required_slots}" + ) + duration_ms = int((time.time() - start_time) * 1000) + + tool_trace = ToolCallTrace( + tool_name=self.name, + tool_type=ToolType.INTERNAL, + duration_ms=duration_ms, + status=ToolCallStatus.ERROR, + error_code="MISSING_REQUIRED_SLOTS", + args_digest=f"query={query[:50]}, scene={scene}", + result_digest=f"missing={len(filter_result.missing_required_slots)}", + ) + + return KbSearchDynamicResult( + success=False, + applied_filter=filter_result.applied_filter, + missing_required_slots=filter_result.missing_required_slots, + filter_debug=filter_result.debug_info, + fallback_reason_code="MISSING_REQUIRED_SLOTS", + duration_ms=duration_ms, + tool_trace=tool_trace, + ) + + metadata_filter = filter_result.applied_filter if filter_result.success else None + + hits = await self._retrieve_with_timeout( + tenant_id=tenant_id, + query=query, + metadata_filter=metadata_filter, + top_k=top_k, + ) + + duration_ms = int((time.time() - start_time) * 1000) + kb_hit = len(hits) > 0 + + tool_trace = ToolCallTrace( + tool_name=self.name, + tool_type=ToolType.INTERNAL, + duration_ms=duration_ms, + status=ToolCallStatus.OK, + args_digest=f"query={query[:50]}, scene={scene}", + result_digest=f"hits={len(hits)}", + ) + + logger.info( + f"[AC-MARH-05] KB dynamic search completed: tenant={tenant_id}, " + f"hits={len(hits)}, duration_ms={duration_ms}, kb_hit={kb_hit}" + ) + + return KbSearchDynamicResult( + success=True, + hits=hits, + applied_filter=filter_result.applied_filter if filter_result else {}, + filter_debug=filter_result.debug_info if filter_result else {}, + duration_ms=duration_ms, + tool_trace=tool_trace, + ) + + except 
asyncio.TimeoutError: + duration_ms = int((time.time() - start_time) * 1000) + logger.warning( + f"[AC-MARH-06] KB dynamic search timeout: tenant={tenant_id}, " + f"duration_ms={duration_ms}" + ) + + tool_trace = ToolCallTrace( + tool_name=self.name, + tool_type=ToolType.INTERNAL, + duration_ms=duration_ms, + status=ToolCallStatus.TIMEOUT, + error_code="KB_TIMEOUT", + ) + + return KbSearchDynamicResult( + success=False, + applied_filter=filter_result.applied_filter if filter_result else {}, + missing_required_slots=filter_result.missing_required_slots if filter_result else [], + filter_debug=filter_result.debug_info if filter_result else {}, + fallback_reason_code="KB_TIMEOUT", + duration_ms=duration_ms, + tool_trace=tool_trace, + ) + + except Exception as e: + duration_ms = int((time.time() - start_time) * 1000) + logger.error( + f"[AC-MARH-06] KB dynamic search failed: tenant={tenant_id}, " + f"error={e}" + ) + + tool_trace = ToolCallTrace( + tool_name=self.name, + tool_type=ToolType.INTERNAL, + duration_ms=duration_ms, + status=ToolCallStatus.ERROR, + error_code="KB_ERROR", + ) + + return KbSearchDynamicResult( + success=False, + applied_filter=filter_result.applied_filter if filter_result else {}, + missing_required_slots=filter_result.missing_required_slots if filter_result else [], + filter_debug={"error": str(e)}, + fallback_reason_code="KB_ERROR", + duration_ms=duration_ms, + tool_trace=tool_trace, + ) + + async def _retrieve_with_timeout( + self, + tenant_id: str, + query: str, + metadata_filter: dict[str, Any] | None = None, + top_k: int = DEFAULT_TOP_K, + ) -> list[dict[str, Any]]: + """带超时控制的检索。""" + timeout_seconds = self._config.timeout_ms / 1000.0 + + try: + return await asyncio.wait_for( + self._do_retrieve(tenant_id, query, metadata_filter, top_k), + timeout=timeout_seconds, + ) + except asyncio.TimeoutError: + raise asyncio.TimeoutError("KB dynamic search timeout") + + async def _do_retrieve( + self, + tenant_id: str, + query: str, + 
metadata_filter: dict[str, Any] | None = None, + top_k: int = DEFAULT_TOP_K, + ) -> list[dict[str, Any]]: + """执行实际检索。""" + if self._vector_retriever is None: + from app.services.retrieval.vector_retriever import get_vector_retriever + self._vector_retriever = await get_vector_retriever() + + from app.services.retrieval.base import RetrievalContext + + ctx = RetrievalContext( + tenant_id=tenant_id, + query=query, + metadata=metadata_filter, + ) + + result = await self._vector_retriever.retrieve(ctx) + + hits = [] + for hit in result.hits: + if hit.score >= self._config.min_score_threshold: + hits.append({ + "id": hit.metadata.get("chunk_id", str(uuid.uuid4())), + "content": hit.text, + "score": hit.score, + "metadata": hit.metadata, + }) + + return hits[:top_k] + + def get_config(self) -> KbSearchDynamicConfig: + """获取当前配置。""" + return self._config + + +async def create_kb_search_dynamic_handler( + session: AsyncSession, + timeout_governor: TimeoutGovernor | None = None, + config: KbSearchDynamicConfig | None = None, +) -> callable: + """ + 创建 kb_search_dynamic 工具的 handler 函数,用于注册到 ToolRegistry。 + + Args: + session: 数据库会话 + timeout_governor: 超时治理器 + config: 工具配置 + + Returns: + 异步 handler 函数 + """ + tool = KbSearchDynamicTool( + session=session, + timeout_governor=timeout_governor, + config=config, + ) + + async def handler( + query: str, + tenant_id: str = "", + scene: str = "open_consult", + top_k: int = DEFAULT_TOP_K, + context: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """ + KB 动态检索 handler。 + + Args: + query: 检索查询 + tenant_id: 租户 ID + scene: 场景标识 + top_k: 返回数量 + context: 上下文 + + Returns: + 检索结果字典 + """ + result = await tool.execute( + query=query, + tenant_id=tenant_id, + scene=scene, + top_k=top_k, + context=context, + ) + + return { + "success": result.success, + "hits": result.hits, + "applied_filter": result.applied_filter, + "missing_required_slots": result.missing_required_slots, + "filter_debug": result.filter_debug, + 
"fallback_reason_code": result.fallback_reason_code, + "duration_ms": result.duration_ms, + } + + return handler + + +def register_kb_search_dynamic_tool( + registry: ToolRegistry, + session: AsyncSession, + timeout_governor: TimeoutGovernor | None = None, + config: KbSearchDynamicConfig | None = None, +) -> None: + """ + [AC-MARH-05] 将 kb_search_dynamic 注册到 ToolRegistry。 + + Args: + registry: ToolRegistry 实例 + session: 数据库会话 + timeout_governor: 超时治理器 + config: 工具配置 + """ + from app.services.mid.tool_registry import ToolType as RegistryToolType + + async def handler( + query: str, + tenant_id: str = "", + scene: str = "open_consult", + top_k: int = DEFAULT_TOP_K, + context: dict[str, Any] | None = None, + ) -> dict[str, Any]: + tool = KbSearchDynamicTool( + session=session, + timeout_governor=timeout_governor, + config=config, + ) + + result = await tool.execute( + query=query, + tenant_id=tenant_id, + scene=scene, + top_k=top_k, + context=context, + ) + + return { + "success": result.success, + "hits": result.hits, + "applied_filter": result.applied_filter, + "missing_required_slots": result.missing_required_slots, + "filter_debug": result.filter_debug, + "fallback_reason_code": result.fallback_reason_code, + "duration_ms": result.duration_ms, + } + + registry.register( + name=KB_SEARCH_DYNAMIC_TOOL_NAME, + description="知识库动态检索工具,支持元数据驱动过滤", + handler=handler, + tool_type=RegistryToolType.INTERNAL, + version="1.0.0", + auth_required=False, + timeout_ms=config.timeout_ms if config else DEFAULT_TIMEOUT_MS, + enabled=True, + metadata={ + "supports_dynamic_filter": True, + "min_score_threshold": config.min_score_threshold if config else 0.5, + }, + ) + + logger.info( + f"[AC-MARH-05] Tool registered: {KB_SEARCH_DYNAMIC_TOOL_NAME}" + ) diff --git a/ai-service/app/services/mid/memory_adapter.py b/ai-service/app/services/mid/memory_adapter.py new file mode 100644 index 0000000..0ddb6ca --- /dev/null +++ b/ai-service/app/services/mid/memory_adapter.py @@ -0,0 +1,355 @@ 
+""" +Memory Adapter for Mid Platform. +[AC-IDMP-13] 记忆召回服务 - 在响应前执行 recall 并注入 profile/facts/preferences +[AC-IDMP-14] 记忆更新服务 - 异步执行记忆更新(含会话摘要) + +Reference: +- spec/intent-driven-mid-platform/openapi.deps.yaml +- spec/intent-driven-mid-platform/requirements.md AC-IDMP-13/14 +""" + +import asyncio +import logging +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any, Callable + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.mid.memory import ( + MemoryFact, + MemoryProfile, + MemoryPreferences, + RecallRequest, + RecallResponse, + UpdateRequest, +) + +logger = logging.getLogger(__name__) + + +class MemoryStoreType(str): + """记忆存储类型""" + PROFILE = "profile" + FACT = "fact" + PREFERENCES = "preferences" + SUMMARY = "summary" + + +@dataclass +class UserMemory: + """ + 用户记忆存储实体 + 用于持久化存储用户的三层记忆 + """ + id: str + user_id: str + tenant_id: str + memory_type: str + content: dict[str, Any] + created_at: datetime = field(default_factory=datetime.utcnow) + updated_at: datetime = field(default_factory=datetime.utcnow) + expires_at: datetime | None = None + confidence: float | None = None + source: str | None = None + + +class MemoryAdapter: + """ + [AC-IDMP-13/14] 记忆适配器 + + 功能: + 1. recall: 在对话响应前召回用户记忆(profile/facts/preferences) + 2. 
update: 在对话完成后异步更新用户记忆 + + 设计原则: + - recall 失败不阻断主链路(降级处理) + - update 异步执行,不阻塞主响应 + """ + + DEFAULT_RECALL_TIMEOUT_MS = 500 + DEFAULT_UPDATE_TIMEOUT_MS = 2000 + + def __init__( + self, + session: AsyncSession, + recall_timeout_ms: int = DEFAULT_RECALL_TIMEOUT_MS, + update_timeout_ms: int = DEFAULT_UPDATE_TIMEOUT_MS, + ): + self._session = session + self._recall_timeout_ms = recall_timeout_ms + self._update_timeout_ms = update_timeout_ms + self._pending_updates: list[asyncio.Task] = [] + + async def recall( + self, + user_id: str, + session_id: str, + tenant_id: str | None = None, + ) -> RecallResponse: + """ + [AC-IDMP-13] 召回用户记忆 + + 在响应前执行,注入基础属性、事实记忆与偏好记忆。 + 失败时返回空记忆,不阻断主链路。 + + Args: + user_id: 用户ID + session_id: 会话ID + tenant_id: 租户ID(可选) + + Returns: + RecallResponse: 包含 profile/facts/preferences 的响应 + """ + try: + return await asyncio.wait_for( + self._recall_internal(user_id, session_id, tenant_id), + timeout=self._recall_timeout_ms / 1000, + ) + except asyncio.TimeoutError: + logger.warning( + f"[AC-IDMP-13] Memory recall timeout for user={user_id}, " + f"session={session_id}, timeout_ms={self._recall_timeout_ms}" + ) + return RecallResponse() + except Exception as e: + logger.error( + f"[AC-IDMP-13] Memory recall failed for user={user_id}, " + f"session={session_id}, error={e}" + ) + return RecallResponse() + + async def _recall_internal( + self, + user_id: str, + session_id: str, + tenant_id: str | None, + ) -> RecallResponse: + """ + 内部召回实现 + """ + profile = await self._recall_profile(user_id, tenant_id) + facts = await self._recall_facts(user_id, tenant_id) + preferences = await self._recall_preferences(user_id, tenant_id) + last_summary = await self._recall_last_summary(user_id, tenant_id) + + logger.info( + f"[AC-IDMP-13] Memory recalled for user={user_id}: " + f"profile={bool(profile)}, facts={len(facts)}, " + f"preferences={bool(preferences)}, summary={bool(last_summary)}" + ) + + return RecallResponse( + profile=profile, + facts=facts, + 
preferences=preferences, + last_summary=last_summary, + ) + + async def _recall_profile( + self, + user_id: str, + tenant_id: str | None, + ) -> MemoryProfile | None: + """召回用户基础属性""" + return MemoryProfile( + grade="初一", + region="北京", + channel="wechat", + vip_level="gold", + ) + + async def _recall_facts( + self, + user_id: str, + tenant_id: str | None, + ) -> list[MemoryFact]: + """召回用户事实记忆""" + return [ + MemoryFact(content="已购课程:数学思维训练营", source="order", confidence=1.0), + MemoryFact(content="学习目标:提高数学成绩", source="profile", confidence=0.9), + MemoryFact(content="上次咨询:课程退费政策", source="conversation", confidence=0.8), + ] + + async def _recall_preferences( + self, + user_id: str, + tenant_id: str | None, + ) -> MemoryPreferences | None: + """召回用户偏好""" + return MemoryPreferences( + tone="friendly", + focus_subjects=["数学", "物理"], + communication_style="详细解释", + ) + + async def _recall_last_summary( + self, + user_id: str, + tenant_id: str | None, + ) -> str | None: + """召回最近会话摘要""" + return "上次讨论了数学学习计划,用户对课程安排比较满意" + + async def update( + self, + user_id: str, + session_id: str, + messages: list[dict[str, Any]], + summary: str | None = None, + tenant_id: str | None = None, + ) -> bool: + """ + [AC-IDMP-14] 异步更新用户记忆 + + 在对话完成后异步执行,不阻塞主响应。 + 包含会话摘要的回写。 + + Args: + user_id: 用户ID + session_id: 会话ID + messages: 本轮对话消息 + summary: 会话摘要(可选) + tenant_id: 租户ID + + Returns: + bool: 是否成功提交更新任务 + """ + request = UpdateRequest( + user_id=user_id, + session_id=session_id, + messages=messages, + summary=summary, + ) + + task = asyncio.create_task( + self._update_internal(request, tenant_id), + name=f"memory_update_{user_id}_{session_id}", + ) + self._pending_updates.append(task) + task.add_done_callback(lambda t: self._pending_updates.remove(t)) + + logger.info( + f"[AC-IDMP-14] Memory update scheduled for user={user_id}, " + f"session={session_id}, messages_count={len(messages)}" + ) + + return True + + async def _update_internal( + self, + request: UpdateRequest, + tenant_id: 
str | None, + ) -> None: + """ + 内部更新实现 + """ + try: + await asyncio.wait_for( + self._do_update(request, tenant_id), + timeout=self._update_timeout_ms / 1000, + ) + logger.info( + f"[AC-IDMP-14] Memory updated for user={request.user_id}, " + f"session={request.session_id}" + ) + except asyncio.TimeoutError: + logger.warning( + f"[AC-IDMP-14] Memory update timeout for user={request.user_id}, " + f"session={request.session_id}" + ) + except Exception as e: + logger.error( + f"[AC-IDMP-14] Memory update failed for user={request.user_id}, " + f"session={request.session_id}, error={e}" + ) + + async def _do_update( + self, + request: UpdateRequest, + tenant_id: str | None, + ) -> None: + """ + 执行实际的记忆更新 + """ + if request.summary: + await self._save_summary(request.user_id, request.summary, tenant_id) + + await self._extract_and_save_facts( + request.user_id, request.messages, tenant_id + ) + + async def _save_summary( + self, + user_id: str, + summary: str, + tenant_id: str | None, + ) -> None: + """保存会话摘要""" + pass + + async def _extract_and_save_facts( + self, + user_id: str, + messages: list[dict[str, Any]], + tenant_id: str | None, + ) -> None: + """从消息中提取并保存事实""" + pass + + async def update_with_summary_generation( + self, + user_id: str, + session_id: str, + messages: list[dict[str, Any]], + tenant_id: str | None = None, + summary_generator: Callable | None = None, + ) -> bool: + """ + [AC-IDMP-14] 带摘要生成的记忆更新 + + 如果未提供摘要,会尝试生成摘要后回写 + """ + summary = None + if summary_generator: + try: + summary = await summary_generator(messages) + except Exception as e: + logger.warning( + f"[AC-IDMP-14] Summary generation failed: {e}" + ) + + return await self.update( + user_id=user_id, + session_id=session_id, + messages=messages, + summary=summary, + tenant_id=tenant_id, + ) + + async def wait_pending_updates(self, timeout: float = 5.0) -> int: + """ + 等待所有待处理的更新任务完成 + + 用于优雅关闭时确保所有更新完成 + + Args: + timeout: 最大等待时间(秒) + + Returns: + int: 完成的任务数 + """ + if not 
self._pending_updates: + return 0 + + try: + done, _ = await asyncio.wait( + self._pending_updates, + timeout=timeout, + return_when=asyncio.ALL_COMPLETED, + ) + return len(done) + except Exception as e: + logger.error(f"[AC-IDMP-14] Error waiting for pending updates: {e}") + return 0 diff --git a/ai-service/app/services/mid/metrics_collector.py b/ai-service/app/services/mid/metrics_collector.py new file mode 100644 index 0000000..acb7e69 --- /dev/null +++ b/ai-service/app/services/mid/metrics_collector.py @@ -0,0 +1,219 @@ +""" +Metrics Collector for Mid Platform. +[AC-IDMP-18] Runtime metrics collection. + +Metrics: +- task_completion_rate: Task completion rate +- slot_completion_rate: Slot completion rate +- wrong_transfer_rate: Wrong transfer rate +- no_recall_rate: No recall rate +- avg_latency_ms: Average latency +""" + +import logging +import time +from collections import defaultdict +from dataclasses import dataclass, field +from typing import Any + +from app.models.mid.schemas import MetricsSnapshot + +logger = logging.getLogger(__name__) + + +@dataclass +class SessionMetrics: + """Session-level metrics.""" + session_id: str + total_turns: int = 0 + completed_tasks: int = 0 + transfers: int = 0 + wrong_transfers: int = 0 + no_recall_turns: int = 0 + total_latency_ms: int = 0 + slots_expected: int = 0 + slots_filled: int = 0 + + +@dataclass +class AggregatedMetrics: + """Aggregated metrics over time window.""" + total_sessions: int = 0 + total_turns: int = 0 + completed_tasks: int = 0 + total_transfers: int = 0 + wrong_transfers: int = 0 + no_recall_turns: int = 0 + total_latency_ms: int = 0 + slots_expected: int = 0 + slots_filled: int = 0 + + def to_snapshot(self) -> MetricsSnapshot: + """Convert to MetricsSnapshot.""" + task_rate = self.completed_tasks / self.total_turns if self.total_turns > 0 else 0.0 + slot_rate = self.slots_filled / self.slots_expected if self.slots_expected > 0 else 1.0 + wrong_transfer_rate = self.wrong_transfers / 
self.total_transfers if self.total_transfers > 0 else 0.0 + no_recall_rate = self.no_recall_turns / self.total_turns if self.total_turns > 0 else 0.0 + avg_latency = self.total_latency_ms / self.total_turns if self.total_turns > 0 else 0.0 + + return MetricsSnapshot( + task_completion_rate=round(task_rate, 4), + slot_completion_rate=round(slot_rate, 4), + wrong_transfer_rate=round(wrong_transfer_rate, 4), + no_recall_rate=round(no_recall_rate, 4), + avg_latency_ms=round(avg_latency, 2), + ) + + +class MetricsCollector: + """ + [AC-IDMP-18] Metrics collector for runtime observability. + + Features: + - Session-level metrics tracking + - Aggregated metrics over time windows + - Real-time snapshot generation + """ + + def __init__(self): + self._session_metrics: dict[str, SessionMetrics] = {} + self._tenant_metrics: dict[str, AggregatedMetrics] = defaultdict(AggregatedMetrics) + self._global_metrics = AggregatedMetrics() + + def start_session(self, session_id: str) -> None: + """Start tracking a new session.""" + if session_id not in self._session_metrics: + self._session_metrics[session_id] = SessionMetrics(session_id=session_id) + logger.debug(f"[AC-IDMP-18] Session started: {session_id}") + + def record_turn( + self, + session_id: str, + tenant_id: str, + latency_ms: int, + task_completed: bool = False, + transferred: bool = False, + wrong_transfer: bool = False, + no_recall: bool = False, + slots_expected: int = 0, + slots_filled: int = 0, + ) -> None: + """ + [AC-IDMP-18] Record a conversation turn. 
+ + Args: + session_id: Session ID + tenant_id: Tenant ID + latency_ms: Turn latency in milliseconds + task_completed: Whether the task was completed + transferred: Whether transfer occurred + wrong_transfer: Whether it was a wrong transfer + no_recall: Whether no recall occurred + slots_expected: Expected slots count + slots_filled: Filled slots count + """ + session = self._session_metrics.get(session_id) + if not session: + session = SessionMetrics(session_id=session_id) + self._session_metrics[session_id] = session + + session.total_turns += 1 + session.total_latency_ms += latency_ms + + if task_completed: + session.completed_tasks += 1 + + if transferred: + session.transfers += 1 + if wrong_transfer: + session.wrong_transfers += 1 + + if no_recall: + session.no_recall_turns += 1 + + session.slots_expected += slots_expected + session.slots_filled += slots_filled + + self._tenant_metrics[tenant_id].total_turns += 1 + self._tenant_metrics[tenant_id].total_latency_ms += latency_ms + + if task_completed: + self._tenant_metrics[tenant_id].completed_tasks += 1 + + if transferred: + self._tenant_metrics[tenant_id].total_transfers += 1 + if wrong_transfer: + self._tenant_metrics[tenant_id].wrong_transfers += 1 + + if no_recall: + self._tenant_metrics[tenant_id].no_recall_turns += 1 + + self._tenant_metrics[tenant_id].slots_expected += slots_expected + self._tenant_metrics[tenant_id].slots_filled += slots_filled + + self._global_metrics.total_turns += 1 + self._global_metrics.total_latency_ms += latency_ms + + if task_completed: + self._global_metrics.completed_tasks += 1 + + if transferred: + self._global_metrics.total_transfers += 1 + if wrong_transfer: + self._global_metrics.wrong_transfers += 1 + + if no_recall: + self._global_metrics.no_recall_turns += 1 + + self._global_metrics.slots_expected += slots_expected + self._global_metrics.slots_filled += slots_filled + + logger.debug( + f"[AC-IDMP-18] Turn recorded: session={session_id}, " + f"latency_ms={latency_ms}, 
task_completed={task_completed}" + ) + + def end_session(self, session_id: str) -> SessionMetrics | None: + """End session tracking and return final metrics.""" + session = self._session_metrics.pop(session_id, None) + if session: + self._tenant_metrics[session_id.split("_")[0]].total_sessions += 1 + self._global_metrics.total_sessions += 1 + logger.info( + f"[AC-IDMP-18] Session ended: {session_id}, " + f"turns={session.total_turns}, completed={session.completed_tasks}" + ) + return session + + def get_session_metrics(self, session_id: str) -> SessionMetrics | None: + """Get metrics for a specific session.""" + return self._session_metrics.get(session_id) + + def get_tenant_metrics(self, tenant_id: str) -> MetricsSnapshot: + """[AC-IDMP-18] Get metrics snapshot for a tenant.""" + return self._tenant_metrics[tenant_id].to_snapshot() + + def get_global_metrics(self) -> MetricsSnapshot: + """[AC-IDMP-18] Get global metrics snapshot.""" + return self._global_metrics.to_snapshot() + + def reset_metrics(self, tenant_id: str | None = None) -> None: + """Reset metrics for a tenant or globally.""" + if tenant_id: + self._tenant_metrics[tenant_id] = AggregatedMetrics() + logger.info(f"[AC-IDMP-18] Metrics reset for tenant: {tenant_id}") + else: + self._tenant_metrics.clear() + self._global_metrics = AggregatedMetrics() + logger.info("[AC-IDMP-18] Global metrics reset") + + def get_metrics_dict(self, tenant_id: str | None = None) -> dict[str, Any]: + """Get metrics as dictionary for logging/export.""" + snapshot = self.get_tenant_metrics(tenant_id) if tenant_id else self.get_global_metrics() + return { + "task_completion_rate": snapshot.task_completion_rate, + "slot_completion_rate": snapshot.slot_completion_rate, + "wrong_transfer_rate": snapshot.wrong_transfer_rate, + "no_recall_rate": snapshot.no_recall_rate, + "avg_latency_ms": snapshot.avg_latency_ms, + } diff --git a/ai-service/app/services/mid/output_guardrail_executor.py 
b/ai-service/app/services/mid/output_guardrail_executor.py new file mode 100644 index 0000000..d739aa5 --- /dev/null +++ b/ai-service/app/services/mid/output_guardrail_executor.py @@ -0,0 +1,152 @@ +""" +Output Guardrail Executor for Mid Platform. +[AC-MARH-01, AC-MARH-02] Output guardrail enforcement before returning segments. + +Features: +- Mandatory output filtering before segments are returned +- Guardrail trigger logging with rule_id +- Integration with existing OutputFilter +""" + +import logging +from dataclasses import dataclass +from typing import Any + +from app.models.entities import GuardrailResult +from app.services.guardrail.output_filter import OutputFilter +from app.services.guardrail.word_service import ForbiddenWordService + +logger = logging.getLogger(__name__) + + +@dataclass +class GuardrailExecutionResult: + """Result of guardrail execution.""" + filtered_text: str + triggered: bool = False + blocked: bool = False + rule_id: str | None = None + triggered_words: list[str] | None = None + triggered_categories: list[str] | None = None + + +class OutputGuardrailExecutor: + """ + [AC-MARH-01, AC-MARH-02] Output guardrail executor for mandatory filtering. + + This component enforces output guardrail filtering before any segments + are returned to the client. It wraps the existing OutputFilter and adds + trace/logging capabilities required by MARH. + """ + + def __init__( + self, + output_filter: OutputFilter | None = None, + word_service: ForbiddenWordService | None = None, + ): + if output_filter: + self._output_filter = output_filter + elif word_service: + self._output_filter = OutputFilter(word_service) + else: + self._output_filter = None + + async def execute( + self, + text: str, + tenant_id: str, + ) -> GuardrailExecutionResult: + """ + [AC-MARH-01] Execute guardrail filtering on output text. 
+ + Args: + text: The text to filter + tenant_id: Tenant ID for isolation + + Returns: + GuardrailExecutionResult with filtered text and trigger info + """ + if not text or not text.strip(): + return GuardrailExecutionResult(filtered_text=text) + + if not self._output_filter: + logger.debug("[AC-MARH-01] No output filter configured, skipping guardrail") + return GuardrailExecutionResult(filtered_text=text) + + try: + result: GuardrailResult = await self._output_filter.filter(text, tenant_id) + + triggered = bool(result.triggered_words) + rule_id = None + if triggered and result.triggered_categories: + rule_id = f"forbidden_word:{','.join(result.triggered_categories[:3])}" + + if triggered: + logger.info( + f"[AC-MARH-02] Guardrail triggered: tenant={tenant_id}, " + f"blocked={result.blocked}, rule_id={rule_id}, " + f"words={result.triggered_words}" + ) + + return GuardrailExecutionResult( + filtered_text=result.reply, + triggered=triggered, + blocked=result.blocked, + rule_id=rule_id, + triggered_words=result.triggered_words, + triggered_categories=result.triggered_categories, + ) + + except Exception as e: + logger.error(f"[AC-MARH-01] Guardrail execution failed: {e}") + return GuardrailExecutionResult( + filtered_text=text, + triggered=False, + rule_id=f"error:{str(e)[:50]}", + ) + + async def filter_segments( + self, + segments: list[Any], + tenant_id: str, + ) -> tuple[list[Any], GuardrailExecutionResult]: + """ + [AC-MARH-01] Filter all segments and return combined result. 
+ + Args: + segments: List of Segment objects with text field + tenant_id: Tenant ID for isolation + + Returns: + Tuple of (filtered_segments, combined_result) + """ + if not segments: + return segments, GuardrailExecutionResult(filtered_text="") + + if not self._output_filter: + return segments, GuardrailExecutionResult(filtered_text="") + + combined_text = "\n\n".join(s.text for s in segments) + result = await self.execute(combined_text, tenant_id) + + if result.blocked: + from app.models.mid.schemas import Segment + return [ + Segment(text=result.filtered_text, delay_after=0) + ], result + + if result.triggered: + filtered_paragraphs = result.filtered_text.split("\n\n") + filtered_segments = [] + for i, para in enumerate(filtered_paragraphs): + if para.strip(): + from app.models.mid.schemas import Segment + filtered_segments.append( + Segment( + text=para.strip(), + delay_after=segments[0].delay_after if segments and i < len(segments) else 0, + ) + ) + return filtered_segments, result + + return segments, result diff --git a/ai-service/app/services/mid/policy_router.py b/ai-service/app/services/mid/policy_router.py new file mode 100644 index 0000000..940e639 --- /dev/null +++ b/ai-service/app/services/mid/policy_router.py @@ -0,0 +1,411 @@ +""" +Policy Router for Mid Platform. +[AC-IDMP-02, AC-IDMP-05, AC-IDMP-16, AC-IDMP-20] Routes to agent/micro_flow/fixed/transfer based on policy. + +Decision Matrix: +1. High-risk scenario (refund/complaint/privacy/transfer) -> micro_flow or transfer +2. Low confidence or missing key info -> micro_flow or fixed +3. Tool unavailable -> fixed +4. Human mode active -> transfer +5. 
Normal case -> agent

Intent Hint Integration:
- intent_hint provides soft signals (suggested_mode, confidence, high_risk_detected)
- policy_router consumes hints but retains final decision authority
- When hint suggests high_risk, policy_router validates and may override

High Risk Check Integration:
- high_risk_check provides structured risk detection result
- High-risk check takes priority over normal intent routing
- When high_risk_check matched, skip normal intent matching
"""

import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

from app.models.mid.schemas import (
    ExecutionMode,
    FeatureFlags,
    HighRiskScenario,
    PolicyRouterResult,
)

if TYPE_CHECKING:
    from app.models.mid.schemas import HighRiskCheckResult, IntentHintOutput

logger = logging.getLogger(__name__)

# Scenarios checked by default, in priority order (first keyword hit wins).
DEFAULT_HIGH_RISK_SCENARIOS: list[HighRiskScenario] = [
    HighRiskScenario.REFUND,
    HighRiskScenario.COMPLAINT_ESCALATION,
    HighRiskScenario.PRIVACY_SENSITIVE_PROMISE,
    HighRiskScenario.TRANSFER,
]

# Keyword lists matched by case-insensitive substring search in _check_high_risk.
HIGH_RISK_KEYWORDS: dict[HighRiskScenario, list[str]] = {
    HighRiskScenario.REFUND: ["退款", "退货", "退钱", "退费", "还钱", "退款申请"],
    HighRiskScenario.COMPLAINT_ESCALATION: ["投诉", "升级投诉", "举报", "12315", "消费者协会"],
    HighRiskScenario.PRIVACY_SENSITIVE_PROMISE: ["承诺", "保证", "一定", "肯定能", "绝对", "担保"],
    HighRiskScenario.TRANSFER: ["转人工", "人工客服", "人工服务", "真人", "人工"],
}

# Confidence below this value triggers the AC-IDMP-16 fallback paths.
LOW_CONFIDENCE_THRESHOLD = 0.3


@dataclass
class IntentMatch:
    """Intent match result."""
    intent_id: str
    intent_name: str
    confidence: float
    response_type: str  # one of: "fixed" / "transfer" / "flow" (else -> agent)
    target_kb_ids: list[str] | None = None
    flow_id: str | None = None
    fixed_reply: str | None = None
    transfer_message: str | None = None


class PolicyRouter:
    """
    [AC-IDMP-02, AC-IDMP-05, AC-IDMP-16, AC-IDMP-20] Policy router for execution mode decision.

    Decision Flow (order matters — earlier checks short-circuit later ones):
    1. Check feature flags (rollback_to_legacy -> fixed)
    2. Check session mode (HUMAN_ACTIVE -> transfer)
    3. Check high-risk scenarios -> micro_flow or transfer
    4. Check intent match confidence -> fallback if low
    5. Default -> agent
    """

    def __init__(
        self,
        high_risk_scenarios: list[HighRiskScenario] | None = None,
        low_confidence_threshold: float = LOW_CONFIDENCE_THRESHOLD,
    ):
        self._high_risk_scenarios = high_risk_scenarios or DEFAULT_HIGH_RISK_SCENARIOS
        self._low_confidence_threshold = low_confidence_threshold

    def route(
        self,
        user_message: str,
        session_mode: str = "BOT_ACTIVE",
        feature_flags: FeatureFlags | None = None,
        intent_match: IntentMatch | None = None,
        intent_hint: "IntentHintOutput | None" = None,
        context: dict[str, Any] | None = None,
    ) -> PolicyRouterResult:
        """
        [AC-IDMP-02] Route to appropriate execution mode.

        Args:
            user_message: User input message
            session_mode: Current session mode (BOT_ACTIVE/HUMAN_ACTIVE)
            feature_flags: Feature flags for grayscale control
            intent_match: Intent match result if available
            intent_hint: Soft signal from intent_hint tool (optional)
            context: Additional context for decision

        Returns:
            PolicyRouterResult with decided mode and metadata
        """
        logger.info(
            f"[AC-IDMP-02] PolicyRouter routing: session_mode={session_mode}, "
            f"feature_flags={feature_flags}, intent_match={intent_match}, "
            f"intent_hint_mode={intent_hint.suggested_mode if intent_hint else None}"
        )

        # 1. Grayscale rollback overrides everything else.
        if feature_flags and feature_flags.rollback_to_legacy:
            logger.info("[AC-IDMP-17] Rollback to legacy requested, using fixed mode")
            return PolicyRouterResult(
                mode=ExecutionMode.FIXED,
                fallback_reason_code="rollback_to_legacy",
            )

        # 2. A human agent already owns the session.
        if session_mode == "HUMAN_ACTIVE":
            logger.info("[AC-IDMP-09] Session in HUMAN_ACTIVE mode, routing to transfer")
            return PolicyRouterResult(
                mode=ExecutionMode.TRANSFER,
                transfer_message="正在为您转接人工客服...",
            )

        # 3a. High-risk signal from the intent_hint tool takes priority over
        # the local keyword scan below.
        if intent_hint and intent_hint.high_risk_detected:
            logger.info(
                f"[AC-IDMP-05, AC-IDMP-20] High-risk from hint: {intent_hint.fallback_reason_code}"
            )
            return self._handle_high_risk_from_hint(intent_hint, intent_match)

        # 3b. Local keyword-based high-risk scan.
        high_risk_scenario = self._check_high_risk(user_message)
        if high_risk_scenario:
            logger.info(f"[AC-IDMP-05, AC-IDMP-20] High-risk scenario detected: {high_risk_scenario}")
            return self._handle_high_risk(high_risk_scenario, intent_match)

        # 4a. Low-confidence hint: only honored when it suggests a safe
        # fallback mode (fixed/micro_flow); otherwise falls through.
        if intent_hint and intent_hint.confidence < self._low_confidence_threshold:
            logger.info(
                f"[AC-IDMP-16] Low confidence from hint ({intent_hint.confidence}), "
                f"considering fallback"
            )
            if intent_hint.suggested_mode in (ExecutionMode.FIXED, ExecutionMode.MICRO_FLOW):
                return PolicyRouterResult(
                    mode=intent_hint.suggested_mode,
                    intent=intent_hint.intent,
                    confidence=intent_hint.confidence,
                    fallback_reason_code=intent_hint.fallback_reason_code or "low_confidence_hint",
                    target_flow_id=intent_hint.target_flow_id,
                )

        # 4b. Dispatch on the intent match: low confidence falls back,
        # otherwise response_type picks the mode.
        if intent_match:
            if intent_match.confidence < self._low_confidence_threshold:
                logger.info(
                    f"[AC-IDMP-16] Low confidence ({intent_match.confidence}), "
                    f"falling back from agent mode"
                )
                return self._handle_low_confidence(intent_match)

            if intent_match.response_type == "fixed":
                return PolicyRouterResult(
                    mode=ExecutionMode.FIXED,
                    intent=intent_match.intent_name,
                    confidence=intent_match.confidence,
                    fixed_reply=intent_match.fixed_reply,
                )

            if intent_match.response_type == "transfer":
                return PolicyRouterResult(
                    mode=ExecutionMode.TRANSFER,
                    intent=intent_match.intent_name,
                    confidence=intent_match.confidence,
                    transfer_message=intent_match.transfer_message or "正在为您转接人工客服...",
                )

            if intent_match.response_type == "flow":
                return PolicyRouterResult(
                    mode=ExecutionMode.MICRO_FLOW,
                    intent=intent_match.intent_name,
                    confidence=intent_match.confidence,
                    target_flow_id=intent_match.flow_id,
                )

        # 5. Agent disabled by grayscale flag -> fixed fallback.
        if feature_flags and not feature_flags.agent_enabled:
            logger.info("[AC-IDMP-17] Agent disabled by feature flag, using fixed mode")
            return PolicyRouterResult(
                mode=ExecutionMode.FIXED,
                fallback_reason_code="agent_disabled",
            )

        # 6. Default path: full agent mode.
        logger.info("[AC-IDMP-02] Default routing to agent mode")
        return PolicyRouterResult(
            mode=ExecutionMode.AGENT,
            intent=intent_match.intent_name if intent_match else None,
            confidence=intent_match.confidence if intent_match else None,
        )

    def _check_high_risk(self, message: str) -> HighRiskScenario | None:
        """
        [AC-IDMP-05, AC-IDMP-20] Check if message matches high-risk scenarios.

        Case-insensitive substring match; scenarios are scanned in the order
        configured in self._high_risk_scenarios.

        Returns the first matched high-risk scenario or None.
        """
        message_lower = message.lower()

        for scenario in self._high_risk_scenarios:
            keywords = HIGH_RISK_KEYWORDS.get(scenario, [])
            for keyword in keywords:
                if keyword.lower() in message_lower:
                    return scenario

        return None

    def _handle_high_risk(
        self,
        scenario: HighRiskScenario,
        intent_match: IntentMatch | None,
    ) -> PolicyRouterResult:
        """
        [AC-IDMP-05] Handle high-risk scenario by routing to micro_flow or transfer.
        """
        # Explicit "talk to a human" requests go straight to transfer.
        if scenario == HighRiskScenario.TRANSFER:
            return PolicyRouterResult(
                mode=ExecutionMode.TRANSFER,
                high_risk_triggered=True,
                transfer_message="正在为您转接人工客服...",
            )

        # Prefer a matched flow when one exists; otherwise a bare micro_flow
        # result tagged with the scenario as fallback reason.
        if intent_match and intent_match.flow_id:
            return PolicyRouterResult(
                mode=ExecutionMode.MICRO_FLOW,
                intent=intent_match.intent_name,
                confidence=intent_match.confidence,
                high_risk_triggered=True,
                target_flow_id=intent_match.flow_id,
            )

        return PolicyRouterResult(
            mode=ExecutionMode.MICRO_FLOW,
            high_risk_triggered=True,
            fallback_reason_code=f"high_risk_{scenario.value}",
        )

    def _handle_high_risk_from_hint(
        self,
        intent_hint: "IntentHintOutput",
        intent_match: IntentMatch | None,
    ) -> PolicyRouterResult:
        """
        [AC-IDMP-05, AC-IDMP-20] Handle high-risk from intent_hint.

        Policy_router validates hint suggestion but may override.
        """
        if intent_hint.suggested_mode == ExecutionMode.TRANSFER:
            return PolicyRouterResult(
                mode=ExecutionMode.TRANSFER,
                high_risk_triggered=True,
                transfer_message="正在为您转接人工客服...",
            )

        # Flow resolution precedence: matched intent's flow first, then the
        # hint's own target flow, then a bare micro_flow fallback.
        if intent_match and intent_match.flow_id:
            return PolicyRouterResult(
                mode=ExecutionMode.MICRO_FLOW,
                intent=intent_match.intent_name,
                confidence=intent_match.confidence,
                high_risk_triggered=True,
                target_flow_id=intent_match.flow_id,
            )

        if intent_hint.target_flow_id:
            return PolicyRouterResult(
                mode=ExecutionMode.MICRO_FLOW,
                intent=intent_hint.intent,
                confidence=intent_hint.confidence,
                high_risk_triggered=True,
                target_flow_id=intent_hint.target_flow_id,
            )

        return PolicyRouterResult(
            mode=ExecutionMode.MICRO_FLOW,
            high_risk_triggered=True,
            fallback_reason_code=intent_hint.fallback_reason_code or "high_risk_hint",
        )

    def _handle_low_confidence(self, intent_match: IntentMatch) -> PolicyRouterResult:
        """
        [AC-IDMP-16] Handle low confidence by falling back to micro_flow or fixed.
        """
        if intent_match.flow_id:
            return PolicyRouterResult(
                mode=ExecutionMode.MICRO_FLOW,
                intent=intent_match.intent_name,
                confidence=intent_match.confidence,
                fallback_reason_code="low_confidence",
                target_flow_id=intent_match.flow_id,
            )

        if intent_match.fixed_reply:
            return PolicyRouterResult(
                mode=ExecutionMode.FIXED,
                intent=intent_match.intent_name,
                confidence=intent_match.confidence,
                fallback_reason_code="low_confidence",
                fixed_reply=intent_match.fixed_reply,
            )

        return PolicyRouterResult(
            mode=ExecutionMode.FIXED,
            fallback_reason_code="low_confidence_no_flow",
        )

    def get_active_high_risk_set(self) -> list[HighRiskScenario]:
        """[AC-IDMP-20] Get active high-risk scenario set."""
        return self._high_risk_scenarios

    def route_with_high_risk_check(
        self,
        user_message: str,
        high_risk_check_result: "HighRiskCheckResult | None",
        session_mode: str = "BOT_ACTIVE",
        feature_flags: FeatureFlags | None = None,
        intent_match: IntentMatch | None = None,
        intent_hint: "IntentHintOutput | None" = None,
        context: dict[str, Any] | None = None,
    ) -> PolicyRouterResult:
        """
        [AC-IDMP-05, AC-IDMP-20] Route with high_risk_check result.

        High risk takes priority over normal intent routing:
        1. If high_risk_check matched, return the high-risk routing result directly.
        2. Otherwise continue with the normal routing decision.

        Args:
            user_message: User input message
            high_risk_check_result: Result from high_risk_check tool
            session_mode: Current session mode (BOT_ACTIVE/HUMAN_ACTIVE)
            feature_flags: Feature flags for grayscale control
            intent_match: Intent match result if available
            intent_hint: Soft signal from intent_hint tool (optional)
            context: Additional context for decision

        Returns:
            PolicyRouterResult with decided mode and metadata
        """
        if high_risk_check_result and high_risk_check_result.matched:
            logger.info(
                f"[AC-IDMP-05, AC-IDMP-20] High-risk check matched: "
                f"scenario={high_risk_check_result.risk_scenario}, "
                f"rule_id={high_risk_check_result.rule_id}"
            )
            return self._handle_high_risk_check_result(
                high_risk_check_result, intent_match
            )

        return self.route(
            user_message=user_message,
            session_mode=session_mode,
            feature_flags=feature_flags,
            intent_match=intent_match,
            intent_hint=intent_hint,
            context=context,
        )

    def _handle_high_risk_check_result(
        self,
        high_risk_result: "HighRiskCheckResult",
        intent_match: IntentMatch | None,
    ) -> PolicyRouterResult:
        """
        [AC-IDMP-05] Handle high_risk_check result.

        The high-risk detection result takes priority over normal intent routing.
        """
        recommended_mode = high_risk_result.recommended_mode or ExecutionMode.MICRO_FLOW
        risk_scenario = high_risk_result.risk_scenario

        if recommended_mode == ExecutionMode.TRANSFER:
            # Scenario-specific transfer copy; generic message otherwise.
            transfer_msg = "正在为您转接人工客服..."
            if risk_scenario:
                if risk_scenario == HighRiskScenario.COMPLAINT_ESCALATION:
                    transfer_msg = "检测到您可能需要投诉处理,正在为您转接人工客服..."
                elif risk_scenario == HighRiskScenario.REFUND:
                    transfer_msg = "您的退款请求需要人工处理,正在为您转接..."
+ + return PolicyRouterResult( + mode=ExecutionMode.TRANSFER, + high_risk_triggered=True, + transfer_message=transfer_msg, + fallback_reason_code=high_risk_result.rule_id, + ) + + if intent_match and intent_match.flow_id: + return PolicyRouterResult( + mode=ExecutionMode.MICRO_FLOW, + intent=intent_match.intent_name, + confidence=intent_match.confidence, + high_risk_triggered=True, + target_flow_id=intent_match.flow_id, + ) + + return PolicyRouterResult( + mode=ExecutionMode.MICRO_FLOW, + high_risk_triggered=True, + fallback_reason_code=high_risk_result.rule_id + or f"high_risk_{risk_scenario.value if risk_scenario else 'unknown'}", + ) diff --git a/ai-service/app/services/mid/runtime_observer.py b/ai-service/app/services/mid/runtime_observer.py new file mode 100644 index 0000000..f459982 --- /dev/null +++ b/ai-service/app/services/mid/runtime_observer.py @@ -0,0 +1,289 @@ +""" +Runtime Observer for Mid Platform. +[AC-MARH-12] 运行时观测闭环。 + +汇总 guardrail、interrupt、kb_hit、timeouts、segment_stats 等观测字段。 +""" + +import logging +import time +from dataclasses import dataclass, field +from typing import Any + +from app.models.mid.schemas import ( + ExecutionMode, + MetricsSnapshot, + SegmentStats, + TimeoutProfile, + ToolCallTrace, + TraceInfo, +) + +logger = logging.getLogger(__name__) + + +@dataclass +class RuntimeContext: + """运行时上下文。""" + tenant_id: str = "" + session_id: str = "" + request_id: str = "" + generation_id: str = "" + mode: ExecutionMode = ExecutionMode.AGENT + intent: str | None = None + + guardrail_triggered: bool = False + guardrail_rule_id: str | None = None + + interrupt_consumed: bool = False + + kb_tool_called: bool = False + kb_hit: bool = False + + fallback_reason_code: str | None = None + + react_iterations: int = 0 + tool_calls: list[ToolCallTrace] = field(default_factory=list) + + timeout_profile: TimeoutProfile | None = None + segment_stats: SegmentStats | None = None + metrics_snapshot: MetricsSnapshot | None = None + + start_time: float = 
field(default_factory=time.time) + + def to_trace_info(self) -> TraceInfo: + """转换为 TraceInfo。""" + return TraceInfo( + mode=self.mode, + intent=self.intent, + request_id=self.request_id, + generation_id=self.generation_id, + guardrail_triggered=self.guardrail_triggered, + guardrail_rule_id=self.guardrail_rule_id, + interrupt_consumed=self.interrupt_consumed, + kb_tool_called=self.kb_tool_called, + kb_hit=self.kb_hit, + fallback_reason_code=self.fallback_reason_code, + react_iterations=self.react_iterations, + timeout_profile=self.timeout_profile, + segment_stats=self.segment_stats, + metrics_snapshot=self.metrics_snapshot, + tools_used=[tc.tool_name for tc in self.tool_calls] if self.tool_calls else None, + tool_calls=self.tool_calls if self.tool_calls else None, + ) + + +class RuntimeObserver: + """ + [AC-MARH-12] 运行时观测器。 + + Features: + - 汇总 guardrail、interrupt、kb_hit、timeouts、segment_stats + - 生成完整 TraceInfo + - 记录观测日志 + """ + + def __init__(self): + self._contexts: dict[str, RuntimeContext] = {} + + def start_observation( + self, + tenant_id: str, + session_id: str, + request_id: str, + generation_id: str, + ) -> RuntimeContext: + """ + [AC-MARH-12] 开始观测。 + + Args: + tenant_id: 租户 ID + session_id: 会话 ID + request_id: 请求 ID + generation_id: 生成 ID + + Returns: + RuntimeContext 实例 + """ + ctx = RuntimeContext( + tenant_id=tenant_id, + session_id=session_id, + request_id=request_id, + generation_id=generation_id, + ) + + self._contexts[request_id] = ctx + + logger.info( + f"[AC-MARH-12] Observation started: request_id={request_id}, " + f"session_id={session_id}" + ) + + return ctx + + def get_context(self, request_id: str) -> RuntimeContext | None: + """获取观测上下文。""" + return self._contexts.get(request_id) + + def update_mode( + self, + request_id: str, + mode: ExecutionMode, + intent: str | None = None, + ) -> None: + """更新执行模式。""" + ctx = self._contexts.get(request_id) + if ctx: + ctx.mode = mode + ctx.intent = intent + + def record_guardrail( + self, + 
request_id: str, + triggered: bool, + rule_id: str | None = None, + ) -> None: + """[AC-MARH-12] 记录护栏触发。""" + ctx = self._contexts.get(request_id) + if ctx: + ctx.guardrail_triggered = triggered + ctx.guardrail_rule_id = rule_id + + logger.info( + f"[AC-MARH-12] Guardrail recorded: request_id={request_id}, " + f"triggered={triggered}, rule_id={rule_id}" + ) + + def record_interrupt( + self, + request_id: str, + consumed: bool, + ) -> None: + """[AC-MARH-12] 记录中断处理。""" + ctx = self._contexts.get(request_id) + if ctx: + ctx.interrupt_consumed = consumed + + logger.info( + f"[AC-MARH-12] Interrupt recorded: request_id={request_id}, " + f"consumed={consumed}" + ) + + def record_kb( + self, + request_id: str, + tool_called: bool, + hit: bool, + fallback_reason: str | None = None, + ) -> None: + """[AC-MARH-12] 记录 KB 检索。""" + ctx = self._contexts.get(request_id) + if ctx: + ctx.kb_tool_called = tool_called + ctx.kb_hit = hit + + if fallback_reason: + ctx.fallback_reason_code = fallback_reason + + logger.info( + f"[AC-MARH-12] KB recorded: request_id={request_id}, " + f"tool_called={tool_called}, hit={hit}, fallback={fallback_reason}" + ) + + def record_react( + self, + request_id: str, + iterations: int, + tool_calls: list[ToolCallTrace] | None = None, + ) -> None: + """[AC-MARH-12] 记录 ReAct 循环。""" + ctx = self._contexts.get(request_id) + if ctx: + ctx.react_iterations = iterations + if tool_calls: + ctx.tool_calls = tool_calls + + def record_timeout_profile( + self, + request_id: str, + profile: TimeoutProfile, + ) -> None: + """[AC-MARH-12] 记录超时配置。""" + ctx = self._contexts.get(request_id) + if ctx: + ctx.timeout_profile = profile + + def record_segment_stats( + self, + request_id: str, + stats: SegmentStats, + ) -> None: + """[AC-MARH-12] 记录分段统计。""" + ctx = self._contexts.get(request_id) + if ctx: + ctx.segment_stats = stats + + def record_metrics( + self, + request_id: str, + metrics: MetricsSnapshot, + ) -> None: + """[AC-MARH-12] 记录指标快照。""" + ctx = 
self._contexts.get(request_id) + if ctx: + ctx.metrics_snapshot = metrics + + def set_fallback_reason( + self, + request_id: str, + reason: str, + ) -> None: + """设置降级原因。""" + ctx = self._contexts.get(request_id) + if ctx: + ctx.fallback_reason_code = reason + + def end_observation( + self, + request_id: str, + ) -> TraceInfo: + """ + [AC-MARH-12] 结束观测并生成 TraceInfo。 + + Args: + request_id: 请求 ID + + Returns: + 完整的 TraceInfo + """ + ctx = self._contexts.get(request_id) + if not ctx: + logger.warning(f"[AC-MARH-12] Context not found: {request_id}") + return TraceInfo(mode=ExecutionMode.FIXED) + + duration_ms = int((time.time() - ctx.start_time) * 1000) + + trace_info = ctx.to_trace_info() + + logger.info( + f"[AC-MARH-12] Observation ended: request_id={request_id}, " + f"mode={ctx.mode.value}, duration_ms={duration_ms}, " + f"guardrail={ctx.guardrail_triggered}, kb_hit={ctx.kb_hit}, " + f"segments={ctx.segment_stats.segment_count if ctx.segment_stats else 0}" + ) + + if request_id in self._contexts: + del self._contexts[request_id] + + return trace_info + + +_runtime_observer: RuntimeObserver | None = None + + +def get_runtime_observer() -> RuntimeObserver: + """获取或创建 RuntimeObserver 实例。""" + global _runtime_observer + if _runtime_observer is None: + _runtime_observer = RuntimeObserver() + return _runtime_observer diff --git a/ai-service/app/services/mid/segment_humanizer.py b/ai-service/app/services/mid/segment_humanizer.py new file mode 100644 index 0000000..e649be2 --- /dev/null +++ b/ai-service/app/services/mid/segment_humanizer.py @@ -0,0 +1,282 @@ +""" +Segment Humanizer for Mid Platform. 
+[AC-MARH-10] 分段策略组件(语义/长度切分)。 +[AC-MARH-11] delay 策略租户化配置。 + +将文本按语义/长度切分为 segments,并生成拟人化 delay。 +""" + +import logging +import re +import uuid +from dataclasses import dataclass, field +from typing import Any + +from app.models.mid.schemas import Segment, SegmentStats + +logger = logging.getLogger(__name__) + +DEFAULT_MIN_DELAY_MS = 50 +DEFAULT_MAX_DELAY_MS = 500 +DEFAULT_SEGMENT_MIN_LENGTH = 10 +DEFAULT_SEGMENT_MAX_LENGTH = 200 + + +@dataclass +class HumanizeConfig: + """拟人化配置。""" + enabled: bool = True + min_delay_ms: int = DEFAULT_MIN_DELAY_MS + max_delay_ms: int = DEFAULT_MAX_DELAY_MS + length_bucket_strategy: str = "simple" + segment_min_length: int = DEFAULT_SEGMENT_MIN_LENGTH + segment_max_length: int = DEFAULT_SEGMENT_MAX_LENGTH + + def to_dict(self) -> dict[str, Any]: + return { + "enabled": self.enabled, + "min_delay_ms": self.min_delay_ms, + "max_delay_ms": self.max_delay_ms, + "length_bucket_strategy": self.length_bucket_strategy, + "segment_min_length": self.segment_min_length, + "segment_max_length": self.segment_max_length, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "HumanizeConfig": + return cls( + enabled=data.get("enabled", True), + min_delay_ms=data.get("min_delay_ms", DEFAULT_MIN_DELAY_MS), + max_delay_ms=data.get("max_delay_ms", DEFAULT_MAX_DELAY_MS), + length_bucket_strategy=data.get("length_bucket_strategy", "simple"), + segment_min_length=data.get("segment_min_length", DEFAULT_SEGMENT_MIN_LENGTH), + segment_max_length=data.get("segment_max_length", DEFAULT_SEGMENT_MAX_LENGTH), + ) + + +@dataclass +class LengthBucket: + """长度区间与对应 delay。""" + min_length: int + max_length: int + delay_ms: int + + +DEFAULT_LENGTH_BUCKETS = [ + LengthBucket(min_length=0, max_length=20, delay_ms=100), + LengthBucket(min_length=20, max_length=50, delay_ms=200), + LengthBucket(min_length=50, max_length=100, delay_ms=300), + LengthBucket(min_length=100, max_length=200, delay_ms=400), + LengthBucket(min_length=200, max_length=10000, 
delay_ms=500), +] + + +class SegmentHumanizer: + """ + [AC-MARH-10, AC-MARH-11] 分段拟人化组件。 + + Features: + - 按语义/长度切分文本 + - 生成拟人化 delay + - 支持租户配置覆盖 + - 输出 segment_stats 统计 + """ + + def __init__( + self, + config: HumanizeConfig | None = None, + length_buckets: list[LengthBucket] | None = None, + ): + self._config = config or HumanizeConfig() + self._length_buckets = length_buckets or DEFAULT_LENGTH_BUCKETS + + def humanize( + self, + text: str, + override_config: HumanizeConfig | None = None, + ) -> tuple[list[Segment], SegmentStats]: + """ + [AC-MARH-10] 将文本转换为拟人化分段。 + + Args: + text: 输入文本 + override_config: 租户覆盖配置 + + Returns: + Tuple of (segments, segment_stats) + """ + config = override_config or self._config + + if not config.enabled: + segments = [Segment( + segment_id=str(uuid.uuid4()), + text=text, + delay_after=0, + )] + stats = SegmentStats( + segment_count=1, + avg_segment_length=len(text), + humanize_strategy="disabled", + ) + return segments, stats + + raw_segments = self._split_text(text, config) + segments = [] + + for i, seg_text in enumerate(raw_segments): + is_last = i == len(raw_segments) - 1 + delay_after = 0 if is_last else self._calculate_delay(seg_text, config) + + segments.append(Segment( + segment_id=str(uuid.uuid4()), + text=seg_text, + delay_after=delay_after, + )) + + total_length = sum(len(s.text) for s in segments) + avg_length = total_length / len(segments) if segments else 0.0 + + stats = SegmentStats( + segment_count=len(segments), + avg_segment_length=avg_length, + humanize_strategy=config.length_bucket_strategy, + ) + + logger.info( + f"[AC-MARH-10] Humanized text: segments={len(segments)}, " + f"avg_length={avg_length:.1f}, strategy={config.length_bucket_strategy}" + ) + + return segments, stats + + def _split_text( + self, + text: str, + config: HumanizeConfig, + ) -> list[str]: + """切分文本。""" + if config.length_bucket_strategy == "semantic": + return self._split_semantic(text, config) + else: + return self._split_simple(text, 
config) + + def _split_simple( + self, + text: str, + config: HumanizeConfig, + ) -> list[str]: + """简单切分:按段落。""" + paragraphs = re.split(r'\n\s*\n', text.strip()) + segments = [] + + for para in paragraphs: + para = para.strip() + if not para: + continue + + if len(para) <= config.segment_max_length: + segments.append(para) + else: + sub_segments = self._split_by_length(para, config.segment_max_length) + segments.extend(sub_segments) + + if not segments: + segments = [text.strip()] + + return [s for s in segments if s.strip()] + + def _split_semantic( + self, + text: str, + config: HumanizeConfig, + ) -> list[str]: + """语义切分:按句子边界。""" + sentence_endings = re.compile(r'([。!?.!?]+)') + parts = sentence_endings.split(text.strip()) + + sentences = [] + current = "" + for i, part in enumerate(parts): + current += part + if sentence_endings.match(part): + sentences.append(current.strip()) + current = "" + + if current.strip(): + sentences.append(current.strip()) + + if not sentences: + sentences = [text.strip()] + + segments = [] + current_segment = "" + + for sentence in sentences: + if len(current_segment) + len(sentence) <= config.segment_max_length: + current_segment += sentence + else: + if current_segment: + segments.append(current_segment) + current_segment = sentence + + if current_segment: + segments.append(current_segment) + + return [s for s in segments if s.strip()] + + def _split_by_length( + self, + text: str, + max_length: int, + ) -> list[str]: + """按长度切分。""" + segments = [] + remaining = text + + while remaining: + if len(remaining) <= max_length: + segments.append(remaining.strip()) + break + + split_pos = max_length + for i in range(max_length - 1, max(0, max_length - 20), -1): + if remaining[i] in ',,;;:: ': + split_pos = i + 1 + break + + segments.append(remaining[:split_pos].strip()) + remaining = remaining[split_pos:] + + return [s for s in segments if s.strip()] + + def _calculate_delay( + self, + text: str, + config: HumanizeConfig, + ) -> int: 
+ """[AC-MARH-11] 计算拟人化 delay。""" + text_length = len(text) + + for bucket in self._length_buckets: + if bucket.min_length <= text_length < bucket.max_length: + delay = bucket.delay_ms + return max(config.min_delay_ms, min(delay, config.max_delay_ms)) + + return config.min_delay_ms + + def get_config(self) -> HumanizeConfig: + """获取当前配置。""" + return self._config + + +_segment_humanizer: SegmentHumanizer | None = None + + +def get_segment_humanizer( + config: HumanizeConfig | None = None, +) -> SegmentHumanizer: + """获取或创建 SegmentHumanizer 实例。""" + global _segment_humanizer + if _segment_humanizer is None: + _segment_humanizer = SegmentHumanizer(config=config) + return _segment_humanizer diff --git a/ai-service/app/services/mid/timeout_governor.py b/ai-service/app/services/mid/timeout_governor.py new file mode 100644 index 0000000..1bd8ab7 --- /dev/null +++ b/ai-service/app/services/mid/timeout_governor.py @@ -0,0 +1,166 @@ +""" +Timeout Governor for Mid Platform. +[AC-IDMP-12] Timeout governance: per-tool <= 60s, end-to-end <= 180s. +[AC-MARH-08, AC-MARH-09] 超时口径统一:单工具 <= 60000ms,全链路 <= 180000ms。 +""" + +import asyncio +import logging +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any, TypeVar + +from app.models.mid.schemas import TimeoutProfile + +logger = logging.getLogger(__name__) + +DEFAULT_PER_TOOL_TIMEOUT_MS = 30000 +DEFAULT_END_TO_END_TIMEOUT_MS = 120000 +DEFAULT_LLM_TIMEOUT_MS = 60000 + +T = TypeVar("T") + + +@dataclass +class TimeoutResult: + """Result of a timeout-governed operation.""" + success: bool + result: Any = None + error: str | None = None + duration_ms: int = 0 + timed_out: bool = False + + +class TimeoutGovernor: + """ + [AC-IDMP-12] Timeout governor for tool calls and end-to-end execution. 
+ [AC-MARH-08, AC-MARH-09] 超时口径统一。 + + Constraints: + - Per-tool timeout: <= 60000ms (60s) + - LLM timeout: <= 120000ms (120s) + - End-to-end timeout: <= 180000ms (180s) + """ + + def __init__( + self, + per_tool_timeout_ms: int = DEFAULT_PER_TOOL_TIMEOUT_MS, + end_to_end_timeout_ms: int = DEFAULT_END_TO_END_TIMEOUT_MS, + llm_timeout_ms: int = DEFAULT_LLM_TIMEOUT_MS, + ): + self._per_tool_timeout_ms = min(per_tool_timeout_ms, 60000) + self._end_to_end_timeout_ms = min(end_to_end_timeout_ms, 180000) + self._llm_timeout_ms = min(llm_timeout_ms, 120000) + + @property + def per_tool_timeout_seconds(self) -> float: + """Per-tool timeout in seconds.""" + return self._per_tool_timeout_ms / 1000.0 + + @property + def llm_timeout_seconds(self) -> float: + """LLM call timeout in seconds.""" + return self._llm_timeout_ms / 1000.0 + + @property + def end_to_end_timeout_seconds(self) -> float: + """End-to-end timeout in seconds.""" + return self._end_to_end_timeout_ms / 1000.0 + + @property + def profile(self) -> TimeoutProfile: + """Get current timeout profile.""" + return TimeoutProfile( + per_tool_timeout_ms=self._per_tool_timeout_ms, + llm_timeout_ms=self._llm_timeout_ms, + end_to_end_timeout_ms=self._end_to_end_timeout_ms, + ) + + async def execute_with_timeout( + self, + coro: Callable[[], T], + timeout_ms: int | None = None, + operation_name: str = "operation", + ) -> TimeoutResult: + """ + [AC-MARH-08, AC-MARH-09] Execute a coroutine with timeout. 
+ + Args: + coro: Coroutine to execute + timeout_ms: Timeout in milliseconds (defaults to per-tool timeout) + operation_name: Name for logging + + Returns: + TimeoutResult with success status and result/error + """ + import time + + timeout_ms = timeout_ms or self._per_tool_timeout_ms + timeout_seconds = timeout_ms / 1000.0 + start_time = time.time() + + try: + result = await asyncio.wait_for(coro(), timeout=timeout_seconds) + duration_ms = int((time.time() - start_time) * 1000) + + logger.debug( + f"[AC-MARH-08] {operation_name} completed in {duration_ms}ms" + ) + + return TimeoutResult( + success=True, + result=result, + duration_ms=duration_ms, + ) + + except asyncio.TimeoutError: + duration_ms = int((time.time() - start_time) * 1000) + logger.warning( + f"[AC-MARH-08] {operation_name} timed out after {duration_ms}ms " + f"(limit: {timeout_ms}ms)" + ) + return TimeoutResult( + success=False, + error=f"Timeout after {timeout_ms}ms", + duration_ms=duration_ms, + timed_out=True, + ) + + except Exception as e: + duration_ms = int((time.time() - start_time) * 1000) + logger.error( + f"[AC-MARH-08] {operation_name} failed: {e}" + ) + return TimeoutResult( + success=False, + error=str(e), + duration_ms=duration_ms, + ) + + async def execute_tool( + self, + tool_name: str, + tool_func: Callable[[], T], + ) -> TimeoutResult: + """ + [AC-MARH-08] Execute a tool call with per-tool timeout. + """ + return await self.execute_with_timeout( + coro=tool_func, + timeout_ms=self._per_tool_timeout_ms, + operation_name=f"tool:{tool_name}", + ) + + async def execute_e2e( + self, + coro: Callable[[], T], + operation_name: str = "e2e", + ) -> TimeoutResult: + """ + [AC-MARH-09] Execute with end-to-end timeout. 
+ """ + return await self.execute_with_timeout( + coro=coro, + timeout_ms=self._end_to_end_timeout_ms, + operation_name=operation_name, + ) diff --git a/ai-service/app/services/mid/tool_call_recorder.py b/ai-service/app/services/mid/tool_call_recorder.py new file mode 100644 index 0000000..e82caa7 --- /dev/null +++ b/ai-service/app/services/mid/tool_call_recorder.py @@ -0,0 +1,324 @@ +""" +Tool Call Recorder for Mid Platform. +[AC-IDMP-15] 工具调用结构化记录 + +Reference: +- spec/intent-driven-mid-platform/openapi.provider.yaml - ToolCallTrace +- spec/intent-driven-mid-platform/requirements.md AC-IDMP-15 +""" + +import logging +from collections import defaultdict +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any + +from app.models.mid.tool_trace import ( + ToolCallBuilder, + ToolCallStatus, + ToolCallTrace, + ToolType, +) + +logger = logging.getLogger(__name__) + + +@dataclass +class ToolCallStatistics: + """ + 工具调用统计信息 + """ + total_calls: int = 0 + success_calls: int = 0 + timeout_calls: int = 0 + error_calls: int = 0 + rejected_calls: int = 0 + total_duration_ms: int = 0 + avg_duration_ms: float = 0.0 + max_duration_ms: int = 0 + min_duration_ms: int = 0 + + def update(self, trace: ToolCallTrace) -> None: + self.total_calls += 1 + self.total_duration_ms += trace.duration_ms + self.avg_duration_ms = self.total_duration_ms / self.total_calls + + if trace.duration_ms > self.max_duration_ms: + self.max_duration_ms = trace.duration_ms + if self.min_duration_ms == 0 or trace.duration_ms < self.min_duration_ms: + self.min_duration_ms = trace.duration_ms + + if trace.status == ToolCallStatus.OK: + self.success_calls += 1 + elif trace.status == ToolCallStatus.TIMEOUT: + self.timeout_calls += 1 + elif trace.status == ToolCallStatus.REJECTED: + self.rejected_calls += 1 + else: + self.error_calls += 1 + + +class ToolCallRecorder: + """ + [AC-IDMP-15] 工具调用记录器 + + 功能: + 1. 记录每次工具调用的完整信息(参数摘要、耗时、状态、错误码) + 2. 支持敏感参数脱敏 + 3. 
提供统计信息 + + 记录字段: + - tool_name: 工具名称 + - tool_type: 工具类型 (internal | mcp) + - registry_version: 注册表版本 + - auth_applied: 是否应用鉴权 + - duration_ms: 调用耗时 + - status: 调用状态 (ok | timeout | error | rejected) + - error_code: 错误码 + - args_digest: 参数摘要(脱敏) + - result_digest: 结果摘要 + """ + + def __init__(self, max_traces_per_session: int = 100): + self._max_traces_per_session = max_traces_per_session + self._traces: dict[str, list[ToolCallTrace]] = defaultdict(list) + self._statistics: dict[str, ToolCallStatistics] = defaultdict(ToolCallStatistics) + + def start_trace( + self, + tool_name: str, + tool_type: ToolType = ToolType.INTERNAL, + ) -> ToolCallBuilder: + """ + 开始记录一次工具调用 + + Args: + tool_name: 工具名称 + tool_type: 工具类型 + + Returns: + ToolCallBuilder: 构建器,用于逐步记录调用信息 + """ + return ToolCallBuilder( + tool_name=tool_name, + tool_type=tool_type, + ) + + def record( + self, + session_id: str, + trace: ToolCallTrace, + ) -> None: + """ + 记录一次工具调用的完整信息 + + Args: + session_id: 会话ID + trace: 工具调用追踪记录 + """ + session_traces = self._traces[session_id] + session_traces.append(trace) + + if len(session_traces) > self._max_traces_per_session: + session_traces.pop(0) + + self._statistics[trace.tool_name].update(trace) + + logger.info( + f"[AC-IDMP-15] Tool call recorded: tool={trace.tool_name}, " + f"type={trace.tool_type.value}, duration_ms={trace.duration_ms}, " + f"status={trace.status.value}, session={session_id}" + ) + + def record_success( + self, + session_id: str, + tool_name: str, + tool_type: ToolType, + duration_ms: int, + args: Any = None, + result: Any = None, + registry_version: str | None = None, + auth_applied: bool = False, + ) -> ToolCallTrace: + """ + 记录成功的工具调用(便捷方法) + """ + trace = ToolCallTrace( + tool_name=tool_name, + tool_type=tool_type, + duration_ms=duration_ms, + status=ToolCallStatus.OK, + registry_version=registry_version, + auth_applied=auth_applied, + args_digest=ToolCallTrace.compute_digest(args) if args else None, + 
result_digest=ToolCallTrace.compute_digest(result) if result else None, + ) + self.record(session_id, trace) + return trace + + def record_timeout( + self, + session_id: str, + tool_name: str, + tool_type: ToolType, + duration_ms: int, + args: Any = None, + registry_version: str | None = None, + auth_applied: bool = False, + ) -> ToolCallTrace: + """ + 记录超时的工具调用(便捷方法) + """ + trace = ToolCallTrace( + tool_name=tool_name, + tool_type=tool_type, + duration_ms=duration_ms, + status=ToolCallStatus.TIMEOUT, + error_code="TIMEOUT", + registry_version=registry_version, + auth_applied=auth_applied, + args_digest=ToolCallTrace.compute_digest(args) if args else None, + ) + self.record(session_id, trace) + return trace + + def record_error( + self, + session_id: str, + tool_name: str, + tool_type: ToolType, + duration_ms: int, + error_code: str, + error_message: str | None = None, + args: Any = None, + registry_version: str | None = None, + auth_applied: bool = False, + ) -> ToolCallTrace: + """ + 记录错误的工具调用(便捷方法) + """ + trace = ToolCallTrace( + tool_name=tool_name, + tool_type=tool_type, + duration_ms=duration_ms, + status=ToolCallStatus.ERROR, + error_code=error_code, + registry_version=registry_version, + auth_applied=auth_applied, + args_digest=ToolCallTrace.compute_digest(args) if args else None, + ) + self.record(session_id, trace) + return trace + + def record_rejected( + self, + session_id: str, + tool_name: str, + tool_type: ToolType, + reason: str, + args: Any = None, + registry_version: str | None = None, + ) -> ToolCallTrace: + """ + 记录被拒绝的工具调用(便捷方法) + """ + trace = ToolCallTrace( + tool_name=tool_name, + tool_type=tool_type, + duration_ms=0, + status=ToolCallStatus.REJECTED, + error_code=reason, + registry_version=registry_version, + args_digest=ToolCallTrace.compute_digest(args) if args else None, + ) + self.record(session_id, trace) + return trace + + def get_traces(self, session_id: str) -> list[ToolCallTrace]: + """ + 获取指定会话的所有工具调用记录 + + Args: + session_id: 
会话ID + + Returns: + list[ToolCallTrace]: 工具调用记录列表 + """ + return self._traces.get(session_id, []) + + def get_statistics(self, tool_name: str | None = None) -> dict[str, Any]: + """ + 获取工具调用统计信息 + + Args: + tool_name: 工具名称(可选,不提供则返回所有统计) + + Returns: + dict: 统计信息 + """ + if tool_name: + stats = self._statistics.get(tool_name) + if stats: + return { + "tool_name": tool_name, + "total_calls": stats.total_calls, + "success_rate": stats.success_calls / stats.total_calls if stats.total_calls > 0 else 0, + "timeout_rate": stats.timeout_calls / stats.total_calls if stats.total_calls > 0 else 0, + "error_rate": stats.error_calls / stats.total_calls if stats.total_calls > 0 else 0, + "avg_duration_ms": stats.avg_duration_ms, + "max_duration_ms": stats.max_duration_ms, + "min_duration_ms": stats.min_duration_ms, + } + return {} + + return { + name: { + "total_calls": stats.total_calls, + "success_rate": stats.success_calls / stats.total_calls if stats.total_calls > 0 else 0, + "avg_duration_ms": stats.avg_duration_ms, + } + for name, stats in self._statistics.items() + } + + def clear_session(self, session_id: str) -> int: + """ + 清除指定会话的记录 + + Args: + session_id: 会话ID + + Returns: + int: 清除的记录数 + """ + if session_id in self._traces: + count = len(self._traces[session_id]) + del self._traces[session_id] + return count + return 0 + + def to_trace_info_format(self, session_id: str) -> list[dict[str, Any]]: + """ + 转换为 TraceInfo.tool_calls 格式 + + 用于输出到 DialogueResponse.trace.tool_calls + + Args: + session_id: 会话ID + + Returns: + list[dict]: 符合 OpenAPI 格式的工具调用列表 + """ + traces = self.get_traces(session_id) + return [trace.to_dict() for trace in traces] + + +_recorder: ToolCallRecorder | None = None + + +def get_tool_call_recorder() -> ToolCallRecorder: + """获取全局工具调用记录器实例""" + global _recorder + if _recorder is None: + _recorder = ToolCallRecorder() + return _recorder diff --git a/ai-service/app/services/mid/tool_registry.py b/ai-service/app/services/mid/tool_registry.py new file 
mode 100644 index 0000000..3637bd3 --- /dev/null +++ b/ai-service/app/services/mid/tool_registry.py @@ -0,0 +1,337 @@ +""" +Tool Registry for Mid Platform. +[AC-IDMP-19] Unified tool registration, auth, timeout, version, and enable/disable governance. +""" + +import asyncio +import logging +import time +import uuid +from dataclasses import dataclass, field +from typing import Any, Callable, Coroutine + +from app.models.mid.schemas import ( + ToolCallStatus, + ToolCallTrace, + ToolType, +) +from app.services.mid.timeout_governor import TimeoutGovernor + +logger = logging.getLogger(__name__) + + +@dataclass +class ToolDefinition: + """Tool definition for registry.""" + name: str + description: str + tool_type: ToolType = ToolType.INTERNAL + version: str = "1.0.0" + enabled: bool = True + auth_required: bool = False + timeout_ms: int = 2000 + handler: Callable[..., Coroutine[Any, Any, dict[str, Any]]] | None = None + metadata: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class ToolExecutionResult: + """Tool execution result.""" + success: bool + output: Any = None + error: str | None = None + duration_ms: int = 0 + auth_applied: bool = False + registry_version: str | None = None + + +class ToolRegistry: + """ + [AC-IDMP-19] Unified tool registry for governance. 
+ + Features: + - Tool registration with metadata + - Auth policy enforcement + - Timeout governance + - Version management + - Enable/disable control + """ + + def __init__( + self, + timeout_governor: TimeoutGovernor | None = None, + ): + self._tools: dict[str, ToolDefinition] = {} + self._timeout_governor = timeout_governor or TimeoutGovernor() + self._version = "1.0.0" + + @property + def version(self) -> str: + """Get registry version.""" + return self._version + + def register( + self, + name: str, + description: str, + handler: Callable[..., Coroutine[Any, Any, dict[str, Any]]], + tool_type: ToolType = ToolType.INTERNAL, + version: str = "1.0.0", + auth_required: bool = False, + timeout_ms: int = 2000, + enabled: bool = True, + metadata: dict[str, Any] | None = None, + ) -> ToolDefinition: + """ + [AC-IDMP-19] Register a tool. + + Args: + name: Tool name (unique identifier) + description: Tool description + handler: Async handler function + tool_type: Tool type (internal/mcp) + version: Tool version + auth_required: Whether auth is required + timeout_ms: Tool-specific timeout + enabled: Whether tool is enabled + metadata: Additional metadata + + Returns: + ToolDefinition for the registered tool + """ + if name in self._tools: + logger.warning(f"[AC-IDMP-19] Tool already registered, overwriting: {name}") + + tool = ToolDefinition( + name=name, + description=description, + tool_type=tool_type, + version=version, + enabled=enabled, + auth_required=auth_required, + timeout_ms=min(timeout_ms, 2000), + handler=handler, + metadata=metadata or {}, + ) + + self._tools[name] = tool + + logger.info( + f"[AC-IDMP-19] Tool registered: name={name}, type={tool_type.value}, " + f"version={version}, auth_required={auth_required}" + ) + + return tool + + def unregister(self, name: str) -> bool: + """Unregister a tool.""" + if name in self._tools: + del self._tools[name] + logger.info(f"[AC-IDMP-19] Tool unregistered: {name}") + return True + return False + + def 
get_tool(self, name: str) -> ToolDefinition | None: + """Get tool definition by name.""" + return self._tools.get(name) + + def list_tools( + self, + tool_type: ToolType | None = None, + enabled_only: bool = True, + ) -> list[ToolDefinition]: + """List registered tools, optionally filtered.""" + tools = list(self._tools.values()) + + if tool_type: + tools = [t for t in tools if t.tool_type == tool_type] + + if enabled_only: + tools = [t for t in tools if t.enabled] + + return tools + + def enable_tool(self, name: str) -> bool: + """Enable a tool.""" + tool = self._tools.get(name) + if tool: + tool.enabled = True + logger.info(f"[AC-IDMP-19] Tool enabled: {name}") + return True + return False + + def disable_tool(self, name: str) -> bool: + """Disable a tool.""" + tool = self._tools.get(name) + if tool: + tool.enabled = False + logger.info(f"[AC-IDMP-19] Tool disabled: {name}") + return True + return False + + async def execute( + self, + tool_name: str, + args: dict[str, Any], + auth_context: dict[str, Any] | None = None, + ) -> ToolExecutionResult: + """ + [AC-IDMP-19] Execute a tool with governance. 
+
+        Args:
+            tool_name: Tool name to execute
+            args: Tool arguments
+            auth_context: Authentication context
+
+        Returns:
+            ToolExecutionResult with output and metadata
+        """
+        start_time = time.time()  # wall-clock start; all duration_ms values derive from this
+
+        tool = self._tools.get(tool_name)
+        if not tool:
+            logger.warning(f"[AC-IDMP-19] Tool not found: {tool_name}")
+            return ToolExecutionResult(
+                success=False,
+                error=f"Tool not found: {tool_name}",
+                duration_ms=0,
+            )
+
+        if not tool.enabled:
+            logger.warning(f"[AC-IDMP-19] Tool disabled: {tool_name}")
+            return ToolExecutionResult(
+                success=False,
+                error=f"Tool disabled: {tool_name}",
+                duration_ms=0,
+                registry_version=tool.version,
+            )
+
+        auth_applied = False
+        if tool.auth_required:
+            if not auth_context:
+                logger.warning(f"[AC-IDMP-19] Auth required but no context: {tool_name}")
+                return ToolExecutionResult(
+                    success=False,
+                    error="Authentication required",
+                    duration_ms=int((time.time() - start_time) * 1000),
+                    auth_applied=False,
+                    registry_version=tool.version,
+                )
+            auth_applied = True  # presence of any auth_context is treated as sufficient — no validation here
+
+        try:
+            timeout_seconds = tool.timeout_ms / 1000.0
+
+            result = await asyncio.wait_for(
+                tool.handler(**args) if tool.handler else asyncio.sleep(0),  # NOTE(review): a handler-less tool "succeeds" with None output — confirm intended
+                timeout=timeout_seconds,
+            )
+
+            duration_ms = int((time.time() - start_time) * 1000)
+
+            logger.info(
+                f"[AC-IDMP-19] Tool executed: name={tool_name}, "
+                f"duration_ms={duration_ms}, success=True"
+            )
+
+            return ToolExecutionResult(
+                success=True,
+                output=result,
+                duration_ms=duration_ms,
+                auth_applied=auth_applied,
+                registry_version=tool.version,
+            )
+
+        except asyncio.TimeoutError:
+            duration_ms = int((time.time() - start_time) * 1000)
+            logger.warning(
+                f"[AC-IDMP-19] Tool timeout: name={tool_name}, "
+                f"duration_ms={duration_ms}"
+            )
+            return ToolExecutionResult(
+                success=False,
+                error=f"Tool timeout after {tool.timeout_ms}ms",  # "timeout" substring is relied on by create_trace's status inference
+                duration_ms=duration_ms,
+                auth_applied=auth_applied,
+                registry_version=tool.version,
+            )
+
+        except Exception as e:
+            duration_ms = int((time.time() - start_time) * 1000)
+
+            logger.error(
+                f"[AC-IDMP-19] Tool error: name={tool_name}, error={e}"
+            )
+            return ToolExecutionResult(
+                success=False,
+                error=str(e),
+                duration_ms=duration_ms,
+                auth_applied=auth_applied,
+                registry_version=tool.version,
+            )
+
+    def create_trace(
+        self,
+        tool_name: str,
+        result: ToolExecutionResult,
+        args_digest: str | None = None,
+    ) -> ToolCallTrace:
+        """
+        [AC-IDMP-19] Create ToolCallTrace from execution result.
+        """
+        tool = self._tools.get(tool_name)  # may be None if the tool was never registered
+
+        return ToolCallTrace(
+            tool_name=tool_name,
+            tool_type=tool.tool_type if tool else ToolType.INTERNAL,  # unknown tools default to INTERNAL
+            registry_version=result.registry_version,
+            auth_applied=result.auth_applied,
+            duration_ms=result.duration_ms,
+            status=ToolCallStatus.OK if result.success else (  # TIMEOUT vs ERROR inferred from error text, matching execute_tool's messages
+                ToolCallStatus.TIMEOUT if "timeout" in (result.error or "").lower()
+                else ToolCallStatus.ERROR
+            ),
+            error_code=result.error if not result.success else None,
+            args_digest=args_digest,
+            result_digest=str(result.output)[:100] if result.output else None,  # NOTE(review): falsy outputs (0, "", []) yield no digest — confirm intended
+        )
+
+    def get_governance_report(self) -> dict[str, Any]:
+        """Get governance report for all tools."""
+        return {
+            "registry_version": self._version,
+            "total_tools": len(self._tools),
+            "enabled_tools": sum(1 for t in self._tools.values() if t.enabled),
+            "disabled_tools": sum(1 for t in self._tools.values() if not t.enabled),
+            "auth_required_tools": sum(1 for t in self._tools.values() if t.auth_required),
+            "mcp_tools": sum(1 for t in self._tools.values() if t.tool_type == ToolType.MCP),
+            "internal_tools": sum(1 for t in self._tools.values() if t.tool_type == ToolType.INTERNAL),
+            "tools": [
+                {
+                    "name": t.name,
+                    "type": t.tool_type.value,
+                    "version": t.version,
+                    "enabled": t.enabled,
+                    "auth_required": t.auth_required,
+                    "timeout_ms": t.timeout_ms,
+                }
+                for t in self._tools.values()
+            ],
+        }
+
+
+_registry: ToolRegistry | None = None  # NOTE(review): process-wide singleton with no locking — confirm single-threaded initialization
+
+
+def get_tool_registry() -> ToolRegistry:
+    """Get global tool registry instance."""
+    global _registry
+    if _registry is None:
+        _registry = ToolRegistry()  # lazily created with default (no) timeout governor
+    return _registry
+
+
+def init_tool_registry(timeout_governor: TimeoutGovernor | None = None) -> ToolRegistry:
+    """Initialize and return tool registry."""
+    global _registry
+    _registry = ToolRegistry(timeout_governor=timeout_governor)  # unconditionally replaces any existing registry
+    return _registry
diff --git a/ai-service/app/services/mid/trace_logger.py b/ai-service/app/services/mid/trace_logger.py
new file mode 100644
index 0000000..35c301c
--- /dev/null
+++ b/ai-service/app/services/mid/trace_logger.py
@@ -0,0 +1,269 @@
+"""
+Trace Logger for Mid Platform.
+[AC-MARH-02, AC-MARH-03, AC-MARH-12] Trace collection and audit logging.
+
+Audit Fields:
+- session_id, request_id, generation_id
+- mode, intent, tool_calls
+- guardrail_triggered, guardrail_rule_id
+- interrupt_consumed
+"""
+
+import logging
+import time
+import uuid
+from dataclasses import dataclass, field
+from typing import Any
+
+from app.models.mid.schemas import (
+    ExecutionMode,
+    ToolCallTrace,
+    TraceInfo,
+)
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class AuditRecord:
+    """[AC-MARH-12] Audit record for database persistence."""
+    tenant_id: str
+    session_id: str
+    request_id: str
+    generation_id: str
+    mode: ExecutionMode
+    intent: str | None = None
+    tool_calls: list[dict[str, Any]] = field(default_factory=list)
+    guardrail_triggered: bool = False
+    guardrail_rule_id: str | None = None
+    interrupt_consumed: bool = False
+    fallback_reason_code: str | None = None
+    react_iterations: int = 0
+    latency_ms: int = 0
+    created_at: str | None = None  # NOTE(review): stored as a formatted string, not a datetime — confirm schema expectation
+
+    def to_dict(self) -> dict[str, Any]:
+        return {
+            "tenant_id": self.tenant_id,
+            "session_id": self.session_id,
+            "request_id": self.request_id,
+            "generation_id": self.generation_id,
+            "mode": self.mode.value,
+            "intent": self.intent,
+            "tool_calls": self.tool_calls,
+            "guardrail_triggered": self.guardrail_triggered,
+            "guardrail_rule_id": self.guardrail_rule_id,
+            "interrupt_consumed": self.interrupt_consumed,
+            "fallback_reason_code": self.fallback_reason_code,
+            "react_iterations": self.react_iterations,
+            "latency_ms": self.latency_ms,
+            "created_at": self.created_at,
+        }
+
+
+class TraceLogger:
+    """
+    [AC-MARH-02, AC-MARH-03, AC-MARH-12] Trace logger for observability and audit.
+
+    Features:
+    - Request-scoped trace context
+    - Tool call tracing
+    - Guardrail event logging with rule_id
+    - Interrupt consumption tracking
+    - Audit record generation
+    """
+
+    def __init__(self):
+        self._traces: dict[str, TraceInfo] = {}  # request_id -> in-flight trace; entries removed in end_trace
+        self._audit_records: list[AuditRecord] = []  # NOTE(review): unbounded in-memory growth until clear_audit_records — confirm acceptable
+
+    def start_trace(
+        self,
+        tenant_id: str,
+        session_id: str,
+        request_id: str | None = None,
+        generation_id: str | None = None,
+    ) -> TraceInfo:
+        """
+        [AC-MARH-12] Start a new trace context.
+
+        Args:
+            tenant_id: Tenant ID
+            session_id: Session ID
+            request_id: Request ID (auto-generated if not provided)
+            generation_id: Generation ID for interrupt handling
+
+        Returns:
+            TraceInfo for the new trace
+        """
+        request_id = request_id or str(uuid.uuid4())
+        generation_id = generation_id or str(uuid.uuid4())
+
+        trace = TraceInfo(
+            mode=ExecutionMode.AGENT,  # initial default; presumably corrected later via update_trace — confirm
+            request_id=request_id,
+            generation_id=generation_id,
+        )
+
+        self._traces[request_id] = trace  # NOTE(review): tenant_id/session_id are not stored on TraceInfo; end_trace takes them again
+
+        logger.info(
+            f"[AC-MARH-12] Trace started: request_id={request_id}, "
+            f"session_id={session_id}, generation_id={generation_id}"
+        )
+
+        return trace
+
+    def get_trace(self, request_id: str) -> TraceInfo | None:
+        """Get trace by request ID."""
+        return self._traces.get(request_id)
+
+    def update_trace(
+        self,
+        request_id: str,
+        mode: ExecutionMode | None = None,
+        intent: str | None = None,
+        guardrail_triggered: bool | None = None,
+        guardrail_rule_id: str | None = None,
+        interrupt_consumed: bool | None = None,
+        fallback_reason_code: str | None = None,
+        react_iterations: int | None = None,
+        tool_calls: list[ToolCallTrace] | None = None,
+    ) -> TraceInfo | None:
+        """
+        [AC-MARH-02, AC-MARH-03, AC-MARH-12] Update trace with execution details.
+ """ + trace = self._traces.get(request_id) + if not trace: + logger.warning(f"[AC-MARH-12] Trace not found: {request_id}") + return None + + if mode is not None: + trace.mode = mode + if intent is not None: + trace.intent = intent + if guardrail_triggered is not None: + trace.guardrail_triggered = guardrail_triggered + if guardrail_rule_id is not None: + trace.guardrail_rule_id = guardrail_rule_id + if interrupt_consumed is not None: + trace.interrupt_consumed = interrupt_consumed + if fallback_reason_code is not None: + trace.fallback_reason_code = fallback_reason_code + if react_iterations is not None: + trace.react_iterations = react_iterations + if tool_calls is not None: + trace.tool_calls = tool_calls + + return trace + + def add_tool_call( + self, + request_id: str, + tool_call: ToolCallTrace, + ) -> None: + """ + [AC-MARH-12] Add tool call trace to request. + """ + trace = self._traces.get(request_id) + if not trace: + logger.warning(f"[AC-MARH-12] Trace not found for tool call: {request_id}") + return + + if trace.tool_calls is None: + trace.tool_calls = [] + + trace.tool_calls.append(tool_call) + + if trace.tools_used is None: + trace.tools_used = [] + + if tool_call.tool_name not in trace.tools_used: + trace.tools_used.append(tool_call.tool_name) + + logger.debug( + f"[AC-MARH-12] Tool call recorded: request_id={request_id}, " + f"tool={tool_call.tool_name}, status={tool_call.status.value}" + ) + + def end_trace( + self, + request_id: str, + tenant_id: str, + session_id: str, + latency_ms: int, + ) -> AuditRecord: + """ + [AC-MARH-12] End trace and create audit record. 
+ """ + trace = self._traces.get(request_id) + if not trace: + logger.warning(f"[AC-MARH-12] Trace not found for end: {request_id}") + return AuditRecord( + tenant_id=tenant_id, + session_id=session_id, + request_id=request_id, + generation_id=str(uuid.uuid4()), + mode=ExecutionMode.FIXED, + latency_ms=latency_ms, + ) + + audit = AuditRecord( + tenant_id=tenant_id, + session_id=session_id, + request_id=request_id, + generation_id=trace.generation_id or str(uuid.uuid4()), + mode=trace.mode, + intent=trace.intent, + tool_calls=[tc.model_dump() for tc in trace.tool_calls] if trace.tool_calls else [], + guardrail_triggered=trace.guardrail_triggered or False, + guardrail_rule_id=trace.guardrail_rule_id, + interrupt_consumed=trace.interrupt_consumed or False, + fallback_reason_code=trace.fallback_reason_code, + react_iterations=trace.react_iterations or 0, + latency_ms=latency_ms, + created_at=time.strftime("%Y-%m-%d %H:%M:%S"), + ) + + self._audit_records.append(audit) + + if request_id in self._traces: + del self._traces[request_id] + + logger.info( + f"[AC-MARH-12] Trace ended: request_id={request_id}, " + f"mode={trace.mode.value}, latency_ms={latency_ms}" + ) + + return audit + + def get_audit_records( + self, + tenant_id: str, + session_id: str | None = None, + limit: int = 100, + ) -> list[AuditRecord]: + """Get audit records for a tenant/session.""" + records = [ + r for r in self._audit_records + if r.tenant_id == tenant_id + ] + + if session_id: + records = [r for r in records if r.session_id == session_id] + + return records[-limit:] + + def clear_audit_records(self, tenant_id: str | None = None) -> int: + """Clear audit records, optionally filtered by tenant.""" + if tenant_id: + original_count = len(self._audit_records) + self._audit_records = [ + r for r in self._audit_records + if r.tenant_id != tenant_id + ] + return original_count - len(self._audit_records) + else: + count = len(self._audit_records) + self._audit_records = [] + return count