feat: implement mid-platform core services and APIs [AC-IDMP-01~20, AC-MARH-01~12]

- Add dialogue, messages, sessions, share API endpoints
- Add mid-platform schemas and models (memory, tool_registry, tool_trace)
- Add core services: agent_orchestrator, policy_router, runtime_observer
- Add tool services: kb_search_dynamic, memory_adapter, tool_registry
- Add guardrail services: output_guardrail_executor, high_risk_handler
- Add utility services: timeout_governor, segment_humanizer, metrics_collector
This commit is contained in:
MerCry 2026-03-05 18:13:34 +08:00
parent 3d9f718f5f
commit b4eb98e7c4
26 changed files with 6978 additions and 0 deletions

View File

@ -0,0 +1,30 @@
"""
Mid Platform API endpoints.
[AC-IDMP-01~20] Mid platform dialogue, messages, and session management.
[AC-IDMP-SHARE] Share session via unique token.
[AC-MRS-09,10] Runtime slot query endpoints.
"""
from fastapi import APIRouter
from .dialogue import router as dialogue_router
from .messages import router as messages_router
from .sessions import router as sessions_router
from .share import router as share_router
from .slots import router as slots_router
router = APIRouter()
router.include_router(dialogue_router)
router.include_router(messages_router)
router.include_router(sessions_router)
router.include_router(share_router)
router.include_router(slots_router)
__all__ = [
"router",
"dialogue_router",
"messages_router",
"sessions_router",
"share_router",
"slots_router",
]

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,104 @@
"""
Messages Controller for Mid Platform.
[AC-IDMP-08] Message report endpoint: POST /mid/messages/report
"""
import logging
import time
from typing import Annotated, Any
from fastapi import APIRouter, Depends, Request
from fastapi.responses import JSONResponse
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_session
from app.core.tenant import get_tenant_id
from app.models.mid.schemas import MessageReportRequest
from app.services.memory import MemoryService
logger = logging.getLogger(__name__)

router = APIRouter(prefix="/mid", tags=["Mid Platform Messages"])

# Maps a reported message ``source`` to the role stored in memory; any other
# source (e.g. "channel") keeps the role supplied by the reporter.
_SOURCE_TO_ROLE = {"human": "human", "bot": "assistant"}


@router.post(
    "/messages/report",
    operation_id="reportMessages",
    summary="Report session messages and events",
    description="""
[AC-IDMP-08] Report messages from channel to mid platform.
Accepts messages from bot, human, or channel sources for session closure.
Returns 202 Accepted for async processing.
""",
    responses={
        202: {"description": "Accepted for async processing"},
        400: {"description": "Invalid request"},
    },
)
async def report_messages(
    request: Request,
    report_request: MessageReportRequest,
    session: Annotated[AsyncSession, Depends(get_session)],
) -> JSONResponse:
    """
    [AC-IDMP-08] Report messages from channel.

    Accepts and stores messages for session data completeness.
    Messages are stored asynchronously without blocking response.
    """
    tenant_id = get_tenant_id()
    if not tenant_id:
        from app.core.exceptions import MissingTenantIdException
        raise MissingTenantIdException()

    logger.info(
        f"[AC-IDMP-08] Message report: tenant={tenant_id}, "
        f"session={report_request.session_id}, count={len(report_request.messages)}"
    )

    try:
        # Make sure the session row exists before appending its messages.
        memory_service = MemoryService(session)
        await memory_service.get_or_create_session(
            tenant_id=tenant_id,
            session_id=report_request.session_id,
        )

        messages_to_save = [
            {
                "role": _SOURCE_TO_ROLE.get(msg.source, msg.role),
                "content": msg.content,
            }
            for msg in report_request.messages
        ]

        if messages_to_save:
            await memory_service.append_messages(
                tenant_id=tenant_id,
                session_id=report_request.session_id,
                messages=messages_to_save,
            )
            logger.info(
                f"[AC-IDMP-08] Messages saved: tenant={tenant_id}, "
                f"session={report_request.session_id}, count={len(messages_to_save)}"
            )

        return JSONResponse(
            status_code=202,
            content={"status": "accepted", "session_id": report_request.session_id},
        )
    except Exception as e:
        # Best-effort semantics: reporting must never block the channel, so
        # failures are logged and still acknowledged with 202.
        # NOTE(review): str(e) is echoed back to the caller — confirm whether
        # exposing internal error details here is acceptable.
        logger.error(f"[AC-IDMP-08] Message report failed: {e}")
        return JSONResponse(
            status_code=202,
            content={"status": "accepted", "warning": str(e)},
        )

View File

@ -0,0 +1,105 @@
"""
Sessions Controller for Mid Platform.
[AC-IDMP-09] Session mode switch endpoint: POST /mid/sessions/{sessionId}/mode
"""
import logging
from typing import Annotated
from fastapi import APIRouter, Depends, Path
from fastapi.responses import JSONResponse
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_session
from app.core.tenant import get_tenant_id
from app.models.mid.schemas import (
SessionMode,
SwitchModeRequest,
SwitchModeResponse,
)
from app.services.memory import MemoryService
logger = logging.getLogger(__name__)

router = APIRouter(prefix="/mid", tags=["Mid Platform Sessions"])

# Current mode per "tenant:session" key.
# NOTE(review): this dict is process-local — state is not shared across
# workers/replicas and is lost on restart; confirm single-process deployment
# or move it to a shared store.
_session_modes: dict[str, SessionMode] = {}


@router.post(
    "/sessions/{sessionId}/mode",
    operation_id="switchSessionMode",
    summary="Switch session mode",
    description="""
[AC-IDMP-09] Switch session mode between BOT_ACTIVE and HUMAN_ACTIVE.
When mode is HUMAN_ACTIVE, dialogue responses will route to transfer mode.
""",
    responses={
        200: {"description": "Mode switched successfully", "model": SwitchModeResponse},
        400: {"description": "Invalid request"},
    },
)
async def switch_session_mode(
    sessionId: Annotated[str, Path(description="Session ID")],
    switch_request: SwitchModeRequest,
    session: Annotated[AsyncSession, Depends(get_session)],
) -> SwitchModeResponse:
    """
    [AC-IDMP-09] Switch session mode.

    Modes:
    - BOT_ACTIVE: Bot handles responses
    - HUMAN_ACTIVE: Transfer to human agent
    """
    tenant_id = get_tenant_id()
    if not tenant_id:
        # Imported lazily; presumably to avoid an import cycle — confirm.
        from app.core.exceptions import MissingTenantIdException
        raise MissingTenantIdException()
    logger.info(
        f"[AC-IDMP-09] Mode switch: tenant={tenant_id}, "
        f"session={sessionId}, mode={switch_request.mode.value}"
    )
    try:
        # Ensure the session row exists before recording its mode.
        memory_service = MemoryService(session)
        await memory_service.get_or_create_session(
            tenant_id=tenant_id,
            session_id=sessionId,
        )
        session_key = f"{tenant_id}:{sessionId}"
        _session_modes[session_key] = switch_request.mode
        logger.info(
            f"[AC-IDMP-09] Mode switched: session={sessionId}, "
            f"mode={switch_request.mode.value}, reason={switch_request.reason}"
        )
        return SwitchModeResponse(
            session_id=sessionId,
            mode=switch_request.mode,
        )
    except Exception as e:
        # Best-effort: the endpoint still replies with the requested mode.
        # NOTE(review): when the try-block fails before the mode is stored,
        # this response claims a switch that did not happen — confirm whether
        # an error status should be returned instead.
        logger.error(f"[AC-IDMP-09] Mode switch failed: {e}")
        return SwitchModeResponse(
            session_id=sessionId,
            mode=switch_request.mode,
        )
def get_session_mode(tenant_id: str, session_id: str) -> SessionMode:
    """Return the recorded mode for this session, defaulting to BOT_ACTIVE."""
    return _session_modes.get(f"{tenant_id}:{session_id}", SessionMode.BOT_ACTIVE)
def clear_session_mode(tenant_id: str, session_id: str) -> None:
    """Clear session mode (reset to BOT_ACTIVE)."""
    # pop with a default is a no-op when the key is absent, matching the
    # original membership-check-then-delete behavior.
    _session_modes.pop(f"{tenant_id}:{session_id}", None)

View File

@ -0,0 +1,224 @@
"""
Mid Platform models for Intent-Driven Agent.
[AC-IDMP-01~20] 中台统一响应协议模型
"""
import uuid
from datetime import datetime
from enum import Enum
from typing import Any
from pydantic import BaseModel, Field
class Mode(str, Enum):
    """[AC-IDMP-02] Execution mode of the dialogue pipeline."""
    AGENT = "agent"
    MICRO_FLOW = "micro_flow"
    FIXED = "fixed"
    TRANSFER = "transfer"


class SessionMode(str, Enum):
    """[AC-IDMP-09] Session mode (bot vs. human handling)."""
    BOT_ACTIVE = "BOT_ACTIVE"
    HUMAN_ACTIVE = "HUMAN_ACTIVE"


class HighRiskScenario(str, Enum):
    """[AC-IDMP-20] Minimal set of high-risk scenarios."""
    REFUND = "refund"
    COMPLAINT_ESCALATION = "complaint_escalation"
    PRIVACY_SENSITIVE_PROMISE = "privacy_sensitive_promise"
    TRANSFER = "transfer"


class ToolCallStatus(str, Enum):
    """[AC-IDMP-15] Tool call status."""
    OK = "ok"
    TIMEOUT = "timeout"
    ERROR = "error"
    REJECTED = "rejected"


class ToolType(str, Enum):
    """[AC-IDMP-19] Tool type."""
    INTERNAL = "internal"
    MCP = "mcp"
class HistoryMessage(BaseModel):
    """[AC-IDMP-03] A history message that has already been delivered."""
    role: str = Field(..., description="消息角色: user/assistant/human")
    content: str = Field(..., description="消息内容")


class InterruptedSegment(BaseModel):
    """[AC-IDMP-04] A response segment that was interrupted by the user."""
    segment_id: str = Field(..., description="分段ID")
    content: str = Field(..., description="分段内容")


class FeatureFlags(BaseModel):
    """[AC-IDMP-17] Feature flags (None means "not specified by the caller")."""
    agent_enabled: bool | None = Field(default=None, description="会话级 Agent 灰度开关")
    rollback_to_legacy: bool | None = Field(default=None, description="强制回滚传统链路")


class DialogueRequest(BaseModel):
    """[AC-IDMP-01~04] Dialogue response request."""
    session_id: str = Field(..., description="会话ID")
    user_id: str | None = Field(default=None, description="用户ID用于记忆召回与更新")
    user_message: str = Field(..., min_length=1, max_length=2000, description="用户消息")
    history: list[HistoryMessage] = Field(default_factory=list, description="已送达历史")
    interrupted_segments: list[InterruptedSegment] | None = Field(default=None, description="打断的分段")
    feature_flags: FeatureFlags | None = Field(default=None, description="特性开关")


class Segment(BaseModel):
    """[AC-IDMP-01] Response segment."""
    segment_id: str = Field(..., description="分段ID")
    text: str = Field(..., description="分段文本")
    delay_after: int = Field(default=0, ge=0, description="分段后延迟(ms)")


class TimeoutProfile(BaseModel):
    """[AC-IDMP-12] Timeout configuration (milliseconds)."""
    per_tool_timeout_ms: int | None = Field(default=30000, le=60000, description="单工具超时(ms)")
    end_to_end_timeout_ms: int | None = Field(default=120000, le=180000, description="端到端超时(ms)")


class MetricsSnapshot(BaseModel):
    """[AC-IDMP-18] Runtime metrics snapshot (rates are fractions in [0, 1])."""
    task_completion_rate: float | None = Field(default=None, ge=0, le=1, description="任务达成率")
    slot_completion_rate: float | None = Field(default=None, ge=0, le=1, description="槽位完整率")
    wrong_transfer_rate: float | None = Field(default=None, ge=0, le=1, description="误转人工率")
    no_recall_rate: float | None = Field(default=None, ge=0, le=1, description="无召回率")
    avg_latency_ms: float | None = Field(default=None, ge=0, description="平均时延(ms)")
class ToolCallTraceModel(BaseModel):
    """[AC-IDMP-15/19] Tool call trace (Pydantic model)."""
    tool_name: str = Field(..., description="工具名称")
    tool_type: ToolType | None = Field(default=None, description="工具类型: internal/mcp")
    registry_version: str | None = Field(default=None, description="注册版本")
    auth_applied: bool | None = Field(default=None, description="是否应用鉴权")
    duration_ms: int = Field(..., ge=0, description="耗时(ms)")
    status: ToolCallStatus = Field(..., description="状态: ok/timeout/error/rejected")
    error_code: str | None = Field(default=None, description="错误码")
    args_digest: str | None = Field(default=None, description="参数摘要")
    result_digest: str | None = Field(default=None, description="结果摘要")


class TraceInfo(BaseModel):
    """[AC-IDMP-02/07] Trace information attached to a dialogue response."""
    mode: Mode = Field(..., description="执行模式")
    intent: str | None = Field(default=None, description="意图")
    request_id: str | None = Field(default=None, description="请求ID")
    generation_id: str | None = Field(default=None, description="生成ID")
    guardrail_triggered: bool | None = Field(default=False, description="护栏是否触发")
    fallback_reason_code: str | None = Field(default=None, description="降级原因码")
    react_iterations: int | None = Field(default=None, ge=0, le=5, description="ReAct循环次数")
    timeout_profile: TimeoutProfile | None = Field(default=None, description="超时配置")
    metrics_snapshot: MetricsSnapshot | None = Field(default=None, description="指标快照")
    high_risk_policy_set: list[HighRiskScenario] | None = Field(
        default=None,
        description="当前启用的高风险最小场景集"
    )
    tools_used: list[str] | None = Field(default=None, description="使用的工具列表")
    tool_calls: list[ToolCallTraceModel] | None = Field(default=None, description="工具调用追踪")
class DialogueResponse(BaseModel):
    """[AC-IDMP-01/02] Dialogue response."""
    segments: list[Segment] = Field(..., description="响应分段列表")
    trace: TraceInfo = Field(..., description="追踪信息")


class ReportedMessage(BaseModel):
    """[AC-IDMP-08] Reported message.

    NOTE(review): ``timestamp`` is a datetime here, while the near-duplicate
    model in app.models.mid.schemas declares it as an ISO string — confirm
    which type the report endpoint actually consumes.
    """
    role: str = Field(..., description="角色: user/assistant/human/system")
    content: str = Field(..., description="消息内容")
    source: str = Field(..., description="来源: bot/human/channel")
    timestamp: datetime = Field(..., description="时间戳")
    segment_id: str | None = Field(default=None, description="分段ID")


class MessageReportRequest(BaseModel):
    """[AC-IDMP-08] Message report request."""
    session_id: str = Field(..., description="会话ID")
    messages: list[ReportedMessage] = Field(..., description="消息列表")


class SwitchModeRequest(BaseModel):
    """[AC-IDMP-09] Switch session mode request."""
    mode: SessionMode = Field(..., description="目标模式")
    reason: str | None = Field(default=None, description="切换原因")


class SwitchModeResponse(BaseModel):
    """[AC-IDMP-09] Switch session mode response."""
    session_id: str = Field(..., description="会话ID")
    mode: SessionMode = Field(..., description="当前模式")
# Re-exports: make the memory / tool-trace / tool-registry models importable
# from this module as a single aggregation point.
# NOTE(review): these imports sit below the class definitions — presumably to
# avoid a circular import with the submodules; confirm before hoisting them.
from app.models.mid.memory import (
    RecallRequest,
    RecallResponse,
    UpdateRequest,
    MemoryProfile,
    MemoryFact,
    MemoryPreferences,
)
# ToolCallStatus / ToolType are aliased because this module already defines
# enums with the same names.
from app.models.mid.tool_trace import (
    ToolCallTrace,
    ToolCallBuilder,
    ToolCallStatus as ToolCallStatusEnum,
    ToolType as ToolTypeEnum,
)
from app.models.mid.tool_registry import (
    ToolDefinition,
    ToolAuthConfig,
    ToolTimeoutPolicy,
    ToolRegistryEntity,
    ToolRegistryCreate,
    ToolRegistryUpdate,
)

# Public API: locally defined models plus the re-exports above.
__all__ = [
    "Mode",
    "SessionMode",
    "HighRiskScenario",
    "ToolCallStatus",
    "ToolType",
    "HistoryMessage",
    "InterruptedSegment",
    "FeatureFlags",
    "DialogueRequest",
    "Segment",
    "TimeoutProfile",
    "MetricsSnapshot",
    "ToolCallTraceModel",
    "TraceInfo",
    "DialogueResponse",
    "ReportedMessage",
    "MessageReportRequest",
    "SwitchModeRequest",
    "SwitchModeResponse",
    "RecallRequest",
    "RecallResponse",
    "UpdateRequest",
    "MemoryProfile",
    "MemoryFact",
    "MemoryPreferences",
    "ToolCallTrace",
    "ToolCallBuilder",
    "ToolCallStatusEnum",
    "ToolTypeEnum",
    "ToolDefinition",
    "ToolAuthConfig",
    "ToolTimeoutPolicy",
    "ToolRegistryEntity",
    "ToolRegistryCreate",
    "ToolRegistryUpdate",
]

View File

@ -0,0 +1,182 @@
"""
Memory models for Mid Platform.
[AC-IDMP-13] 记忆召回数据模型
[AC-IDMP-14] 记忆更新数据模型
Reference: spec/intent-driven-mid-platform/openapi.deps.yaml
"""
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any
from uuid import UUID
@dataclass
class MemoryProfile:
"""
[AC-IDMP-13] 用户基础属性记忆
包含年级地区渠道等基础信息
"""
grade: str | None = None
region: str | None = None
channel: str | None = None
vip_level: str | None = None
registration_date: datetime | None = None
extra: dict[str, Any] = field(default_factory=dict)
def to_dict(self) -> dict[str, Any]:
result = {}
if self.grade:
result["grade"] = self.grade
if self.region:
result["region"] = self.region
if self.channel:
result["channel"] = self.channel
if self.vip_level:
result["vip_level"] = self.vip_level
if self.registration_date:
result["registration_date"] = self.registration_date.isoformat()
if self.extra:
result.update(self.extra)
return result
@dataclass
class MemoryFact:
"""
[AC-IDMP-13] 事实型记忆
包含已购课程学习结论等客观事实
"""
content: str
source: str | None = None
confidence: float | None = None
created_at: datetime | None = None
expires_at: datetime | None = None
def to_string(self) -> str:
return self.content
@dataclass
class MemoryPreferences:
"""
[AC-IDMP-13] 偏好记忆
包含语气偏好关注科目等用户偏好
"""
tone: str | None = None
focus_subjects: list[str] = field(default_factory=list)
communication_style: str | None = None
preferred_time: str | None = None
extra: dict[str, Any] = field(default_factory=dict)
def to_dict(self) -> dict[str, Any]:
result = {}
if self.tone:
result["tone"] = self.tone
if self.focus_subjects:
result["focus_subjects"] = self.focus_subjects
if self.communication_style:
result["communication_style"] = self.communication_style
if self.preferred_time:
result["preferred_time"] = self.preferred_time
if self.extra:
result.update(self.extra)
return result
@dataclass
class RecallRequest:
    """
    [AC-IDMP-13] Memory recall request.

    Reference: openapi.deps.yaml - RecallRequest
    """
    user_id: str
    session_id: str

    def to_dict(self) -> dict[str, Any]:
        """Serialize to the wire format expected by the memory service."""
        return {"user_id": self.user_id, "session_id": self.session_id}
@dataclass
class RecallResponse:
    """
    [AC-IDMP-13] Memory recall response.

    Reference: openapi.deps.yaml - RecallResponse
    """
    profile: MemoryProfile | None = None
    facts: list[MemoryFact] = field(default_factory=list)
    preferences: MemoryPreferences | None = None
    last_summary: str | None = None

    def to_dict(self) -> dict[str, Any]:
        """Serialize populated sections only; empty sections are omitted."""
        payload: dict[str, Any] = {}
        if self.profile:
            payload["profile"] = self.profile.to_dict()
        if self.facts:
            payload["facts"] = [fact.to_string() for fact in self.facts]
        if self.preferences:
            payload["preferences"] = self.preferences.to_dict()
        if self.last_summary:
            payload["last_summary"] = self.last_summary
        return payload

    def get_context_for_prompt(self) -> str:
        """Build the context string injected into the prompt."""
        sections: list[str] = []
        if self.profile:
            attrs = [
                label
                for label in (
                    f"年级: {self.profile.grade}" if self.profile.grade else "",
                    f"地区: {self.profile.region}" if self.profile.region else "",
                    f"会员等级: {self.profile.vip_level}" if self.profile.vip_level else "",
                )
                if label
            ]
            if attrs:
                sections.append("【用户属性】" + "".join(attrs))
        if self.facts:
            # Only the five most recent/relevant facts are surfaced.
            sections.append("【已知事实】" + "".join(fact.content for fact in self.facts[:5]))
        if self.preferences:
            prefs = [
                label
                for label in (
                    f"语气偏好: {self.preferences.tone}" if self.preferences.tone else "",
                    f"关注科目: {', '.join(self.preferences.focus_subjects)}"
                    if self.preferences.focus_subjects
                    else "",
                )
                if label
            ]
            if prefs:
                sections.append("【用户偏好】" + "".join(prefs))
        if self.last_summary:
            sections.append(f"【上次会话摘要】{self.last_summary}")
        return "\n".join(sections) if sections else ""
@dataclass
class UpdateRequest:
"""
[AC-IDMP-14] 记忆更新请求
Reference: openapi.deps.yaml - UpdateRequest
"""
user_id: str
session_id: str
messages: list[dict[str, Any]]
summary: str | None = None
def to_dict(self) -> dict[str, Any]:
result = {
"user_id": self.user_id,
"session_id": self.session_id,
"messages": self.messages,
}
if self.summary:
result["summary"] = self.summary
return result

View File

@ -0,0 +1,407 @@
"""
Mid Platform schemas.
[AC-IDMP-01, AC-IDMP-02, AC-IDMP-07, AC-IDMP-11, AC-IDMP-12, AC-IDMP-15, AC-IDMP-17, AC-IDMP-18, AC-IDMP-19, AC-IDMP-20]
Aligned with spec/intent-driven-mid-platform/openapi.provider.yaml
"""
from __future__ import annotations
import uuid
from enum import Enum
from typing import Any
from pydantic import BaseModel, Field
class ExecutionMode(str, Enum):
    """[AC-IDMP-02] Execution mode for dialogue response."""
    AGENT = "agent"
    MICRO_FLOW = "micro_flow"
    FIXED = "fixed"
    TRANSFER = "transfer"


class HighRiskScenario(str, Enum):
    """[AC-IDMP-20] High risk scenario types for mandatory takeover."""
    REFUND = "refund"
    COMPLAINT_ESCALATION = "complaint_escalation"
    PRIVACY_SENSITIVE_PROMISE = "privacy_sensitive_promise"
    TRANSFER = "transfer"


class ToolCallStatus(str, Enum):
    """[AC-IDMP-15] Tool call status."""
    OK = "ok"
    TIMEOUT = "timeout"
    ERROR = "error"
    REJECTED = "rejected"


class ToolType(str, Enum):
    """[AC-IDMP-19] Tool type for registry governance."""
    INTERNAL = "internal"
    MCP = "mcp"


class SessionMode(str, Enum):
    """[AC-IDMP-09] Session mode for bot/human switching."""
    BOT_ACTIVE = "BOT_ACTIVE"
    HUMAN_ACTIVE = "HUMAN_ACTIVE"


class HistoryMessage(BaseModel):
    """[AC-IDMP-03] History message with only delivered content."""
    role: str = Field(..., description="Message role: user, assistant, or human")
    content: str = Field(..., description="Message content")


class InterruptedSegment(BaseModel):
    """[AC-IDMP-04] Interrupted segment for handling user interruption."""
    segment_id: str = Field(..., description="Segment ID")
    content: str = Field(..., description="Segment content")


class FeatureFlags(BaseModel):
    """[AC-IDMP-17] Feature flags for session-level grayscale and rollback.

    NOTE(review): defaults here are True/False, while the near-duplicate
    FeatureFlags model in app.models.mid.models defaults both fields to
    None — confirm which defaults the provider contract expects.
    """
    agent_enabled: bool | None = Field(default=True, description="Session-level Agent grayscale switch")
    rollback_to_legacy: bool | None = Field(default=False, description="Force rollback to legacy pipeline")


class HumanizeConfigRequest(BaseModel):
    """[AC-MARH-11] Humanize (anthropomorphic pacing) configuration request."""
    enabled: bool | None = Field(default=True, description="Enable humanize strategy")
    min_delay_ms: int | None = Field(default=50, ge=0, description="Minimum delay in milliseconds")
    # NOTE(review): no cross-field check that min_delay_ms <= max_delay_ms —
    # confirm whether downstream consumers validate this.
    max_delay_ms: int | None = Field(default=500, ge=0, description="Maximum delay in milliseconds")
    length_bucket_strategy: str | None = Field(default="simple", description="Strategy: simple or semantic")
class DialogueRequest(BaseModel):
    """[AC-IDMP-01, AC-IDMP-03, AC-IDMP-04, AC-IDMP-17, AC-MARH-11] Dialogue request schema."""
    session_id: str = Field(..., description="Session ID for conversation tracking")
    user_id: str | None = Field(default=None, description="User ID for memory recall and update")
    user_message: str = Field(..., min_length=1, max_length=2000, description="User message content")
    history: list[HistoryMessage] = Field(default_factory=list, description="Only delivered history")
    interrupted_segments: list[InterruptedSegment] | None = Field(default=None, description="Interrupted segments")
    feature_flags: FeatureFlags | None = Field(default=None, description="Feature flags for grayscale control")
    humanize_config: HumanizeConfigRequest | None = Field(
        default=None, description="Humanize config for segment delay"
    )


class Segment(BaseModel):
    """[AC-IDMP-01] Response segment with delay control."""
    # segment_id is auto-generated when the producer does not supply one.
    segment_id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Segment ID")
    text: str = Field(..., description="Segment text content")
    delay_after: int = Field(default=0, ge=0, description="Delay after this segment in milliseconds")


class TimeoutProfile(BaseModel):
    """[AC-MARH-08, AC-MARH-09] Timeout configuration profile (milliseconds)."""
    per_tool_timeout_ms: int = Field(default=30000, le=60000, description="Per-tool timeout in milliseconds")
    llm_timeout_ms: int = Field(default=60000, le=120000, description="LLM call timeout in milliseconds")
    end_to_end_timeout_ms: int = Field(default=120000, le=180000, description="End-to-end timeout in milliseconds")


class MetricsSnapshot(BaseModel):
    """[AC-IDMP-18] Runtime metrics snapshot (rates are fractions in [0, 1])."""
    task_completion_rate: float | None = Field(default=None, ge=0.0, le=1.0, description="Task completion rate")
    slot_completion_rate: float | None = Field(default=None, ge=0.0, le=1.0, description="Slot completion rate")
    wrong_transfer_rate: float | None = Field(default=None, ge=0.0, le=1.0, description="Wrong transfer rate")
    no_recall_rate: float | None = Field(default=None, ge=0.0, le=1.0, description="No recall rate")
    avg_latency_ms: float | None = Field(default=None, ge=0.0, description="Average latency in milliseconds")


class ToolCallTrace(BaseModel):
    """[AC-IDMP-15, AC-IDMP-19] Tool call trace for observability."""
    tool_name: str = Field(..., description="Tool name")
    tool_type: ToolType | None = Field(default=ToolType.INTERNAL, description="Tool type: internal or mcp")
    registry_version: str | None = Field(default=None, description="Tool registry version")
    auth_applied: bool | None = Field(default=False, description="Whether auth was applied")
    duration_ms: int = Field(..., ge=0, description="Duration in milliseconds")
    status: ToolCallStatus = Field(..., description="Tool call status")
    error_code: str | None = Field(default=None, description="Error code if failed")
    args_digest: str | None = Field(default=None, description="Arguments digest for logging")
    result_digest: str | None = Field(default=None, description="Result digest for logging")
class SegmentStats(BaseModel):
    """[AC-MARH-12] Segment statistics for humanize strategy."""
    segment_count: int = Field(default=0, ge=0, description="Number of segments")
    avg_segment_length: float = Field(default=0.0, ge=0.0, description="Average segment length")
    humanize_strategy: str | None = Field(default=None, description="Humanize strategy used")


class TraceInfo(BaseModel):
    """[AC-MARH-02, AC-MARH-03, AC-MARH-05, AC-MARH-06, AC-MARH-07, AC-MARH-11,
    AC-MARH-12, AC-MARH-18, AC-MARH-19, AC-MARH-20] Trace info for observability.

    NOTE(review): AC-MARH-18/19/20 fall outside the AC-MARH-01~12 range named
    in the module docstring — possibly AC-IDMP-18/19/20 were intended; confirm
    against the spec.
    """
    mode: ExecutionMode = Field(..., description="Execution mode")
    intent: str | None = Field(default=None, description="Matched intent")
    # request_id / generation_id are auto-generated when callers omit them.
    request_id: str | None = Field(
        default_factory=lambda: str(uuid.uuid4()), description="Request ID"
    )
    generation_id: str | None = Field(
        default_factory=lambda: str(uuid.uuid4()),
        description="Generation ID for interrupt handling",
    )
    guardrail_triggered: bool | None = Field(default=False, description="Whether guardrail was triggered")
    guardrail_rule_id: str | None = Field(default=None, description="Guardrail rule ID that triggered")
    interrupt_consumed: bool | None = Field(default=False, description="Whether interrupted segments were consumed")
    kb_tool_called: bool | None = Field(default=False, description="Whether KB tool was called")
    kb_hit: bool | None = Field(default=False, description="Whether KB search had results")
    fallback_reason_code: str | None = Field(default=None, description="Fallback reason code")
    react_iterations: int | None = Field(default=0, ge=0, le=5, description="ReAct loop iterations")
    timeout_profile: TimeoutProfile | None = Field(default=None, description="Timeout profile")
    segment_stats: SegmentStats | None = Field(default=None, description="Segment statistics")
    metrics_snapshot: MetricsSnapshot | None = Field(default=None, description="Metrics snapshot")
    high_risk_policy_set: list[HighRiskScenario] | None = Field(default=None, description="Active high-risk policy set")
    tools_used: list[str] | None = Field(default=None, description="Tools used in this request")
    tool_calls: list[ToolCallTrace] | None = Field(default=None, description="Tool call traces")


class DialogueResponse(BaseModel):
    """[AC-IDMP-01, AC-IDMP-02] Dialogue response with segments and trace."""
    segments: list[Segment] = Field(..., description="Response segments")
    trace: TraceInfo = Field(..., description="Trace info for observability")
class ReportedMessage(BaseModel):
    """[AC-IDMP-08] Reported message for message report API.

    NOTE(review): ``timestamp`` is a plain ISO string here, while the sibling
    model in app.models.mid.models declares it as a datetime — confirm which
    type the report contract expects.
    """
    role: str = Field(..., description="Message role: user, assistant, human, or system")
    content: str = Field(..., description="Message content")
    source: str = Field(..., description="Message source: bot, human, or channel")
    timestamp: str = Field(..., description="Message timestamp in ISO format")
    segment_id: str | None = Field(default=None, description="Segment ID if applicable")


class MessageReportRequest(BaseModel):
    """[AC-IDMP-08] Message report request schema."""
    session_id: str = Field(..., description="Session ID")
    messages: list[ReportedMessage] = Field(..., description="Messages to report")


class SwitchModeRequest(BaseModel):
    """[AC-IDMP-09] Switch session mode request."""
    mode: SessionMode = Field(..., description="Target mode: BOT_ACTIVE or HUMAN_ACTIVE")
    reason: str | None = Field(default=None, description="Reason for mode switch")


class SwitchModeResponse(BaseModel):
    """[AC-IDMP-09] Switch session mode response."""
    session_id: str = Field(..., description="Session ID")
    mode: SessionMode = Field(..., description="Current mode after switch")


class MidSessionState(BaseModel):
    """Internal session state for mid platform (not exposed via the API)."""
    session_id: str
    tenant_id: str
    mode: SessionMode = SessionMode.BOT_ACTIVE
    # Regenerated per session; used to correlate interrupt handling.
    generation_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    active_flow_id: str | None = None
    context: dict[str, Any] | None = None
    # Presumably ISO-format timestamps; None until first persisted — confirm.
    created_at: str | None = None
    updated_at: str | None = None


class PolicyRouterResult(BaseModel):
    """[AC-IDMP-02, AC-IDMP-05, AC-IDMP-16] Policy router decision result."""
    mode: ExecutionMode = Field(..., description="Decided execution mode")
    intent: str | None = Field(default=None, description="Matched intent")
    confidence: float | None = Field(default=None, ge=0.0, le=1.0, description="Intent confidence")
    fallback_reason_code: str | None = Field(default=None, description="Fallback reason if applicable")
    high_risk_triggered: bool = Field(default=False, description="Whether high-risk scenario triggered")
    target_flow_id: str | None = Field(default=None, description="Target flow ID for micro_flow mode")
    fixed_reply: str | None = Field(default=None, description="Fixed reply for fixed mode")
    transfer_message: str | None = Field(default=None, description="Transfer message for transfer mode")


class ReActContext(BaseModel):
    """[AC-IDMP-11] ReAct loop context for iteration control."""
    iteration: int = Field(default=0, ge=0, le=5, description="Current iteration count")
    max_iterations: int = Field(default=5, ge=3, le=5, description="Maximum iterations allowed")
    tool_calls: list[ToolCallTrace] = Field(default_factory=list, description="Tool call history")
    should_continue: bool = Field(default=True, description="Whether to continue ReAct loop")
    final_answer: str | None = Field(default=None, description="Final answer if completed")
# --- [AC-IDMP-SHARE] session sharing schemas ---

class CreateShareRequest(BaseModel):
    """[AC-IDMP-SHARE] Request to create a shared session."""
    title: str | None = Field(default=None, max_length=255, description="Share title")
    description: str | None = Field(default=None, max_length=1000, description="Share description")
    expires_in_days: int = Field(default=7, ge=1, le=365, description="Expiration time in days")
    max_concurrent_users: int = Field(default=10, ge=1, le=100, description="Maximum concurrent users")


class ShareResponse(BaseModel):
    """[AC-IDMP-SHARE] Response after creating a share."""
    share_token: str = Field(..., description="Unique share token")
    share_url: str = Field(..., description="Full share URL")
    expires_at: str = Field(..., description="Expiration time in ISO format")
    title: str | None = Field(default=None, description="Share title")
    description: str | None = Field(default=None, description="Share description")
    max_concurrent_users: int = Field(..., description="Maximum concurrent users")


class SharedSessionInfo(BaseModel):
    """[AC-IDMP-SHARE] Information about a shared session."""
    session_id: str = Field(..., description="Session ID")
    title: str | None = Field(default=None, description="Share title")
    description: str | None = Field(default=None, description="Share description")
    expires_at: str = Field(..., description="Expiration time in ISO format")
    max_concurrent_users: int = Field(..., description="Maximum concurrent users")
    current_users: int = Field(..., description="Current online users")
    history: list[HistoryMessage] = Field(default_factory=list, description="Historical messages")


class SharedMessageRequest(BaseModel):
    """[AC-IDMP-SHARE] Request to send a message via shared session."""
    user_message: str = Field(..., min_length=1, max_length=2000, description="User message content")


class ShareListItem(BaseModel):
    """[AC-IDMP-SHARE] Share list item for listing all shares of a session."""
    share_token: str = Field(..., description="Share token")
    share_url: str = Field(..., description="Full share URL")
    title: str | None = Field(default=None, description="Share title")
    description: str | None = Field(default=None, description="Share description")
    expires_at: str = Field(..., description="Expiration time in ISO format")
    is_active: bool = Field(..., description="Whether share is active")
    max_concurrent_users: int = Field(..., description="Maximum concurrent users")
    current_users: int = Field(..., description="Current online users")
    created_at: str = Field(..., description="Creation time in ISO format")


class ShareListResponse(BaseModel):
    """[AC-IDMP-SHARE] Response for listing shares."""
    shares: list[ShareListItem] = Field(..., description="List of shares")
class KbSearchDynamicHit(BaseModel):
"""[AC-MARH-05] Single KB search hit."""
id: str = Field(..., description="Hit ID")
content: str = Field(..., description="Hit content")
score: float = Field(..., ge=0.0, le=1.0, description="Relevance score")
metadata: dict[str, Any] = Field(default_factory=dict, description="Hit metadata")
class MissingRequiredSlot(BaseModel):
"""[AC-MARH-05] Missing required slot info."""
field_key: str = Field(..., description="Field key")
label: str = Field(..., description="Field label")
reason: str = Field(..., description="Missing reason")
class KbSearchDynamicResultSchema(BaseModel):
"""[AC-MARH-05, AC-MARH-06] KB dynamic search result schema."""
success: bool = Field(..., description="Whether search succeeded")
hits: list[KbSearchDynamicHit] = Field(default_factory=list, description="Search hits")
applied_filter: dict[str, Any] = Field(default_factory=dict, description="Applied filter")
missing_required_slots: list[MissingRequiredSlot] = Field(
default_factory=list, description="Missing required slots"
)
filter_debug: dict[str, Any] = Field(default_factory=dict, description="Filter debug info")
fallback_reason_code: str | None = Field(default=None, description="Fallback reason code")
duration_ms: int = Field(default=0, ge=0, description="Duration in milliseconds")
class IntentHintOutput(BaseModel):
    """[AC-IDMP-02, AC-IDMP-16] Output of the lightweight intent-hint tool.

    Carries the recognized intent plus routing hints (suggested execution
    mode, target flow / knowledge bases) and an optional fallback code.
    """
    intent: str | None = Field(default=None, description="识别到的意图名称")
    confidence: float = Field(default=0.0, ge=0.0, le=1.0, description="置信度 0~1")
    response_type: str | None = Field(
        default=None,
        description="响应类型: fixed|rag|flow|transfer|null"
    )
    suggested_mode: ExecutionMode | None = Field(
        default=None,
        description="建议执行模式: agent|micro_flow|fixed|transfer"
    )
    target_flow_id: str | None = Field(default=None, description="目标流程IDflow模式")
    target_kb_ids: list[str] | None = Field(default=None, description="目标知识库ID列表")
    fallback_reason_code: str | None = Field(default=None, description="降级原因码")
    high_risk_detected: bool = Field(default=False, description="是否检测到高风险场景")
    duration_ms: int = Field(default=0, ge=0, description="执行耗时(毫秒)")
class HighRiskCheckResult(BaseModel):
    """[AC-IDMP-05, AC-IDMP-20] Output of the high-risk detection tool.

    When ``matched`` is True, the matched scenario, rule and text fragment
    are reported together with a recommended execution mode.
    """
    matched: bool = Field(default=False, description="是否命中高风险场景")
    risk_scenario: HighRiskScenario | None = Field(
        default=None,
        description="风险场景: refund|complaint_escalation|privacy_sensitive_promise|transfer|none"
    )
    confidence: float = Field(default=0.0, ge=0.0, le=1.0, description="置信度 0~1")
    recommended_mode: ExecutionMode | None = Field(
        default=None,
        description="推荐执行模式: micro_flow|transfer|agent"
    )
    rule_id: str | None = Field(default=None, description="匹配的规则ID")
    reason: str | None = Field(default=None, description="匹配原因说明")
    fallback_reason_code: str | None = Field(default=None, description="降级原因码(工具失败时)")
    duration_ms: int = Field(default=0, ge=0, description="执行耗时(毫秒)")
    matched_text: str | None = Field(default=None, description="匹配到的文本片段")
    matched_pattern: str | None = Field(default=None, description="匹配到的模式(关键词或正则)")
class SlotSource(str, Enum):
    """[AC-IDMP-13] Provenance of a slot value."""
    USER_CONFIRMED = "user_confirmed"
    RULE_EXTRACTED = "rule_extracted"
    LLM_INFERRED = "llm_inferred"
    DEFAULT = "default"
class MemorySlot(BaseModel):
    """[AC-IDMP-13] A single structured slot with provenance and confidence."""
    key: str = Field(..., description="槽位键名")
    value: Any = Field(..., description="槽位值")
    source: SlotSource = Field(default=SlotSource.DEFAULT, description="槽位来源")
    confidence: float = Field(default=1.0, ge=0.0, le=1.0, description="置信度")
    updated_at: str | None = Field(default=None, description="最后更新时间")
class MemoryRecallResult(BaseModel):
    """[AC-IDMP-13] Output of the memory-recall tool.

    Aggregates the user profile, factual memories, preferences, last-session
    summary and structured slots recalled for the current conversation.
    """
    profile: dict[str, Any] = Field(default_factory=dict, description="用户基础属性")
    facts: list[str] = Field(default_factory=list, description="事实型记忆列表")
    preferences: dict[str, Any] = Field(default_factory=dict, description="用户偏好")
    last_summary: str | None = Field(default=None, description="最近会话摘要")
    slots: dict[str, MemorySlot] = Field(default_factory=dict, description="结构化槽位")
    missing_slots: list[str] = Field(default_factory=list, description="缺失的必填槽位")
    fallback_reason_code: str | None = Field(default=None, description="降级原因码")
    duration_ms: int = Field(default=0, ge=0, description="执行耗时(毫秒)")
    def get_context_for_prompt(self) -> str:
        """Build the context string injected into the prompt.

        Each populated section is rendered on its own line. Entries within a
        section are now separated with "; " — previously profile entries,
        facts and preferences were concatenated with ``"".join``, producing
        unreadable runs such as "age: 30city: NYC". Only truthy values are
        included, and facts are capped at the first 5 entries.
        """
        parts: list[str] = []
        if self.profile:
            profile_parts = [f"{key}: {value}" for key, value in self.profile.items() if value]
            if profile_parts:
                parts.append("【用户属性】" + "; ".join(profile_parts))
        if self.facts:
            parts.append("【已知事实】" + "; ".join(self.facts[:5]))
        if self.preferences:
            pref_parts = [f"{key}: {value}" for key, value in self.preferences.items() if value]
            if pref_parts:
                parts.append("【用户偏好】" + "; ".join(pref_parts))
        if self.last_summary:
            parts.append(f"【上次会话摘要】{self.last_summary}")
        if self.slots:
            slot_parts = [f"{key}={slot.value}" for key, slot in self.slots.items()]
            if slot_parts:
                parts.append("【已知槽位】" + ", ".join(slot_parts))
        return "\n".join(parts) if parts else ""

View File

@ -0,0 +1,222 @@
"""
Tool Registry models for Mid Platform.
[AC-IDMP-19] Tool Registry 治理模型
Reference: spec/intent-driven-mid-platform/openapi.provider.yaml
"""
import uuid
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any
from sqlalchemy import JSON, Column
from sqlmodel import Field, Index, SQLModel
class ToolStatus(str, Enum):
    """Lifecycle status of a registered tool."""
    ENABLED = "enabled"
    DISABLED = "disabled"
    DEPRECATED = "deprecated"
class ToolAuthType(str, Enum):
"""工具鉴权类型"""
NONE = "none"
API_KEY = "api_key"
OAUTH = "oauth"
CUSTOM = "custom"
@dataclass
class ToolAuthConfig:
"""
[AC-IDMP-19] 工具鉴权配置
"""
auth_type: ToolAuthType = ToolAuthType.NONE
required_scopes: list[str] = field(default_factory=list)
api_key_header: str | None = None
oauth_url: str | None = None
custom_validator: str | None = None
def to_dict(self) -> dict[str, Any]:
result = {"auth_type": self.auth_type.value}
if self.required_scopes:
result["required_scopes"] = self.required_scopes
if self.api_key_header:
result["api_key_header"] = self.api_key_header
if self.oauth_url:
result["oauth_url"] = self.oauth_url
if self.custom_validator:
result["custom_validator"] = self.custom_validator
return result
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "ToolAuthConfig":
return cls(
auth_type=ToolAuthType(data.get("auth_type", "none")),
required_scopes=data.get("required_scopes", []),
api_key_header=data.get("api_key_header"),
oauth_url=data.get("oauth_url"),
custom_validator=data.get("custom_validator"),
)
@dataclass
class ToolTimeoutPolicy:
    """
    [AC-IDMP-19] Timeout and retry policy for a single tool invocation.
    Reference: openapi.provider.yaml - TimeoutProfile
    """
    per_tool_timeout_ms: int = 30000
    end_to_end_timeout_ms: int = 120000
    retry_count: int = 0
    retry_delay_ms: int = 100
    # Serialized field names; a plain (unannotated) class attribute so the
    # dataclass machinery does not treat it as a field.
    _KEYS = ("per_tool_timeout_ms", "end_to_end_timeout_ms", "retry_count", "retry_delay_ms")
    def to_dict(self) -> dict[str, Any]:
        """Serialize every policy field."""
        return {key: getattr(self, key) for key in self._KEYS}
    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ToolTimeoutPolicy":
        """Build a policy from ``data``; absent keys keep their defaults."""
        fallback = cls()
        return cls(**{key: data.get(key, getattr(fallback, key)) for key in cls._KEYS})
@dataclass
class ToolDefinition:
    """
    [AC-IDMP-19] In-memory tool definition.

    Fields:
    - name: tool name
    - type: tool kind ("internal" | "mcp")
    - version: semantic version string
    - timeout_policy: per-call / end-to-end timeout settings
    - auth_config: authentication requirements
    - is_enabled: enable/disable switch
    """
    name: str
    type: str = "internal"
    version: str = "1.0.0"
    description: str | None = None
    timeout_policy: ToolTimeoutPolicy = field(default_factory=ToolTimeoutPolicy)
    auth_config: ToolAuthConfig = field(default_factory=ToolAuthConfig)
    is_enabled: bool = True
    metadata: dict[str, Any] = field(default_factory=dict)
    # NOTE(review): datetime.utcnow() is deprecated (3.12+) and yields naive
    # timestamps; consider datetime.now(timezone.utc) — confirm downstream
    # consumers before switching, since isoformat output would gain an offset.
    created_at: datetime = field(default_factory=datetime.utcnow)
    updated_at: datetime = field(default_factory=datetime.utcnow)
    def to_dict(self) -> dict[str, Any]:
        # Nested policy/auth objects are serialized via their own to_dict;
        # timestamps are emitted as ISO-8601 strings.
        return {
            "name": self.name,
            "type": self.type,
            "version": self.version,
            "description": self.description,
            "timeout_policy": self.timeout_policy.to_dict(),
            "auth_config": self.auth_config.to_dict(),
            "is_enabled": self.is_enabled,
            "metadata": self.metadata,
            "created_at": self.created_at.isoformat(),
            "updated_at": self.updated_at.isoformat(),
        }
    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ToolDefinition":
        # Missing timestamps fall back to "now" rather than failing.
        return cls(
            name=data["name"],
            type=data.get("type", "internal"),
            version=data.get("version", "1.0.0"),
            description=data.get("description"),
            timeout_policy=ToolTimeoutPolicy.from_dict(data.get("timeout_policy", {})),
            auth_config=ToolAuthConfig.from_dict(data.get("auth_config", {})),
            is_enabled=data.get("is_enabled", True),
            metadata=data.get("metadata", {}),
            created_at=datetime.fromisoformat(data["created_at"]) if data.get("created_at") else datetime.utcnow(),
            updated_at=datetime.fromisoformat(data["updated_at"]) if data.get("updated_at") else datetime.utcnow(),
        )
class ToolRegistryEntity(SQLModel, table=True):
    """
    [AC-IDMP-19] Tool registry database entity.

    Persists per-tenant tool definitions so configuration can be updated
    dynamically; (tenant_id, name) is unique per the composite index below.
    """
    __tablename__ = "tool_registry"
    __table_args__ = (
        Index("ix_tool_registry_tenant_name", "tenant_id", "name", unique=True),
        Index("ix_tool_registry_tenant_enabled", "tenant_id", "is_enabled"),
    )
    id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
    tenant_id: str = Field(..., description="租户ID", index=True)
    name: str = Field(..., description="工具名称", max_length=128)
    type: str = Field(default="internal", description="工具类型: internal | mcp")
    version: str = Field(default="1.0.0", description="版本号", max_length=32)
    description: str | None = Field(default=None, description="工具描述")
    timeout_policy: dict[str, Any] | None = Field(
        default=None,
        sa_column=Column("timeout_policy", JSON, nullable=True),
        description="超时策略配置"
    )
    auth_config: dict[str, Any] | None = Field(
        default=None,
        sa_column=Column("auth_config", JSON, nullable=True),
        description="鉴权配置"
    )
    is_enabled: bool = Field(default=True, description="是否启用")
    # Attribute is named metadata_ because SQLAlchemy/SQLModel reserve
    # `metadata` on declarative classes; the DB column is still "metadata".
    metadata_: dict[str, Any] | None = Field(
        default=None,
        sa_column=Column("metadata", JSON, nullable=True),
        description="扩展元数据"
    )
    # NOTE(review): naive UTC timestamps via deprecated datetime.utcnow —
    # consider timezone-aware defaults, but verify DB column expectations.
    created_at: datetime = Field(default_factory=datetime.utcnow, description="创建时间")
    updated_at: datetime = Field(default_factory=datetime.utcnow, description="更新时间")
    def to_definition(self) -> ToolDefinition:
        """Convert this row into the in-memory ToolDefinition model."""
        return ToolDefinition(
            name=self.name,
            type=self.type,
            version=self.version,
            description=self.description,
            timeout_policy=ToolTimeoutPolicy.from_dict(self.timeout_policy or {}),
            auth_config=ToolAuthConfig.from_dict(self.auth_config or {}),
            is_enabled=self.is_enabled,
            metadata=self.metadata_ or {},
            created_at=self.created_at,
            updated_at=self.updated_at,
        )
class ToolRegistryCreate(SQLModel):
    """Request payload for registering a tool (tenant_id/id set server-side)."""
    name: str = Field(..., max_length=128)
    type: str = "internal"
    version: str = "1.0.0"
    description: str | None = None
    timeout_policy: dict[str, Any] | None = None
    auth_config: dict[str, Any] | None = None
    is_enabled: bool = True
    # Mirrors ToolRegistryEntity.metadata_ (column name "metadata").
    metadata_: dict[str, Any] | None = None
class ToolRegistryUpdate(SQLModel):
    """Partial-update payload: every field optional; None means "leave unchanged"."""
    type: str | None = None
    version: str | None = None
    description: str | None = None
    timeout_policy: dict[str, Any] | None = None
    auth_config: dict[str, Any] | None = None
    is_enabled: bool | None = None
    # Mirrors ToolRegistryEntity.metadata_ (column name "metadata").
    metadata_: dict[str, Any] | None = None

View File

@ -0,0 +1,174 @@
"""
Tool trace models for Mid Platform.
[AC-IDMP-15] 工具调用结构化记录
Reference: spec/intent-driven-mid-platform/openapi.provider.yaml - ToolCallTrace
"""
import hashlib
import json
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any
class ToolCallStatus(str, Enum):
"""工具调用状态"""
OK = "ok"
TIMEOUT = "timeout"
ERROR = "error"
REJECTED = "rejected"
class ToolType(str, Enum):
"""工具类型"""
INTERNAL = "internal"
MCP = "mcp"
@dataclass
class ToolCallTrace:
"""
[AC-IDMP-15] 工具调用追踪记录
Reference: openapi.provider.yaml - ToolCallTrace
记录字段
- tool_name: 工具名称
- tool_type: 工具类型 (internal | mcp)
- registry_version: 注册表版本
- auth_applied: 是否应用鉴权
- duration_ms: 调用耗时毫秒
- status: 调用状态 (ok | timeout | error | rejected)
- error_code: 错误码
- args_digest: 参数摘要脱敏
- result_digest: 结果摘要
"""
tool_name: str
duration_ms: int
status: ToolCallStatus
tool_type: ToolType = ToolType.INTERNAL
registry_version: str | None = None
auth_applied: bool = False
error_code: str | None = None
args_digest: str | None = None
result_digest: str | None = None
started_at: datetime = field(default_factory=datetime.utcnow)
completed_at: datetime | None = None
def to_dict(self) -> dict[str, Any]:
result = {
"tool_name": self.tool_name,
"duration_ms": self.duration_ms,
"status": self.status.value,
}
if self.tool_type != ToolType.INTERNAL:
result["tool_type"] = self.tool_type.value
if self.registry_version:
result["registry_version"] = self.registry_version
if self.auth_applied:
result["auth_applied"] = self.auth_applied
if self.error_code:
result["error_code"] = self.error_code
if self.args_digest:
result["args_digest"] = self.args_digest
if self.result_digest:
result["result_digest"] = self.result_digest
return result
@staticmethod
def compute_digest(data: Any, max_length: int = 64) -> str:
"""
计算数据摘要用于脱敏记录
Args:
data: 原始数据
max_length: 最大长度限制
Returns:
摘要字符串
"""
if data is None:
return ""
if isinstance(data, (dict, list)):
data_str = json.dumps(data, ensure_ascii=False, sort_keys=True)
else:
data_str = str(data)
if len(data_str) <= max_length:
return data_str
hash_value = hashlib.sha256(data_str.encode("utf-8")).hexdigest()[:16]
preview = data_str[:32]
return f"{preview}...[hash:{hash_value}]"
@dataclass
class ToolCallBuilder:
    """
    [AC-IDMP-15] Incremental builder for ToolCallTrace records.

    Collects arguments, registry info and the result or error while a tool
    runs, then materializes the final trace (including duration) via build().
    All ``with_*`` setters return self so calls can be chained.
    """
    tool_name: str
    tool_type: ToolType = ToolType.INTERNAL
    registry_version: str | None = None
    auth_applied: bool = False
    _started_at: datetime = field(default_factory=datetime.utcnow)
    _args: Any = None
    _result: Any = None
    _error: Exception | None = None
    _status: ToolCallStatus = ToolCallStatus.OK
    _error_code: str | None = None
    def with_args(self, args: Any) -> "ToolCallBuilder":
        """Record the call arguments (digested later in build())."""
        self._args = args
        return self
    def with_registry_info(self, version: str, auth_applied: bool) -> "ToolCallBuilder":
        """Record the registry version and whether auth was applied."""
        self.registry_version = version
        self.auth_applied = auth_applied
        return self
    def with_result(self, result: Any) -> "ToolCallBuilder":
        """Record a successful result and mark the call OK."""
        self._result = result
        self._status = ToolCallStatus.OK
        return self
    def with_error(self, error: Exception, error_code: str | None = None) -> "ToolCallBuilder":
        """Record a failure; TimeoutError maps to TIMEOUT, anything else to ERROR."""
        self._error = error
        self._error_code = error_code
        self._status = (
            ToolCallStatus.TIMEOUT
            if isinstance(error, TimeoutError)
            else ToolCallStatus.ERROR
        )
        return self
    def with_rejected(self, reason: str) -> "ToolCallBuilder":
        """Mark the call rejected; ``reason`` is stored as the error code."""
        self._status = ToolCallStatus.REJECTED
        self._error_code = reason
        return self
    def build(self) -> ToolCallTrace:
        """Materialize the trace; duration is measured from the recorded start."""
        finished = datetime.utcnow()
        elapsed_ms = int((finished - self._started_at).total_seconds() * 1000)
        return ToolCallTrace(
            tool_name=self.tool_name,
            tool_type=self.tool_type,
            registry_version=self.registry_version,
            auth_applied=self.auth_applied,
            duration_ms=elapsed_ms,
            status=self._status,
            error_code=self._error_code,
            args_digest=ToolCallTrace.compute_digest(self._args) if self._args else None,
            result_digest=ToolCallTrace.compute_digest(self._result) if self._result else None,
            started_at=self._started_at,
            completed_at=finished,
        )

View File

@ -0,0 +1,57 @@
"""
Mid Platform services.
[AC-IDMP-02, AC-IDMP-05, AC-IDMP-11, AC-IDMP-12, AC-IDMP-13, AC-IDMP-14, AC-IDMP-15, AC-IDMP-16, AC-IDMP-17, AC-IDMP-19, AC-IDMP-20]
[AC-MARH-05, AC-MARH-06, AC-MARH-10, AC-MARH-11, AC-MARH-12]
"""
from .policy_router import PolicyRouter, PolicyRouterResult, IntentMatch
from .agent_orchestrator import AgentOrchestrator, ReActContext
from .timeout_governor import TimeoutGovernor
from .feature_flags import FeatureFlagService
from .high_risk_handler import HighRiskHandler, HighRiskMatch
from .trace_logger import TraceLogger, AuditRecord
from .metrics_collector import MetricsCollector, SessionMetrics, AggregatedMetrics
from .tool_registry import ToolRegistry, ToolDefinition, ToolExecutionResult, get_tool_registry, init_tool_registry
from .tool_call_recorder import ToolCallRecorder, ToolCallStatistics, get_tool_call_recorder
from .memory_adapter import MemoryAdapter, UserMemory
from .default_kb_tool_runner import DefaultKbToolRunner, KbToolResult, KbToolConfig, get_default_kb_tool_runner
from .segment_humanizer import SegmentHumanizer, HumanizeConfig, LengthBucket, get_segment_humanizer
from .runtime_observer import RuntimeObserver, RuntimeContext, get_runtime_observer
# Explicit public API of the mid-platform services package; keep in sync
# with the re-export imports above.
__all__ = [
    "PolicyRouter",
    "PolicyRouterResult",
    "IntentMatch",
    "AgentOrchestrator",
    "ReActContext",
    "TimeoutGovernor",
    "FeatureFlagService",
    "HighRiskHandler",
    "HighRiskMatch",
    "TraceLogger",
    "AuditRecord",
    "MetricsCollector",
    "SessionMetrics",
    "AggregatedMetrics",
    "ToolRegistry",
    "ToolDefinition",
    "ToolExecutionResult",
    "get_tool_registry",
    "init_tool_registry",
    "ToolCallRecorder",
    "ToolCallStatistics",
    "get_tool_call_recorder",
    "MemoryAdapter",
    "UserMemory",
    "DefaultKbToolRunner",
    "KbToolResult",
    "KbToolConfig",
    "get_default_kb_tool_runner",
    "SegmentHumanizer",
    "HumanizeConfig",
    "LengthBucket",
    "get_segment_humanizer",
    "RuntimeObserver",
    "RuntimeContext",
    "get_runtime_observer",
]

View File

@ -0,0 +1,370 @@
"""
Agent Orchestrator for Mid Platform.
[AC-MARH-07] ReAct loop with iteration limit (3-5 iterations).
ReAct Flow:
1. Thought: Agent thinks about what to do
2. Action: Agent decides to use a tool
3. Observation: Tool execution result
4. Repeat until final answer or max iterations reached
"""
import asyncio
import logging
import time
import uuid
from dataclasses import dataclass
from typing import Any
from app.models.mid.schemas import (
ExecutionMode,
ReActContext,
ToolCallStatus,
ToolCallTrace,
ToolType,
TraceInfo,
)
from app.services.mid.timeout_governor import TimeoutGovernor
logger = logging.getLogger(__name__)
# [AC-MARH-07] Default (and hard upper) cap on ReAct iterations.
DEFAULT_MAX_ITERATIONS = 5
# Floor for configured iteration counts; lower values are raised to this.
MIN_ITERATIONS = 3
@dataclass
class ToolResult:
    """Tool execution result as observed by the ReAct loop."""
    # True when the tool ran and reported success.
    success: bool
    output: str | None = None
    error: str | None = None
    duration_ms: int = 0
@dataclass
class AgentThought:
    """Agent thought in the ReAct loop.

    ``action`` is None when the thought is a final answer (the orchestrator
    then stops looping); otherwise it names the tool to invoke with
    ``action_input`` as its arguments.
    """
    content: str
    action: str | None = None
    action_input: dict[str, Any] | None = None
class AgentOrchestrator:
    """
    [AC-MARH-07] Agent orchestrator with ReAct loop control.
    Features:
    - ReAct loop with max 5 iterations (min 3)
    - Per-tool timeout (2s) and end-to-end timeout (8s)
    - Automatic fallback on iteration limit or timeout
    """
    def __init__(
        self,
        max_iterations: int = DEFAULT_MAX_ITERATIONS,
        timeout_governor: TimeoutGovernor | None = None,
        llm_client: Any = None,
        tool_registry: Any = None,
    ):
        # Clamp into [MIN_ITERATIONS, 5].
        # NOTE(review): the literal 5 duplicates DEFAULT_MAX_ITERATIONS — keep in sync.
        self._max_iterations = max(min(max_iterations, 5), MIN_ITERATIONS)
        self._timeout_governor = timeout_governor or TimeoutGovernor()
        # Duck-typed collaborators: llm_client is expected to expose
        # `async generate(messages)`, tool_registry `async execute(tool_name=..., args=...)`
        # — presumably; confirm against the concrete implementations.
        self._llm_client = llm_client
        self._tool_registry = tool_registry
    async def execute(
        self,
        user_message: str,
        context: dict[str, Any] | None = None,
        on_thought: Any = None,
        on_action: Any = None,
    ) -> tuple[str, ReActContext, TraceInfo]:
        """
        [AC-MARH-07] Execute ReAct loop with iteration control.
        Args:
            user_message: User input message
            context: Execution context (history, retrieval results, etc.)
            on_thought: Callback for thought events
            on_action: Callback for action events
        Returns:
            Tuple of (final_answer, react_context, trace_info)
        """
        react_ctx = ReActContext(max_iterations=self._max_iterations)
        tool_calls: list[ToolCallTrace] = []
        # start_time feeds trace timing; overall_start (below) drives the
        # end-to-end timeout budget.
        start_time = time.time()
        logger.info(
            f"[AC-MARH-07] Starting ReAct loop: max_iterations={self._max_iterations}"
        )
        try:
            overall_start = time.time()
            end_to_end_timeout = self._timeout_governor.end_to_end_timeout_seconds
            # Fall back to 15s when the governor has no LLM-specific timeout.
            llm_timeout = (
                self._timeout_governor.llm_timeout_seconds
                if hasattr(self._timeout_governor, 'llm_timeout_seconds')
                else 15.0
            )
            # [AC-MARH-07] Main loop: think -> act -> observe, bounded by both
            # the iteration cap and the end-to-end time budget.
            while react_ctx.should_continue and react_ctx.iteration < react_ctx.max_iterations:
                react_ctx.iteration += 1
                elapsed = time.time() - overall_start
                remaining_time = end_to_end_timeout - elapsed
                # [AC-MARH-09] Budget exhausted: answer with the timeout fallback.
                if remaining_time <= 0:
                    logger.warning(
                        "[AC-MARH-09] ReAct loop exceeded end-to-end timeout"
                    )
                    react_ctx.final_answer = "抱歉,处理超时,请稍后重试或联系人工客服。"
                    break
                logger.info(
                    f"[AC-MARH-07] ReAct iteration {react_ctx.iteration}/"
                    f"{react_ctx.max_iterations}, remaining_time={remaining_time:.1f}s"
                )
                # The LLM call is bounded by the smaller of the per-call LLM
                # timeout and what remains of the end-to-end budget.
                thought = await asyncio.wait_for(
                    self._think(user_message, context, react_ctx),
                    timeout=min(llm_timeout, remaining_time)
                )
                if on_thought:
                    await on_thought(thought)
                # No action requested -> the thought itself is the final answer.
                if not thought.action:
                    logger.info(
                        f"[AC-MARH-07] No action, setting final_answer: "
                        f"{thought.content[:200] if thought.content else 'None'}"
                    )
                    react_ctx.final_answer = thought.content
                    react_ctx.should_continue = False
                    break
                tool_result, tool_trace = await self._act(thought, react_ctx)
                tool_calls.append(tool_trace)
                react_ctx.tool_calls.append(tool_trace)
                if on_action:
                    await on_action(thought.action, tool_result)
                if tool_result.success:
                    # Feed the observation back for the next _think round.
                    context = context or {}
                    context["last_observation"] = tool_result.output
                else:
                    # [AC-MARH-08] A per-tool timeout aborts the whole loop;
                    # other tool errors let the loop continue and retry/replan.
                    if tool_trace.status == ToolCallStatus.TIMEOUT:
                        logger.warning(f"[AC-MARH-08] Tool timeout: {thought.action}")
                        react_ctx.final_answer = "抱歉,操作超时,请稍后重试或联系人工客服。"
                        react_ctx.should_continue = False
                        break
            # Iteration cap reached without an answer: synthesize one from
            # the observations gathered so far.
            if react_ctx.should_continue and not react_ctx.final_answer:
                logger.warning(f"[AC-MARH-07] ReAct reached max iterations: {react_ctx.iteration}")
                react_ctx.final_answer = await self._force_final_answer(user_message, context, react_ctx)
        except asyncio.TimeoutError:
            # Raised by the wait_for around _think; recorded as an E2E timeout.
            logger.error("[AC-MARH-09] ReAct loop timed out (end-to-end)")
            react_ctx.final_answer = "抱歉,处理超时,请稍后重试或联系人工客服。"
            tool_calls.append(ToolCallTrace(
                tool_name="react_loop",
                tool_type=ToolType.INTERNAL,
                duration_ms=int((time.time() - start_time) * 1000),
                status=ToolCallStatus.TIMEOUT,
                error_code="E2E_TIMEOUT",
            ))
        total_duration_ms = int((time.time() - start_time) * 1000)
        trace = TraceInfo(
            mode=ExecutionMode.AGENT,
            request_id=str(uuid.uuid4()),
            generation_id=str(uuid.uuid4()),
            react_iterations=react_ctx.iteration,
            tools_used=[tc.tool_name for tc in tool_calls if tc.tool_name != "react_loop"],
            tool_calls=tool_calls if tool_calls else None,
        )
        logger.info(
            f"[AC-MARH-07] ReAct completed: iterations={react_ctx.iteration}, "
            f"duration_ms={total_duration_ms}"
        )
        return react_ctx.final_answer or "抱歉,我暂时无法处理您的请求。", react_ctx, trace
    async def _think(
        self,
        user_message: str,
        context: dict[str, Any] | None,
        react_ctx: ReActContext,
    ) -> AgentThought:
        """
        [AC-MARH-07] Agent thinks about next action.
        Calls the LLM with a ReAct prompt built from the last observation and
        the three most recent tool-call digests. Without an LLM client (or on
        LLM failure) returns a plain thought with no action, which ends the loop.
        """
        if not self._llm_client:
            return AgentThought(content=f"思考中... 用户消息: {user_message}")
        try:
            observations = []
            if context and "last_observation" in context:
                observations.append(f"上一步结果: {context['last_observation']}")
            # Only the last 3 tool calls are surfaced to keep the prompt short.
            for tc in react_ctx.tool_calls[-3:]:
                observations.append(f"工具 {tc.tool_name}: {tc.result_digest or '无结果'}")
            prompt = self._build_react_prompt(user_message, observations)
            response = await self._llm_client.generate([{"role": "user", "content": prompt}])
            logger.info(f"[AC-MARH-07] LLM response content: {response.content[:500] if response.content else 'None'}")
            return self._parse_thought(response.content)
        except Exception as e:
            # Degrade gracefully: an action-less thought becomes the final answer.
            logger.error(f"[AC-MARH-07] Think failed: {e}")
            return AgentThought(content=f"思考失败: {str(e)}")
    def _build_react_prompt(self, user_message: str, observations: list[str]) -> str:
        """Build ReAct prompt for LLM."""
        obs_text = "\n".join(observations) if observations else ""
        return f"""你是一个智能助手,正在使用 ReAct 模式处理用户请求。
用户消息: {user_message}
历史观察:
{obs_text}
请思考下一步行动如果已经有足够信息回答用户请直接给出最终答案
如果需要使用工具请按以下格式回复:
Thought: [你的思考]
Action: [工具名称]
Action Input: {{"param1": "value1"}}
"""
    def _parse_thought(self, content: str) -> AgentThought:
        """Parse LLM response into AgentThought.

        Scans for "Action:" / "Action Input:" lines; a malformed JSON input
        degrades to an empty dict rather than failing the turn.
        """
        action = None
        action_input = None
        if "Action:" in content:
            lines = content.split("\n")
            for line in lines:
                if line.startswith("Action:"):
                    action = line.replace("Action:", "").strip()
                elif line.startswith("Action Input:"):
                    # NOTE(review): local import kept byte-identical; module
                    # imports are cached, so repeated executions are cheap.
                    import json
                    try:
                        action_input = json.loads(line.replace("Action Input:", "").strip())
                    except json.JSONDecodeError:
                        action_input = {}
        return AgentThought(content=content, action=action, action_input=action_input)
    async def _act(
        self,
        thought: AgentThought,
        react_ctx: ReActContext,
    ) -> tuple[ToolResult, ToolCallTrace]:
        """
        [AC-MARH-07, AC-MARH-08] Execute tool action with timeout.

        Delegates to the tool registry under the per-tool timeout; every
        outcome (missing registry, success, timeout, error) yields both a
        ToolResult and a ToolCallTrace so the call is always observable.
        """
        tool_name = thought.action or "unknown"
        start_time = time.time()
        if not self._tool_registry:
            duration_ms = int((time.time() - start_time) * 1000)
            return ToolResult(
                success=False,
                error="Tool registry not configured",
                duration_ms=duration_ms,
            ), ToolCallTrace(
                tool_name=tool_name,
                tool_type=ToolType.INTERNAL,
                duration_ms=duration_ms,
                status=ToolCallStatus.ERROR,
                error_code="NO_REGISTRY",
            )
        try:
            # Registry is expected to return a dict with success/output/error
            # keys — presumably; confirm against the registry implementation.
            result = await asyncio.wait_for(
                self._tool_registry.execute(
                    tool_name=tool_name,
                    args=thought.action_input or {},
                ),
                timeout=self._timeout_governor.per_tool_timeout_seconds
            )
            duration_ms = int((time.time() - start_time) * 1000)
            return ToolResult(
                success=result.get("success", False),
                output=result.get("output"),
                error=result.get("error"),
                duration_ms=duration_ms,
            ), ToolCallTrace(
                tool_name=tool_name,
                tool_type=ToolType.INTERNAL,
                duration_ms=duration_ms,
                status=ToolCallStatus.OK if result.get("success") else ToolCallStatus.ERROR,
                args_digest=str(thought.action_input)[:100] if thought.action_input else None,
                result_digest=str(result.get("output"))[:100] if result.get("output") else None,
            )
        except asyncio.TimeoutError:
            duration_ms = int((time.time() - start_time) * 1000)
            logger.warning(f"[AC-MARH-08] Tool timeout: {tool_name}, duration={duration_ms}ms")
            return ToolResult(
                success=False,
                error="Tool timeout",
                duration_ms=duration_ms,
            ), ToolCallTrace(
                tool_name=tool_name,
                tool_type=ToolType.INTERNAL,
                duration_ms=duration_ms,
                status=ToolCallStatus.TIMEOUT,
                error_code="TOOL_TIMEOUT",
            )
        except Exception as e:
            duration_ms = int((time.time() - start_time) * 1000)
            logger.error(f"[AC-MARH-07] Tool error: {tool_name}, error={e}")
            return ToolResult(
                success=False,
                error=str(e),
                duration_ms=duration_ms,
            ), ToolCallTrace(
                tool_name=tool_name,
                tool_type=ToolType.INTERNAL,
                duration_ms=duration_ms,
                status=ToolCallStatus.ERROR,
                error_code="TOOL_ERROR",
            )
    async def _force_final_answer(
        self,
        user_message: str,
        context: dict[str, Any] | None,
        react_ctx: ReActContext,
    ) -> str:
        """Force final answer when max iterations reached.

        Summarizes collected tool digests through the LLM; falls back to a
        static apology when no LLM is configured or the call fails.
        """
        observations = []
        for tc in react_ctx.tool_calls:
            if tc.result_digest:
                observations.append(f"- {tc.tool_name}: {tc.result_digest}")
        obs_text = "\n".join(observations) if observations else ""
        if self._llm_client:
            try:
                prompt = f"""基于以下信息,请给出最终回答:
用户消息: {user_message}
收集到的信息:
{obs_text}
请直接给出回答不要再调用工具"""
                response = await self._llm_client.generate([{"role": "user", "content": prompt}])
                return response.content
            except Exception as e:
                logger.error(f"[AC-MARH-07] Force final answer failed: {e}")
        return "抱歉,我已经尽力处理您的请求,但可能需要更多信息。请稍后重试或联系人工客服。"

View File

@ -0,0 +1,244 @@
"""
Default KB Tool Runner for Mid Platform.
[AC-MARH-05] Agent 默认 KB 检索工具调用
[AC-MARH-06] KB 失败时可观测降级
Agent 模式处理开放咨询时默认尝试调用 KB 检索工具获取事实依据
"""
import logging
import time
import uuid
from dataclasses import dataclass, field
from typing import Any
from app.models.mid.schemas import ToolCallStatus, ToolCallTrace, ToolType
from app.services.mid.timeout_governor import TimeoutGovernor
logger = logging.getLogger(__name__)
# Default number of KB hits returned (see KbToolConfig.top_k).
DEFAULT_KB_TOP_K = 5
# Default per-call KB retrieval timeout in milliseconds (see KbToolConfig.timeout_ms).
DEFAULT_KB_TIMEOUT_MS = 2000
@dataclass
class KbToolResult:
    """Outcome of a KB retrieval attempt (see DefaultKbToolRunner.execute)."""
    success: bool
    hits: list[dict[str, Any]] = field(default_factory=list)
    error: str | None = None
    # Observable degradation signal: KB_DISABLED / KB_TIMEOUT / KB_ERROR; None on success.
    fallback_reason_code: str | None = None
    duration_ms: int = 0
    tool_trace: ToolCallTrace | None = None
@dataclass
class KbToolConfig:
    """Configuration for the default KB retrieval tool."""
    enabled: bool = True
    top_k: int = DEFAULT_KB_TOP_K
    timeout_ms: int = DEFAULT_KB_TIMEOUT_MS
    # Hits scoring below this threshold are dropped during retrieval.
    min_score_threshold: float = 0.5
class DefaultKbToolRunner:
    """
    [AC-MARH-05] Default KB retrieval tool runner for Agent mode.
    Features:
    - Called by default when Agent mode handles open consultations
    - Timeout control via KbToolConfig.timeout_ms
    - Failures surface an observable fallback signal (fallback_reason_code)
    - Logs kb_tool_called / kb_hit state and emits a ToolCallTrace
    """
    def __init__(
        self,
        timeout_governor: TimeoutGovernor | None = None,
        config: KbToolConfig | None = None,
    ):
        # NOTE(review): the governor is stored but the actual retrieval
        # timeout comes from config.timeout_ms — confirm intended.
        self._timeout_governor = timeout_governor or TimeoutGovernor()
        self._config = config or KbToolConfig()
        # Lazily resolved on first retrieval (see _do_retrieve).
        self._vector_retriever = None
    async def execute(
        self,
        tenant_id: str,
        query: str,
        metadata_filter: dict[str, Any] | None = None,
    ) -> KbToolResult:
        """
        [AC-MARH-05] Run KB retrieval for ``query``.
        Args:
            tenant_id: Tenant ID
            query: Retrieval query text
            metadata_filter: Optional metadata filter conditions
        Returns:
            KbToolResult with hits, timing and trace info. On failure,
            ``success`` is False and ``fallback_reason_code`` is one of
            KB_DISABLED / KB_TIMEOUT / KB_ERROR.
        """
        if not self._config.enabled:
            logger.info(f"[AC-MARH-05] KB tool disabled for tenant={tenant_id}")
            return KbToolResult(
                success=False,
                fallback_reason_code="KB_DISABLED",
            )
        start_time = time.time()
        # (Removed an unused `tool_trace_id = str(uuid.uuid4())` local here.)
        logger.info(
            f"[AC-MARH-05] Starting KB retrieval: tenant={tenant_id}, "
            f"query={query[:50]}..., top_k={self._config.top_k}"
        )
        try:
            hits = await self._retrieve_with_timeout(
                tenant_id=tenant_id,
                query=query,
                metadata_filter=metadata_filter,
            )
            duration_ms = int((time.time() - start_time) * 1000)
            kb_hit = len(hits) > 0
            tool_trace = ToolCallTrace(
                tool_name="kb_retrieval",
                tool_type=ToolType.INTERNAL,
                duration_ms=duration_ms,
                status=ToolCallStatus.OK,
                args_digest=f"query={query[:50]}",
                result_digest=f"hits={len(hits)}",
            )
            logger.info(
                f"[AC-MARH-05] KB retrieval completed: tenant={tenant_id}, "
                f"hits={len(hits)}, duration_ms={duration_ms}, kb_hit={kb_hit}"
            )
            return KbToolResult(
                success=True,
                hits=hits,
                duration_ms=duration_ms,
                tool_trace=tool_trace,
            )
        except TimeoutError:
            # [AC-MARH-06] Observable degradation: timeout path.
            duration_ms = int((time.time() - start_time) * 1000)
            logger.warning(
                f"[AC-MARH-06] KB retrieval timeout: tenant={tenant_id}, "
                f"duration_ms={duration_ms}"
            )
            tool_trace = ToolCallTrace(
                tool_name="kb_retrieval",
                tool_type=ToolType.INTERNAL,
                duration_ms=duration_ms,
                status=ToolCallStatus.TIMEOUT,
                error_code="KB_TIMEOUT",
            )
            return KbToolResult(
                success=False,
                error="KB retrieval timeout",
                fallback_reason_code="KB_TIMEOUT",
                duration_ms=duration_ms,
                tool_trace=tool_trace,
            )
        except Exception as e:
            # [AC-MARH-06] Observable degradation: generic error path.
            duration_ms = int((time.time() - start_time) * 1000)
            logger.error(
                f"[AC-MARH-06] KB retrieval failed: tenant={tenant_id}, "
                f"error={e}"
            )
            tool_trace = ToolCallTrace(
                tool_name="kb_retrieval",
                tool_type=ToolType.INTERNAL,
                duration_ms=duration_ms,
                status=ToolCallStatus.ERROR,
                error_code="KB_ERROR",
            )
            return KbToolResult(
                success=False,
                error=str(e),
                fallback_reason_code="KB_ERROR",
                duration_ms=duration_ms,
                tool_trace=tool_trace,
            )
    async def _retrieve_with_timeout(
        self,
        tenant_id: str,
        query: str,
        metadata_filter: dict[str, Any] | None = None,
    ) -> list[dict[str, Any]]:
        """Run the retrieval bounded by config.timeout_ms; raises TimeoutError."""
        import asyncio
        timeout_seconds = self._config.timeout_ms / 1000.0
        try:
            return await asyncio.wait_for(
                self._do_retrieve(tenant_id, query, metadata_filter),
                timeout=timeout_seconds,
            )
        except asyncio.TimeoutError:
            raise TimeoutError("KB retrieval timeout")
    async def _do_retrieve(
        self,
        tenant_id: str,
        query: str,
        metadata_filter: dict[str, Any] | None = None,
    ) -> list[dict[str, Any]]:
        """Perform the actual vector retrieval and shape the hits.

        The vector retriever is resolved lazily and cached on the instance.
        Hits below min_score_threshold are dropped, then truncated to top_k.
        """
        if self._vector_retriever is None:
            from app.services.retrieval.vector_retriever import get_vector_retriever
            self._vector_retriever = await get_vector_retriever()
        from app.services.retrieval.base import RetrievalContext
        ctx = RetrievalContext(
            tenant_id=tenant_id,
            query=query,
            metadata_filter=metadata_filter,
        )
        result = await self._vector_retriever.retrieve(ctx)
        hits = []
        for hit in result.hits:
            if hit.score >= self._config.min_score_threshold:
                hits.append({
                    "id": hit.metadata.get("chunk_id", str(uuid.uuid4())),
                    "content": hit.text,
                    "score": hit.score,
                    "metadata": hit.metadata,
                })
        return hits[:self._config.top_k]
    def get_config(self) -> KbToolConfig:
        """Return the active configuration."""
        return self._config
# Process-wide singleton cache for the runner.
_default_kb_tool_runner: DefaultKbToolRunner | None = None
def get_default_kb_tool_runner(
    timeout_governor: TimeoutGovernor | None = None,
    config: KbToolConfig | None = None,
) -> DefaultKbToolRunner:
    """Get or lazily create the shared DefaultKbToolRunner instance.

    NOTE(review): the arguments are only honored on the first call; once the
    singleton exists, later callers receive it unchanged — confirm intended.
    """
    global _default_kb_tool_runner
    if _default_kb_tool_runner is None:
        _default_kb_tool_runner = DefaultKbToolRunner(
            timeout_governor=timeout_governor,
            config=config,
        )
    return _default_kb_tool_runner

View File

@ -0,0 +1,100 @@
"""
Feature Flags Service for Mid Platform.
[AC-IDMP-17] Session-level grayscale and rollback support.
"""
import logging
from dataclasses import dataclass
from typing import Any
from app.models.mid.schemas import FeatureFlags
logger = logging.getLogger(__name__)
@dataclass
class FeatureFlagConfig:
    """Feature flag configuration (global default or per-session override)."""
    agent_enabled: bool = True
    # When True, the session is forced back onto the legacy pipeline.
    rollback_to_legacy: bool = False
    react_max_iterations: int = 5
    enable_tool_registry: bool = True
    enable_trace_logging: bool = True
class FeatureFlagService:
    """
    [AC-IDMP-17] Feature flag service for session-level control.
    Supports:
    - Session-level Agent enable/disable
    - Force rollback to legacy pipeline
    - Dynamic configuration per session
    """
    def __init__(self):
        # Per-session overrides; sessions without an entry use the global config.
        self._session_flags: dict[str, FeatureFlagConfig] = {}
        self._global_config = FeatureFlagConfig()
    def _resolve(self, session_id: str) -> FeatureFlagConfig:
        """Return the session override, falling back to the global config."""
        return self._session_flags.get(session_id, self._global_config)
    def get_flags(self, session_id: str) -> FeatureFlags:
        """
        [AC-IDMP-17] Get feature flags for a session.
        Args:
            session_id: Session ID
        Returns:
            FeatureFlags for the session
        """
        config = self._resolve(session_id)
        return FeatureFlags(
            agent_enabled=config.agent_enabled,
            rollback_to_legacy=config.rollback_to_legacy,
        )
    def set_flags(self, session_id: str, flags: FeatureFlags) -> None:
        """
        [AC-IDMP-17] Set feature flags for a session.

        Fields left as None on ``flags`` inherit the global defaults.
        Args:
            session_id: Session ID
            flags: Feature flags to set
        """
        base = self._global_config
        config = FeatureFlagConfig(
            agent_enabled=base.agent_enabled if flags.agent_enabled is None else flags.agent_enabled,
            rollback_to_legacy=base.rollback_to_legacy if flags.rollback_to_legacy is None else flags.rollback_to_legacy,
        )
        self._session_flags[session_id] = config
        logger.info(
            f"[AC-IDMP-17] Feature flags set for session {session_id}: "
            f"agent_enabled={config.agent_enabled}, rollback_to_legacy={config.rollback_to_legacy}"
        )
    def clear_flags(self, session_id: str) -> None:
        """Remove any session-level override (no-op when none exists)."""
        if self._session_flags.pop(session_id, None) is not None:
            logger.info(f"[AC-IDMP-17] Feature flags cleared for session {session_id}")
    def is_agent_enabled(self, session_id: str) -> bool:
        """Check if Agent mode is enabled for a session."""
        return self._resolve(session_id).agent_enabled
    def should_rollback(self, session_id: str) -> bool:
        """Check if the session should fall back to the legacy pipeline."""
        return self._resolve(session_id).rollback_to_legacy
    def set_global_config(self, config: FeatureFlagConfig) -> None:
        """Replace the global default configuration."""
        self._global_config = config
        logger.info(f"[AC-IDMP-17] Global config updated: {config}")
    def get_react_max_iterations(self, session_id: str) -> int:
        """ReAct iteration cap for a session."""
        return self._resolve(session_id).react_max_iterations

View File

@ -0,0 +1,232 @@
"""
High Risk Handler for Mid Platform.
[AC-IDMP-05, AC-IDMP-20] High-risk scenario detection and mandatory takeover.
High-Risk Scenarios (minimum set):
1. Refund (退款)
2. Complaint Escalation (投诉升级)
3. Privacy/Sensitive Promise (隐私与敏感承诺)
4. Transfer (转人工)
"""
import logging
from dataclasses import dataclass
from typing import Any
from app.models.mid.schemas import (
ExecutionMode,
HighRiskScenario,
PolicyRouterResult,
)
logger = logging.getLogger(__name__)
# [AC-IDMP-20] Minimum required high-risk scenario set; a handler must keep
# at least one scenario enabled.
DEFAULT_HIGH_RISK_SCENARIOS = [
    HighRiskScenario.REFUND,
    HighRiskScenario.COMPLAINT_ESCALATION,
    HighRiskScenario.PRIVACY_SENSITIVE_PROMISE,
    HighRiskScenario.TRANSFER,
]
# Default detection regexes per scenario (Chinese customer-service phrases).
# Applied case-insensitively against the raw user message; tenants may extend
# them via HighRiskHandler(custom_patterns=...).
HIGH_RISK_PATTERNS = {
    HighRiskScenario.REFUND: [  # refund / return-of-goods / money-back requests
        r"退款",
        r"退货",
        r"退钱",
        r"退费",
        r"还钱",
        r"申请退款",
        r"我要退",
    ],
    HighRiskScenario.COMPLAINT_ESCALATION: [  # complaints and regulator escalation
        r"投诉",
        r"升级投诉",
        r"举报",
        r"12315",
        r"消费者协会",
        r"工商局",
        r"市场监督管理局",
    ],
    HighRiskScenario.PRIVACY_SENSITIVE_PROMISE: [  # promises/guarantees of refund or compensation
        r"承诺.{0,10}(退款|赔偿|补偿)",
        r"保证.{0,10}(退款|赔偿|补偿)",
        r"一定.{0,10}(退款|赔偿|补偿)",
        r"肯定能.{0,10}(退款|赔偿|补偿)",
        r"绝对.{0,10}(退款|赔偿|补偿)",
        r"担保",
    ],
    HighRiskScenario.TRANSFER: [  # explicit requests for a human agent
        r"转人工",
        r"人工客服",
        r"人工服务",
        r"真人",
        r"人工",
        r"转接人工",
    ],
}
@dataclass
class HighRiskMatch:
    """High-risk scenario match result."""
    scenario: HighRiskScenario  # which scenario fired
    matched_pattern: str  # the regex pattern (source text) that matched
    matched_text: str  # exact substring of the message that matched
    confidence: float = 1.0  # pattern matches are treated as certain
class HighRiskHandler:
    """
    [AC-IDMP-05, AC-IDMP-20] High-risk scenario handler.

    Features:
    - Configurable high-risk scenario set (minimum set required)
    - Pattern-based detection
    - Mandatory takeover to micro_flow or transfer
    """
    def __init__(
        self,
        enabled_scenarios: list[HighRiskScenario] | None = None,
        custom_patterns: dict[HighRiskScenario, list[str]] | None = None,
    ):
        """
        Args:
            enabled_scenarios: Scenarios to detect; defaults to the minimum set.
            custom_patterns: Extra regex patterns merged on top of the defaults.

        Raises:
            ValueError: If the enabled scenario set would be empty.
        """
        self._enabled_scenarios = enabled_scenarios or DEFAULT_HIGH_RISK_SCENARIOS.copy()
        if not self._enabled_scenarios:
            raise ValueError("[AC-IDMP-20] High-risk scenario set cannot be empty")
        # BUG FIX: copy the inner lists too. The previous shallow .copy() shared
        # the per-scenario lists with module-level HIGH_RISK_PATTERNS, so
        # custom_patterns leaked into the defaults seen by every other instance.
        self._patterns = {
            scenario: list(patterns) for scenario, patterns in HIGH_RISK_PATTERNS.items()
        }
        if custom_patterns:
            for scenario, patterns in custom_patterns.items():
                if scenario not in self._patterns:
                    self._patterns[scenario] = []
                self._patterns[scenario].extend(patterns)
        self._compiled_patterns = self._compile_patterns()
        logger.info(
            f"[AC-IDMP-20] HighRiskHandler initialized with scenarios: "
            f"{[s.value for s in self._enabled_scenarios]}"
        )
    def _compile_for(self, scenario: HighRiskScenario) -> list:
        """Compile the configured regexes for one scenario (case-insensitive)."""
        import re
        return [re.compile(p, re.IGNORECASE) for p in self._patterns.get(scenario, [])]
    def _compile_patterns(self) -> dict[HighRiskScenario, list]:
        """Compile regex patterns for better performance."""
        return {scenario: self._compile_for(scenario) for scenario in self._enabled_scenarios}
    def detect(self, message: str) -> HighRiskMatch | None:
        """
        [AC-IDMP-05, AC-IDMP-20] Detect high-risk scenario in message.

        Scenarios are checked in enabled order; the first pattern hit wins.

        Args:
            message: User message to check

        Returns:
            HighRiskMatch if detected, None otherwise
        """
        for scenario in self._enabled_scenarios:
            for pattern in self._compiled_patterns.get(scenario, []):
                match = pattern.search(message)
                if match:
                    logger.info(
                        f"[AC-IDMP-05] High-risk scenario detected: "
                        f"scenario={scenario.value}, pattern={pattern.pattern}"
                    )
                    return HighRiskMatch(
                        scenario=scenario,
                        matched_pattern=pattern.pattern,
                        matched_text=match.group(),
                    )
        return None
    def handle(
        self,
        match: HighRiskMatch,
        context: dict[str, Any] | None = None,
    ) -> PolicyRouterResult:
        """
        [AC-IDMP-05] Handle high-risk scenario by routing to appropriate mode.

        Routing priority: explicit TRANSFER intent -> transfer; a configured
        flow_id in context -> that micro-flow; otherwise a scenario-specific
        default, falling back to transfer for unknown scenarios.

        Args:
            match: High-risk match result
            context: Additional context (intent match, flow config, etc.)

        Returns:
            PolicyRouterResult with execution mode decision
        """
        context = context or {}
        if match.scenario == HighRiskScenario.TRANSFER:
            return PolicyRouterResult(
                mode=ExecutionMode.TRANSFER,
                high_risk_triggered=True,
                transfer_message="正在为您转接人工客服,请稍候...",
            )
        flow_id = context.get("flow_id")
        if flow_id:
            return PolicyRouterResult(
                mode=ExecutionMode.MICRO_FLOW,
                high_risk_triggered=True,
                target_flow_id=flow_id,
            )
        if match.scenario == HighRiskScenario.REFUND:
            return PolicyRouterResult(
                mode=ExecutionMode.MICRO_FLOW,
                high_risk_triggered=True,
                fallback_reason_code="high_risk_refund",
            )
        if match.scenario == HighRiskScenario.COMPLAINT_ESCALATION:
            return PolicyRouterResult(
                mode=ExecutionMode.TRANSFER,
                high_risk_triggered=True,
                transfer_message="检测到您可能需要投诉处理,正在为您转接人工客服...",
            )
        if match.scenario == HighRiskScenario.PRIVACY_SENSITIVE_PROMISE:
            return PolicyRouterResult(
                mode=ExecutionMode.MICRO_FLOW,
                high_risk_triggered=True,
                fallback_reason_code="high_risk_privacy_promise",
            )
        # Unknown/custom scenarios default to the safest action: transfer.
        return PolicyRouterResult(
            mode=ExecutionMode.TRANSFER,
            high_risk_triggered=True,
            transfer_message="正在为您转接人工客服...",
        )
    def get_enabled_scenarios(self) -> list[HighRiskScenario]:
        """[AC-IDMP-20] Get enabled high-risk scenarios (returns a copy)."""
        return self._enabled_scenarios.copy()
    def add_scenario(self, scenario: HighRiskScenario) -> None:
        """Add a high-risk scenario to the enabled set (idempotent)."""
        if scenario not in self._enabled_scenarios:
            self._enabled_scenarios.append(scenario)
            # BUG FIX: compile from self._patterns so custom patterns supplied
            # at init are honored (previously compiled from the module-level
            # defaults only, via an __import__('re') hack).
            self._compiled_patterns[scenario] = self._compile_for(scenario)
            logger.info(f"[AC-IDMP-20] Added high-risk scenario: {scenario.value}")
    def remove_scenario(self, scenario: HighRiskScenario) -> bool:
        """
        Remove a high-risk scenario from the enabled set.

        Returns:
            True if removed; False when the scenario is absent or removal
            would empty the set (minimum set must be preserved).
        """
        if scenario in self._enabled_scenarios and len(self._enabled_scenarios) > 1:
            self._enabled_scenarios.remove(scenario)
            if scenario in self._compiled_patterns:
                del self._compiled_patterns[scenario]
            logger.info(f"[AC-IDMP-20] Removed high-risk scenario: {scenario.value}")
            return True
        return False

View File

@ -0,0 +1,161 @@
"""
Interrupt Context Enricher for Mid Platform.
[AC-MARH-03, AC-MARH-04] Interrupted segments processing and fallback handling.
Features:
- Consumes interrupted_segments from request
- Provides context for re-planning
- Fallback handling for invalid/empty interrupt data
"""
import logging
from dataclasses import dataclass
from typing import Any
from app.models.mid.schemas import InterruptedSegment
logger = logging.getLogger(__name__)
@dataclass
class InterruptContext:
    """Context derived from interrupted segments."""
    consumed: bool = False  # True when valid interrupted segments were ingested
    interrupted_content: str | None = None  # joined content of the valid segments
    interrupted_segment_ids: list[str] | None = None  # ids of the valid segments
    fallback_triggered: bool = False  # True when segments were missing/invalid or processing failed
    fallback_reason: str | None = None  # why the fallback path was taken
class InterruptContextEnricher:
    """
    [AC-MARH-03, AC-MARH-04] Interrupt context enricher for handling user interruption.

    This component processes interrupted_segments from the request and provides
    context for re-planning. It handles edge cases where interrupt data is
    invalid or empty, ensuring the main flow continues without disruption.
    """
    def __init__(self):
        # Stateless by design; kept explicit for future configuration.
        pass
    def enrich(
        self,
        interrupted_segments: list[InterruptedSegment] | None,
        generation_id: str | None = None,
    ) -> InterruptContext:
        """
        [AC-MARH-03, AC-MARH-04] Process interrupted segments and return context.

        Args:
            interrupted_segments: List of interrupted segments from request
            generation_id: Current generation ID for matching (currently unused;
                reserved for matching segments to a specific generation)

        Returns:
            InterruptContext with processed interrupt information
        """
        if not interrupted_segments:
            logger.debug("[AC-MARH-04] No interrupted segments provided")
            return InterruptContext(
                consumed=False,
                fallback_triggered=False,
            )
        try:
            # Keep only segments with non-blank content and a segment_id.
            valid_segments = [
                s for s in interrupted_segments
                if s.content and s.content.strip() and s.segment_id
            ]
            if not valid_segments:
                logger.info("[AC-MARH-04] Interrupted segments empty or invalid, using fallback")
                return InterruptContext(
                    consumed=False,
                    fallback_triggered=True,
                    fallback_reason="empty_or_invalid_segments",
                )
            interrupted_content = "\n".join(s.content for s in valid_segments)
            segment_ids = [s.segment_id for s in valid_segments]
            logger.info(
                f"[AC-MARH-03] Interrupted segments consumed: "
                f"count={len(valid_segments)}, segment_ids={segment_ids[:3]}..."
            )
            return InterruptContext(
                consumed=True,
                interrupted_content=interrupted_content,
                interrupted_segment_ids=segment_ids,
                fallback_triggered=False,
            )
        except Exception as e:
            # [AC-MARH-04] Never let malformed interrupt data break the main flow.
            logger.warning(f"[AC-MARH-04] Failed to process interrupted segments: {e}")
            return InterruptContext(
                consumed=False,
                fallback_triggered=True,
                fallback_reason=f"error:{str(e)[:50]}",
            )
    def build_replan_context(
        self,
        interrupt_context: InterruptContext,
        base_context: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """
        [AC-MARH-03] Build context for re-planning after interruption.

        Args:
            interrupt_context: Processed interrupt context
            base_context: Existing context to enrich (copied, never mutated)

        Returns:
            Enriched context for re-planning
        """
        context = base_context.copy() if base_context else {}
        if interrupt_context.consumed and interrupt_context.interrupted_content:
            context["interrupted_content"] = interrupt_context.interrupted_content
            context["interrupted_segment_ids"] = interrupt_context.interrupted_segment_ids
            context["avoid_duplicate"] = True
            logger.debug(
                f"[AC-MARH-03] Re-plan context enriched with interrupted content: "
                f"{len(interrupt_context.interrupted_content)} chars"
            )
        elif interrupt_context.fallback_triggered:
            context["interrupt_fallback"] = True
            context["interrupt_fallback_reason"] = interrupt_context.fallback_reason
            logger.debug(
                f"[AC-MARH-04] Re-plan context marked with fallback: "
                f"{interrupt_context.fallback_reason}"
            )
        return context
    def should_skip_content(
        self,
        content: str,
        interrupt_context: InterruptContext,
    ) -> bool:
        """
        [AC-MARH-03] Check if content should be skipped to avoid duplicate.

        Args:
            content: Content to check
            interrupt_context: Processed interrupt context

        Returns:
            True if content matches interrupted content and should be skipped
        """
        if not interrupt_context.consumed or not interrupt_context.interrupted_content:
            return False
        stripped = content.strip()
        # BUG FIX: '"" in s' is always True, so blank content was previously
        # reported (and logged) as a duplicate of the interrupted segment.
        if not stripped:
            return False
        if stripped in interrupt_context.interrupted_content:
            logger.info(
                f"[AC-MARH-03] Content matches interrupted segment, skipping: "
                f"{content[:50]}..."
            )
            return True
        return False

View File

@ -0,0 +1,487 @@
"""
KB Search Dynamic Tool for Mid Platform.
[AC-MARH-05] Agent 默认 KB 检索工具支持元数据驱动参数
[AC-MARH-06] KB 失败时可观测降级
核心特性
- 通过元数据配置动态生成检索参数/过滤器
- 必填字段缺失时返回 missing_required_slots
- 工具执行可观测tool_call/tool_result
- 超时降级返回 fallback_reason_code
"""
from __future__ import annotations
import asyncio
import logging
import time
import uuid
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.mid.schemas import ToolCallStatus, ToolCallTrace, ToolType
from app.services.mid.metadata_filter_builder import (
FilterBuildResult,
MetadataFilterBuilder,
)
from app.services.mid.timeout_governor import TimeoutGovernor
if TYPE_CHECKING:
from app.services.mid.tool_registry import ToolRegistry
logger = logging.getLogger(__name__)
# Default number of hits returned by the tool.
DEFAULT_TOP_K = 5
# Default retrieval timeout budget in milliseconds.
DEFAULT_TIMEOUT_MS = 2000
# Name under which this tool is registered in the ToolRegistry.
KB_SEARCH_DYNAMIC_TOOL_NAME = "kb_search_dynamic"
@dataclass
class KbSearchDynamicResult:
    """Result of a dynamic KB retrieval call."""
    success: bool = True  # False when disabled, slots missing, timed out, or errored
    hits: list[dict[str, Any]] = field(default_factory=list)  # retrieval hits
    applied_filter: dict[str, Any] = field(default_factory=dict)  # filter actually applied
    missing_required_slots: list[dict[str, str]] = field(default_factory=list)  # required fields not provided
    filter_debug: dict[str, Any] = field(default_factory=dict)  # filter-builder debug info
    fallback_reason_code: str | None = None  # KB_DISABLED / MISSING_REQUIRED_SLOTS / KB_TIMEOUT / KB_ERROR
    duration_ms: int = 0  # wall-clock duration of the call
    tool_trace: ToolCallTrace | None = None  # observability trace entry
@dataclass
class KbSearchDynamicConfig:
    """Configuration for the dynamic KB retrieval tool."""
    enabled: bool = True  # master switch; execute() short-circuits with KB_DISABLED when False
    top_k: int = DEFAULT_TOP_K  # default number of hits returned
    timeout_ms: int = DEFAULT_TIMEOUT_MS  # retrieval timeout budget
    min_score_threshold: float = 0.5  # hits scoring below this are dropped
class KbSearchDynamicTool:
    """
    [AC-MARH-05] Dynamic KB retrieval tool.

    Builds retrieval parameters/filters dynamically from tenant metadata
    configuration instead of hard-coding the inputs.

    Fixed outer-shell parameters:
    - query: retrieval query
    - scene: scene identifier
    - tenant_id: tenant ID
    - top_k: number of results to return
    - context: context carrying the dynamic filter values

    Result structure:
    - hits: retrieval hits
    - applied_filter: the filter that was applied
    - missing_required_slots: required fields that were not provided
    - filter_debug: filter debugging info
    - fallback_reason_code: degradation reason code
    """
    def __init__(
        self,
        session: AsyncSession,
        timeout_governor: TimeoutGovernor | None = None,
        config: KbSearchDynamicConfig | None = None,
    ):
        # Database session used by the metadata filter builder.
        self._session = session
        self._timeout_governor = timeout_governor or TimeoutGovernor()
        self._config = config or KbSearchDynamicConfig()
        # Both collaborators are created lazily on first use.
        self._vector_retriever = None
        self._filter_builder: MetadataFilterBuilder | None = None
    @property
    def name(self) -> str:
        """Tool name as registered in the ToolRegistry."""
        return KB_SEARCH_DYNAMIC_TOOL_NAME
    @property
    def description(self) -> str:
        """Tool description (user-facing, in Chinese)."""
        return (
            "知识库动态检索工具。"
            "根据租户配置的元数据字段定义,动态构建检索过滤器。"
            "支持必填字段检测和可观测降级。"
        )
    def get_tool_schema(self) -> dict[str, Any]:
        """
        Return the tool schema used for the Agent's tool description.

        NOTE(review): the original intent was a schema dynamically generated
        from tenant-configured filter fields, but the schema returned here is
        static — confirm whether dynamic generation is still planned.
        """
        return {
            "name": self.name,
            "description": self.description,
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "检索查询文本",
                    },
                    "scene": {
                        "type": "string",
                        "description": "场景标识,如 'open_consult', 'intent_match'",
                    },
                    "top_k": {
                        "type": "integer",
                        "description": "返回结果数量",
                        "default": DEFAULT_TOP_K,
                    },
                    "context": {
                        "type": "object",
                        "description": "上下文信息,包含动态过滤字段值",
                    },
                },
                "required": ["query"],
            },
        }
    async def execute(
        self,
        query: str,
        tenant_id: str,
        scene: str = "open_consult",
        top_k: int | None = None,
        context: dict[str, Any] | None = None,
    ) -> KbSearchDynamicResult:
        """
        [AC-MARH-05] Execute the dynamic KB retrieval.

        Never raises: every failure path (missing required slots, timeout,
        unexpected error) is converted into an unsuccessful result carrying a
        fallback_reason_code and an observability trace.

        Args:
            query: Retrieval query
            tenant_id: Tenant ID
            scene: Scene identifier
            top_k: Number of results to return
            context: Context carrying the dynamic filter values

        Returns:
            KbSearchDynamicResult with hits and trace information
        """
        if not self._config.enabled:
            logger.info(f"[AC-MARH-05] KB search dynamic disabled for tenant={tenant_id}")
            return KbSearchDynamicResult(
                success=False,
                fallback_reason_code="KB_DISABLED",
            )
        start_time = time.time()
        top_k = top_k or self._config.top_k
        logger.info(
            f"[AC-MARH-05] Starting KB dynamic search: tenant={tenant_id}, "
            f"query={query[:50]}..., scene={scene}, top_k={top_k}"
        )
        filter_result: FilterBuildResult | None = None
        try:
            # Lazily create the filter builder bound to this session.
            if self._filter_builder is None:
                self._filter_builder = MetadataFilterBuilder(self._session)
            filter_result = await self._filter_builder.build_filter(
                tenant_id=tenant_id,
                context=context,
            )
            # [AC-MARH-05] Required metadata slots missing: fail fast with an
            # observable trace instead of running a degraded retrieval.
            if filter_result.missing_required_slots:
                logger.warning(
                    f"[AC-MARH-05] Missing required slots: "
                    f"{filter_result.missing_required_slots}"
                )
                duration_ms = int((time.time() - start_time) * 1000)
                tool_trace = ToolCallTrace(
                    tool_name=self.name,
                    tool_type=ToolType.INTERNAL,
                    duration_ms=duration_ms,
                    status=ToolCallStatus.ERROR,
                    error_code="MISSING_REQUIRED_SLOTS",
                    args_digest=f"query={query[:50]}, scene={scene}",
                    result_digest=f"missing={len(filter_result.missing_required_slots)}",
                )
                return KbSearchDynamicResult(
                    success=False,
                    applied_filter=filter_result.applied_filter,
                    missing_required_slots=filter_result.missing_required_slots,
                    filter_debug=filter_result.debug_info,
                    fallback_reason_code="MISSING_REQUIRED_SLOTS",
                    duration_ms=duration_ms,
                    tool_trace=tool_trace,
                )
            # Only apply the filter when building it succeeded.
            metadata_filter = filter_result.applied_filter if filter_result.success else None
            hits = await self._retrieve_with_timeout(
                tenant_id=tenant_id,
                query=query,
                metadata_filter=metadata_filter,
                top_k=top_k,
            )
            duration_ms = int((time.time() - start_time) * 1000)
            kb_hit = len(hits) > 0
            tool_trace = ToolCallTrace(
                tool_name=self.name,
                tool_type=ToolType.INTERNAL,
                duration_ms=duration_ms,
                status=ToolCallStatus.OK,
                args_digest=f"query={query[:50]}, scene={scene}",
                result_digest=f"hits={len(hits)}",
            )
            logger.info(
                f"[AC-MARH-05] KB dynamic search completed: tenant={tenant_id}, "
                f"hits={len(hits)}, duration_ms={duration_ms}, kb_hit={kb_hit}"
            )
            return KbSearchDynamicResult(
                success=True,
                hits=hits,
                applied_filter=filter_result.applied_filter if filter_result else {},
                filter_debug=filter_result.debug_info if filter_result else {},
                duration_ms=duration_ms,
                tool_trace=tool_trace,
            )
        except asyncio.TimeoutError:
            # [AC-MARH-06] Timeout degradation: observable fallback with reason code.
            duration_ms = int((time.time() - start_time) * 1000)
            logger.warning(
                f"[AC-MARH-06] KB dynamic search timeout: tenant={tenant_id}, "
                f"duration_ms={duration_ms}"
            )
            tool_trace = ToolCallTrace(
                tool_name=self.name,
                tool_type=ToolType.INTERNAL,
                duration_ms=duration_ms,
                status=ToolCallStatus.TIMEOUT,
                error_code="KB_TIMEOUT",
            )
            return KbSearchDynamicResult(
                success=False,
                applied_filter=filter_result.applied_filter if filter_result else {},
                missing_required_slots=filter_result.missing_required_slots if filter_result else [],
                filter_debug=filter_result.debug_info if filter_result else {},
                fallback_reason_code="KB_TIMEOUT",
                duration_ms=duration_ms,
                tool_trace=tool_trace,
            )
        except Exception as e:
            # [AC-MARH-06] Any other failure also degrades observably.
            duration_ms = int((time.time() - start_time) * 1000)
            logger.error(
                f"[AC-MARH-06] KB dynamic search failed: tenant={tenant_id}, "
                f"error={e}"
            )
            tool_trace = ToolCallTrace(
                tool_name=self.name,
                tool_type=ToolType.INTERNAL,
                duration_ms=duration_ms,
                status=ToolCallStatus.ERROR,
                error_code="KB_ERROR",
            )
            return KbSearchDynamicResult(
                success=False,
                applied_filter=filter_result.applied_filter if filter_result else {},
                missing_required_slots=filter_result.missing_required_slots if filter_result else [],
                filter_debug={"error": str(e)},
                fallback_reason_code="KB_ERROR",
                duration_ms=duration_ms,
                tool_trace=tool_trace,
            )
    async def _retrieve_with_timeout(
        self,
        tenant_id: str,
        query: str,
        metadata_filter: dict[str, Any] | None = None,
        top_k: int = DEFAULT_TOP_K,
    ) -> list[dict[str, Any]]:
        """Run the retrieval under the configured timeout budget.

        Raises:
            asyncio.TimeoutError: When the retrieval exceeds timeout_ms.
        """
        timeout_seconds = self._config.timeout_ms / 1000.0
        try:
            return await asyncio.wait_for(
                self._do_retrieve(tenant_id, query, metadata_filter, top_k),
                timeout=timeout_seconds,
            )
        except asyncio.TimeoutError:
            # Re-raise with a descriptive message for the caller's handler.
            raise asyncio.TimeoutError("KB dynamic search timeout")
    async def _do_retrieve(
        self,
        tenant_id: str,
        query: str,
        metadata_filter: dict[str, Any] | None = None,
        top_k: int = DEFAULT_TOP_K,
    ) -> list[dict[str, Any]]:
        """Perform the actual vector retrieval and post-filter the hits."""
        if self._vector_retriever is None:
            # Local imports — presumably to avoid import cycles; confirm.
            from app.services.retrieval.vector_retriever import get_vector_retriever
            self._vector_retriever = await get_vector_retriever()
        from app.services.retrieval.base import RetrievalContext
        ctx = RetrievalContext(
            tenant_id=tenant_id,
            query=query,
            metadata=metadata_filter,
        )
        result = await self._vector_retriever.retrieve(ctx)
        hits = []
        # Drop low-score hits, then truncate to the requested top_k.
        for hit in result.hits:
            if hit.score >= self._config.min_score_threshold:
                hits.append({
                    "id": hit.metadata.get("chunk_id", str(uuid.uuid4())),
                    "content": hit.text,
                    "score": hit.score,
                    "metadata": hit.metadata,
                })
        return hits[:top_k]
    def get_config(self) -> KbSearchDynamicConfig:
        """Return the tool's current configuration object."""
        return self._config
async def create_kb_search_dynamic_handler(
    session: AsyncSession,
    timeout_governor: TimeoutGovernor | None = None,
    config: KbSearchDynamicConfig | None = None,
) -> callable:
    """
    Create the handler function for the kb_search_dynamic tool, for
    registration into the ToolRegistry.

    A single KbSearchDynamicTool instance is created here and shared by
    every invocation of the returned handler.

    Args:
        session: Database session
        timeout_governor: Timeout governor
        config: Tool configuration

    Returns:
        Async handler function
    """
    tool = KbSearchDynamicTool(
        session=session,
        timeout_governor=timeout_governor,
        config=config,
    )
    async def handler(
        query: str,
        tenant_id: str = "",
        scene: str = "open_consult",
        top_k: int = DEFAULT_TOP_K,
        context: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """
        kb_search_dynamic handler.

        Args:
            query: Retrieval query
            tenant_id: Tenant ID
            scene: Scene identifier
            top_k: Number of results
            context: Context carrying dynamic filter values

        Returns:
            Retrieval result as a plain dict
        """
        result = await tool.execute(
            query=query,
            tenant_id=tenant_id,
            scene=scene,
            top_k=top_k,
            context=context,
        )
        return {
            "success": result.success,
            "hits": result.hits,
            "applied_filter": result.applied_filter,
            "missing_required_slots": result.missing_required_slots,
            "filter_debug": result.filter_debug,
            "fallback_reason_code": result.fallback_reason_code,
            "duration_ms": result.duration_ms,
        }
    return handler
def register_kb_search_dynamic_tool(
    registry: ToolRegistry,
    session: AsyncSession,
    timeout_governor: TimeoutGovernor | None = None,
    config: KbSearchDynamicConfig | None = None,
) -> None:
    """
    [AC-MARH-05] Register kb_search_dynamic into the ToolRegistry.

    Args:
        registry: ToolRegistry instance
        session: Database session
        timeout_governor: Timeout governor
        config: Tool configuration
    """
    from app.services.mid.tool_registry import ToolType as RegistryToolType
    async def handler(
        query: str,
        tenant_id: str = "",
        scene: str = "open_consult",
        top_k: int = DEFAULT_TOP_K,
        context: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        # NOTE(review): unlike create_kb_search_dynamic_handler, a fresh tool
        # instance is built on every call here — confirm this is intended.
        tool = KbSearchDynamicTool(
            session=session,
            timeout_governor=timeout_governor,
            config=config,
        )
        result = await tool.execute(
            query=query,
            tenant_id=tenant_id,
            scene=scene,
            top_k=top_k,
            context=context,
        )
        return {
            "success": result.success,
            "hits": result.hits,
            "applied_filter": result.applied_filter,
            "missing_required_slots": result.missing_required_slots,
            "filter_debug": result.filter_debug,
            "fallback_reason_code": result.fallback_reason_code,
            "duration_ms": result.duration_ms,
        }
    registry.register(
        name=KB_SEARCH_DYNAMIC_TOOL_NAME,
        description="知识库动态检索工具,支持元数据驱动过滤",
        handler=handler,
        tool_type=RegistryToolType.INTERNAL,
        version="1.0.0",
        auth_required=False,
        timeout_ms=config.timeout_ms if config else DEFAULT_TIMEOUT_MS,
        enabled=True,
        metadata={
            "supports_dynamic_filter": True,
            "min_score_threshold": config.min_score_threshold if config else 0.5,
        },
    )
    logger.info(
        f"[AC-MARH-05] Tool registered: {KB_SEARCH_DYNAMIC_TOOL_NAME}"
    )

View File

@ -0,0 +1,355 @@
"""
Memory Adapter for Mid Platform.
[AC-IDMP-13] 记忆召回服务 - 在响应前执行 recall 并注入 profile/facts/preferences
[AC-IDMP-14] 记忆更新服务 - 异步执行记忆更新含会话摘要
Reference:
- spec/intent-driven-mid-platform/openapi.deps.yaml
- spec/intent-driven-mid-platform/requirements.md AC-IDMP-13/14
"""
import asyncio
import logging
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Callable
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.mid.memory import (
MemoryFact,
MemoryProfile,
MemoryPreferences,
RecallRequest,
RecallResponse,
UpdateRequest,
)
logger = logging.getLogger(__name__)
class MemoryStoreType(str):
    """Memory store type constants.

    NOTE(review): this is a plain ``str`` subclass used as a constant
    namespace, not an ``enum.Enum`` — values compare as ordinary strings.
    """
    PROFILE = "profile"
    FACT = "fact"
    PREFERENCES = "preferences"
    SUMMARY = "summary"
@dataclass
class UserMemory:
    """
    User memory storage entity.

    Used to persist the three layers of user memory
    (profile / facts / preferences, plus session summaries).
    """
    id: str  # memory record id
    user_id: str  # owning user
    tenant_id: str  # owning tenant
    memory_type: str  # one of the MemoryStoreType constants
    content: dict[str, Any]  # layer-specific payload
    # NOTE(review): datetime.utcnow is naive and deprecated since 3.12 —
    # consider datetime.now(timezone.utc); verify downstream expects naive UTC.
    created_at: datetime = field(default_factory=datetime.utcnow)
    updated_at: datetime = field(default_factory=datetime.utcnow)
    expires_at: datetime | None = None  # optional TTL for the record
    confidence: float | None = None  # extraction confidence, if applicable
    source: str | None = None  # where the memory came from (e.g. order, conversation)
class MemoryAdapter:
    """
    [AC-IDMP-13/14] Memory adapter.

    Responsibilities:
    1. recall: recall user memory (profile/facts/preferences) before responding
    2. update: asynchronously update user memory after the dialogue completes

    Design principles:
    - recall failures never block the main pipeline (graceful degradation)
    - update runs asynchronously and does not block the main response
    """
    # Timeout budgets in milliseconds.
    DEFAULT_RECALL_TIMEOUT_MS = 500
    DEFAULT_UPDATE_TIMEOUT_MS = 2000
    def __init__(
        self,
        session: AsyncSession,
        recall_timeout_ms: int = DEFAULT_RECALL_TIMEOUT_MS,
        update_timeout_ms: int = DEFAULT_UPDATE_TIMEOUT_MS,
    ):
        self._session = session
        self._recall_timeout_ms = recall_timeout_ms
        self._update_timeout_ms = update_timeout_ms
        # In-flight background update tasks (see update / wait_pending_updates).
        self._pending_updates: list[asyncio.Task] = []
    async def recall(
        self,
        user_id: str,
        session_id: str,
        tenant_id: str | None = None,
    ) -> RecallResponse:
        """
        [AC-IDMP-13] Recall user memory.

        Executed before responding; injects profile attributes, fact memory
        and preference memory. On timeout or failure returns an empty
        RecallResponse so the main pipeline is never blocked.

        Args:
            user_id: User ID
            session_id: Session ID
            tenant_id: Tenant ID (optional)

        Returns:
            RecallResponse: Response containing profile/facts/preferences
        """
        try:
            return await asyncio.wait_for(
                self._recall_internal(user_id, session_id, tenant_id),
                timeout=self._recall_timeout_ms / 1000,
            )
        except asyncio.TimeoutError:
            logger.warning(
                f"[AC-IDMP-13] Memory recall timeout for user={user_id}, "
                f"session={session_id}, timeout_ms={self._recall_timeout_ms}"
            )
            return RecallResponse()
        except Exception as e:
            logger.error(
                f"[AC-IDMP-13] Memory recall failed for user={user_id}, "
                f"session={session_id}, error={e}"
            )
            return RecallResponse()
    async def _recall_internal(
        self,
        user_id: str,
        session_id: str,
        tenant_id: str | None,
    ) -> RecallResponse:
        """
        Internal recall implementation: gathers all memory layers sequentially.
        """
        profile = await self._recall_profile(user_id, tenant_id)
        facts = await self._recall_facts(user_id, tenant_id)
        preferences = await self._recall_preferences(user_id, tenant_id)
        last_summary = await self._recall_last_summary(user_id, tenant_id)
        logger.info(
            f"[AC-IDMP-13] Memory recalled for user={user_id}: "
            f"profile={bool(profile)}, facts={len(facts)}, "
            f"preferences={bool(preferences)}, summary={bool(last_summary)}"
        )
        return RecallResponse(
            profile=profile,
            facts=facts,
            preferences=preferences,
            last_summary=last_summary,
        )
    async def _recall_profile(
        self,
        user_id: str,
        tenant_id: str | None,
    ) -> MemoryProfile | None:
        """Recall the user's basic profile attributes.

        TODO(review): returns hardcoded demo data — replace with a real
        store lookup before production use.
        """
        return MemoryProfile(
            grade="初一",
            region="北京",
            channel="wechat",
            vip_level="gold",
        )
    async def _recall_facts(
        self,
        user_id: str,
        tenant_id: str | None,
    ) -> list[MemoryFact]:
        """Recall the user's fact memory.

        TODO(review): returns hardcoded demo data — replace with a real
        store lookup before production use.
        """
        return [
            MemoryFact(content="已购课程:数学思维训练营", source="order", confidence=1.0),
            MemoryFact(content="学习目标:提高数学成绩", source="profile", confidence=0.9),
            MemoryFact(content="上次咨询:课程退费政策", source="conversation", confidence=0.8),
        ]
    async def _recall_preferences(
        self,
        user_id: str,
        tenant_id: str | None,
    ) -> MemoryPreferences | None:
        """Recall the user's preferences.

        TODO(review): returns hardcoded demo data — replace with a real
        store lookup before production use.
        """
        return MemoryPreferences(
            tone="friendly",
            focus_subjects=["数学", "物理"],
            communication_style="详细解释",
        )
    async def _recall_last_summary(
        self,
        user_id: str,
        tenant_id: str | None,
    ) -> str | None:
        """Recall the most recent session summary.

        TODO(review): returns hardcoded demo data — replace with a real
        store lookup before production use.
        """
        return "上次讨论了数学学习计划,用户对课程安排比较满意"
    async def update(
        self,
        user_id: str,
        session_id: str,
        messages: list[dict[str, Any]],
        summary: str | None = None,
        tenant_id: str | None = None,
    ) -> bool:
        """
        [AC-IDMP-14] Asynchronously update user memory.

        Runs after the dialogue completes as a fire-and-forget task so the
        main response is not blocked; includes the session-summary write-back.
        Requires a running event loop (asyncio.create_task).

        Args:
            user_id: User ID
            session_id: Session ID
            messages: Messages of the current turn
            summary: Session summary (optional)
            tenant_id: Tenant ID

        Returns:
            bool: Whether the update task was successfully scheduled
        """
        request = UpdateRequest(
            user_id=user_id,
            session_id=session_id,
            messages=messages,
            summary=summary,
        )
        task = asyncio.create_task(
            self._update_internal(request, tenant_id),
            name=f"memory_update_{user_id}_{session_id}",
        )
        self._pending_updates.append(task)
        # Self-remove from the pending list when the task finishes.
        task.add_done_callback(lambda t: self._pending_updates.remove(t))
        logger.info(
            f"[AC-IDMP-14] Memory update scheduled for user={user_id}, "
            f"session={session_id}, messages_count={len(messages)}"
        )
        return True
    async def _update_internal(
        self,
        request: UpdateRequest,
        tenant_id: str | None,
    ) -> None:
        """
        Internal update implementation: runs _do_update under the update
        timeout budget and swallows (but logs) all failures.
        """
        try:
            await asyncio.wait_for(
                self._do_update(request, tenant_id),
                timeout=self._update_timeout_ms / 1000,
            )
            logger.info(
                f"[AC-IDMP-14] Memory updated for user={request.user_id}, "
                f"session={request.session_id}"
            )
        except asyncio.TimeoutError:
            logger.warning(
                f"[AC-IDMP-14] Memory update timeout for user={request.user_id}, "
                f"session={request.session_id}"
            )
        except Exception as e:
            logger.error(
                f"[AC-IDMP-14] Memory update failed for user={request.user_id}, "
                f"session={request.session_id}, error={e}"
            )
    async def _do_update(
        self,
        request: UpdateRequest,
        tenant_id: str | None,
    ) -> None:
        """
        Perform the actual memory update: summary write-back (when present)
        followed by fact extraction from the turn's messages.
        """
        if request.summary:
            await self._save_summary(request.user_id, request.summary, tenant_id)
        await self._extract_and_save_facts(
            request.user_id, request.messages, tenant_id
        )
    async def _save_summary(
        self,
        user_id: str,
        summary: str,
        tenant_id: str | None,
    ) -> None:
        """Persist the session summary. TODO(review): not implemented yet."""
        pass
    async def _extract_and_save_facts(
        self,
        user_id: str,
        messages: list[dict[str, Any]],
        tenant_id: str | None,
    ) -> None:
        """Extract facts from messages and persist them. TODO(review): not implemented yet."""
        pass
    async def update_with_summary_generation(
        self,
        user_id: str,
        session_id: str,
        messages: list[dict[str, Any]],
        tenant_id: str | None = None,
        summary_generator: Callable | None = None,
    ) -> bool:
        """
        [AC-IDMP-14] Memory update with summary generation.

        When a summary generator is provided it is used to produce the
        summary first; generation failures are logged and the update
        proceeds without a summary.
        """
        summary = None
        if summary_generator:
            try:
                summary = await summary_generator(messages)
            except Exception as e:
                logger.warning(
                    f"[AC-IDMP-14] Summary generation failed: {e}"
                )
        return await self.update(
            user_id=user_id,
            session_id=session_id,
            messages=messages,
            summary=summary,
            tenant_id=tenant_id,
        )
    async def wait_pending_updates(self, timeout: float = 5.0) -> int:
        """
        Wait for all pending update tasks to finish.

        Used during graceful shutdown to ensure updates are flushed.

        Args:
            timeout: Maximum time to wait, in seconds

        Returns:
            int: Number of completed tasks
        """
        if not self._pending_updates:
            return 0
        try:
            done, _ = await asyncio.wait(
                self._pending_updates,
                timeout=timeout,
                return_when=asyncio.ALL_COMPLETED,
            )
            return len(done)
        except Exception as e:
            logger.error(f"[AC-IDMP-14] Error waiting for pending updates: {e}")
            return 0

View File

@ -0,0 +1,219 @@
"""
Metrics Collector for Mid Platform.
[AC-IDMP-18] Runtime metrics collection.
Metrics:
- task_completion_rate: Task completion rate
- slot_completion_rate: Slot completion rate
- wrong_transfer_rate: Wrong transfer rate
- no_recall_rate: No recall rate
- avg_latency_ms: Average latency
"""
import logging
import time
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any
from app.models.mid.schemas import MetricsSnapshot
logger = logging.getLogger(__name__)
@dataclass
class SessionMetrics:
    """Session-level metrics (running counters for one session)."""
    session_id: str
    total_turns: int = 0  # conversation turns recorded
    completed_tasks: int = 0  # turns where the task completed
    transfers: int = 0  # turns that transferred to a human
    wrong_transfers: int = 0  # transfers flagged as wrong by the caller
    no_recall_turns: int = 0  # turns flagged as no-recall by the caller
    total_latency_ms: int = 0  # summed per-turn latency
    slots_expected: int = 0  # cumulative expected slot count
    slots_filled: int = 0  # cumulative filled slot count
@dataclass
class AggregatedMetrics:
    """Aggregated metrics over time window."""
    total_sessions: int = 0
    total_turns: int = 0
    completed_tasks: int = 0
    total_transfers: int = 0
    wrong_transfers: int = 0
    no_recall_turns: int = 0
    total_latency_ms: int = 0
    slots_expected: int = 0
    slots_filled: int = 0
    def to_snapshot(self) -> MetricsSnapshot:
        """Build a point-in-time MetricsSnapshot from the running counters."""
        def _rate(numerator: int, denominator: int, empty: float) -> float:
            # Guard divide-by-zero with the metric-specific neutral value.
            return numerator / denominator if denominator > 0 else empty
        turns = self.total_turns
        return MetricsSnapshot(
            task_completion_rate=round(_rate(self.completed_tasks, turns, 0.0), 4),
            slot_completion_rate=round(_rate(self.slots_filled, self.slots_expected, 1.0), 4),
            wrong_transfer_rate=round(_rate(self.wrong_transfers, self.total_transfers, 0.0), 4),
            no_recall_rate=round(_rate(self.no_recall_turns, turns, 0.0), 4),
            avg_latency_ms=round(_rate(self.total_latency_ms, turns, 0.0), 2),
        )
class MetricsCollector:
"""
[AC-IDMP-18] Metrics collector for runtime observability.
Features:
- Session-level metrics tracking
- Aggregated metrics over time windows
- Real-time snapshot generation
"""
    def __init__(self):
        # Per-session metrics, keyed by session_id.
        self._session_metrics: dict[str, SessionMetrics] = {}
        # Per-tenant aggregates; defaultdict creates an empty aggregate on first use.
        self._tenant_metrics: dict[str, AggregatedMetrics] = defaultdict(AggregatedMetrics)
        # Process-wide aggregate across all tenants.
        self._global_metrics = AggregatedMetrics()
    def start_session(self, session_id: str) -> None:
        """Start tracking a new session (idempotent: existing metrics are kept)."""
        if session_id not in self._session_metrics:
            self._session_metrics[session_id] = SessionMetrics(session_id=session_id)
            logger.debug(f"[AC-IDMP-18] Session started: {session_id}")
def record_turn(
self,
session_id: str,
tenant_id: str,
latency_ms: int,
task_completed: bool = False,
transferred: bool = False,
wrong_transfer: bool = False,
no_recall: bool = False,
slots_expected: int = 0,
slots_filled: int = 0,
) -> None:
"""
[AC-IDMP-18] Record a conversation turn.
Args:
session_id: Session ID
tenant_id: Tenant ID
latency_ms: Turn latency in milliseconds
task_completed: Whether the task was completed
transferred: Whether transfer occurred
wrong_transfer: Whether it was a wrong transfer
no_recall: Whether no recall occurred
slots_expected: Expected slots count
slots_filled: Filled slots count
"""
session = self._session_metrics.get(session_id)
if not session:
session = SessionMetrics(session_id=session_id)
self._session_metrics[session_id] = session
session.total_turns += 1
session.total_latency_ms += latency_ms
if task_completed:
session.completed_tasks += 1
if transferred:
session.transfers += 1
if wrong_transfer:
session.wrong_transfers += 1
if no_recall:
session.no_recall_turns += 1
session.slots_expected += slots_expected
session.slots_filled += slots_filled
self._tenant_metrics[tenant_id].total_turns += 1
self._tenant_metrics[tenant_id].total_latency_ms += latency_ms
if task_completed:
self._tenant_metrics[tenant_id].completed_tasks += 1
if transferred:
self._tenant_metrics[tenant_id].total_transfers += 1
if wrong_transfer:
self._tenant_metrics[tenant_id].wrong_transfers += 1
if no_recall:
self._tenant_metrics[tenant_id].no_recall_turns += 1
self._tenant_metrics[tenant_id].slots_expected += slots_expected
self._tenant_metrics[tenant_id].slots_filled += slots_filled
self._global_metrics.total_turns += 1
self._global_metrics.total_latency_ms += latency_ms
if task_completed:
self._global_metrics.completed_tasks += 1
if transferred:
self._global_metrics.total_transfers += 1
if wrong_transfer:
self._global_metrics.wrong_transfers += 1
if no_recall:
self._global_metrics.no_recall_turns += 1
self._global_metrics.slots_expected += slots_expected
self._global_metrics.slots_filled += slots_filled
logger.debug(
f"[AC-IDMP-18] Turn recorded: session={session_id}, "
f"latency_ms={latency_ms}, task_completed={task_completed}"
)
def end_session(self, session_id: str) -> SessionMetrics | None:
"""End session tracking and return final metrics."""
session = self._session_metrics.pop(session_id, None)
if session:
self._tenant_metrics[session_id.split("_")[0]].total_sessions += 1
self._global_metrics.total_sessions += 1
logger.info(
f"[AC-IDMP-18] Session ended: {session_id}, "
f"turns={session.total_turns}, completed={session.completed_tasks}"
)
return session
def get_session_metrics(self, session_id: str) -> SessionMetrics | None:
"""Get metrics for a specific session."""
return self._session_metrics.get(session_id)
def get_tenant_metrics(self, tenant_id: str) -> MetricsSnapshot:
"""[AC-IDMP-18] Get metrics snapshot for a tenant."""
return self._tenant_metrics[tenant_id].to_snapshot()
def get_global_metrics(self) -> MetricsSnapshot:
"""[AC-IDMP-18] Get global metrics snapshot."""
return self._global_metrics.to_snapshot()
def reset_metrics(self, tenant_id: str | None = None) -> None:
"""Reset metrics for a tenant or globally."""
if tenant_id:
self._tenant_metrics[tenant_id] = AggregatedMetrics()
logger.info(f"[AC-IDMP-18] Metrics reset for tenant: {tenant_id}")
else:
self._tenant_metrics.clear()
self._global_metrics = AggregatedMetrics()
logger.info("[AC-IDMP-18] Global metrics reset")
def get_metrics_dict(self, tenant_id: str | None = None) -> dict[str, Any]:
"""Get metrics as dictionary for logging/export."""
snapshot = self.get_tenant_metrics(tenant_id) if tenant_id else self.get_global_metrics()
return {
"task_completion_rate": snapshot.task_completion_rate,
"slot_completion_rate": snapshot.slot_completion_rate,
"wrong_transfer_rate": snapshot.wrong_transfer_rate,
"no_recall_rate": snapshot.no_recall_rate,
"avg_latency_ms": snapshot.avg_latency_ms,
}

View File

@ -0,0 +1,152 @@
"""
Output Guardrail Executor for Mid Platform.
[AC-MARH-01, AC-MARH-02] Output guardrail enforcement before returning segments.
Features:
- Mandatory output filtering before segments are returned
- Guardrail trigger logging with rule_id
- Integration with existing OutputFilter
"""
import logging
from dataclasses import dataclass
from typing import Any
from app.models.entities import GuardrailResult
from app.services.guardrail.output_filter import OutputFilter
from app.services.guardrail.word_service import ForbiddenWordService
logger = logging.getLogger(__name__)
@dataclass
class GuardrailExecutionResult:
"""Result of guardrail execution."""
filtered_text: str
triggered: bool = False
blocked: bool = False
rule_id: str | None = None
triggered_words: list[str] | None = None
triggered_categories: list[str] | None = None
class OutputGuardrailExecutor:
    """
    [AC-MARH-01, AC-MARH-02] Output guardrail executor for mandatory filtering.
    This component enforces output guardrail filtering before any segments
    are returned to the client. It wraps the existing OutputFilter and adds
    trace/logging capabilities required by MARH.
    """
    def __init__(
        self,
        output_filter: OutputFilter | None = None,
        word_service: ForbiddenWordService | None = None,
    ):
        # Prefer a ready-made filter; otherwise build one from the word
        # service; with neither, filtering is a no-op (see execute()).
        if output_filter:
            self._output_filter = output_filter
        elif word_service:
            self._output_filter = OutputFilter(word_service)
        else:
            self._output_filter = None
    async def execute(
        self,
        text: str,
        tenant_id: str,
    ) -> GuardrailExecutionResult:
        """
        [AC-MARH-01] Execute guardrail filtering on output text.
        Args:
            text: The text to filter
            tenant_id: Tenant ID for isolation
        Returns:
            GuardrailExecutionResult with filtered text and trigger info
        """
        # Empty/whitespace-only text and a missing filter both short-circuit
        # to an untriggered pass-through result.
        if not text or not text.strip():
            return GuardrailExecutionResult(filtered_text=text)
        if not self._output_filter:
            logger.debug("[AC-MARH-01] No output filter configured, skipping guardrail")
            return GuardrailExecutionResult(filtered_text=text)
        try:
            result: GuardrailResult = await self._output_filter.filter(text, tenant_id)
            triggered = bool(result.triggered_words)
            rule_id = None
            if triggered and result.triggered_categories:
                # Cap at three categories to keep rule_id bounded.
                rule_id = f"forbidden_word:{','.join(result.triggered_categories[:3])}"
            if triggered:
                logger.info(
                    f"[AC-MARH-02] Guardrail triggered: tenant={tenant_id}, "
                    f"blocked={result.blocked}, rule_id={rule_id}, "
                    f"words={result.triggered_words}"
                )
            return GuardrailExecutionResult(
                filtered_text=result.reply,
                triggered=triggered,
                blocked=result.blocked,
                rule_id=rule_id,
                triggered_words=result.triggered_words,
                triggered_categories=result.triggered_categories,
            )
        except Exception as e:
            # Fail open: a guardrail error must not block the reply, but it is
            # surfaced through rule_id for observability.
            logger.error(f"[AC-MARH-01] Guardrail execution failed: {e}")
            return GuardrailExecutionResult(
                filtered_text=text,
                triggered=False,
                rule_id=f"error:{str(e)[:50]}",
            )
    async def filter_segments(
        self,
        segments: list[Any],
        tenant_id: str,
    ) -> tuple[list[Any], GuardrailExecutionResult]:
        """
        [AC-MARH-01] Filter all segments and return combined result.
        Args:
            segments: List of Segment objects with text field
            tenant_id: Tenant ID for isolation
        Returns:
            Tuple of (filtered_segments, combined_result)
        """
        if not segments:
            return segments, GuardrailExecutionResult(filtered_text="")
        if not self._output_filter:
            return segments, GuardrailExecutionResult(filtered_text="")
        combined_text = "\n\n".join(s.text for s in segments)
        result = await self.execute(combined_text, tenant_id)
        if result.blocked or result.triggered:
            # Local import kept (presumably to avoid a circular import at
            # module load time); hoisted out of the per-paragraph loop where
            # the original re-executed it for every segment.
            from app.models.mid.schemas import Segment
        if result.blocked:
            # A blocked reply collapses to a single replacement segment.
            return [
                Segment(text=result.filtered_text, delay_after=0)
            ], result
        if result.triggered:
            filtered_paragraphs = result.filtered_text.split("\n\n")
            filtered_segments = []
            for i, para in enumerate(filtered_paragraphs):
                if para.strip():
                    filtered_segments.append(
                        Segment(
                            text=para.strip(),
                            # Carry over the matching original segment's delay.
                            # The original used segments[0].delay_after for
                            # every paragraph despite indexing with i, which
                            # dropped per-segment pacing. Assumes filtered
                            # paragraphs align positionally with the original
                            # segments -- TODO confirm.
                            delay_after=segments[i].delay_after if i < len(segments) else 0,
                        )
                    )
            return filtered_segments, result
        return segments, result

View File

@ -0,0 +1,411 @@
"""
Policy Router for Mid Platform.
[AC-IDMP-02, AC-IDMP-05, AC-IDMP-16, AC-IDMP-20] Routes to agent/micro_flow/fixed/transfer based on policy.
Decision Matrix:
1. High-risk scenario (refund/complaint/privacy/transfer) -> micro_flow or transfer
2. Low confidence or missing key info -> micro_flow or fixed
3. Tool unavailable -> fixed
4. Human mode active -> transfer
5. Normal case -> agent
Intent Hint Integration:
- intent_hint provides soft signals (suggested_mode, confidence, high_risk_detected)
- policy_router consumes hints but retains final decision authority
- When hint suggests high_risk, policy_router validates and may override
High Risk Check Integration:
- high_risk_check provides structured risk detection result
- High-risk check takes priority over normal intent routing
- When high_risk_check matched, skip normal intent matching
"""
import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
from app.models.mid.schemas import (
ExecutionMode,
FeatureFlags,
HighRiskScenario,
PolicyRouterResult,
)
if TYPE_CHECKING:
from app.models.mid.schemas import HighRiskCheckResult, IntentHintOutput
logger = logging.getLogger(__name__)
# [AC-IDMP-20] Scenarios treated as high-risk when no custom set is injected
# into PolicyRouter.
DEFAULT_HIGH_RISK_SCENARIOS: list[HighRiskScenario] = [
    HighRiskScenario.REFUND,
    HighRiskScenario.COMPLAINT_ESCALATION,
    HighRiskScenario.PRIVACY_SENSITIVE_PROMISE,
    HighRiskScenario.TRANSFER,
]
# Keyword triggers per scenario, matched case-insensitively as substrings of
# the user message (see PolicyRouter._check_high_risk).
HIGH_RISK_KEYWORDS: dict[HighRiskScenario, list[str]] = {
    HighRiskScenario.REFUND: ["退款", "退货", "退钱", "退费", "还钱", "退款申请"],
    HighRiskScenario.COMPLAINT_ESCALATION: ["投诉", "升级投诉", "举报", "12315", "消费者协会"],
    HighRiskScenario.PRIVACY_SENSITIVE_PROMISE: ["承诺", "保证", "一定", "肯定能", "绝对", "担保"],
    HighRiskScenario.TRANSFER: ["转人工", "人工客服", "人工服务", "真人", "人工"],
}
# [AC-IDMP-16] Match/hint confidence below this triggers fallback routing.
LOW_CONFIDENCE_THRESHOLD = 0.3
@dataclass
class IntentMatch:
"""Intent match result."""
intent_id: str
intent_name: str
confidence: float
response_type: str
target_kb_ids: list[str] | None = None
flow_id: str | None = None
fixed_reply: str | None = None
transfer_message: str | None = None
class PolicyRouter:
    """
    [AC-IDMP-02, AC-IDMP-05, AC-IDMP-16, AC-IDMP-20] Policy router for execution mode decision.
    Decision Flow:
    1. Check feature flags (rollback_to_legacy -> fixed)
    2. Check session mode (HUMAN_ACTIVE -> transfer)
    3. Check high-risk scenarios -> micro_flow or transfer
    4. Check intent match confidence -> fallback if low
    5. Default -> agent
    """
    def __init__(
        self,
        high_risk_scenarios: list[HighRiskScenario] | None = None,
        low_confidence_threshold: float = LOW_CONFIDENCE_THRESHOLD,
    ):
        # Both knobs are injectable (tenant overrides / tests); defaults come
        # from the module-level constants.
        self._high_risk_scenarios = high_risk_scenarios or DEFAULT_HIGH_RISK_SCENARIOS
        self._low_confidence_threshold = low_confidence_threshold
    def route(
        self,
        user_message: str,
        session_mode: str = "BOT_ACTIVE",
        feature_flags: FeatureFlags | None = None,
        intent_match: IntentMatch | None = None,
        intent_hint: "IntentHintOutput | None" = None,
        context: dict[str, Any] | None = None,
    ) -> PolicyRouterResult:
        """
        [AC-IDMP-02] Route to appropriate execution mode.

        The order of the checks below encodes routing precedence and must not
        be rearranged.
        Args:
            user_message: User input message
            session_mode: Current session mode (BOT_ACTIVE/HUMAN_ACTIVE)
            feature_flags: Feature flags for grayscale control
            intent_match: Intent match result if available
            intent_hint: Soft signal from intent_hint tool (optional)
            context: Additional context for decision (currently unused here)
        Returns:
            PolicyRouterResult with decided mode and metadata
        """
        logger.info(
            f"[AC-IDMP-02] PolicyRouter routing: session_mode={session_mode}, "
            f"feature_flags={feature_flags}, intent_match={intent_match}, "
            f"intent_hint_mode={intent_hint.suggested_mode if intent_hint else None}"
        )
        # 1. Grayscale rollback overrides everything else.
        if feature_flags and feature_flags.rollback_to_legacy:
            logger.info("[AC-IDMP-17] Rollback to legacy requested, using fixed mode")
            return PolicyRouterResult(
                mode=ExecutionMode.FIXED,
                fallback_reason_code="rollback_to_legacy",
            )
        # 2. A human agent already owns the session.
        if session_mode == "HUMAN_ACTIVE":
            logger.info("[AC-IDMP-09] Session in HUMAN_ACTIVE mode, routing to transfer")
            return PolicyRouterResult(
                mode=ExecutionMode.TRANSFER,
                transfer_message="正在为您转接人工客服...",
            )
        # 3a. High-risk soft signal from the intent_hint tool; validated
        # (and possibly overridden) in _handle_high_risk_from_hint.
        if intent_hint and intent_hint.high_risk_detected:
            logger.info(
                f"[AC-IDMP-05, AC-IDMP-20] High-risk from hint: {intent_hint.fallback_reason_code}"
            )
            return self._handle_high_risk_from_hint(intent_hint, intent_match)
        # 3b. Keyword-based high-risk detection on the raw user message.
        high_risk_scenario = self._check_high_risk(user_message)
        if high_risk_scenario:
            logger.info(f"[AC-IDMP-05, AC-IDMP-20] High-risk scenario detected: {high_risk_scenario}")
            return self._handle_high_risk(high_risk_scenario, intent_match)
        # 4a. Low-confidence hint: only honored when it suggests a safe
        # fallback mode (fixed or micro_flow); otherwise routing continues.
        if intent_hint and intent_hint.confidence < self._low_confidence_threshold:
            logger.info(
                f"[AC-IDMP-16] Low confidence from hint ({intent_hint.confidence}), "
                f"considering fallback"
            )
            if intent_hint.suggested_mode in (ExecutionMode.FIXED, ExecutionMode.MICRO_FLOW):
                return PolicyRouterResult(
                    mode=intent_hint.suggested_mode,
                    intent=intent_hint.intent,
                    confidence=intent_hint.confidence,
                    fallback_reason_code=intent_hint.fallback_reason_code or "low_confidence_hint",
                    target_flow_id=intent_hint.target_flow_id,
                )
        # 4b. Configured intent match: low confidence falls back; otherwise
        # the intent's response_type ("fixed"/"transfer"/"flow") decides.
        # Any other response_type drops through to the checks below.
        if intent_match:
            if intent_match.confidence < self._low_confidence_threshold:
                logger.info(
                    f"[AC-IDMP-16] Low confidence ({intent_match.confidence}), "
                    f"falling back from agent mode"
                )
                return self._handle_low_confidence(intent_match)
            if intent_match.response_type == "fixed":
                return PolicyRouterResult(
                    mode=ExecutionMode.FIXED,
                    intent=intent_match.intent_name,
                    confidence=intent_match.confidence,
                    fixed_reply=intent_match.fixed_reply,
                )
            if intent_match.response_type == "transfer":
                return PolicyRouterResult(
                    mode=ExecutionMode.TRANSFER,
                    intent=intent_match.intent_name,
                    confidence=intent_match.confidence,
                    transfer_message=intent_match.transfer_message or "正在为您转接人工客服...",
                )
            if intent_match.response_type == "flow":
                return PolicyRouterResult(
                    mode=ExecutionMode.MICRO_FLOW,
                    intent=intent_match.intent_name,
                    confidence=intent_match.confidence,
                    target_flow_id=intent_match.flow_id,
                )
        # 5. Agent kill-switch; NOTE this is only reached when no intent
        # branch above returned first.
        if feature_flags and not feature_flags.agent_enabled:
            logger.info("[AC-IDMP-17] Agent disabled by feature flag, using fixed mode")
            return PolicyRouterResult(
                mode=ExecutionMode.FIXED,
                fallback_reason_code="agent_disabled",
            )
        # 6. Default: full agent mode.
        logger.info("[AC-IDMP-02] Default routing to agent mode")
        return PolicyRouterResult(
            mode=ExecutionMode.AGENT,
            intent=intent_match.intent_name if intent_match else None,
            confidence=intent_match.confidence if intent_match else None,
        )
    def _check_high_risk(self, message: str) -> HighRiskScenario | None:
        """
        [AC-IDMP-05, AC-IDMP-20] Check if message matches high-risk scenarios.
        Returns the first matched high-risk scenario or None.

        Matching is a case-insensitive substring test against
        HIGH_RISK_KEYWORDS, in the order of the configured scenario list.
        """
        message_lower = message.lower()
        for scenario in self._high_risk_scenarios:
            keywords = HIGH_RISK_KEYWORDS.get(scenario, [])
            for keyword in keywords:
                if keyword.lower() in message_lower:
                    return scenario
        return None
    def _handle_high_risk(
        self,
        scenario: HighRiskScenario,
        intent_match: IntentMatch | None,
    ) -> PolicyRouterResult:
        """
        [AC-IDMP-05] Handle high-risk scenario by routing to micro_flow or transfer.

        TRANSFER goes straight to a human; other scenarios prefer the matched
        intent's flow and otherwise fall back to a generic micro_flow result.
        """
        if scenario == HighRiskScenario.TRANSFER:
            return PolicyRouterResult(
                mode=ExecutionMode.TRANSFER,
                high_risk_triggered=True,
                transfer_message="正在为您转接人工客服...",
            )
        if intent_match and intent_match.flow_id:
            return PolicyRouterResult(
                mode=ExecutionMode.MICRO_FLOW,
                intent=intent_match.intent_name,
                confidence=intent_match.confidence,
                high_risk_triggered=True,
                target_flow_id=intent_match.flow_id,
            )
        return PolicyRouterResult(
            mode=ExecutionMode.MICRO_FLOW,
            high_risk_triggered=True,
            fallback_reason_code=f"high_risk_{scenario.value}",
        )
    def _handle_high_risk_from_hint(
        self,
        intent_hint: "IntentHintOutput",
        intent_match: IntentMatch | None,
    ) -> PolicyRouterResult:
        """
        [AC-IDMP-05, AC-IDMP-20] Handle high-risk from intent_hint.
        Policy_router validates hint suggestion but may override.

        Preference order: hint-suggested transfer, then the matched intent's
        flow, then the hint's own flow, then a generic micro_flow fallback.
        """
        if intent_hint.suggested_mode == ExecutionMode.TRANSFER:
            return PolicyRouterResult(
                mode=ExecutionMode.TRANSFER,
                high_risk_triggered=True,
                transfer_message="正在为您转接人工客服...",
            )
        if intent_match and intent_match.flow_id:
            return PolicyRouterResult(
                mode=ExecutionMode.MICRO_FLOW,
                intent=intent_match.intent_name,
                confidence=intent_match.confidence,
                high_risk_triggered=True,
                target_flow_id=intent_match.flow_id,
            )
        if intent_hint.target_flow_id:
            return PolicyRouterResult(
                mode=ExecutionMode.MICRO_FLOW,
                intent=intent_hint.intent,
                confidence=intent_hint.confidence,
                high_risk_triggered=True,
                target_flow_id=intent_hint.target_flow_id,
            )
        return PolicyRouterResult(
            mode=ExecutionMode.MICRO_FLOW,
            high_risk_triggered=True,
            fallback_reason_code=intent_hint.fallback_reason_code or "high_risk_hint",
        )
    def _handle_low_confidence(self, intent_match: IntentMatch) -> PolicyRouterResult:
        """
        [AC-IDMP-16] Handle low confidence by falling back to micro_flow or fixed.

        Preference order: the intent's flow, then its fixed reply, then a
        plain fixed-mode result with no reply attached.
        """
        if intent_match.flow_id:
            return PolicyRouterResult(
                mode=ExecutionMode.MICRO_FLOW,
                intent=intent_match.intent_name,
                confidence=intent_match.confidence,
                fallback_reason_code="low_confidence",
                target_flow_id=intent_match.flow_id,
            )
        if intent_match.fixed_reply:
            return PolicyRouterResult(
                mode=ExecutionMode.FIXED,
                intent=intent_match.intent_name,
                confidence=intent_match.confidence,
                fallback_reason_code="low_confidence",
                fixed_reply=intent_match.fixed_reply,
            )
        return PolicyRouterResult(
            mode=ExecutionMode.FIXED,
            fallback_reason_code="low_confidence_no_flow",
        )
    def get_active_high_risk_set(self) -> list[HighRiskScenario]:
        """[AC-IDMP-20] Get active high-risk scenario set.

        NOTE(review): this returns the internal list itself, so callers that
        mutate it change the router's behavior — confirm whether a copy was
        intended.
        """
        return self._high_risk_scenarios
    def route_with_high_risk_check(
        self,
        user_message: str,
        high_risk_check_result: "HighRiskCheckResult | None",
        session_mode: str = "BOT_ACTIVE",
        feature_flags: FeatureFlags | None = None,
        intent_match: IntentMatch | None = None,
        intent_hint: "IntentHintOutput | None" = None,
        context: dict[str, Any] | None = None,
    ) -> PolicyRouterResult:
        """
        [AC-IDMP-05, AC-IDMP-20] Route with high_risk_check result.
        High risk takes priority over normal intent routing:
        1. If high_risk_check matched, return the high-risk routing result directly.
        2. Otherwise continue with the normal routing decision.
        Args:
            user_message: User input message
            high_risk_check_result: Result from high_risk_check tool
            session_mode: Current session mode (BOT_ACTIVE/HUMAN_ACTIVE)
            feature_flags: Feature flags for grayscale control
            intent_match: Intent match result if available
            intent_hint: Soft signal from intent_hint tool (optional)
            context: Additional context for decision
        Returns:
            PolicyRouterResult with decided mode and metadata
        """
        if high_risk_check_result and high_risk_check_result.matched:
            logger.info(
                f"[AC-IDMP-05, AC-IDMP-20] High-risk check matched: "
                f"scenario={high_risk_check_result.risk_scenario}, "
                f"rule_id={high_risk_check_result.rule_id}"
            )
            return self._handle_high_risk_check_result(
                high_risk_check_result, intent_match
            )
        # No structured high-risk hit: delegate to the normal decision flow.
        return self.route(
            user_message=user_message,
            session_mode=session_mode,
            feature_flags=feature_flags,
            intent_match=intent_match,
            intent_hint=intent_hint,
            context=context,
        )
    def _handle_high_risk_check_result(
        self,
        high_risk_result: "HighRiskCheckResult",
        intent_match: IntentMatch | None,
    ) -> PolicyRouterResult:
        """
        [AC-IDMP-05] Handle high_risk_check result.
        The high-risk detection result takes priority over normal intent routing.
        """
        # Missing recommendation defaults to the safer micro_flow mode.
        recommended_mode = high_risk_result.recommended_mode or ExecutionMode.MICRO_FLOW
        risk_scenario = high_risk_result.risk_scenario
        if recommended_mode == ExecutionMode.TRANSFER:
            # Scenario-specific transfer wording for complaints and refunds.
            transfer_msg = "正在为您转接人工客服..."
            if risk_scenario:
                if risk_scenario == HighRiskScenario.COMPLAINT_ESCALATION:
                    transfer_msg = "检测到您可能需要投诉处理,正在为您转接人工客服..."
                elif risk_scenario == HighRiskScenario.REFUND:
                    transfer_msg = "您的退款请求需要人工处理,正在为您转接..."
            return PolicyRouterResult(
                mode=ExecutionMode.TRANSFER,
                high_risk_triggered=True,
                transfer_message=transfer_msg,
                fallback_reason_code=high_risk_result.rule_id,
            )
        if intent_match and intent_match.flow_id:
            return PolicyRouterResult(
                mode=ExecutionMode.MICRO_FLOW,
                intent=intent_match.intent_name,
                confidence=intent_match.confidence,
                high_risk_triggered=True,
                target_flow_id=intent_match.flow_id,
            )
        return PolicyRouterResult(
            mode=ExecutionMode.MICRO_FLOW,
            high_risk_triggered=True,
            fallback_reason_code=high_risk_result.rule_id
            or f"high_risk_{risk_scenario.value if risk_scenario else 'unknown'}",
        )

View File

@ -0,0 +1,289 @@
"""
Runtime Observer for Mid Platform.
[AC-MARH-12] 运行时观测闭环
汇总 guardrailinterruptkb_hittimeoutssegment_stats 等观测字段
"""
import logging
import time
from dataclasses import dataclass, field
from typing import Any
from app.models.mid.schemas import (
ExecutionMode,
MetricsSnapshot,
SegmentStats,
TimeoutProfile,
ToolCallTrace,
TraceInfo,
)
logger = logging.getLogger(__name__)
@dataclass
class RuntimeContext:
    """Per-request runtime observation context accumulated by RuntimeObserver."""
    # Identity of the observed request.
    tenant_id: str = ""
    session_id: str = ""
    request_id: str = ""
    generation_id: str = ""
    # Routing outcome.
    mode: ExecutionMode = ExecutionMode.AGENT
    intent: str | None = None
    # Guardrail observation.
    guardrail_triggered: bool = False
    guardrail_rule_id: str | None = None
    # Interrupt handling.
    interrupt_consumed: bool = False
    # KB retrieval observation.
    kb_tool_called: bool = False
    kb_hit: bool = False
    fallback_reason_code: str | None = None
    # ReAct loop stats.
    react_iterations: int = 0
    tool_calls: list[ToolCallTrace] = field(default_factory=list)
    # Optional profiles/stats attached while the request runs.
    timeout_profile: TimeoutProfile | None = None
    segment_stats: SegmentStats | None = None
    metrics_snapshot: MetricsSnapshot | None = None
    # Wall-clock start; used for duration logging at end_observation().
    start_time: float = field(default_factory=time.time)
    def to_trace_info(self) -> TraceInfo:
        """Convert the accumulated context into a TraceInfo."""
        return TraceInfo(
            mode=self.mode,
            intent=self.intent,
            request_id=self.request_id,
            generation_id=self.generation_id,
            guardrail_triggered=self.guardrail_triggered,
            guardrail_rule_id=self.guardrail_rule_id,
            interrupt_consumed=self.interrupt_consumed,
            kb_tool_called=self.kb_tool_called,
            kb_hit=self.kb_hit,
            fallback_reason_code=self.fallback_reason_code,
            react_iterations=self.react_iterations,
            timeout_profile=self.timeout_profile,
            segment_stats=self.segment_stats,
            metrics_snapshot=self.metrics_snapshot,
            # Empty tool-call lists are normalized to None in the trace.
            tools_used=[tc.tool_name for tc in self.tool_calls] if self.tool_calls else None,
            tool_calls=self.tool_calls if self.tool_calls else None,
        )
class RuntimeObserver:
    """
    [AC-MARH-12] Runtime observer.
    Features:
    - Aggregates guardrail / interrupt / kb_hit / timeouts / segment_stats
    - Builds a complete TraceInfo per request
    - Emits observation logs
    """
    def __init__(self):
        # In-flight contexts keyed by request_id; removed by end_observation().
        self._contexts: dict[str, RuntimeContext] = {}
    def start_observation(
        self,
        tenant_id: str,
        session_id: str,
        request_id: str,
        generation_id: str,
    ) -> RuntimeContext:
        """
        [AC-MARH-12] Begin observing one request.
        Args:
            tenant_id: Tenant ID
            session_id: Session ID
            request_id: Request ID (key for all later record_* calls)
            generation_id: Generation ID
        Returns:
            The newly registered RuntimeContext.
        """
        ctx = RuntimeContext(
            tenant_id=tenant_id,
            session_id=session_id,
            request_id=request_id,
            generation_id=generation_id,
        )
        self._contexts[request_id] = ctx
        logger.info(
            f"[AC-MARH-12] Observation started: request_id={request_id}, "
            f"session_id={session_id}"
        )
        return ctx
    def get_context(self, request_id: str) -> RuntimeContext | None:
        """Return the live observation context, or None if unknown."""
        return self._contexts.get(request_id)
    def update_mode(
        self,
        request_id: str,
        mode: ExecutionMode,
        intent: str | None = None,
    ) -> None:
        """Update the decided execution mode (and intent) for a request."""
        ctx = self._contexts.get(request_id)
        if ctx:
            ctx.mode = mode
            ctx.intent = intent
    def record_guardrail(
        self,
        request_id: str,
        triggered: bool,
        rule_id: str | None = None,
    ) -> None:
        """[AC-MARH-12] Record a guardrail trigger for the request."""
        ctx = self._contexts.get(request_id)
        if ctx:
            ctx.guardrail_triggered = triggered
            ctx.guardrail_rule_id = rule_id
            logger.info(
                f"[AC-MARH-12] Guardrail recorded: request_id={request_id}, "
                f"triggered={triggered}, rule_id={rule_id}"
            )
    def record_interrupt(
        self,
        request_id: str,
        consumed: bool,
    ) -> None:
        """[AC-MARH-12] Record interrupt handling for the request."""
        ctx = self._contexts.get(request_id)
        if ctx:
            ctx.interrupt_consumed = consumed
            logger.info(
                f"[AC-MARH-12] Interrupt recorded: request_id={request_id}, "
                f"consumed={consumed}"
            )
    def record_kb(
        self,
        request_id: str,
        tool_called: bool,
        hit: bool,
        fallback_reason: str | None = None,
    ) -> None:
        """[AC-MARH-12] Record KB retrieval outcome for the request."""
        ctx = self._contexts.get(request_id)
        if ctx:
            ctx.kb_tool_called = tool_called
            ctx.kb_hit = hit
            # Only overwrite the fallback reason when one was provided.
            if fallback_reason:
                ctx.fallback_reason_code = fallback_reason
            logger.info(
                f"[AC-MARH-12] KB recorded: request_id={request_id}, "
                f"tool_called={tool_called}, hit={hit}, fallback={fallback_reason}"
            )
    def record_react(
        self,
        request_id: str,
        iterations: int,
        tool_calls: list[ToolCallTrace] | None = None,
    ) -> None:
        """[AC-MARH-12] Record ReAct loop iterations and tool-call traces."""
        ctx = self._contexts.get(request_id)
        if ctx:
            ctx.react_iterations = iterations
            if tool_calls:
                ctx.tool_calls = tool_calls
    def record_timeout_profile(
        self,
        request_id: str,
        profile: TimeoutProfile,
    ) -> None:
        """[AC-MARH-12] Attach the effective timeout profile to the request."""
        ctx = self._contexts.get(request_id)
        if ctx:
            ctx.timeout_profile = profile
    def record_segment_stats(
        self,
        request_id: str,
        stats: SegmentStats,
    ) -> None:
        """[AC-MARH-12] Attach segmentation statistics to the request."""
        ctx = self._contexts.get(request_id)
        if ctx:
            ctx.segment_stats = stats
    def record_metrics(
        self,
        request_id: str,
        metrics: MetricsSnapshot,
    ) -> None:
        """[AC-MARH-12] Attach a metrics snapshot to the request."""
        ctx = self._contexts.get(request_id)
        if ctx:
            ctx.metrics_snapshot = metrics
    def set_fallback_reason(
        self,
        request_id: str,
        reason: str,
    ) -> None:
        """Set the fallback reason code for the request."""
        ctx = self._contexts.get(request_id)
        if ctx:
            ctx.fallback_reason_code = reason
    def end_observation(
        self,
        request_id: str,
    ) -> TraceInfo:
        """
        [AC-MARH-12] Finish observation and build the TraceInfo.
        Args:
            request_id: Request ID
        Returns:
            The completed TraceInfo, or a minimal FIXED-mode trace when the
            request was never observed.
        """
        # pop() removes and returns in one step (the original did get() plus
        # a second membership test and del).
        ctx = self._contexts.pop(request_id, None)
        if not ctx:
            logger.warning(f"[AC-MARH-12] Context not found: {request_id}")
            return TraceInfo(mode=ExecutionMode.FIXED)
        duration_ms = int((time.time() - ctx.start_time) * 1000)
        trace_info = ctx.to_trace_info()
        logger.info(
            f"[AC-MARH-12] Observation ended: request_id={request_id}, "
            f"mode={ctx.mode.value}, duration_ms={duration_ms}, "
            f"guardrail={ctx.guardrail_triggered}, kb_hit={ctx.kb_hit}, "
            f"segments={ctx.segment_stats.segment_count if ctx.segment_stats else 0}"
        )
        return trace_info
# Lazily constructed process-wide observer instance.
_runtime_observer: RuntimeObserver | None = None
def get_runtime_observer() -> RuntimeObserver:
    """Return the shared RuntimeObserver, constructing it on first use."""
    global _runtime_observer
    if _runtime_observer is None:
        _runtime_observer = RuntimeObserver()
    return _runtime_observer

View File

@ -0,0 +1,282 @@
"""
Segment Humanizer for Mid Platform.
[AC-MARH-10] 分段策略组件语义/长度切分
[AC-MARH-11] delay 策略租户化配置
将文本按语义/长度切分为 segments并生成拟人化 delay
"""
import logging
import re
import uuid
from dataclasses import dataclass, field
from typing import Any
from app.models.mid.schemas import Segment, SegmentStats
logger = logging.getLogger(__name__)
# Default delay window (ms) and segment-length bounds for humanization.
DEFAULT_MIN_DELAY_MS = 50
DEFAULT_MAX_DELAY_MS = 500
DEFAULT_SEGMENT_MIN_LENGTH = 10
DEFAULT_SEGMENT_MAX_LENGTH = 200
@dataclass
class HumanizeConfig:
    """Humanization settings; tenants may override via from_dict()."""
    enabled: bool = True
    min_delay_ms: int = DEFAULT_MIN_DELAY_MS
    max_delay_ms: int = DEFAULT_MAX_DELAY_MS
    length_bucket_strategy: str = "simple"
    segment_min_length: int = DEFAULT_SEGMENT_MIN_LENGTH
    segment_max_length: int = DEFAULT_SEGMENT_MAX_LENGTH
    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dict (round-trips through from_dict)."""
        return dict(
            enabled=self.enabled,
            min_delay_ms=self.min_delay_ms,
            max_delay_ms=self.max_delay_ms,
            length_bucket_strategy=self.length_bucket_strategy,
            segment_min_length=self.segment_min_length,
            segment_max_length=self.segment_max_length,
        )
    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "HumanizeConfig":
        """Build a config from a dict, filling gaps with class defaults."""
        base = cls()
        return cls(
            enabled=data.get("enabled", base.enabled),
            min_delay_ms=data.get("min_delay_ms", base.min_delay_ms),
            max_delay_ms=data.get("max_delay_ms", base.max_delay_ms),
            length_bucket_strategy=data.get("length_bucket_strategy", base.length_bucket_strategy),
            segment_min_length=data.get("segment_min_length", base.segment_min_length),
            segment_max_length=data.get("segment_max_length", base.segment_max_length),
        )
@dataclass
class LengthBucket:
    """A [min_length, max_length) interval mapped to a fixed delay (ms)."""
    min_length: int
    max_length: int
    delay_ms: int
# Default mapping from segment length to humanized delay; intervals are
# half-open and contiguous from 0 to 10000 characters.
DEFAULT_LENGTH_BUCKETS = [
    LengthBucket(0, 20, 100),
    LengthBucket(20, 50, 200),
    LengthBucket(50, 100, 300),
    LengthBucket(100, 200, 400),
    LengthBucket(200, 10000, 500),
]
class SegmentHumanizer:
    """
    [AC-MARH-10, AC-MARH-11] Segmentation and humanization component.
    Features:
    - Splits text by semantics or length
    - Generates human-like inter-segment delays
    - Supports per-tenant config overrides
    - Emits segment_stats for observability
    """
    def __init__(
        self,
        config: HumanizeConfig | None = None,
        length_buckets: list[LengthBucket] | None = None,
    ):
        self._config = config or HumanizeConfig()
        self._length_buckets = length_buckets or DEFAULT_LENGTH_BUCKETS
    def humanize(
        self,
        text: str,
        override_config: HumanizeConfig | None = None,
    ) -> tuple[list[Segment], SegmentStats]:
        """
        [AC-MARH-10] Convert text into humanized segments.
        Args:
            text: Input text
            override_config: Tenant override configuration
        Returns:
            Tuple of (segments, segment_stats)
        """
        config = override_config or self._config
        if not config.enabled:
            # Humanization disabled: a single segment with no delay.
            segments = [Segment(
                segment_id=str(uuid.uuid4()),
                text=text,
                delay_after=0,
            )]
            stats = SegmentStats(
                segment_count=1,
                avg_segment_length=len(text),
                humanize_strategy="disabled",
            )
            return segments, stats
        raw_segments = self._split_text(text, config)
        segments = []
        for i, seg_text in enumerate(raw_segments):
            # The final segment carries no trailing delay.
            is_last = i == len(raw_segments) - 1
            delay_after = 0 if is_last else self._calculate_delay(seg_text, config)
            segments.append(Segment(
                segment_id=str(uuid.uuid4()),
                text=seg_text,
                delay_after=delay_after,
            ))
        total_length = sum(len(s.text) for s in segments)
        avg_length = total_length / len(segments) if segments else 0.0
        stats = SegmentStats(
            segment_count=len(segments),
            avg_segment_length=avg_length,
            humanize_strategy=config.length_bucket_strategy,
        )
        logger.info(
            f"[AC-MARH-10] Humanized text: segments={len(segments)}, "
            f"avg_length={avg_length:.1f}, strategy={config.length_bucket_strategy}"
        )
        return segments, stats
    def _split_text(
        self,
        text: str,
        config: HumanizeConfig,
    ) -> list[str]:
        """Dispatch to the configured splitting strategy."""
        if config.length_bucket_strategy == "semantic":
            return self._split_semantic(text, config)
        else:
            return self._split_simple(text, config)
    def _split_simple(
        self,
        text: str,
        config: HumanizeConfig,
    ) -> list[str]:
        """Simple split: blank-line-separated paragraphs."""
        paragraphs = re.split(r'\n\s*\n', text.strip())
        segments = []
        for para in paragraphs:
            para = para.strip()
            if not para:
                continue
            if len(para) <= config.segment_max_length:
                segments.append(para)
            else:
                # Over-long paragraphs are further split by raw length.
                sub_segments = self._split_by_length(para, config.segment_max_length)
                segments.extend(sub_segments)
        if not segments:
            segments = [text.strip()]
        return [s for s in segments if s.strip()]
    def _split_semantic(
        self,
        text: str,
        config: HumanizeConfig,
    ) -> list[str]:
        """Semantic split: on sentence boundaries (CJK and ASCII endings)."""
        sentence_endings = re.compile(r'([。!?.!?]+)')
        parts = sentence_endings.split(text.strip())
        sentences = []
        current = ""
        # re.split with a capture group alternates text and punctuation runs;
        # a sentence is complete once a punctuation run is appended.
        # (Dropped the unused enumerate index from the original loop.)
        for part in parts:
            current += part
            if sentence_endings.match(part):
                sentences.append(current.strip())
                current = ""
        if current.strip():
            sentences.append(current.strip())
        if not sentences:
            sentences = [text.strip()]
        # Greedily pack consecutive sentences up to segment_max_length.
        segments = []
        current_segment = ""
        for sentence in sentences:
            if len(current_segment) + len(sentence) <= config.segment_max_length:
                current_segment += sentence
            else:
                if current_segment:
                    segments.append(current_segment)
                current_segment = sentence
        if current_segment:
            segments.append(current_segment)
        return [s for s in segments if s.strip()]
    def _split_by_length(
        self,
        text: str,
        max_length: int,
    ) -> list[str]:
        """Split by length, preferring to break at nearby soft punctuation."""
        segments = []
        remaining = text
        while remaining:
            if len(remaining) <= max_length:
                segments.append(remaining.strip())
                break
            # Look back up to 20 chars for a softer break point.
            split_pos = max_length
            for i in range(max_length - 1, max(0, max_length - 20), -1):
                if remaining[i] in ',;:,;: ':
                    split_pos = i + 1
                    break
            segments.append(remaining[:split_pos].strip())
            remaining = remaining[split_pos:]
        return [s for s in segments if s.strip()]
    def _calculate_delay(
        self,
        text: str,
        config: HumanizeConfig,
    ) -> int:
        """[AC-MARH-11] Compute the humanized delay for one segment."""
        text_length = len(text)
        for bucket in self._length_buckets:
            if bucket.min_length <= text_length < bucket.max_length:
                # Clamp the bucket delay into the configured window.
                return max(config.min_delay_ms, min(bucket.delay_ms, config.max_delay_ms))
        # NOTE(review): lengths beyond every bucket (>= 10000 with defaults)
        # fall through to the minimum delay — confirm this is intended.
        return config.min_delay_ms
    def get_config(self) -> HumanizeConfig:
        """Return the base (non-overridden) configuration."""
        return self._config
# Lazily constructed module-wide humanizer instance.
_segment_humanizer: SegmentHumanizer | None = None
def get_segment_humanizer(
    config: HumanizeConfig | None = None,
) -> SegmentHumanizer:
    """Return the shared SegmentHumanizer, constructing it on first use.

    NOTE(review): `config` only takes effect on the very first call; later
    calls return the existing instance unchanged — confirm callers expect this.
    """
    global _segment_humanizer
    if _segment_humanizer is None:
        _segment_humanizer = SegmentHumanizer(config=config)
    return _segment_humanizer

View File

@ -0,0 +1,166 @@
"""
Timeout Governor for Mid Platform.
[AC-IDMP-12] Timeout governance: per-tool <= 60s, end-to-end <= 180s.
[AC-MARH-08, AC-MARH-09] 超时口径统一单工具 <= 60000ms全链路 <= 180000ms
"""
import asyncio
import logging
import time
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from typing import Any, TypeVar

from app.models.mid.schemas import TimeoutProfile
logger = logging.getLogger(__name__)
# Defaults sit below the hard ceilings enforced by TimeoutGovernor
# (per-tool <= 60000ms, LLM <= 120000ms, end-to-end <= 180000ms).
DEFAULT_PER_TOOL_TIMEOUT_MS = 30000
DEFAULT_END_TO_END_TIMEOUT_MS = 120000
DEFAULT_LLM_TIMEOUT_MS = 60000
# Generic result type for timeout-governed callables.
T = TypeVar("T")
@dataclass
class TimeoutResult:
    """Result of a timeout-governed operation.

    ``success`` discriminates the variants: on success ``result`` holds the
    operation's return value; on failure ``error`` holds a description and
    ``timed_out`` distinguishes timeouts from other exceptions.
    """

    success: bool  # True when the operation completed without error
    result: Any = None  # operation return value (None on failure)
    error: str | None = None  # human-readable failure description
    duration_ms: int = 0  # wall-clock duration of the attempt
    timed_out: bool = False  # True only when the timeout fired
class TimeoutGovernor:
    """
    [AC-IDMP-12] Timeout governor for tool calls and end-to-end execution.

    [AC-MARH-08, AC-MARH-09] Unified timeout policy.

    Hard ceilings (requested values are clamped down to these):
        - Per-tool timeout: <= 60000ms (60s)
        - LLM timeout: <= 120000ms (120s)
        - End-to-end timeout: <= 180000ms (180s)
    """

    def __init__(
        self,
        per_tool_timeout_ms: int = DEFAULT_PER_TOOL_TIMEOUT_MS,
        end_to_end_timeout_ms: int = DEFAULT_END_TO_END_TIMEOUT_MS,
        llm_timeout_ms: int = DEFAULT_LLM_TIMEOUT_MS,
    ):
        # Callers may only tighten timeouts; the spec ceilings always win.
        self._per_tool_timeout_ms = min(per_tool_timeout_ms, 60000)
        self._end_to_end_timeout_ms = min(end_to_end_timeout_ms, 180000)
        self._llm_timeout_ms = min(llm_timeout_ms, 120000)

    @property
    def per_tool_timeout_seconds(self) -> float:
        """Per-tool timeout in seconds."""
        return self._per_tool_timeout_ms / 1000.0

    @property
    def llm_timeout_seconds(self) -> float:
        """LLM call timeout in seconds."""
        return self._llm_timeout_ms / 1000.0

    @property
    def end_to_end_timeout_seconds(self) -> float:
        """End-to-end timeout in seconds."""
        return self._end_to_end_timeout_ms / 1000.0

    @property
    def profile(self) -> TimeoutProfile:
        """Current timeout configuration as a TimeoutProfile."""
        return TimeoutProfile(
            per_tool_timeout_ms=self._per_tool_timeout_ms,
            llm_timeout_ms=self._llm_timeout_ms,
            end_to_end_timeout_ms=self._end_to_end_timeout_ms,
        )

    async def execute_with_timeout(
        self,
        coro: Callable[[], Awaitable[T]],
        timeout_ms: int | None = None,
        operation_name: str = "operation",
    ) -> TimeoutResult:
        """
        [AC-MARH-08, AC-MARH-09] Run an async operation under a timeout.

        Args:
            coro: Zero-argument callable returning an awaitable; invoked
                here so the timeout covers the entire operation
            timeout_ms: Timeout in milliseconds (defaults to per-tool timeout)
            operation_name: Name for logging

        Returns:
            TimeoutResult with success status and result/error
        """
        timeout_ms = timeout_ms or self._per_tool_timeout_ms
        timeout_seconds = timeout_ms / 1000.0
        # Monotonic clock: duration is immune to wall-clock adjustments.
        start_time = time.monotonic()
        try:
            result = await asyncio.wait_for(coro(), timeout=timeout_seconds)
            duration_ms = int((time.monotonic() - start_time) * 1000)
            logger.debug(
                f"[AC-MARH-08] {operation_name} completed in {duration_ms}ms"
            )
            return TimeoutResult(
                success=True,
                result=result,
                duration_ms=duration_ms,
            )
        except asyncio.TimeoutError:
            duration_ms = int((time.monotonic() - start_time) * 1000)
            logger.warning(
                f"[AC-MARH-08] {operation_name} timed out after {duration_ms}ms "
                f"(limit: {timeout_ms}ms)"
            )
            return TimeoutResult(
                success=False,
                error=f"Timeout after {timeout_ms}ms",
                duration_ms=duration_ms,
                timed_out=True,
            )
        except Exception as e:
            # CancelledError derives from BaseException and so still propagates.
            duration_ms = int((time.monotonic() - start_time) * 1000)
            logger.error(
                f"[AC-MARH-08] {operation_name} failed: {e}"
            )
            return TimeoutResult(
                success=False,
                error=str(e),
                duration_ms=duration_ms,
            )

    async def execute_tool(
        self,
        tool_name: str,
        tool_func: Callable[[], Awaitable[T]],
    ) -> TimeoutResult:
        """
        [AC-MARH-08] Execute a tool call with the per-tool timeout.
        """
        return await self.execute_with_timeout(
            coro=tool_func,
            timeout_ms=self._per_tool_timeout_ms,
            operation_name=f"tool:{tool_name}",
        )

    async def execute_e2e(
        self,
        coro: Callable[[], Awaitable[T]],
        operation_name: str = "e2e",
    ) -> TimeoutResult:
        """
        [AC-MARH-09] Execute with the end-to-end timeout.
        """
        return await self.execute_with_timeout(
            coro=coro,
            timeout_ms=self._end_to_end_timeout_ms,
            operation_name=operation_name,
        )

View File

@ -0,0 +1,324 @@
"""
Tool Call Recorder for Mid Platform.
[AC-IDMP-15] 工具调用结构化记录
Reference:
- spec/intent-driven-mid-platform/openapi.provider.yaml - ToolCallTrace
- spec/intent-driven-mid-platform/requirements.md AC-IDMP-15
"""
import logging
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any
from app.models.mid.tool_trace import (
ToolCallBuilder,
ToolCallStatus,
ToolCallTrace,
ToolType,
)
logger = logging.getLogger(__name__)
@dataclass
class ToolCallStatistics:
    """Aggregate statistics for one tool's call history."""

    total_calls: int = 0
    success_calls: int = 0
    timeout_calls: int = 0
    error_calls: int = 0
    rejected_calls: int = 0
    total_duration_ms: int = 0
    avg_duration_ms: float = 0.0
    max_duration_ms: int = 0
    min_duration_ms: int = 0

    def update(self, trace: ToolCallTrace) -> None:
        """Fold a single trace into the aggregate counters."""
        self.total_calls += 1
        self.total_duration_ms += trace.duration_ms
        self.avg_duration_ms = self.total_duration_ms / self.total_calls
        self.max_duration_ms = max(self.max_duration_ms, trace.duration_ms)
        # 0 doubles as the "unset" sentinel for the minimum.
        if self.min_duration_ms == 0:
            self.min_duration_ms = trace.duration_ms
        else:
            self.min_duration_ms = min(self.min_duration_ms, trace.duration_ms)
        status = trace.status
        if status == ToolCallStatus.OK:
            self.success_calls += 1
        elif status == ToolCallStatus.TIMEOUT:
            self.timeout_calls += 1
        elif status == ToolCallStatus.REJECTED:
            self.rejected_calls += 1
        else:
            # Any other status is counted as an error.
            self.error_calls += 1
class ToolCallRecorder:
    """
    [AC-IDMP-15] Structured recorder for tool invocations.

    Responsibilities:
        1. Capture the full outcome of each tool call: digested args,
           duration, status, error code.
        2. Redact sensitive parameters by storing digests (via
           ``ToolCallTrace.compute_digest``) instead of raw values.
        3. Maintain per-tool aggregate statistics.

    Recorded fields (see ToolCallTrace):
        - tool_name: tool identifier
        - tool_type: internal | mcp
        - registry_version: registry version at call time
        - auth_applied: whether auth was enforced
        - duration_ms: call duration
        - status: ok | timeout | error | rejected
        - error_code: error/rejection code
        - args_digest: redacted argument digest
        - result_digest: result digest
    """

    def __init__(self, max_traces_per_session: int = 100):
        # Per-session bounded history; oldest traces evicted beyond the cap.
        self._max_traces_per_session = max_traces_per_session
        self._traces: dict[str, list[ToolCallTrace]] = defaultdict(list)
        # Per-tool aggregates; grows only with the number of distinct tools.
        self._statistics: dict[str, ToolCallStatistics] = defaultdict(ToolCallStatistics)

    @staticmethod
    def _rate(part: int, total: int) -> float:
        """Safe ratio: returns 0 when ``total`` is zero."""
        return part / total if total > 0 else 0

    def start_trace(
        self,
        tool_name: str,
        tool_type: ToolType = ToolType.INTERNAL,
    ) -> ToolCallBuilder:
        """
        Begin recording one tool call.

        Args:
            tool_name: Tool identifier
            tool_type: Tool type

        Returns:
            ToolCallBuilder: builder used to fill in call details step by step
        """
        return ToolCallBuilder(
            tool_name=tool_name,
            tool_type=tool_type,
        )

    def record(
        self,
        session_id: str,
        trace: ToolCallTrace,
    ) -> None:
        """
        Store a completed tool-call trace for a session and update stats.

        Args:
            session_id: Session ID
            trace: Completed tool call trace
        """
        session_traces = self._traces[session_id]
        session_traces.append(trace)
        # Evict the oldest trace once the per-session cap is exceeded.
        if len(session_traces) > self._max_traces_per_session:
            session_traces.pop(0)
        self._statistics[trace.tool_name].update(trace)
        logger.info(
            f"[AC-IDMP-15] Tool call recorded: tool={trace.tool_name}, "
            f"type={trace.tool_type.value}, duration_ms={trace.duration_ms}, "
            f"status={trace.status.value}, session={session_id}"
        )

    def record_success(
        self,
        session_id: str,
        tool_name: str,
        tool_type: ToolType,
        duration_ms: int,
        args: Any = None,
        result: Any = None,
        registry_version: str | None = None,
        auth_applied: bool = False,
    ) -> ToolCallTrace:
        """
        Convenience method: record a successful tool call.
        """
        trace = ToolCallTrace(
            tool_name=tool_name,
            tool_type=tool_type,
            duration_ms=duration_ms,
            status=ToolCallStatus.OK,
            registry_version=registry_version,
            auth_applied=auth_applied,
            args_digest=ToolCallTrace.compute_digest(args) if args else None,
            result_digest=ToolCallTrace.compute_digest(result) if result else None,
        )
        self.record(session_id, trace)
        return trace

    def record_timeout(
        self,
        session_id: str,
        tool_name: str,
        tool_type: ToolType,
        duration_ms: int,
        args: Any = None,
        registry_version: str | None = None,
        auth_applied: bool = False,
    ) -> ToolCallTrace:
        """
        Convenience method: record a timed-out tool call.
        """
        trace = ToolCallTrace(
            tool_name=tool_name,
            tool_type=tool_type,
            duration_ms=duration_ms,
            status=ToolCallStatus.TIMEOUT,
            error_code="TIMEOUT",
            registry_version=registry_version,
            auth_applied=auth_applied,
            args_digest=ToolCallTrace.compute_digest(args) if args else None,
        )
        self.record(session_id, trace)
        return trace

    def record_error(
        self,
        session_id: str,
        tool_name: str,
        tool_type: ToolType,
        duration_ms: int,
        error_code: str,
        error_message: str | None = None,
        args: Any = None,
        registry_version: str | None = None,
        auth_applied: bool = False,
    ) -> ToolCallTrace:
        """
        Convenience method: record a failed tool call.

        NOTE: ``error_message`` is accepted for interface symmetry but is
        not persisted on the trace (only ``error_code`` is).
        """
        trace = ToolCallTrace(
            tool_name=tool_name,
            tool_type=tool_type,
            duration_ms=duration_ms,
            status=ToolCallStatus.ERROR,
            error_code=error_code,
            registry_version=registry_version,
            auth_applied=auth_applied,
            args_digest=ToolCallTrace.compute_digest(args) if args else None,
        )
        self.record(session_id, trace)
        return trace

    def record_rejected(
        self,
        session_id: str,
        tool_name: str,
        tool_type: ToolType,
        reason: str,
        args: Any = None,
        registry_version: str | None = None,
    ) -> ToolCallTrace:
        """
        Convenience method: record a rejected tool call (never executed,
        so duration is 0 and the rejection reason becomes the error code).
        """
        trace = ToolCallTrace(
            tool_name=tool_name,
            tool_type=tool_type,
            duration_ms=0,
            status=ToolCallStatus.REJECTED,
            error_code=reason,
            registry_version=registry_version,
            args_digest=ToolCallTrace.compute_digest(args) if args else None,
        )
        self.record(session_id, trace)
        return trace

    def get_traces(self, session_id: str) -> list[ToolCallTrace]:
        """
        Return all recorded traces for a session ([] when unknown).

        NOTE: returns the internal list for known sessions — callers
        should treat it as read-only.
        """
        return self._traces.get(session_id, [])

    def get_statistics(self, tool_name: str | None = None) -> dict[str, Any]:
        """
        Aggregate call statistics.

        Args:
            tool_name: Optional tool name; omit for a summary of all tools

        Returns:
            dict: per-tool detail, all-tool summary, or {} for an
            unknown tool name
        """
        if tool_name:
            stats = self._statistics.get(tool_name)
            if stats:
                return {
                    "tool_name": tool_name,
                    "total_calls": stats.total_calls,
                    "success_rate": self._rate(stats.success_calls, stats.total_calls),
                    "timeout_rate": self._rate(stats.timeout_calls, stats.total_calls),
                    "error_rate": self._rate(stats.error_calls, stats.total_calls),
                    "avg_duration_ms": stats.avg_duration_ms,
                    "max_duration_ms": stats.max_duration_ms,
                    "min_duration_ms": stats.min_duration_ms,
                }
            return {}
        return {
            name: {
                "total_calls": stats.total_calls,
                "success_rate": self._rate(stats.success_calls, stats.total_calls),
                "avg_duration_ms": stats.avg_duration_ms,
            }
            for name, stats in self._statistics.items()
        }

    def clear_session(self, session_id: str) -> int:
        """
        Drop all traces of one session (statistics are kept).

        Args:
            session_id: Session ID

        Returns:
            int: number of traces removed
        """
        if session_id in self._traces:
            count = len(self._traces[session_id])
            del self._traces[session_id]
            return count
        return 0

    def to_trace_info_format(self, session_id: str) -> list[dict[str, Any]]:
        """
        Convert a session's traces to the TraceInfo.tool_calls wire format
        (used for DialogueResponse.trace.tool_calls).

        Args:
            session_id: Session ID

        Returns:
            list[dict]: OpenAPI-shaped tool call dicts
        """
        traces = self.get_traces(session_id)
        return [trace.to_dict() for trace in traces]
# Process-wide singleton recorder; created lazily on first access.
_recorder: ToolCallRecorder | None = None


def get_tool_call_recorder() -> ToolCallRecorder:
    """Return the global ToolCallRecorder, creating it on first use."""
    global _recorder
    if _recorder is None:
        _recorder = ToolCallRecorder()
    return _recorder

View File

@ -0,0 +1,337 @@
"""
Tool Registry for Mid Platform.
[AC-IDMP-19] Unified tool registration, auth, timeout, version, and enable/disable governance.
"""
import asyncio
import logging
import time
import uuid
from dataclasses import dataclass, field
from typing import Any, Callable, Coroutine
from app.models.mid.schemas import (
ToolCallStatus,
ToolCallTrace,
ToolType,
)
from app.services.mid.timeout_governor import TimeoutGovernor
logger = logging.getLogger(__name__)
@dataclass
class ToolDefinition:
    """Tool definition for registry.

    Describes one registered tool: identity, governance flags
    (enabled/auth), timeout budget, and the async handler to invoke.
    """

    name: str  # unique tool identifier
    description: str  # human-readable description
    tool_type: ToolType = ToolType.INTERNAL  # internal | mcp
    version: str = "1.0.0"  # tool version (distinct from the registry version)
    enabled: bool = True  # disabled tools are refused by ToolRegistry.execute
    auth_required: bool = False  # when True, execute() requires an auth_context
    timeout_ms: int = 2000  # per-call timeout; clamped by ToolRegistry.register
    handler: Callable[..., Coroutine[Any, Any, dict[str, Any]]] | None = None  # async implementation
    metadata: dict[str, Any] = field(default_factory=dict)  # free-form extra info
@dataclass
class ToolExecutionResult:
    """Tool execution result.

    ``success`` discriminates: on success ``output`` holds the handler's
    return value; on failure ``error`` holds a message (not-found,
    disabled, auth, timeout, or exception text).
    """

    success: bool  # True when the handler completed within its timeout
    output: Any = None  # handler return value (None on failure)
    error: str | None = None  # failure description
    duration_ms: int = 0  # wall-clock duration of the attempt
    auth_applied: bool = False  # True when auth was required and a context was supplied
    registry_version: str | None = None  # NOTE: ToolRegistry.execute fills this with the *tool's* version
class ToolRegistry:
    """
    [AC-IDMP-19] Unified tool registry for governance.

    Features:
        - Tool registration with metadata
        - Auth policy enforcement
        - Timeout governance (per-tool ceiling 60000ms per AC-MARH-08)
        - Version management
        - Enable/disable control
    """

    def __init__(
        self,
        timeout_governor: TimeoutGovernor | None = None,
    ):
        self._tools: dict[str, ToolDefinition] = {}
        self._timeout_governor = timeout_governor or TimeoutGovernor()
        self._version = "1.0.0"

    @property
    def version(self) -> str:
        """Get registry version."""
        return self._version

    def register(
        self,
        name: str,
        description: str,
        handler: Callable[..., Coroutine[Any, Any, dict[str, Any]]],
        tool_type: ToolType = ToolType.INTERNAL,
        version: str = "1.0.0",
        auth_required: bool = False,
        timeout_ms: int = 2000,
        enabled: bool = True,
        metadata: dict[str, Any] | None = None,
    ) -> ToolDefinition:
        """
        [AC-IDMP-19] Register a tool (overwrites a same-named tool, with a
        warning).

        Args:
            name: Tool name (unique identifier)
            description: Tool description
            handler: Async handler function
            tool_type: Tool type (internal/mcp)
            version: Tool version
            auth_required: Whether auth is required
            timeout_ms: Tool-specific timeout; clamped to the 60000ms
                per-tool ceiling mandated by AC-MARH-08
            enabled: Whether tool is enabled
            metadata: Additional metadata

        Returns:
            ToolDefinition for the registered tool
        """
        if name in self._tools:
            logger.warning(f"[AC-IDMP-19] Tool already registered, overwriting: {name}")
        tool = ToolDefinition(
            name=name,
            description=description,
            tool_type=tool_type,
            version=version,
            enabled=enabled,
            auth_required=auth_required,
            # BUGFIX: clamp to the spec's 60s per-tool ceiling (AC-MARH-08).
            # The previous min(timeout_ms, 2000) silently shrank every
            # requested timeout to 2s, contradicting TimeoutGovernor.
            timeout_ms=min(timeout_ms, 60000),
            handler=handler,
            metadata=metadata or {},
        )
        self._tools[name] = tool
        logger.info(
            f"[AC-IDMP-19] Tool registered: name={name}, type={tool_type.value}, "
            f"version={version}, auth_required={auth_required}"
        )
        return tool

    def unregister(self, name: str) -> bool:
        """Unregister a tool. Returns False when the name is unknown."""
        if name in self._tools:
            del self._tools[name]
            logger.info(f"[AC-IDMP-19] Tool unregistered: {name}")
            return True
        return False

    def get_tool(self, name: str) -> ToolDefinition | None:
        """Get tool definition by name."""
        return self._tools.get(name)

    def list_tools(
        self,
        tool_type: ToolType | None = None,
        enabled_only: bool = True,
    ) -> list[ToolDefinition]:
        """List registered tools, optionally filtered by type/enabled state."""
        tools = list(self._tools.values())
        if tool_type:
            tools = [t for t in tools if t.tool_type == tool_type]
        if enabled_only:
            tools = [t for t in tools if t.enabled]
        return tools

    def enable_tool(self, name: str) -> bool:
        """Enable a tool. Returns False when the name is unknown."""
        tool = self._tools.get(name)
        if tool:
            tool.enabled = True
            logger.info(f"[AC-IDMP-19] Tool enabled: {name}")
            return True
        return False

    def disable_tool(self, name: str) -> bool:
        """Disable a tool. Returns False when the name is unknown."""
        tool = self._tools.get(name)
        if tool:
            tool.enabled = False
            logger.info(f"[AC-IDMP-19] Tool disabled: {name}")
            return True
        return False

    async def execute(
        self,
        tool_name: str,
        args: dict[str, Any],
        auth_context: dict[str, Any] | None = None,
    ) -> ToolExecutionResult:
        """
        [AC-IDMP-19] Execute a tool with governance (existence, enablement,
        auth, and per-tool timeout checks, in that order).

        Args:
            tool_name: Tool name to execute
            args: Tool arguments
            auth_context: Authentication context

        Returns:
            ToolExecutionResult with output and metadata
        """
        # Monotonic clock: duration is immune to wall-clock adjustments.
        start_time = time.monotonic()
        tool = self._tools.get(tool_name)
        if not tool:
            logger.warning(f"[AC-IDMP-19] Tool not found: {tool_name}")
            return ToolExecutionResult(
                success=False,
                error=f"Tool not found: {tool_name}",
                duration_ms=0,
            )
        if not tool.enabled:
            logger.warning(f"[AC-IDMP-19] Tool disabled: {tool_name}")
            return ToolExecutionResult(
                success=False,
                error=f"Tool disabled: {tool_name}",
                duration_ms=0,
                registry_version=tool.version,
            )
        auth_applied = False
        if tool.auth_required:
            # NOTE: presence of any auth_context is accepted; no deeper
            # credential validation happens here.
            if not auth_context:
                logger.warning(f"[AC-IDMP-19] Auth required but no context: {tool_name}")
                return ToolExecutionResult(
                    success=False,
                    error="Authentication required",
                    duration_ms=int((time.monotonic() - start_time) * 1000),
                    auth_applied=False,
                    registry_version=tool.version,
                )
            auth_applied = True
        try:
            timeout_seconds = tool.timeout_ms / 1000.0
            # A handler-less tool degenerates to a no-op that succeeds
            # with output None.
            result = await asyncio.wait_for(
                tool.handler(**args) if tool.handler else asyncio.sleep(0),
                timeout=timeout_seconds,
            )
            duration_ms = int((time.monotonic() - start_time) * 1000)
            logger.info(
                f"[AC-IDMP-19] Tool executed: name={tool_name}, "
                f"duration_ms={duration_ms}, success=True"
            )
            return ToolExecutionResult(
                success=True,
                output=result,
                duration_ms=duration_ms,
                auth_applied=auth_applied,
                registry_version=tool.version,
            )
        except asyncio.TimeoutError:
            duration_ms = int((time.monotonic() - start_time) * 1000)
            logger.warning(
                f"[AC-IDMP-19] Tool timeout: name={tool_name}, "
                f"duration_ms={duration_ms}"
            )
            return ToolExecutionResult(
                success=False,
                error=f"Tool timeout after {tool.timeout_ms}ms",
                duration_ms=duration_ms,
                auth_applied=auth_applied,
                registry_version=tool.version,
            )
        except Exception as e:
            duration_ms = int((time.monotonic() - start_time) * 1000)
            logger.error(
                f"[AC-IDMP-19] Tool error: name={tool_name}, error={e}"
            )
            return ToolExecutionResult(
                success=False,
                error=str(e),
                duration_ms=duration_ms,
                auth_applied=auth_applied,
                registry_version=tool.version,
            )

    def create_trace(
        self,
        tool_name: str,
        result: ToolExecutionResult,
        args_digest: str | None = None,
    ) -> ToolCallTrace:
        """
        [AC-IDMP-19] Build a ToolCallTrace from an execution result.

        NOTE: the failure status is inferred from the error text
        (substring "timeout" -> TIMEOUT, otherwise ERROR).
        """
        tool = self._tools.get(tool_name)
        return ToolCallTrace(
            tool_name=tool_name,
            tool_type=tool.tool_type if tool else ToolType.INTERNAL,
            registry_version=result.registry_version,
            auth_applied=result.auth_applied,
            duration_ms=result.duration_ms,
            status=ToolCallStatus.OK if result.success else (
                ToolCallStatus.TIMEOUT if "timeout" in (result.error or "").lower()
                else ToolCallStatus.ERROR
            ),
            error_code=result.error if not result.success else None,
            args_digest=args_digest,
            result_digest=str(result.output)[:100] if result.output else None,
        )

    def get_governance_report(self) -> dict[str, Any]:
        """Get governance report (counts plus a per-tool summary)."""
        return {
            "registry_version": self._version,
            "total_tools": len(self._tools),
            "enabled_tools": sum(1 for t in self._tools.values() if t.enabled),
            "disabled_tools": sum(1 for t in self._tools.values() if not t.enabled),
            "auth_required_tools": sum(1 for t in self._tools.values() if t.auth_required),
            "mcp_tools": sum(1 for t in self._tools.values() if t.tool_type == ToolType.MCP),
            "internal_tools": sum(1 for t in self._tools.values() if t.tool_type == ToolType.INTERNAL),
            "tools": [
                {
                    "name": t.name,
                    "type": t.tool_type.value,
                    "version": t.version,
                    "enabled": t.enabled,
                    "auth_required": t.auth_required,
                    "timeout_ms": t.timeout_ms,
                }
                for t in self._tools.values()
            ],
        }
# Process-wide singleton registry.
_registry: ToolRegistry | None = None


def get_tool_registry() -> ToolRegistry:
    """Return the global ToolRegistry, creating it lazily with defaults."""
    global _registry
    if _registry is None:
        _registry = ToolRegistry()
    return _registry


def init_tool_registry(timeout_governor: TimeoutGovernor | None = None) -> ToolRegistry:
    """(Re)initialize the global registry, discarding any existing instance
    and all of its registered tools."""
    global _registry
    _registry = ToolRegistry(timeout_governor=timeout_governor)
    return _registry

View File

@ -0,0 +1,269 @@
"""
Trace Logger for Mid Platform.
[AC-MARH-02, AC-MARH-03, AC-MARH-12] Trace collection and audit logging.
Audit Fields:
- session_id, request_id, generation_id
- mode, intent, tool_calls
- guardrail_triggered, guardrail_rule_id
- interrupt_consumed
"""
import logging
import time
import uuid
from dataclasses import dataclass, field
from typing import Any
from app.models.mid.schemas import (
ExecutionMode,
ToolCallTrace,
TraceInfo,
)
logger = logging.getLogger(__name__)
@dataclass
class AuditRecord:
    """[AC-MARH-12] Audit record for database persistence.

    One record is produced per completed request (see TraceLogger.end_trace)
    and serialized via to_dict() for storage.
    """

    tenant_id: str  # owning tenant
    session_id: str  # dialogue session
    request_id: str  # unique per request
    generation_id: str  # identifies the generation (used for interrupt handling)
    mode: ExecutionMode  # execution mode used for the request
    intent: str | None = None  # classified intent, if any
    tool_calls: list[dict[str, Any]] = field(default_factory=list)  # serialized tool-call traces
    guardrail_triggered: bool = False  # whether any guardrail fired
    guardrail_rule_id: str | None = None  # id of the triggering guardrail rule
    interrupt_consumed: bool = False  # whether an interrupt was consumed
    fallback_reason_code: str | None = None  # why a fallback occurred, if any
    react_iterations: int = 0  # number of ReAct loop iterations
    latency_ms: int = 0  # end-to-end latency
    created_at: str | None = None  # "%Y-%m-%d %H:%M:%S" string set by end_trace

    def to_dict(self) -> dict[str, Any]:
        """Serialize for persistence; ``mode`` is flattened to its enum value."""
        return {
            "tenant_id": self.tenant_id,
            "session_id": self.session_id,
            "request_id": self.request_id,
            "generation_id": self.generation_id,
            "mode": self.mode.value,
            "intent": self.intent,
            "tool_calls": self.tool_calls,
            "guardrail_triggered": self.guardrail_triggered,
            "guardrail_rule_id": self.guardrail_rule_id,
            "interrupt_consumed": self.interrupt_consumed,
            "fallback_reason_code": self.fallback_reason_code,
            "react_iterations": self.react_iterations,
            "latency_ms": self.latency_ms,
            "created_at": self.created_at,
        }
class TraceLogger:
"""
[AC-MARH-02, AC-MARH-03, AC-MARH-12] Trace logger for observability and audit.
Features:
- Request-scoped trace context
- Tool call tracing
- Guardrail event logging with rule_id
- Interrupt consumption tracking
- Audit record generation
"""
def __init__(self):
    """Initialize empty in-memory stores."""
    # Active traces keyed by request_id; entries are removed by end_trace.
    self._traces: dict[str, TraceInfo] = {}
    # Completed audit records, appended in completion order.
    self._audit_records: list[AuditRecord] = []
def start_trace(
    self,
    tenant_id: str,
    session_id: str,
    request_id: str | None = None,
    generation_id: str | None = None,
) -> TraceInfo:
    """
    [AC-MARH-12] Open a new trace context registered under its request id.

    Args:
        tenant_id: Tenant ID (accepted for interface symmetry; not used here)
        session_id: Session ID (logged only)
        request_id: Request ID (UUID4 auto-generated when falsy)
        generation_id: Generation ID for interrupt handling (UUID4
            auto-generated when falsy)

    Returns:
        The newly created TraceInfo; its mode starts as AGENT.
    """
    req_id = request_id or str(uuid.uuid4())
    gen_id = generation_id or str(uuid.uuid4())
    new_trace = TraceInfo(
        mode=ExecutionMode.AGENT,
        request_id=req_id,
        generation_id=gen_id,
    )
    self._traces[req_id] = new_trace
    logger.info(
        f"[AC-MARH-12] Trace started: request_id={req_id}, "
        f"session_id={session_id}, generation_id={gen_id}"
    )
    return new_trace
def get_trace(self, request_id: str) -> TraceInfo | None:
    """Look up an active trace; None when the request id is unknown."""
    try:
        return self._traces[request_id]
    except KeyError:
        return None
def update_trace(
    self,
    request_id: str,
    mode: ExecutionMode | None = None,
    intent: str | None = None,
    guardrail_triggered: bool | None = None,
    guardrail_rule_id: str | None = None,
    interrupt_consumed: bool | None = None,
    fallback_reason_code: str | None = None,
    react_iterations: int | None = None,
    tool_calls: list[ToolCallTrace] | None = None,
) -> TraceInfo | None:
    """
    [AC-MARH-02, AC-MARH-03, AC-MARH-12] Update a trace with execution
    details.

    Only non-None arguments are applied; None means "leave unchanged".
    Returns the updated TraceInfo, or None (with a warning) when the
    request id is unknown.
    """
    trace = self._traces.get(request_id)
    if trace is None:
        logger.warning(f"[AC-MARH-12] Trace not found: {request_id}")
        return None
    updates = {
        "mode": mode,
        "intent": intent,
        "guardrail_triggered": guardrail_triggered,
        "guardrail_rule_id": guardrail_rule_id,
        "interrupt_consumed": interrupt_consumed,
        "fallback_reason_code": fallback_reason_code,
        "react_iterations": react_iterations,
        "tool_calls": tool_calls,
    }
    for attr, value in updates.items():
        if value is not None:
            setattr(trace, attr, value)
    return trace
def add_tool_call(
    self,
    request_id: str,
    tool_call: ToolCallTrace,
) -> None:
    """
    [AC-MARH-12] Append a tool-call trace to an active request trace.

    Also maintains the deduplicated ``tools_used`` name list. No-op
    (with a warning) when the request id has no active trace.
    """
    trace = self._traces.get(request_id)
    if trace is None:
        logger.warning(f"[AC-MARH-12] Trace not found for tool call: {request_id}")
        return
    # Lazily materialize both lists on first use.
    calls = trace.tool_calls if trace.tool_calls is not None else []
    calls.append(tool_call)
    trace.tool_calls = calls
    used = trace.tools_used if trace.tools_used is not None else []
    if tool_call.tool_name not in used:
        used.append(tool_call.tool_name)
    trace.tools_used = used
    logger.debug(
        f"[AC-MARH-12] Tool call recorded: request_id={request_id}, "
        f"tool={tool_call.tool_name}, status={tool_call.status.value}"
    )
def end_trace(
    self,
    request_id: str,
    tenant_id: str,
    session_id: str,
    latency_ms: int,
) -> AuditRecord:
    """
    [AC-MARH-12] Close a trace and convert it into an AuditRecord.

    The trace is removed from the active set and the record appended to
    the audit log. An unknown request id yields a minimal FIXED-mode
    record (with a warning).
    """
    trace = self._traces.pop(request_id, None)
    if trace is None:
        logger.warning(f"[AC-MARH-12] Trace not found for end: {request_id}")
        return AuditRecord(
            tenant_id=tenant_id,
            session_id=session_id,
            request_id=request_id,
            generation_id=str(uuid.uuid4()),
            mode=ExecutionMode.FIXED,
            latency_ms=latency_ms,
        )
    audit = AuditRecord(
        tenant_id=tenant_id,
        session_id=session_id,
        request_id=request_id,
        generation_id=trace.generation_id or str(uuid.uuid4()),
        mode=trace.mode,
        intent=trace.intent,
        tool_calls=[tc.model_dump() for tc in trace.tool_calls] if trace.tool_calls else [],
        guardrail_triggered=trace.guardrail_triggered or False,
        guardrail_rule_id=trace.guardrail_rule_id,
        interrupt_consumed=trace.interrupt_consumed or False,
        fallback_reason_code=trace.fallback_reason_code,
        react_iterations=trace.react_iterations or 0,
        latency_ms=latency_ms,
        created_at=time.strftime("%Y-%m-%d %H:%M:%S"),
    )
    self._audit_records.append(audit)
    logger.info(
        f"[AC-MARH-12] Trace ended: request_id={request_id}, "
        f"mode={trace.mode.value}, latency_ms={latency_ms}"
    )
    return audit
def get_audit_records(
    self,
    tenant_id: str,
    session_id: str | None = None,
    limit: int = 100,
) -> list[AuditRecord]:
    """Return up to the ``limit`` most recent audit records for a tenant,
    optionally narrowed to one session."""
    matched = [
        record
        for record in self._audit_records
        if record.tenant_id == tenant_id
        and (not session_id or record.session_id == session_id)
    ]
    return matched[-limit:]
def clear_audit_records(self, tenant_id: str | None = None) -> int:
    """Delete audit records — all of them, or only one tenant's.

    Returns the number of records removed.
    """
    before = len(self._audit_records)
    if tenant_id:
        self._audit_records = [
            record for record in self._audit_records
            if record.tenant_id != tenant_id
        ]
        return before - len(self._audit_records)
    self._audit_records = []
    return before