feat: implement mid-platform core services and APIs [AC-IDMP-01~20, AC-MARH-01~12]
- Add dialogue, messages, sessions, share API endpoints - Add mid-platform schemas and models (memory, tool_registry, tool_trace) - Add core services: agent_orchestrator, policy_router, runtime_observer - Add tool services: kb_search_dynamic, memory_adapter, tool_registry - Add guardrail services: output_guardrail_executor, high_risk_handler - Add utility services: timeout_governor, segment_humanizer, metrics_collector
This commit is contained in:
parent
3d9f718f5f
commit
b4eb98e7c4
|
|
@ -0,0 +1,30 @@
|
|||
"""
|
||||
Mid Platform API endpoints.
|
||||
[AC-IDMP-01~20] Mid platform dialogue, messages, and session management.
|
||||
[AC-IDMP-SHARE] Share session via unique token.
|
||||
[AC-MRS-09,10] Runtime slot query endpoints.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .dialogue import router as dialogue_router
|
||||
from .messages import router as messages_router
|
||||
from .sessions import router as sessions_router
|
||||
from .share import router as share_router
|
||||
from .slots import router as slots_router
|
||||
|
||||
router = APIRouter()
|
||||
router.include_router(dialogue_router)
|
||||
router.include_router(messages_router)
|
||||
router.include_router(sessions_router)
|
||||
router.include_router(share_router)
|
||||
router.include_router(slots_router)
|
||||
|
||||
__all__ = [
|
||||
"router",
|
||||
"dialogue_router",
|
||||
"messages_router",
|
||||
"sessions_router",
|
||||
"share_router",
|
||||
"slots_router",
|
||||
]
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,104 @@
|
|||
"""
|
||||
Messages Controller for Mid Platform.
|
||||
[AC-IDMP-08] Message report endpoint: POST /mid/messages/report
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Annotated, Any
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import get_session
|
||||
from app.core.tenant import get_tenant_id
|
||||
from app.models.mid.schemas import MessageReportRequest
|
||||
from app.services.memory import MemoryService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/mid", tags=["Mid Platform Messages"])
|
||||
|
||||
|
||||
@router.post(
|
||||
"/messages/report",
|
||||
operation_id="reportMessages",
|
||||
summary="Report session messages and events",
|
||||
description="""
|
||||
[AC-IDMP-08] Report messages from channel to mid platform.
|
||||
|
||||
Accepts messages from bot, human, or channel sources for session closure.
|
||||
Returns 202 Accepted for async processing.
|
||||
""",
|
||||
responses={
|
||||
202: {"description": "Accepted for async processing"},
|
||||
400: {"description": "Invalid request"},
|
||||
},
|
||||
)
|
||||
async def report_messages(
|
||||
request: Request,
|
||||
report_request: MessageReportRequest,
|
||||
session: Annotated[AsyncSession, Depends(get_session)],
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
[AC-IDMP-08] Report messages from channel.
|
||||
|
||||
Accepts and stores messages for session data completeness.
|
||||
Messages are stored asynchronously without blocking response.
|
||||
"""
|
||||
tenant_id = get_tenant_id()
|
||||
|
||||
if not tenant_id:
|
||||
from app.core.exceptions import MissingTenantIdException
|
||||
raise MissingTenantIdException()
|
||||
|
||||
logger.info(
|
||||
f"[AC-IDMP-08] Message report: tenant={tenant_id}, "
|
||||
f"session={report_request.session_id}, count={len(report_request.messages)}"
|
||||
)
|
||||
|
||||
try:
|
||||
memory_service = MemoryService(session)
|
||||
|
||||
await memory_service.get_or_create_session(
|
||||
tenant_id=tenant_id,
|
||||
session_id=report_request.session_id,
|
||||
)
|
||||
|
||||
messages_to_save = []
|
||||
for msg in report_request.messages:
|
||||
role = msg.role
|
||||
if msg.source == "human":
|
||||
role = "human"
|
||||
elif msg.source == "bot":
|
||||
role = "assistant"
|
||||
|
||||
messages_to_save.append({
|
||||
"role": role,
|
||||
"content": msg.content,
|
||||
})
|
||||
|
||||
if messages_to_save:
|
||||
await memory_service.append_messages(
|
||||
tenant_id=tenant_id,
|
||||
session_id=report_request.session_id,
|
||||
messages=messages_to_save,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[AC-IDMP-08] Messages saved: tenant={tenant_id}, "
|
||||
f"session={report_request.session_id}, count={len(messages_to_save)}"
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=202,
|
||||
content={"status": "accepted", "session_id": report_request.session_id},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[AC-IDMP-08] Message report failed: {e}")
|
||||
return JSONResponse(
|
||||
status_code=202,
|
||||
content={"status": "accepted", "warning": str(e)},
|
||||
)
|
||||
|
|
@ -0,0 +1,105 @@
|
|||
"""
|
||||
Sessions Controller for Mid Platform.
|
||||
[AC-IDMP-09] Session mode switch endpoint: POST /mid/sessions/{sessionId}/mode
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Path
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import get_session
|
||||
from app.core.tenant import get_tenant_id
|
||||
from app.models.mid.schemas import (
|
||||
SessionMode,
|
||||
SwitchModeRequest,
|
||||
SwitchModeResponse,
|
||||
)
|
||||
from app.services.memory import MemoryService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/mid", tags=["Mid Platform Sessions"])
|
||||
|
||||
_session_modes: dict[str, SessionMode] = {}
|
||||
|
||||
|
||||
@router.post(
|
||||
"/sessions/{sessionId}/mode",
|
||||
operation_id="switchSessionMode",
|
||||
summary="Switch session mode",
|
||||
description="""
|
||||
[AC-IDMP-09] Switch session mode between BOT_ACTIVE and HUMAN_ACTIVE.
|
||||
|
||||
When mode is HUMAN_ACTIVE, dialogue responses will route to transfer mode.
|
||||
""",
|
||||
responses={
|
||||
200: {"description": "Mode switched successfully", "model": SwitchModeResponse},
|
||||
400: {"description": "Invalid request"},
|
||||
},
|
||||
)
|
||||
async def switch_session_mode(
|
||||
sessionId: Annotated[str, Path(description="Session ID")],
|
||||
switch_request: SwitchModeRequest,
|
||||
session: Annotated[AsyncSession, Depends(get_session)],
|
||||
) -> SwitchModeResponse:
|
||||
"""
|
||||
[AC-IDMP-09] Switch session mode.
|
||||
|
||||
Modes:
|
||||
- BOT_ACTIVE: Bot handles responses
|
||||
- HUMAN_ACTIVE: Transfer to human agent
|
||||
"""
|
||||
tenant_id = get_tenant_id()
|
||||
|
||||
if not tenant_id:
|
||||
from app.core.exceptions import MissingTenantIdException
|
||||
raise MissingTenantIdException()
|
||||
|
||||
logger.info(
|
||||
f"[AC-IDMP-09] Mode switch: tenant={tenant_id}, "
|
||||
f"session={sessionId}, mode={switch_request.mode.value}"
|
||||
)
|
||||
|
||||
try:
|
||||
memory_service = MemoryService(session)
|
||||
|
||||
await memory_service.get_or_create_session(
|
||||
tenant_id=tenant_id,
|
||||
session_id=sessionId,
|
||||
)
|
||||
|
||||
session_key = f"{tenant_id}:{sessionId}"
|
||||
_session_modes[session_key] = switch_request.mode
|
||||
|
||||
logger.info(
|
||||
f"[AC-IDMP-09] Mode switched: session={sessionId}, "
|
||||
f"mode={switch_request.mode.value}, reason={switch_request.reason}"
|
||||
)
|
||||
|
||||
return SwitchModeResponse(
|
||||
session_id=sessionId,
|
||||
mode=switch_request.mode,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[AC-IDMP-09] Mode switch failed: {e}")
|
||||
return SwitchModeResponse(
|
||||
session_id=sessionId,
|
||||
mode=switch_request.mode,
|
||||
)
|
||||
|
||||
|
||||
def get_session_mode(tenant_id: str, session_id: str) -> SessionMode:
|
||||
"""Get current session mode."""
|
||||
session_key = f"{tenant_id}:{session_id}"
|
||||
return _session_modes.get(session_key, SessionMode.BOT_ACTIVE)
|
||||
|
||||
|
||||
def clear_session_mode(tenant_id: str, session_id: str) -> None:
|
||||
"""Clear session mode (reset to BOT_ACTIVE)."""
|
||||
session_key = f"{tenant_id}:{session_id}"
|
||||
if session_key in _session_modes:
|
||||
del _session_modes[session_key]
|
||||
|
|
@ -0,0 +1,224 @@
|
|||
"""
|
||||
Mid Platform models for Intent-Driven Agent.
|
||||
[AC-IDMP-01~20] 中台统一响应协议模型
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Mode(str, Enum):
|
||||
"""[AC-IDMP-02] 执行模式"""
|
||||
AGENT = "agent"
|
||||
MICRO_FLOW = "micro_flow"
|
||||
FIXED = "fixed"
|
||||
TRANSFER = "transfer"
|
||||
|
||||
|
||||
class SessionMode(str, Enum):
|
||||
"""[AC-IDMP-09] 会话模式"""
|
||||
BOT_ACTIVE = "BOT_ACTIVE"
|
||||
HUMAN_ACTIVE = "HUMAN_ACTIVE"
|
||||
|
||||
|
||||
class HighRiskScenario(str, Enum):
|
||||
"""[AC-IDMP-20] 高风险场景最小集"""
|
||||
REFUND = "refund"
|
||||
COMPLAINT_ESCALATION = "complaint_escalation"
|
||||
PRIVACY_SENSITIVE_PROMISE = "privacy_sensitive_promise"
|
||||
TRANSFER = "transfer"
|
||||
|
||||
|
||||
class ToolCallStatus(str, Enum):
|
||||
"""[AC-IDMP-15] 工具调用状态"""
|
||||
OK = "ok"
|
||||
TIMEOUT = "timeout"
|
||||
ERROR = "error"
|
||||
REJECTED = "rejected"
|
||||
|
||||
|
||||
class ToolType(str, Enum):
|
||||
"""[AC-IDMP-19] 工具类型"""
|
||||
INTERNAL = "internal"
|
||||
MCP = "mcp"
|
||||
|
||||
|
||||
class HistoryMessage(BaseModel):
|
||||
"""[AC-IDMP-03] 已送达历史消息"""
|
||||
role: str = Field(..., description="消息角色: user/assistant/human")
|
||||
content: str = Field(..., description="消息内容")
|
||||
|
||||
|
||||
class InterruptedSegment(BaseModel):
|
||||
"""[AC-IDMP-04] 打断的分段"""
|
||||
segment_id: str = Field(..., description="分段ID")
|
||||
content: str = Field(..., description="分段内容")
|
||||
|
||||
|
||||
class FeatureFlags(BaseModel):
|
||||
"""[AC-IDMP-17] 特性开关"""
|
||||
agent_enabled: bool | None = Field(default=None, description="会话级 Agent 灰度开关")
|
||||
rollback_to_legacy: bool | None = Field(default=None, description="强制回滚传统链路")
|
||||
|
||||
|
||||
class DialogueRequest(BaseModel):
|
||||
"""[AC-IDMP-01~04] 会话响应请求"""
|
||||
session_id: str = Field(..., description="会话ID")
|
||||
user_id: str | None = Field(default=None, description="用户ID,用于记忆召回与更新")
|
||||
user_message: str = Field(..., min_length=1, max_length=2000, description="用户消息")
|
||||
history: list[HistoryMessage] = Field(default_factory=list, description="已送达历史")
|
||||
interrupted_segments: list[InterruptedSegment] | None = Field(default=None, description="打断的分段")
|
||||
feature_flags: FeatureFlags | None = Field(default=None, description="特性开关")
|
||||
|
||||
|
||||
class Segment(BaseModel):
|
||||
"""[AC-IDMP-01] 响应分段"""
|
||||
segment_id: str = Field(..., description="分段ID")
|
||||
text: str = Field(..., description="分段文本")
|
||||
delay_after: int = Field(default=0, ge=0, description="分段后延迟(ms)")
|
||||
|
||||
|
||||
class TimeoutProfile(BaseModel):
|
||||
"""[AC-IDMP-12] 超时配置"""
|
||||
per_tool_timeout_ms: int | None = Field(default=30000, le=60000, description="单工具超时(ms)")
|
||||
end_to_end_timeout_ms: int | None = Field(default=120000, le=180000, description="端到端超时(ms)")
|
||||
|
||||
|
||||
class MetricsSnapshot(BaseModel):
|
||||
"""[AC-IDMP-18] 运行指标快照"""
|
||||
task_completion_rate: float | None = Field(default=None, ge=0, le=1, description="任务达成率")
|
||||
slot_completion_rate: float | None = Field(default=None, ge=0, le=1, description="槽位完整率")
|
||||
wrong_transfer_rate: float | None = Field(default=None, ge=0, le=1, description="误转人工率")
|
||||
no_recall_rate: float | None = Field(default=None, ge=0, le=1, description="无召回率")
|
||||
avg_latency_ms: float | None = Field(default=None, ge=0, description="平均时延(ms)")
|
||||
|
||||
|
||||
class ToolCallTraceModel(BaseModel):
|
||||
"""[AC-IDMP-15/19] 工具调用追踪 (Pydantic 模型)"""
|
||||
tool_name: str = Field(..., description="工具名称")
|
||||
tool_type: ToolType | None = Field(default=None, description="工具类型: internal/mcp")
|
||||
registry_version: str | None = Field(default=None, description="注册版本")
|
||||
auth_applied: bool | None = Field(default=None, description="是否应用鉴权")
|
||||
duration_ms: int = Field(..., ge=0, description="耗时(ms)")
|
||||
status: ToolCallStatus = Field(..., description="状态: ok/timeout/error/rejected")
|
||||
error_code: str | None = Field(default=None, description="错误码")
|
||||
args_digest: str | None = Field(default=None, description="参数摘要")
|
||||
result_digest: str | None = Field(default=None, description="结果摘要")
|
||||
|
||||
|
||||
class TraceInfo(BaseModel):
|
||||
"""[AC-IDMP-02/07] 追踪信息"""
|
||||
mode: Mode = Field(..., description="执行模式")
|
||||
intent: str | None = Field(default=None, description="意图")
|
||||
request_id: str | None = Field(default=None, description="请求ID")
|
||||
generation_id: str | None = Field(default=None, description="生成ID")
|
||||
guardrail_triggered: bool | None = Field(default=False, description="护栏是否触发")
|
||||
fallback_reason_code: str | None = Field(default=None, description="降级原因码")
|
||||
react_iterations: int | None = Field(default=None, ge=0, le=5, description="ReAct循环次数")
|
||||
timeout_profile: TimeoutProfile | None = Field(default=None, description="超时配置")
|
||||
metrics_snapshot: MetricsSnapshot | None = Field(default=None, description="指标快照")
|
||||
high_risk_policy_set: list[HighRiskScenario] | None = Field(
|
||||
default=None,
|
||||
description="当前启用的高风险最小场景集"
|
||||
)
|
||||
tools_used: list[str] | None = Field(default=None, description="使用的工具列表")
|
||||
tool_calls: list[ToolCallTraceModel] | None = Field(default=None, description="工具调用追踪")
|
||||
|
||||
|
||||
class DialogueResponse(BaseModel):
|
||||
"""[AC-IDMP-01/02] 会话响应"""
|
||||
segments: list[Segment] = Field(..., description="响应分段列表")
|
||||
trace: TraceInfo = Field(..., description="追踪信息")
|
||||
|
||||
|
||||
class ReportedMessage(BaseModel):
|
||||
"""[AC-IDMP-08] 上报的消息"""
|
||||
role: str = Field(..., description="角色: user/assistant/human/system")
|
||||
content: str = Field(..., description="消息内容")
|
||||
source: str = Field(..., description="来源: bot/human/channel")
|
||||
timestamp: datetime = Field(..., description="时间戳")
|
||||
segment_id: str | None = Field(default=None, description="分段ID")
|
||||
|
||||
|
||||
class MessageReportRequest(BaseModel):
|
||||
"""[AC-IDMP-08] 消息上报请求"""
|
||||
session_id: str = Field(..., description="会话ID")
|
||||
messages: list[ReportedMessage] = Field(..., description="消息列表")
|
||||
|
||||
|
||||
class SwitchModeRequest(BaseModel):
|
||||
"""[AC-IDMP-09] 切换模式请求"""
|
||||
mode: SessionMode = Field(..., description="目标模式")
|
||||
reason: str | None = Field(default=None, description="切换原因")
|
||||
|
||||
|
||||
class SwitchModeResponse(BaseModel):
|
||||
"""[AC-IDMP-09] 切换模式响应"""
|
||||
session_id: str = Field(..., description="会话ID")
|
||||
mode: SessionMode = Field(..., description="当前模式")
|
||||
|
||||
|
||||
from app.models.mid.memory import (
|
||||
RecallRequest,
|
||||
RecallResponse,
|
||||
UpdateRequest,
|
||||
MemoryProfile,
|
||||
MemoryFact,
|
||||
MemoryPreferences,
|
||||
)
|
||||
from app.models.mid.tool_trace import (
|
||||
ToolCallTrace,
|
||||
ToolCallBuilder,
|
||||
ToolCallStatus as ToolCallStatusEnum,
|
||||
ToolType as ToolTypeEnum,
|
||||
)
|
||||
from app.models.mid.tool_registry import (
|
||||
ToolDefinition,
|
||||
ToolAuthConfig,
|
||||
ToolTimeoutPolicy,
|
||||
ToolRegistryEntity,
|
||||
ToolRegistryCreate,
|
||||
ToolRegistryUpdate,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Mode",
|
||||
"SessionMode",
|
||||
"HighRiskScenario",
|
||||
"ToolCallStatus",
|
||||
"ToolType",
|
||||
"HistoryMessage",
|
||||
"InterruptedSegment",
|
||||
"FeatureFlags",
|
||||
"DialogueRequest",
|
||||
"Segment",
|
||||
"TimeoutProfile",
|
||||
"MetricsSnapshot",
|
||||
"ToolCallTraceModel",
|
||||
"TraceInfo",
|
||||
"DialogueResponse",
|
||||
"ReportedMessage",
|
||||
"MessageReportRequest",
|
||||
"SwitchModeRequest",
|
||||
"SwitchModeResponse",
|
||||
"RecallRequest",
|
||||
"RecallResponse",
|
||||
"UpdateRequest",
|
||||
"MemoryProfile",
|
||||
"MemoryFact",
|
||||
"MemoryPreferences",
|
||||
"ToolCallTrace",
|
||||
"ToolCallBuilder",
|
||||
"ToolCallStatusEnum",
|
||||
"ToolTypeEnum",
|
||||
"ToolDefinition",
|
||||
"ToolAuthConfig",
|
||||
"ToolTimeoutPolicy",
|
||||
"ToolRegistryEntity",
|
||||
"ToolRegistryCreate",
|
||||
"ToolRegistryUpdate",
|
||||
]
|
||||
|
|
@ -0,0 +1,182 @@
|
|||
"""
|
||||
Memory models for Mid Platform.
|
||||
[AC-IDMP-13] 记忆召回数据模型
|
||||
[AC-IDMP-14] 记忆更新数据模型
|
||||
|
||||
Reference: spec/intent-driven-mid-platform/openapi.deps.yaml
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryProfile:
|
||||
"""
|
||||
[AC-IDMP-13] 用户基础属性记忆
|
||||
包含年级、地区、渠道等基础信息
|
||||
"""
|
||||
grade: str | None = None
|
||||
region: str | None = None
|
||||
channel: str | None = None
|
||||
vip_level: str | None = None
|
||||
registration_date: datetime | None = None
|
||||
extra: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
result = {}
|
||||
if self.grade:
|
||||
result["grade"] = self.grade
|
||||
if self.region:
|
||||
result["region"] = self.region
|
||||
if self.channel:
|
||||
result["channel"] = self.channel
|
||||
if self.vip_level:
|
||||
result["vip_level"] = self.vip_level
|
||||
if self.registration_date:
|
||||
result["registration_date"] = self.registration_date.isoformat()
|
||||
if self.extra:
|
||||
result.update(self.extra)
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryFact:
|
||||
"""
|
||||
[AC-IDMP-13] 事实型记忆
|
||||
包含已购课程、学习结论等客观事实
|
||||
"""
|
||||
content: str
|
||||
source: str | None = None
|
||||
confidence: float | None = None
|
||||
created_at: datetime | None = None
|
||||
expires_at: datetime | None = None
|
||||
|
||||
def to_string(self) -> str:
|
||||
return self.content
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryPreferences:
|
||||
"""
|
||||
[AC-IDMP-13] 偏好记忆
|
||||
包含语气偏好、关注科目等用户偏好
|
||||
"""
|
||||
tone: str | None = None
|
||||
focus_subjects: list[str] = field(default_factory=list)
|
||||
communication_style: str | None = None
|
||||
preferred_time: str | None = None
|
||||
extra: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
result = {}
|
||||
if self.tone:
|
||||
result["tone"] = self.tone
|
||||
if self.focus_subjects:
|
||||
result["focus_subjects"] = self.focus_subjects
|
||||
if self.communication_style:
|
||||
result["communication_style"] = self.communication_style
|
||||
if self.preferred_time:
|
||||
result["preferred_time"] = self.preferred_time
|
||||
if self.extra:
|
||||
result.update(self.extra)
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
class RecallRequest:
|
||||
"""
|
||||
[AC-IDMP-13] 记忆召回请求
|
||||
Reference: openapi.deps.yaml - RecallRequest
|
||||
"""
|
||||
user_id: str
|
||||
session_id: str
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"user_id": self.user_id,
|
||||
"session_id": self.session_id,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class RecallResponse:
|
||||
"""
|
||||
[AC-IDMP-13] 记忆召回响应
|
||||
Reference: openapi.deps.yaml - RecallResponse
|
||||
"""
|
||||
profile: MemoryProfile | None = None
|
||||
facts: list[MemoryFact] = field(default_factory=list)
|
||||
preferences: MemoryPreferences | None = None
|
||||
last_summary: str | None = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
result = {}
|
||||
if self.profile:
|
||||
result["profile"] = self.profile.to_dict()
|
||||
if self.facts:
|
||||
result["facts"] = [f.to_string() for f in self.facts]
|
||||
if self.preferences:
|
||||
result["preferences"] = self.preferences.to_dict()
|
||||
if self.last_summary:
|
||||
result["last_summary"] = self.last_summary
|
||||
return result
|
||||
|
||||
def get_context_for_prompt(self) -> str:
|
||||
"""
|
||||
生成用于注入 Prompt 的上下文字符串
|
||||
"""
|
||||
parts = []
|
||||
|
||||
if self.profile:
|
||||
profile_parts = []
|
||||
if self.profile.grade:
|
||||
profile_parts.append(f"年级: {self.profile.grade}")
|
||||
if self.profile.region:
|
||||
profile_parts.append(f"地区: {self.profile.region}")
|
||||
if self.profile.vip_level:
|
||||
profile_parts.append(f"会员等级: {self.profile.vip_level}")
|
||||
if profile_parts:
|
||||
parts.append("【用户属性】" + "、".join(profile_parts))
|
||||
|
||||
if self.facts:
|
||||
fact_strs = [f.content for f in self.facts[:5]]
|
||||
parts.append("【已知事实】" + ";".join(fact_strs))
|
||||
|
||||
if self.preferences:
|
||||
pref_parts = []
|
||||
if self.preferences.tone:
|
||||
pref_parts.append(f"语气偏好: {self.preferences.tone}")
|
||||
if self.preferences.focus_subjects:
|
||||
pref_parts.append(f"关注科目: {', '.join(self.preferences.focus_subjects)}")
|
||||
if pref_parts:
|
||||
parts.append("【用户偏好】" + "、".join(pref_parts))
|
||||
|
||||
if self.last_summary:
|
||||
parts.append(f"【上次会话摘要】{self.last_summary}")
|
||||
|
||||
return "\n".join(parts) if parts else ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class UpdateRequest:
|
||||
"""
|
||||
[AC-IDMP-14] 记忆更新请求
|
||||
Reference: openapi.deps.yaml - UpdateRequest
|
||||
"""
|
||||
user_id: str
|
||||
session_id: str
|
||||
messages: list[dict[str, Any]]
|
||||
summary: str | None = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
result = {
|
||||
"user_id": self.user_id,
|
||||
"session_id": self.session_id,
|
||||
"messages": self.messages,
|
||||
}
|
||||
if self.summary:
|
||||
result["summary"] = self.summary
|
||||
return result
|
||||
|
|
@ -0,0 +1,407 @@
|
|||
"""
|
||||
Mid Platform schemas.
|
||||
[AC-IDMP-01, AC-IDMP-02, AC-IDMP-07, AC-IDMP-11, AC-IDMP-12, AC-IDMP-15, AC-IDMP-17, AC-IDMP-18, AC-IDMP-19, AC-IDMP-20]
|
||||
Aligned with spec/intent-driven-mid-platform/openapi.provider.yaml
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ExecutionMode(str, Enum):
|
||||
"""[AC-IDMP-02] Execution mode for dialogue response."""
|
||||
AGENT = "agent"
|
||||
MICRO_FLOW = "micro_flow"
|
||||
FIXED = "fixed"
|
||||
TRANSFER = "transfer"
|
||||
|
||||
|
||||
class HighRiskScenario(str, Enum):
|
||||
"""[AC-IDMP-20] High risk scenario types for mandatory takeover."""
|
||||
REFUND = "refund"
|
||||
COMPLAINT_ESCALATION = "complaint_escalation"
|
||||
PRIVACY_SENSITIVE_PROMISE = "privacy_sensitive_promise"
|
||||
TRANSFER = "transfer"
|
||||
|
||||
|
||||
class ToolCallStatus(str, Enum):
|
||||
"""[AC-IDMP-15] Tool call status."""
|
||||
OK = "ok"
|
||||
TIMEOUT = "timeout"
|
||||
ERROR = "error"
|
||||
REJECTED = "rejected"
|
||||
|
||||
|
||||
class ToolType(str, Enum):
|
||||
"""[AC-IDMP-19] Tool type for registry governance."""
|
||||
INTERNAL = "internal"
|
||||
MCP = "mcp"
|
||||
|
||||
|
||||
class SessionMode(str, Enum):
|
||||
"""[AC-IDMP-09] Session mode for bot/human switching."""
|
||||
BOT_ACTIVE = "BOT_ACTIVE"
|
||||
HUMAN_ACTIVE = "HUMAN_ACTIVE"
|
||||
|
||||
|
||||
class HistoryMessage(BaseModel):
|
||||
"""[AC-IDMP-03] History message with only delivered content."""
|
||||
role: str = Field(..., description="Message role: user, assistant, or human")
|
||||
content: str = Field(..., description="Message content")
|
||||
|
||||
|
||||
class InterruptedSegment(BaseModel):
|
||||
"""[AC-IDMP-04] Interrupted segment for handling user interruption."""
|
||||
segment_id: str = Field(..., description="Segment ID")
|
||||
content: str = Field(..., description="Segment content")
|
||||
|
||||
|
||||
class FeatureFlags(BaseModel):
|
||||
"""[AC-IDMP-17] Feature flags for session-level grayscale and rollback."""
|
||||
agent_enabled: bool | None = Field(default=True, description="Session-level Agent grayscale switch")
|
||||
rollback_to_legacy: bool | None = Field(default=False, description="Force rollback to legacy pipeline")
|
||||
|
||||
|
||||
class HumanizeConfigRequest(BaseModel):
|
||||
"""[AC-MARH-11] 拟人化配置请求。"""
|
||||
enabled: bool | None = Field(default=True, description="Enable humanize strategy")
|
||||
min_delay_ms: int | None = Field(default=50, ge=0, description="Minimum delay in milliseconds")
|
||||
max_delay_ms: int | None = Field(default=500, ge=0, description="Maximum delay in milliseconds")
|
||||
length_bucket_strategy: str | None = Field(default="simple", description="Strategy: simple or semantic")
|
||||
|
||||
|
||||
class DialogueRequest(BaseModel):
|
||||
"""[AC-IDMP-01, AC-IDMP-03, AC-IDMP-04, AC-IDMP-17, AC-MARH-11] Dialogue request schema."""
|
||||
session_id: str = Field(..., description="Session ID for conversation tracking")
|
||||
user_id: str | None = Field(default=None, description="User ID for memory recall and update")
|
||||
user_message: str = Field(..., min_length=1, max_length=2000, description="User message content")
|
||||
history: list[HistoryMessage] = Field(default_factory=list, description="Only delivered history")
|
||||
interrupted_segments: list[InterruptedSegment] | None = Field(default=None, description="Interrupted segments")
|
||||
feature_flags: FeatureFlags | None = Field(default=None, description="Feature flags for grayscale control")
|
||||
humanize_config: HumanizeConfigRequest | None = Field(
|
||||
default=None, description="Humanize config for segment delay"
|
||||
)
|
||||
|
||||
|
||||
class Segment(BaseModel):
|
||||
"""[AC-IDMP-01] Response segment with delay control."""
|
||||
segment_id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Segment ID")
|
||||
text: str = Field(..., description="Segment text content")
|
||||
delay_after: int = Field(default=0, ge=0, description="Delay after this segment in milliseconds")
|
||||
|
||||
|
||||
class TimeoutProfile(BaseModel):
|
||||
"""[AC-MARH-08, AC-MARH-09] Timeout configuration profile."""
|
||||
per_tool_timeout_ms: int = Field(default=30000, le=60000, description="Per-tool timeout in milliseconds")
|
||||
llm_timeout_ms: int = Field(default=60000, le=120000, description="LLM call timeout in milliseconds")
|
||||
end_to_end_timeout_ms: int = Field(default=120000, le=180000, description="End-to-end timeout in milliseconds")
|
||||
|
||||
|
||||
class MetricsSnapshot(BaseModel):
|
||||
"""[AC-IDMP-18] Runtime metrics snapshot."""
|
||||
task_completion_rate: float | None = Field(default=None, ge=0.0, le=1.0, description="Task completion rate")
|
||||
slot_completion_rate: float | None = Field(default=None, ge=0.0, le=1.0, description="Slot completion rate")
|
||||
wrong_transfer_rate: float | None = Field(default=None, ge=0.0, le=1.0, description="Wrong transfer rate")
|
||||
no_recall_rate: float | None = Field(default=None, ge=0.0, le=1.0, description="No recall rate")
|
||||
avg_latency_ms: float | None = Field(default=None, ge=0.0, description="Average latency in milliseconds")
|
||||
|
||||
|
||||
class ToolCallTrace(BaseModel):
|
||||
"""[AC-IDMP-15, AC-IDMP-19] Tool call trace for observability."""
|
||||
tool_name: str = Field(..., description="Tool name")
|
||||
tool_type: ToolType | None = Field(default=ToolType.INTERNAL, description="Tool type: internal or mcp")
|
||||
registry_version: str | None = Field(default=None, description="Tool registry version")
|
||||
auth_applied: bool | None = Field(default=False, description="Whether auth was applied")
|
||||
duration_ms: int = Field(..., ge=0, description="Duration in milliseconds")
|
||||
status: ToolCallStatus = Field(..., description="Tool call status")
|
||||
error_code: str | None = Field(default=None, description="Error code if failed")
|
||||
args_digest: str | None = Field(default=None, description="Arguments digest for logging")
|
||||
result_digest: str | None = Field(default=None, description="Result digest for logging")
|
||||
|
||||
|
||||
class SegmentStats(BaseModel):
|
||||
"""[AC-MARH-12] Segment statistics for humanize strategy."""
|
||||
segment_count: int = Field(default=0, ge=0, description="Number of segments")
|
||||
avg_segment_length: float = Field(default=0.0, ge=0.0, description="Average segment length")
|
||||
humanize_strategy: str | None = Field(default=None, description="Humanize strategy used")
|
||||
|
||||
|
||||
class TraceInfo(BaseModel):
|
||||
"""[AC-MARH-02, AC-MARH-03, AC-MARH-05, AC-MARH-06, AC-MARH-07, AC-MARH-11,
|
||||
AC-MARH-12, AC-MARH-18, AC-MARH-19, AC-MARH-20] Trace info for observability."""
|
||||
mode: ExecutionMode = Field(..., description="Execution mode")
|
||||
intent: str | None = Field(default=None, description="Matched intent")
|
||||
request_id: str | None = Field(
|
||||
default_factory=lambda: str(uuid.uuid4()), description="Request ID"
|
||||
)
|
||||
generation_id: str | None = Field(
|
||||
default_factory=lambda: str(uuid.uuid4()),
|
||||
description="Generation ID for interrupt handling",
|
||||
)
|
||||
guardrail_triggered: bool | None = Field(default=False, description="Whether guardrail was triggered")
|
||||
guardrail_rule_id: str | None = Field(default=None, description="Guardrail rule ID that triggered")
|
||||
interrupt_consumed: bool | None = Field(default=False, description="Whether interrupted segments were consumed")
|
||||
kb_tool_called: bool | None = Field(default=False, description="Whether KB tool was called")
|
||||
kb_hit: bool | None = Field(default=False, description="Whether KB search had results")
|
||||
fallback_reason_code: str | None = Field(default=None, description="Fallback reason code")
|
||||
react_iterations: int | None = Field(default=0, ge=0, le=5, description="ReAct loop iterations")
|
||||
timeout_profile: TimeoutProfile | None = Field(default=None, description="Timeout profile")
|
||||
segment_stats: SegmentStats | None = Field(default=None, description="Segment statistics")
|
||||
metrics_snapshot: MetricsSnapshot | None = Field(default=None, description="Metrics snapshot")
|
||||
high_risk_policy_set: list[HighRiskScenario] | None = Field(default=None, description="Active high-risk policy set")
|
||||
tools_used: list[str] | None = Field(default=None, description="Tools used in this request")
|
||||
tool_calls: list[ToolCallTrace] | None = Field(default=None, description="Tool call traces")
|
||||
|
||||
|
||||
class DialogueResponse(BaseModel):
|
||||
"""[AC-IDMP-01, AC-IDMP-02] Dialogue response with segments and trace."""
|
||||
segments: list[Segment] = Field(..., description="Response segments")
|
||||
trace: TraceInfo = Field(..., description="Trace info for observability")
|
||||
|
||||
|
||||
class ReportedMessage(BaseModel):
|
||||
"""[AC-IDMP-08] Reported message for message report API."""
|
||||
role: str = Field(..., description="Message role: user, assistant, human, or system")
|
||||
content: str = Field(..., description="Message content")
|
||||
source: str = Field(..., description="Message source: bot, human, or channel")
|
||||
timestamp: str = Field(..., description="Message timestamp in ISO format")
|
||||
segment_id: str | None = Field(default=None, description="Segment ID if applicable")
|
||||
|
||||
|
||||
class MessageReportRequest(BaseModel):
|
||||
"""[AC-IDMP-08] Message report request schema."""
|
||||
session_id: str = Field(..., description="Session ID")
|
||||
messages: list[ReportedMessage] = Field(..., description="Messages to report")
|
||||
|
||||
|
||||
class SwitchModeRequest(BaseModel):
|
||||
"""[AC-IDMP-09] Switch session mode request."""
|
||||
mode: SessionMode = Field(..., description="Target mode: BOT_ACTIVE or HUMAN_ACTIVE")
|
||||
reason: str | None = Field(default=None, description="Reason for mode switch")
|
||||
|
||||
|
||||
class SwitchModeResponse(BaseModel):
|
||||
"""[AC-IDMP-09] Switch session mode response."""
|
||||
session_id: str = Field(..., description="Session ID")
|
||||
mode: SessionMode = Field(..., description="Current mode after switch")
|
||||
|
||||
|
||||
class MidSessionState(BaseModel):
|
||||
"""Internal session state for mid platform."""
|
||||
session_id: str
|
||||
tenant_id: str
|
||||
mode: SessionMode = SessionMode.BOT_ACTIVE
|
||||
generation_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
active_flow_id: str | None = None
|
||||
context: dict[str, Any] | None = None
|
||||
created_at: str | None = None
|
||||
updated_at: str | None = None
|
||||
|
||||
|
||||
class PolicyRouterResult(BaseModel):
|
||||
"""[AC-IDMP-02, AC-IDMP-05, AC-IDMP-16] Policy router decision result."""
|
||||
mode: ExecutionMode = Field(..., description="Decided execution mode")
|
||||
intent: str | None = Field(default=None, description="Matched intent")
|
||||
confidence: float | None = Field(default=None, ge=0.0, le=1.0, description="Intent confidence")
|
||||
fallback_reason_code: str | None = Field(default=None, description="Fallback reason if applicable")
|
||||
high_risk_triggered: bool = Field(default=False, description="Whether high-risk scenario triggered")
|
||||
target_flow_id: str | None = Field(default=None, description="Target flow ID for micro_flow mode")
|
||||
fixed_reply: str | None = Field(default=None, description="Fixed reply for fixed mode")
|
||||
transfer_message: str | None = Field(default=None, description="Transfer message for transfer mode")
|
||||
|
||||
|
||||
class ReActContext(BaseModel):
|
||||
"""[AC-IDMP-11] ReAct loop context for iteration control."""
|
||||
iteration: int = Field(default=0, ge=0, le=5, description="Current iteration count")
|
||||
max_iterations: int = Field(default=5, ge=3, le=5, description="Maximum iterations allowed")
|
||||
tool_calls: list[ToolCallTrace] = Field(default_factory=list, description="Tool call history")
|
||||
should_continue: bool = Field(default=True, description="Whether to continue ReAct loop")
|
||||
final_answer: str | None = Field(default=None, description="Final answer if completed")
|
||||
|
||||
|
||||
class CreateShareRequest(BaseModel):
|
||||
"""[AC-IDMP-SHARE] Request to create a shared session."""
|
||||
title: str | None = Field(default=None, max_length=255, description="Share title")
|
||||
description: str | None = Field(default=None, max_length=1000, description="Share description")
|
||||
expires_in_days: int = Field(default=7, ge=1, le=365, description="Expiration time in days")
|
||||
max_concurrent_users: int = Field(default=10, ge=1, le=100, description="Maximum concurrent users")
|
||||
|
||||
|
||||
class ShareResponse(BaseModel):
|
||||
"""[AC-IDMP-SHARE] Response after creating a share."""
|
||||
share_token: str = Field(..., description="Unique share token")
|
||||
share_url: str = Field(..., description="Full share URL")
|
||||
expires_at: str = Field(..., description="Expiration time in ISO format")
|
||||
title: str | None = Field(default=None, description="Share title")
|
||||
description: str | None = Field(default=None, description="Share description")
|
||||
max_concurrent_users: int = Field(..., description="Maximum concurrent users")
|
||||
|
||||
|
||||
class SharedSessionInfo(BaseModel):
|
||||
"""[AC-IDMP-SHARE] Information about a shared session."""
|
||||
session_id: str = Field(..., description="Session ID")
|
||||
title: str | None = Field(default=None, description="Share title")
|
||||
description: str | None = Field(default=None, description="Share description")
|
||||
expires_at: str = Field(..., description="Expiration time in ISO format")
|
||||
max_concurrent_users: int = Field(..., description="Maximum concurrent users")
|
||||
current_users: int = Field(..., description="Current online users")
|
||||
history: list[HistoryMessage] = Field(default_factory=list, description="Historical messages")
|
||||
|
||||
|
||||
class SharedMessageRequest(BaseModel):
|
||||
"""[AC-IDMP-SHARE] Request to send a message via shared session."""
|
||||
user_message: str = Field(..., min_length=1, max_length=2000, description="User message content")
|
||||
|
||||
|
||||
class ShareListItem(BaseModel):
|
||||
"""[AC-IDMP-SHARE] Share list item for listing all shares of a session."""
|
||||
share_token: str = Field(..., description="Share token")
|
||||
share_url: str = Field(..., description="Full share URL")
|
||||
title: str | None = Field(default=None, description="Share title")
|
||||
description: str | None = Field(default=None, description="Share description")
|
||||
expires_at: str = Field(..., description="Expiration time in ISO format")
|
||||
is_active: bool = Field(..., description="Whether share is active")
|
||||
max_concurrent_users: int = Field(..., description="Maximum concurrent users")
|
||||
current_users: int = Field(..., description="Current online users")
|
||||
created_at: str = Field(..., description="Creation time in ISO format")
|
||||
|
||||
|
||||
class ShareListResponse(BaseModel):
|
||||
"""[AC-IDMP-SHARE] Response for listing shares."""
|
||||
shares: list[ShareListItem] = Field(..., description="List of shares")
|
||||
|
||||
|
||||
class KbSearchDynamicHit(BaseModel):
|
||||
"""[AC-MARH-05] Single KB search hit."""
|
||||
id: str = Field(..., description="Hit ID")
|
||||
content: str = Field(..., description="Hit content")
|
||||
score: float = Field(..., ge=0.0, le=1.0, description="Relevance score")
|
||||
metadata: dict[str, Any] = Field(default_factory=dict, description="Hit metadata")
|
||||
|
||||
|
||||
class MissingRequiredSlot(BaseModel):
|
||||
"""[AC-MARH-05] Missing required slot info."""
|
||||
field_key: str = Field(..., description="Field key")
|
||||
label: str = Field(..., description="Field label")
|
||||
reason: str = Field(..., description="Missing reason")
|
||||
|
||||
|
||||
class KbSearchDynamicResultSchema(BaseModel):
|
||||
"""[AC-MARH-05, AC-MARH-06] KB dynamic search result schema."""
|
||||
success: bool = Field(..., description="Whether search succeeded")
|
||||
hits: list[KbSearchDynamicHit] = Field(default_factory=list, description="Search hits")
|
||||
applied_filter: dict[str, Any] = Field(default_factory=dict, description="Applied filter")
|
||||
missing_required_slots: list[MissingRequiredSlot] = Field(
|
||||
default_factory=list, description="Missing required slots"
|
||||
)
|
||||
filter_debug: dict[str, Any] = Field(default_factory=dict, description="Filter debug info")
|
||||
fallback_reason_code: str | None = Field(default=None, description="Fallback reason code")
|
||||
duration_ms: int = Field(default=0, ge=0, description="Duration in milliseconds")
|
||||
|
||||
|
||||
class IntentHintOutput(BaseModel):
|
||||
"""[AC-IDMP-02, AC-IDMP-16] 轻量意图提示工具输出。"""
|
||||
intent: str | None = Field(default=None, description="识别到的意图名称")
|
||||
confidence: float = Field(default=0.0, ge=0.0, le=1.0, description="置信度 0~1")
|
||||
response_type: str | None = Field(
|
||||
default=None,
|
||||
description="响应类型: fixed|rag|flow|transfer|null"
|
||||
)
|
||||
suggested_mode: ExecutionMode | None = Field(
|
||||
default=None,
|
||||
description="建议执行模式: agent|micro_flow|fixed|transfer"
|
||||
)
|
||||
target_flow_id: str | None = Field(default=None, description="目标流程ID(flow模式)")
|
||||
target_kb_ids: list[str] | None = Field(default=None, description="目标知识库ID列表")
|
||||
fallback_reason_code: str | None = Field(default=None, description="降级原因码")
|
||||
high_risk_detected: bool = Field(default=False, description="是否检测到高风险场景")
|
||||
duration_ms: int = Field(default=0, ge=0, description="执行耗时(毫秒)")
|
||||
|
||||
|
||||
class HighRiskCheckResult(BaseModel):
|
||||
"""[AC-IDMP-05, AC-IDMP-20] 高风险检测工具输出。"""
|
||||
matched: bool = Field(default=False, description="是否命中高风险场景")
|
||||
risk_scenario: HighRiskScenario | None = Field(
|
||||
default=None,
|
||||
description="风险场景: refund|complaint_escalation|privacy_sensitive_promise|transfer|none"
|
||||
)
|
||||
confidence: float = Field(default=0.0, ge=0.0, le=1.0, description="置信度 0~1")
|
||||
recommended_mode: ExecutionMode | None = Field(
|
||||
default=None,
|
||||
description="推荐执行模式: micro_flow|transfer|agent"
|
||||
)
|
||||
rule_id: str | None = Field(default=None, description="匹配的规则ID")
|
||||
reason: str | None = Field(default=None, description="匹配原因说明")
|
||||
fallback_reason_code: str | None = Field(default=None, description="降级原因码(工具失败时)")
|
||||
duration_ms: int = Field(default=0, ge=0, description="执行耗时(毫秒)")
|
||||
matched_text: str | None = Field(default=None, description="匹配到的文本片段")
|
||||
matched_pattern: str | None = Field(default=None, description="匹配到的模式(关键词或正则)")
|
||||
|
||||
|
||||
class SlotSource(str, Enum):
|
||||
"""[AC-IDMP-13] 槽位来源类型。"""
|
||||
USER_CONFIRMED = "user_confirmed"
|
||||
RULE_EXTRACTED = "rule_extracted"
|
||||
LLM_INFERRED = "llm_inferred"
|
||||
DEFAULT = "default"
|
||||
|
||||
|
||||
class MemorySlot(BaseModel):
|
||||
"""[AC-IDMP-13] 单个槽位信息。"""
|
||||
key: str = Field(..., description="槽位键名")
|
||||
value: Any = Field(..., description="槽位值")
|
||||
source: SlotSource = Field(default=SlotSource.DEFAULT, description="槽位来源")
|
||||
confidence: float = Field(default=1.0, ge=0.0, le=1.0, description="置信度")
|
||||
updated_at: str | None = Field(default=None, description="最后更新时间")
|
||||
|
||||
|
||||
class MemoryRecallResult(BaseModel):
|
||||
"""[AC-IDMP-13] 记忆召回工具输出。"""
|
||||
profile: dict[str, Any] = Field(default_factory=dict, description="用户基础属性")
|
||||
facts: list[str] = Field(default_factory=list, description="事实型记忆列表")
|
||||
preferences: dict[str, Any] = Field(default_factory=dict, description="用户偏好")
|
||||
last_summary: str | None = Field(default=None, description="最近会话摘要")
|
||||
slots: dict[str, MemorySlot] = Field(default_factory=dict, description="结构化槽位")
|
||||
missing_slots: list[str] = Field(default_factory=list, description="缺失的必填槽位")
|
||||
fallback_reason_code: str | None = Field(default=None, description="降级原因码")
|
||||
duration_ms: int = Field(default=0, ge=0, description="执行耗时(毫秒)")
|
||||
|
||||
def get_context_for_prompt(self) -> str:
|
||||
"""生成用于注入 Prompt 的上下文字符串。"""
|
||||
parts = []
|
||||
|
||||
if self.profile:
|
||||
profile_parts = []
|
||||
for key, value in self.profile.items():
|
||||
if value:
|
||||
profile_parts.append(f"{key}: {value}")
|
||||
if profile_parts:
|
||||
parts.append("【用户属性】" + "、".join(profile_parts))
|
||||
|
||||
if self.facts:
|
||||
parts.append("【已知事实】" + ";".join(self.facts[:5]))
|
||||
|
||||
if self.preferences:
|
||||
pref_parts = []
|
||||
for key, value in self.preferences.items():
|
||||
if value:
|
||||
pref_parts.append(f"{key}: {value}")
|
||||
if pref_parts:
|
||||
parts.append("【用户偏好】" + "、".join(pref_parts))
|
||||
|
||||
if self.last_summary:
|
||||
parts.append(f"【上次会话摘要】{self.last_summary}")
|
||||
|
||||
if self.slots:
|
||||
slot_parts = []
|
||||
for key, slot in self.slots.items():
|
||||
slot_parts.append(f"{key}={slot.value}")
|
||||
if slot_parts:
|
||||
parts.append("【已知槽位】" + ", ".join(slot_parts))
|
||||
|
||||
return "\n".join(parts) if parts else ""
|
||||
|
|
@ -0,0 +1,222 @@
|
|||
"""
|
||||
Tool Registry models for Mid Platform.
|
||||
[AC-IDMP-19] Tool Registry 治理模型
|
||||
|
||||
Reference: spec/intent-driven-mid-platform/openapi.provider.yaml
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import JSON, Column
|
||||
from sqlmodel import Field, Index, SQLModel
|
||||
|
||||
|
||||
class ToolStatus(str, Enum):
|
||||
"""工具状态"""
|
||||
ENABLED = "enabled"
|
||||
DISABLED = "disabled"
|
||||
DEPRECATED = "deprecated"
|
||||
|
||||
|
||||
class ToolAuthType(str, Enum):
|
||||
"""工具鉴权类型"""
|
||||
NONE = "none"
|
||||
API_KEY = "api_key"
|
||||
OAUTH = "oauth"
|
||||
CUSTOM = "custom"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolAuthConfig:
|
||||
"""
|
||||
[AC-IDMP-19] 工具鉴权配置
|
||||
"""
|
||||
auth_type: ToolAuthType = ToolAuthType.NONE
|
||||
required_scopes: list[str] = field(default_factory=list)
|
||||
api_key_header: str | None = None
|
||||
oauth_url: str | None = None
|
||||
custom_validator: str | None = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
result = {"auth_type": self.auth_type.value}
|
||||
if self.required_scopes:
|
||||
result["required_scopes"] = self.required_scopes
|
||||
if self.api_key_header:
|
||||
result["api_key_header"] = self.api_key_header
|
||||
if self.oauth_url:
|
||||
result["oauth_url"] = self.oauth_url
|
||||
if self.custom_validator:
|
||||
result["custom_validator"] = self.custom_validator
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "ToolAuthConfig":
|
||||
return cls(
|
||||
auth_type=ToolAuthType(data.get("auth_type", "none")),
|
||||
required_scopes=data.get("required_scopes", []),
|
||||
api_key_header=data.get("api_key_header"),
|
||||
oauth_url=data.get("oauth_url"),
|
||||
custom_validator=data.get("custom_validator"),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolTimeoutPolicy:
|
||||
"""
|
||||
[AC-IDMP-19] 工具超时策略
|
||||
Reference: openapi.provider.yaml - TimeoutProfile
|
||||
"""
|
||||
per_tool_timeout_ms: int = 30000
|
||||
end_to_end_timeout_ms: int = 120000
|
||||
retry_count: int = 0
|
||||
retry_delay_ms: int = 100
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"per_tool_timeout_ms": self.per_tool_timeout_ms,
|
||||
"end_to_end_timeout_ms": self.end_to_end_timeout_ms,
|
||||
"retry_count": self.retry_count,
|
||||
"retry_delay_ms": self.retry_delay_ms,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "ToolTimeoutPolicy":
|
||||
return cls(
|
||||
per_tool_timeout_ms=data.get("per_tool_timeout_ms", 30000),
|
||||
end_to_end_timeout_ms=data.get("end_to_end_timeout_ms", 120000),
|
||||
retry_count=data.get("retry_count", 0),
|
||||
retry_delay_ms=data.get("retry_delay_ms", 100),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolDefinition:
|
||||
"""
|
||||
[AC-IDMP-19] 工具定义(内存模型)
|
||||
|
||||
包含:
|
||||
- name: 工具名称
|
||||
- type: 工具类型 (internal | mcp)
|
||||
- version: 版本号
|
||||
- timeout_policy: 超时策略
|
||||
- auth_config: 鉴权配置
|
||||
- is_enabled: 启停状态
|
||||
"""
|
||||
name: str
|
||||
type: str = "internal"
|
||||
version: str = "1.0.0"
|
||||
description: str | None = None
|
||||
timeout_policy: ToolTimeoutPolicy = field(default_factory=ToolTimeoutPolicy)
|
||||
auth_config: ToolAuthConfig = field(default_factory=ToolAuthConfig)
|
||||
is_enabled: bool = True
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
created_at: datetime = field(default_factory=datetime.utcnow)
|
||||
updated_at: datetime = field(default_factory=datetime.utcnow)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"name": self.name,
|
||||
"type": self.type,
|
||||
"version": self.version,
|
||||
"description": self.description,
|
||||
"timeout_policy": self.timeout_policy.to_dict(),
|
||||
"auth_config": self.auth_config.to_dict(),
|
||||
"is_enabled": self.is_enabled,
|
||||
"metadata": self.metadata,
|
||||
"created_at": self.created_at.isoformat(),
|
||||
"updated_at": self.updated_at.isoformat(),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "ToolDefinition":
|
||||
return cls(
|
||||
name=data["name"],
|
||||
type=data.get("type", "internal"),
|
||||
version=data.get("version", "1.0.0"),
|
||||
description=data.get("description"),
|
||||
timeout_policy=ToolTimeoutPolicy.from_dict(data.get("timeout_policy", {})),
|
||||
auth_config=ToolAuthConfig.from_dict(data.get("auth_config", {})),
|
||||
is_enabled=data.get("is_enabled", True),
|
||||
metadata=data.get("metadata", {}),
|
||||
created_at=datetime.fromisoformat(data["created_at"]) if data.get("created_at") else datetime.utcnow(),
|
||||
updated_at=datetime.fromisoformat(data["updated_at"]) if data.get("updated_at") else datetime.utcnow(),
|
||||
)
|
||||
|
||||
|
||||
class ToolRegistryEntity(SQLModel, table=True):
|
||||
"""
|
||||
[AC-IDMP-19] 工具注册表数据库实体
|
||||
支持动态配置更新
|
||||
"""
|
||||
__tablename__ = "tool_registry"
|
||||
__table_args__ = (
|
||||
Index("ix_tool_registry_tenant_name", "tenant_id", "name", unique=True),
|
||||
Index("ix_tool_registry_tenant_enabled", "tenant_id", "is_enabled"),
|
||||
)
|
||||
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
||||
tenant_id: str = Field(..., description="租户ID", index=True)
|
||||
name: str = Field(..., description="工具名称", max_length=128)
|
||||
type: str = Field(default="internal", description="工具类型: internal | mcp")
|
||||
version: str = Field(default="1.0.0", description="版本号", max_length=32)
|
||||
description: str | None = Field(default=None, description="工具描述")
|
||||
timeout_policy: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
sa_column=Column("timeout_policy", JSON, nullable=True),
|
||||
description="超时策略配置"
|
||||
)
|
||||
auth_config: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
sa_column=Column("auth_config", JSON, nullable=True),
|
||||
description="鉴权配置"
|
||||
)
|
||||
is_enabled: bool = Field(default=True, description="是否启用")
|
||||
metadata_: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
sa_column=Column("metadata", JSON, nullable=True),
|
||||
description="扩展元数据"
|
||||
)
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow, description="创建时间")
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow, description="更新时间")
|
||||
|
||||
def to_definition(self) -> ToolDefinition:
|
||||
"""转换为内存模型"""
|
||||
return ToolDefinition(
|
||||
name=self.name,
|
||||
type=self.type,
|
||||
version=self.version,
|
||||
description=self.description,
|
||||
timeout_policy=ToolTimeoutPolicy.from_dict(self.timeout_policy or {}),
|
||||
auth_config=ToolAuthConfig.from_dict(self.auth_config or {}),
|
||||
is_enabled=self.is_enabled,
|
||||
metadata=self.metadata_ or {},
|
||||
created_at=self.created_at,
|
||||
updated_at=self.updated_at,
|
||||
)
|
||||
|
||||
|
||||
class ToolRegistryCreate(SQLModel):
|
||||
"""创建工具注册请求"""
|
||||
name: str = Field(..., max_length=128)
|
||||
type: str = "internal"
|
||||
version: str = "1.0.0"
|
||||
description: str | None = None
|
||||
timeout_policy: dict[str, Any] | None = None
|
||||
auth_config: dict[str, Any] | None = None
|
||||
is_enabled: bool = True
|
||||
metadata_: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class ToolRegistryUpdate(SQLModel):
|
||||
"""更新工具注册请求"""
|
||||
type: str | None = None
|
||||
version: str | None = None
|
||||
description: str | None = None
|
||||
timeout_policy: dict[str, Any] | None = None
|
||||
auth_config: dict[str, Any] | None = None
|
||||
is_enabled: bool | None = None
|
||||
metadata_: dict[str, Any] | None = None
|
||||
|
|
@ -0,0 +1,174 @@
|
|||
"""
|
||||
Tool trace models for Mid Platform.
|
||||
[AC-IDMP-15] 工具调用结构化记录
|
||||
|
||||
Reference: spec/intent-driven-mid-platform/openapi.provider.yaml - ToolCallTrace
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ToolCallStatus(str, Enum):
|
||||
"""工具调用状态"""
|
||||
OK = "ok"
|
||||
TIMEOUT = "timeout"
|
||||
ERROR = "error"
|
||||
REJECTED = "rejected"
|
||||
|
||||
|
||||
class ToolType(str, Enum):
|
||||
"""工具类型"""
|
||||
INTERNAL = "internal"
|
||||
MCP = "mcp"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCallTrace:
|
||||
"""
|
||||
[AC-IDMP-15] 工具调用追踪记录
|
||||
Reference: openapi.provider.yaml - ToolCallTrace
|
||||
|
||||
记录字段:
|
||||
- tool_name: 工具名称
|
||||
- tool_type: 工具类型 (internal | mcp)
|
||||
- registry_version: 注册表版本
|
||||
- auth_applied: 是否应用鉴权
|
||||
- duration_ms: 调用耗时(毫秒)
|
||||
- status: 调用状态 (ok | timeout | error | rejected)
|
||||
- error_code: 错误码
|
||||
- args_digest: 参数摘要(脱敏)
|
||||
- result_digest: 结果摘要
|
||||
"""
|
||||
tool_name: str
|
||||
duration_ms: int
|
||||
status: ToolCallStatus
|
||||
tool_type: ToolType = ToolType.INTERNAL
|
||||
registry_version: str | None = None
|
||||
auth_applied: bool = False
|
||||
error_code: str | None = None
|
||||
args_digest: str | None = None
|
||||
result_digest: str | None = None
|
||||
started_at: datetime = field(default_factory=datetime.utcnow)
|
||||
completed_at: datetime | None = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
result = {
|
||||
"tool_name": self.tool_name,
|
||||
"duration_ms": self.duration_ms,
|
||||
"status": self.status.value,
|
||||
}
|
||||
if self.tool_type != ToolType.INTERNAL:
|
||||
result["tool_type"] = self.tool_type.value
|
||||
if self.registry_version:
|
||||
result["registry_version"] = self.registry_version
|
||||
if self.auth_applied:
|
||||
result["auth_applied"] = self.auth_applied
|
||||
if self.error_code:
|
||||
result["error_code"] = self.error_code
|
||||
if self.args_digest:
|
||||
result["args_digest"] = self.args_digest
|
||||
if self.result_digest:
|
||||
result["result_digest"] = self.result_digest
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def compute_digest(data: Any, max_length: int = 64) -> str:
|
||||
"""
|
||||
计算数据摘要(用于脱敏记录)
|
||||
|
||||
Args:
|
||||
data: 原始数据
|
||||
max_length: 最大长度限制
|
||||
|
||||
Returns:
|
||||
摘要字符串
|
||||
"""
|
||||
if data is None:
|
||||
return ""
|
||||
|
||||
if isinstance(data, (dict, list)):
|
||||
data_str = json.dumps(data, ensure_ascii=False, sort_keys=True)
|
||||
else:
|
||||
data_str = str(data)
|
||||
|
||||
if len(data_str) <= max_length:
|
||||
return data_str
|
||||
|
||||
hash_value = hashlib.sha256(data_str.encode("utf-8")).hexdigest()[:16]
|
||||
preview = data_str[:32]
|
||||
return f"{preview}...[hash:{hash_value}]"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCallBuilder:
|
||||
"""
|
||||
[AC-IDMP-15] 工具调用记录构建器
|
||||
用于在工具执行过程中逐步构建追踪记录
|
||||
"""
|
||||
tool_name: str
|
||||
tool_type: ToolType = ToolType.INTERNAL
|
||||
registry_version: str | None = None
|
||||
auth_applied: bool = False
|
||||
_started_at: datetime = field(default_factory=datetime.utcnow)
|
||||
_args: Any = None
|
||||
_result: Any = None
|
||||
_error: Exception | None = None
|
||||
_status: ToolCallStatus = ToolCallStatus.OK
|
||||
_error_code: str | None = None
|
||||
|
||||
def with_args(self, args: Any) -> "ToolCallBuilder":
|
||||
"""设置调用参数"""
|
||||
self._args = args
|
||||
return self
|
||||
|
||||
def with_registry_info(self, version: str, auth_applied: bool) -> "ToolCallBuilder":
|
||||
"""设置注册表信息"""
|
||||
self.registry_version = version
|
||||
self.auth_applied = auth_applied
|
||||
return self
|
||||
|
||||
def with_result(self, result: Any) -> "ToolCallBuilder":
|
||||
"""设置调用结果"""
|
||||
self._result = result
|
||||
self._status = ToolCallStatus.OK
|
||||
return self
|
||||
|
||||
def with_error(self, error: Exception, error_code: str | None = None) -> "ToolCallBuilder":
|
||||
"""设置错误信息"""
|
||||
self._error = error
|
||||
self._error_code = error_code
|
||||
if isinstance(error, TimeoutError):
|
||||
self._status = ToolCallStatus.TIMEOUT
|
||||
else:
|
||||
self._status = ToolCallStatus.ERROR
|
||||
return self
|
||||
|
||||
def with_rejected(self, reason: str) -> "ToolCallBuilder":
|
||||
"""设置拒绝状态"""
|
||||
self._status = ToolCallStatus.REJECTED
|
||||
self._error_code = reason
|
||||
return self
|
||||
|
||||
def build(self) -> ToolCallTrace:
|
||||
"""构建追踪记录"""
|
||||
completed_at = datetime.utcnow()
|
||||
duration_ms = int((completed_at - self._started_at).total_seconds() * 1000)
|
||||
|
||||
return ToolCallTrace(
|
||||
tool_name=self.tool_name,
|
||||
tool_type=self.tool_type,
|
||||
registry_version=self.registry_version,
|
||||
auth_applied=self.auth_applied,
|
||||
duration_ms=duration_ms,
|
||||
status=self._status,
|
||||
error_code=self._error_code,
|
||||
args_digest=ToolCallTrace.compute_digest(self._args) if self._args else None,
|
||||
result_digest=ToolCallTrace.compute_digest(self._result) if self._result else None,
|
||||
started_at=self._started_at,
|
||||
completed_at=completed_at,
|
||||
)
|
||||
|
|
@ -0,0 +1,57 @@
|
|||
"""
|
||||
Mid Platform services.
|
||||
[AC-IDMP-02, AC-IDMP-05, AC-IDMP-11, AC-IDMP-12, AC-IDMP-13, AC-IDMP-14, AC-IDMP-15, AC-IDMP-16, AC-IDMP-17, AC-IDMP-19, AC-IDMP-20]
|
||||
[AC-MARH-05, AC-MARH-06, AC-MARH-10, AC-MARH-11, AC-MARH-12]
|
||||
"""
|
||||
|
||||
from .policy_router import PolicyRouter, PolicyRouterResult, IntentMatch
|
||||
from .agent_orchestrator import AgentOrchestrator, ReActContext
|
||||
from .timeout_governor import TimeoutGovernor
|
||||
from .feature_flags import FeatureFlagService
|
||||
from .high_risk_handler import HighRiskHandler, HighRiskMatch
|
||||
from .trace_logger import TraceLogger, AuditRecord
|
||||
from .metrics_collector import MetricsCollector, SessionMetrics, AggregatedMetrics
|
||||
from .tool_registry import ToolRegistry, ToolDefinition, ToolExecutionResult, get_tool_registry, init_tool_registry
|
||||
from .tool_call_recorder import ToolCallRecorder, ToolCallStatistics, get_tool_call_recorder
|
||||
from .memory_adapter import MemoryAdapter, UserMemory
|
||||
from .default_kb_tool_runner import DefaultKbToolRunner, KbToolResult, KbToolConfig, get_default_kb_tool_runner
|
||||
from .segment_humanizer import SegmentHumanizer, HumanizeConfig, LengthBucket, get_segment_humanizer
|
||||
from .runtime_observer import RuntimeObserver, RuntimeContext, get_runtime_observer
|
||||
|
||||
__all__ = [
|
||||
"PolicyRouter",
|
||||
"PolicyRouterResult",
|
||||
"IntentMatch",
|
||||
"AgentOrchestrator",
|
||||
"ReActContext",
|
||||
"TimeoutGovernor",
|
||||
"FeatureFlagService",
|
||||
"HighRiskHandler",
|
||||
"HighRiskMatch",
|
||||
"TraceLogger",
|
||||
"AuditRecord",
|
||||
"MetricsCollector",
|
||||
"SessionMetrics",
|
||||
"AggregatedMetrics",
|
||||
"ToolRegistry",
|
||||
"ToolDefinition",
|
||||
"ToolExecutionResult",
|
||||
"get_tool_registry",
|
||||
"init_tool_registry",
|
||||
"ToolCallRecorder",
|
||||
"ToolCallStatistics",
|
||||
"get_tool_call_recorder",
|
||||
"MemoryAdapter",
|
||||
"UserMemory",
|
||||
"DefaultKbToolRunner",
|
||||
"KbToolResult",
|
||||
"KbToolConfig",
|
||||
"get_default_kb_tool_runner",
|
||||
"SegmentHumanizer",
|
||||
"HumanizeConfig",
|
||||
"LengthBucket",
|
||||
"get_segment_humanizer",
|
||||
"RuntimeObserver",
|
||||
"RuntimeContext",
|
||||
"get_runtime_observer",
|
||||
]
|
||||
|
|
@ -0,0 +1,370 @@
|
|||
"""
|
||||
Agent Orchestrator for Mid Platform.
|
||||
[AC-MARH-07] ReAct loop with iteration limit (3-5 iterations).
|
||||
|
||||
ReAct Flow:
|
||||
1. Thought: Agent thinks about what to do
|
||||
2. Action: Agent decides to use a tool
|
||||
3. Observation: Tool execution result
|
||||
4. Repeat until final answer or max iterations reached
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from app.models.mid.schemas import (
|
||||
ExecutionMode,
|
||||
ReActContext,
|
||||
ToolCallStatus,
|
||||
ToolCallTrace,
|
||||
ToolType,
|
||||
TraceInfo,
|
||||
)
|
||||
from app.services.mid.timeout_governor import TimeoutGovernor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_MAX_ITERATIONS = 5
|
||||
MIN_ITERATIONS = 3
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolResult:
|
||||
"""Tool execution result."""
|
||||
success: bool
|
||||
output: str | None = None
|
||||
error: str | None = None
|
||||
duration_ms: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentThought:
|
||||
"""Agent thought in ReAct loop."""
|
||||
content: str
|
||||
action: str | None = None
|
||||
action_input: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class AgentOrchestrator:
|
||||
"""
|
||||
[AC-MARH-07] Agent orchestrator with ReAct loop control.
|
||||
|
||||
Features:
|
||||
- ReAct loop with max 5 iterations (min 3)
|
||||
- Per-tool timeout (2s) and end-to-end timeout (8s)
|
||||
- Automatic fallback on iteration limit or timeout
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_iterations: int = DEFAULT_MAX_ITERATIONS,
|
||||
timeout_governor: TimeoutGovernor | None = None,
|
||||
llm_client: Any = None,
|
||||
tool_registry: Any = None,
|
||||
):
|
||||
self._max_iterations = max(min(max_iterations, 5), MIN_ITERATIONS)
|
||||
self._timeout_governor = timeout_governor or TimeoutGovernor()
|
||||
self._llm_client = llm_client
|
||||
self._tool_registry = tool_registry
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
user_message: str,
|
||||
context: dict[str, Any] | None = None,
|
||||
on_thought: Any = None,
|
||||
on_action: Any = None,
|
||||
) -> tuple[str, ReActContext, TraceInfo]:
|
||||
"""
|
||||
[AC-MARH-07] Execute ReAct loop with iteration control.
|
||||
|
||||
Args:
|
||||
user_message: User input message
|
||||
context: Execution context (history, retrieval results, etc.)
|
||||
on_thought: Callback for thought events
|
||||
on_action: Callback for action events
|
||||
|
||||
Returns:
|
||||
Tuple of (final_answer, react_context, trace_info)
|
||||
"""
|
||||
react_ctx = ReActContext(max_iterations=self._max_iterations)
|
||||
tool_calls: list[ToolCallTrace] = []
|
||||
start_time = time.time()
|
||||
|
||||
logger.info(
|
||||
f"[AC-MARH-07] Starting ReAct loop: max_iterations={self._max_iterations}"
|
||||
)
|
||||
|
||||
try:
|
||||
overall_start = time.time()
|
||||
end_to_end_timeout = self._timeout_governor.end_to_end_timeout_seconds
|
||||
llm_timeout = (
|
||||
self._timeout_governor.llm_timeout_seconds
|
||||
if hasattr(self._timeout_governor, 'llm_timeout_seconds')
|
||||
else 15.0
|
||||
)
|
||||
|
||||
while react_ctx.should_continue and react_ctx.iteration < react_ctx.max_iterations:
|
||||
react_ctx.iteration += 1
|
||||
|
||||
elapsed = time.time() - overall_start
|
||||
remaining_time = end_to_end_timeout - elapsed
|
||||
if remaining_time <= 0:
|
||||
logger.warning(
|
||||
"[AC-MARH-09] ReAct loop exceeded end-to-end timeout"
|
||||
)
|
||||
react_ctx.final_answer = "抱歉,处理超时,请稍后重试或联系人工客服。"
|
||||
break
|
||||
|
||||
logger.info(
|
||||
f"[AC-MARH-07] ReAct iteration {react_ctx.iteration}/"
|
||||
f"{react_ctx.max_iterations}, remaining_time={remaining_time:.1f}s"
|
||||
)
|
||||
|
||||
thought = await asyncio.wait_for(
|
||||
self._think(user_message, context, react_ctx),
|
||||
timeout=min(llm_timeout, remaining_time)
|
||||
)
|
||||
if on_thought:
|
||||
await on_thought(thought)
|
||||
|
||||
if not thought.action:
|
||||
logger.info(
|
||||
f"[AC-MARH-07] No action, setting final_answer: "
|
||||
f"{thought.content[:200] if thought.content else 'None'}"
|
||||
)
|
||||
react_ctx.final_answer = thought.content
|
||||
react_ctx.should_continue = False
|
||||
break
|
||||
|
||||
tool_result, tool_trace = await self._act(thought, react_ctx)
|
||||
tool_calls.append(tool_trace)
|
||||
react_ctx.tool_calls.append(tool_trace)
|
||||
|
||||
if on_action:
|
||||
await on_action(thought.action, tool_result)
|
||||
|
||||
if tool_result.success:
|
||||
context = context or {}
|
||||
context["last_observation"] = tool_result.output
|
||||
else:
|
||||
if tool_trace.status == ToolCallStatus.TIMEOUT:
|
||||
logger.warning(f"[AC-MARH-08] Tool timeout: {thought.action}")
|
||||
react_ctx.final_answer = "抱歉,操作超时,请稍后重试或联系人工客服。"
|
||||
react_ctx.should_continue = False
|
||||
break
|
||||
|
||||
if react_ctx.should_continue and not react_ctx.final_answer:
|
||||
logger.warning(f"[AC-MARH-07] ReAct reached max iterations: {react_ctx.iteration}")
|
||||
react_ctx.final_answer = await self._force_final_answer(user_message, context, react_ctx)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("[AC-MARH-09] ReAct loop timed out (end-to-end)")
|
||||
react_ctx.final_answer = "抱歉,处理超时,请稍后重试或联系人工客服。"
|
||||
tool_calls.append(ToolCallTrace(
|
||||
tool_name="react_loop",
|
||||
tool_type=ToolType.INTERNAL,
|
||||
duration_ms=int((time.time() - start_time) * 1000),
|
||||
status=ToolCallStatus.TIMEOUT,
|
||||
error_code="E2E_TIMEOUT",
|
||||
))
|
||||
|
||||
total_duration_ms = int((time.time() - start_time) * 1000)
|
||||
trace = TraceInfo(
|
||||
mode=ExecutionMode.AGENT,
|
||||
request_id=str(uuid.uuid4()),
|
||||
generation_id=str(uuid.uuid4()),
|
||||
react_iterations=react_ctx.iteration,
|
||||
tools_used=[tc.tool_name for tc in tool_calls if tc.tool_name != "react_loop"],
|
||||
tool_calls=tool_calls if tool_calls else None,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[AC-MARH-07] ReAct completed: iterations={react_ctx.iteration}, "
|
||||
f"duration_ms={total_duration_ms}"
|
||||
)
|
||||
|
||||
return react_ctx.final_answer or "抱歉,我暂时无法处理您的请求。", react_ctx, trace
|
||||
|
||||
async def _think(
|
||||
self,
|
||||
user_message: str,
|
||||
context: dict[str, Any] | None,
|
||||
react_ctx: ReActContext,
|
||||
) -> AgentThought:
|
||||
"""
|
||||
[AC-MARH-07] Agent thinks about next action.
|
||||
|
||||
In real implementation, this would call LLM with ReAct prompt.
|
||||
For now, returns a simple thought without action.
|
||||
"""
|
||||
if not self._llm_client:
|
||||
return AgentThought(content=f"思考中... 用户消息: {user_message}")
|
||||
|
||||
try:
|
||||
observations = []
|
||||
if context and "last_observation" in context:
|
||||
observations.append(f"上一步结果: {context['last_observation']}")
|
||||
|
||||
for tc in react_ctx.tool_calls[-3:]:
|
||||
observations.append(f"工具 {tc.tool_name}: {tc.result_digest or '无结果'}")
|
||||
|
||||
prompt = self._build_react_prompt(user_message, observations)
|
||||
response = await self._llm_client.generate([{"role": "user", "content": prompt}])
|
||||
|
||||
logger.info(f"[AC-MARH-07] LLM response content: {response.content[:500] if response.content else 'None'}")
|
||||
|
||||
return self._parse_thought(response.content)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[AC-MARH-07] Think failed: {e}")
|
||||
return AgentThought(content=f"思考失败: {str(e)}")
|
||||
|
||||
def _build_react_prompt(self, user_message: str, observations: list[str]) -> str:
|
||||
"""Build ReAct prompt for LLM."""
|
||||
obs_text = "\n".join(observations) if observations else "无"
|
||||
return f"""你是一个智能助手,正在使用 ReAct 模式处理用户请求。
|
||||
|
||||
用户消息: {user_message}
|
||||
|
||||
历史观察:
|
||||
{obs_text}
|
||||
|
||||
请思考下一步行动。如果已经有足够信息回答用户,请直接给出最终答案。
|
||||
如果需要使用工具,请按以下格式回复:
|
||||
Thought: [你的思考]
|
||||
Action: [工具名称]
|
||||
Action Input: {{"param1": "value1"}}
|
||||
"""
|
||||
|
||||
def _parse_thought(self, content: str) -> AgentThought:
|
||||
"""Parse LLM response into AgentThought."""
|
||||
action = None
|
||||
action_input = None
|
||||
|
||||
if "Action:" in content:
|
||||
lines = content.split("\n")
|
||||
for line in lines:
|
||||
if line.startswith("Action:"):
|
||||
action = line.replace("Action:", "").strip()
|
||||
elif line.startswith("Action Input:"):
|
||||
import json
|
||||
try:
|
||||
action_input = json.loads(line.replace("Action Input:", "").strip())
|
||||
except json.JSONDecodeError:
|
||||
action_input = {}
|
||||
|
||||
return AgentThought(content=content, action=action, action_input=action_input)
|
||||
|
||||
async def _act(
|
||||
self,
|
||||
thought: AgentThought,
|
||||
react_ctx: ReActContext,
|
||||
) -> tuple[ToolResult, ToolCallTrace]:
|
||||
"""
|
||||
[AC-MARH-07, AC-MARH-08] Execute tool action with timeout.
|
||||
"""
|
||||
tool_name = thought.action or "unknown"
|
||||
start_time = time.time()
|
||||
|
||||
if not self._tool_registry:
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
return ToolResult(
|
||||
success=False,
|
||||
error="Tool registry not configured",
|
||||
duration_ms=duration_ms,
|
||||
), ToolCallTrace(
|
||||
tool_name=tool_name,
|
||||
tool_type=ToolType.INTERNAL,
|
||||
duration_ms=duration_ms,
|
||||
status=ToolCallStatus.ERROR,
|
||||
error_code="NO_REGISTRY",
|
||||
)
|
||||
|
||||
try:
|
||||
result = await asyncio.wait_for(
|
||||
self._tool_registry.execute(
|
||||
tool_name=tool_name,
|
||||
args=thought.action_input or {},
|
||||
),
|
||||
timeout=self._timeout_governor.per_tool_timeout_seconds
|
||||
)
|
||||
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
return ToolResult(
|
||||
success=result.get("success", False),
|
||||
output=result.get("output"),
|
||||
error=result.get("error"),
|
||||
duration_ms=duration_ms,
|
||||
), ToolCallTrace(
|
||||
tool_name=tool_name,
|
||||
tool_type=ToolType.INTERNAL,
|
||||
duration_ms=duration_ms,
|
||||
status=ToolCallStatus.OK if result.get("success") else ToolCallStatus.ERROR,
|
||||
args_digest=str(thought.action_input)[:100] if thought.action_input else None,
|
||||
result_digest=str(result.get("output"))[:100] if result.get("output") else None,
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
logger.warning(f"[AC-MARH-08] Tool timeout: {tool_name}, duration={duration_ms}ms")
|
||||
return ToolResult(
|
||||
success=False,
|
||||
error="Tool timeout",
|
||||
duration_ms=duration_ms,
|
||||
), ToolCallTrace(
|
||||
tool_name=tool_name,
|
||||
tool_type=ToolType.INTERNAL,
|
||||
duration_ms=duration_ms,
|
||||
status=ToolCallStatus.TIMEOUT,
|
||||
error_code="TOOL_TIMEOUT",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
logger.error(f"[AC-MARH-07] Tool error: {tool_name}, error={e}")
|
||||
return ToolResult(
|
||||
success=False,
|
||||
error=str(e),
|
||||
duration_ms=duration_ms,
|
||||
), ToolCallTrace(
|
||||
tool_name=tool_name,
|
||||
tool_type=ToolType.INTERNAL,
|
||||
duration_ms=duration_ms,
|
||||
status=ToolCallStatus.ERROR,
|
||||
error_code="TOOL_ERROR",
|
||||
)
|
||||
|
||||
async def _force_final_answer(
|
||||
self,
|
||||
user_message: str,
|
||||
context: dict[str, Any] | None,
|
||||
react_ctx: ReActContext,
|
||||
) -> str:
|
||||
"""Force final answer when max iterations reached."""
|
||||
observations = []
|
||||
for tc in react_ctx.tool_calls:
|
||||
if tc.result_digest:
|
||||
observations.append(f"- {tc.tool_name}: {tc.result_digest}")
|
||||
|
||||
obs_text = "\n".join(observations) if observations else "无"
|
||||
|
||||
if self._llm_client:
|
||||
try:
|
||||
prompt = f"""基于以下信息,请给出最终回答:
|
||||
|
||||
用户消息: {user_message}
|
||||
|
||||
收集到的信息:
|
||||
{obs_text}
|
||||
|
||||
请直接给出回答,不要再调用工具。"""
|
||||
response = await self._llm_client.generate([{"role": "user", "content": prompt}])
|
||||
return response.content
|
||||
except Exception as e:
|
||||
logger.error(f"[AC-MARH-07] Force final answer failed: {e}")
|
||||
|
||||
return "抱歉,我已经尽力处理您的请求,但可能需要更多信息。请稍后重试或联系人工客服。"
|
||||
|
|
@ -0,0 +1,244 @@
|
|||
"""
|
||||
Default KB Tool Runner for Mid Platform.
|
||||
[AC-MARH-05] Agent 默认 KB 检索工具调用。
|
||||
[AC-MARH-06] KB 失败时可观测降级。
|
||||
|
||||
当 Agent 模式处理开放咨询时,默认尝试调用 KB 检索工具获取事实依据。
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from app.models.mid.schemas import ToolCallStatus, ToolCallTrace, ToolType
|
||||
from app.services.mid.timeout_governor import TimeoutGovernor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_KB_TOP_K = 5
|
||||
DEFAULT_KB_TIMEOUT_MS = 2000
|
||||
|
||||
|
||||
@dataclass
|
||||
class KbToolResult:
|
||||
"""KB 检索结果。"""
|
||||
success: bool
|
||||
hits: list[dict[str, Any]] = field(default_factory=list)
|
||||
error: str | None = None
|
||||
fallback_reason_code: str | None = None
|
||||
duration_ms: int = 0
|
||||
tool_trace: ToolCallTrace | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class KbToolConfig:
|
||||
"""KB 工具配置。"""
|
||||
enabled: bool = True
|
||||
top_k: int = DEFAULT_KB_TOP_K
|
||||
timeout_ms: int = DEFAULT_KB_TIMEOUT_MS
|
||||
min_score_threshold: float = 0.5
|
||||
|
||||
|
||||
class DefaultKbToolRunner:
|
||||
"""
|
||||
[AC-MARH-05] Agent 默认 KB 检索工具执行器。
|
||||
|
||||
Features:
|
||||
- Agent 模式下默认调用 KB 检索
|
||||
- 支持超时控制
|
||||
- 失败时返回可观测降级信号
|
||||
- 记录 kb_tool_called, kb_hit 状态
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
timeout_governor: TimeoutGovernor | None = None,
|
||||
config: KbToolConfig | None = None,
|
||||
):
|
||||
self._timeout_governor = timeout_governor or TimeoutGovernor()
|
||||
self._config = config or KbToolConfig()
|
||||
self._vector_retriever = None
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
tenant_id: str,
|
||||
query: str,
|
||||
metadata_filter: dict[str, Any] | None = None,
|
||||
) -> KbToolResult:
|
||||
"""
|
||||
[AC-MARH-05] 执行 KB 检索。
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
query: 检索查询
|
||||
metadata_filter: 元数据过滤条件
|
||||
|
||||
Returns:
|
||||
KbToolResult 包含检索结果和追踪信息
|
||||
"""
|
||||
if not self._config.enabled:
|
||||
logger.info(f"[AC-MARH-05] KB tool disabled for tenant={tenant_id}")
|
||||
return KbToolResult(
|
||||
success=False,
|
||||
fallback_reason_code="KB_DISABLED",
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
tool_trace_id = str(uuid.uuid4())
|
||||
|
||||
logger.info(
|
||||
f"[AC-MARH-05] Starting KB retrieval: tenant={tenant_id}, "
|
||||
f"query={query[:50]}..., top_k={self._config.top_k}"
|
||||
)
|
||||
|
||||
try:
|
||||
hits = await self._retrieve_with_timeout(
|
||||
tenant_id=tenant_id,
|
||||
query=query,
|
||||
metadata_filter=metadata_filter,
|
||||
)
|
||||
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
kb_hit = len(hits) > 0
|
||||
|
||||
tool_trace = ToolCallTrace(
|
||||
tool_name="kb_retrieval",
|
||||
tool_type=ToolType.INTERNAL,
|
||||
duration_ms=duration_ms,
|
||||
status=ToolCallStatus.OK,
|
||||
args_digest=f"query={query[:50]}",
|
||||
result_digest=f"hits={len(hits)}",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[AC-MARH-05] KB retrieval completed: tenant={tenant_id}, "
|
||||
f"hits={len(hits)}, duration_ms={duration_ms}, kb_hit={kb_hit}"
|
||||
)
|
||||
|
||||
return KbToolResult(
|
||||
success=True,
|
||||
hits=hits,
|
||||
duration_ms=duration_ms,
|
||||
tool_trace=tool_trace,
|
||||
)
|
||||
|
||||
except TimeoutError:
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
logger.warning(
|
||||
f"[AC-MARH-06] KB retrieval timeout: tenant={tenant_id}, "
|
||||
f"duration_ms={duration_ms}"
|
||||
)
|
||||
|
||||
tool_trace = ToolCallTrace(
|
||||
tool_name="kb_retrieval",
|
||||
tool_type=ToolType.INTERNAL,
|
||||
duration_ms=duration_ms,
|
||||
status=ToolCallStatus.TIMEOUT,
|
||||
error_code="KB_TIMEOUT",
|
||||
)
|
||||
|
||||
return KbToolResult(
|
||||
success=False,
|
||||
error="KB retrieval timeout",
|
||||
fallback_reason_code="KB_TIMEOUT",
|
||||
duration_ms=duration_ms,
|
||||
tool_trace=tool_trace,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
logger.error(
|
||||
f"[AC-MARH-06] KB retrieval failed: tenant={tenant_id}, "
|
||||
f"error={e}"
|
||||
)
|
||||
|
||||
tool_trace = ToolCallTrace(
|
||||
tool_name="kb_retrieval",
|
||||
tool_type=ToolType.INTERNAL,
|
||||
duration_ms=duration_ms,
|
||||
status=ToolCallStatus.ERROR,
|
||||
error_code="KB_ERROR",
|
||||
)
|
||||
|
||||
return KbToolResult(
|
||||
success=False,
|
||||
error=str(e),
|
||||
fallback_reason_code="KB_ERROR",
|
||||
duration_ms=duration_ms,
|
||||
tool_trace=tool_trace,
|
||||
)
|
||||
|
||||
async def _retrieve_with_timeout(
|
||||
self,
|
||||
tenant_id: str,
|
||||
query: str,
|
||||
metadata_filter: dict[str, Any] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""带超时控制的检索。"""
|
||||
import asyncio
|
||||
|
||||
timeout_seconds = self._config.timeout_ms / 1000.0
|
||||
|
||||
try:
|
||||
return await asyncio.wait_for(
|
||||
self._do_retrieve(tenant_id, query, metadata_filter),
|
||||
timeout=timeout_seconds,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
raise TimeoutError("KB retrieval timeout")
|
||||
|
||||
async def _do_retrieve(
|
||||
self,
|
||||
tenant_id: str,
|
||||
query: str,
|
||||
metadata_filter: dict[str, Any] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""执行实际检索。"""
|
||||
if self._vector_retriever is None:
|
||||
from app.services.retrieval.vector_retriever import get_vector_retriever
|
||||
self._vector_retriever = await get_vector_retriever()
|
||||
|
||||
from app.services.retrieval.base import RetrievalContext
|
||||
|
||||
ctx = RetrievalContext(
|
||||
tenant_id=tenant_id,
|
||||
query=query,
|
||||
metadata_filter=metadata_filter,
|
||||
)
|
||||
|
||||
result = await self._vector_retriever.retrieve(ctx)
|
||||
|
||||
hits = []
|
||||
for hit in result.hits:
|
||||
if hit.score >= self._config.min_score_threshold:
|
||||
hits.append({
|
||||
"id": hit.metadata.get("chunk_id", str(uuid.uuid4())),
|
||||
"content": hit.text,
|
||||
"score": hit.score,
|
||||
"metadata": hit.metadata,
|
||||
})
|
||||
|
||||
return hits[:self._config.top_k]
|
||||
|
||||
def get_config(self) -> KbToolConfig:
|
||||
"""获取当前配置。"""
|
||||
return self._config
|
||||
|
||||
|
||||
_default_kb_tool_runner: DefaultKbToolRunner | None = None
|
||||
|
||||
|
||||
def get_default_kb_tool_runner(
|
||||
timeout_governor: TimeoutGovernor | None = None,
|
||||
config: KbToolConfig | None = None,
|
||||
) -> DefaultKbToolRunner:
|
||||
"""获取或创建 DefaultKbToolRunner 实例。"""
|
||||
global _default_kb_tool_runner
|
||||
if _default_kb_tool_runner is None:
|
||||
_default_kb_tool_runner = DefaultKbToolRunner(
|
||||
timeout_governor=timeout_governor,
|
||||
config=config,
|
||||
)
|
||||
return _default_kb_tool_runner
|
||||
|
|
@ -0,0 +1,100 @@
|
|||
"""
|
||||
Feature Flags Service for Mid Platform.
|
||||
[AC-IDMP-17] Session-level grayscale and rollback support.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from app.models.mid.schemas import FeatureFlags
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FeatureFlagConfig:
|
||||
"""Feature flag configuration."""
|
||||
agent_enabled: bool = True
|
||||
rollback_to_legacy: bool = False
|
||||
react_max_iterations: int = 5
|
||||
enable_tool_registry: bool = True
|
||||
enable_trace_logging: bool = True
|
||||
|
||||
|
||||
class FeatureFlagService:
|
||||
"""
|
||||
[AC-IDMP-17] Feature flag service for session-level control.
|
||||
|
||||
Supports:
|
||||
- Session-level Agent enable/disable
|
||||
- Force rollback to legacy pipeline
|
||||
- Dynamic configuration per session
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._session_flags: dict[str, FeatureFlagConfig] = {}
|
||||
self._global_config = FeatureFlagConfig()
|
||||
|
||||
def get_flags(self, session_id: str) -> FeatureFlags:
|
||||
"""
|
||||
[AC-IDMP-17] Get feature flags for a session.
|
||||
|
||||
Args:
|
||||
session_id: Session ID
|
||||
|
||||
Returns:
|
||||
FeatureFlags for the session
|
||||
"""
|
||||
config = self._session_flags.get(session_id, self._global_config)
|
||||
|
||||
return FeatureFlags(
|
||||
agent_enabled=config.agent_enabled,
|
||||
rollback_to_legacy=config.rollback_to_legacy,
|
||||
)
|
||||
|
||||
def set_flags(self, session_id: str, flags: FeatureFlags) -> None:
|
||||
"""
|
||||
[AC-IDMP-17] Set feature flags for a session.
|
||||
|
||||
Args:
|
||||
session_id: Session ID
|
||||
flags: Feature flags to set
|
||||
"""
|
||||
config = FeatureFlagConfig(
|
||||
agent_enabled=flags.agent_enabled if flags.agent_enabled is not None else self._global_config.agent_enabled,
|
||||
rollback_to_legacy=flags.rollback_to_legacy if flags.rollback_to_legacy is not None else self._global_config.rollback_to_legacy,
|
||||
)
|
||||
|
||||
self._session_flags[session_id] = config
|
||||
|
||||
logger.info(
|
||||
f"[AC-IDMP-17] Feature flags set for session {session_id}: "
|
||||
f"agent_enabled={config.agent_enabled}, rollback_to_legacy={config.rollback_to_legacy}"
|
||||
)
|
||||
|
||||
def clear_flags(self, session_id: str) -> None:
|
||||
"""Clear feature flags for a session."""
|
||||
if session_id in self._session_flags:
|
||||
del self._session_flags[session_id]
|
||||
logger.info(f"[AC-IDMP-17] Feature flags cleared for session {session_id}")
|
||||
|
||||
def is_agent_enabled(self, session_id: str) -> bool:
|
||||
"""Check if Agent mode is enabled for a session."""
|
||||
config = self._session_flags.get(session_id, self._global_config)
|
||||
return config.agent_enabled
|
||||
|
||||
def should_rollback(self, session_id: str) -> bool:
|
||||
"""Check if should rollback to legacy for a session."""
|
||||
config = self._session_flags.get(session_id, self._global_config)
|
||||
return config.rollback_to_legacy
|
||||
|
||||
def set_global_config(self, config: FeatureFlagConfig) -> None:
|
||||
"""Set global default configuration."""
|
||||
self._global_config = config
|
||||
logger.info(f"[AC-IDMP-17] Global config updated: {config}")
|
||||
|
||||
def get_react_max_iterations(self, session_id: str) -> int:
|
||||
"""Get ReAct max iterations for a session."""
|
||||
config = self._session_flags.get(session_id, self._global_config)
|
||||
return config.react_max_iterations
|
||||
|
|
@ -0,0 +1,232 @@
|
|||
"""
|
||||
High Risk Handler for Mid Platform.
|
||||
[AC-IDMP-05, AC-IDMP-20] High-risk scenario detection and mandatory takeover.
|
||||
|
||||
High-Risk Scenarios (minimum set):
|
||||
1. Refund (退款)
|
||||
2. Complaint Escalation (投诉升级)
|
||||
3. Privacy/Sensitive Promise (隐私与敏感承诺)
|
||||
4. Transfer (转人工)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from app.models.mid.schemas import (
|
||||
ExecutionMode,
|
||||
HighRiskScenario,
|
||||
PolicyRouterResult,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_HIGH_RISK_SCENARIOS = [
|
||||
HighRiskScenario.REFUND,
|
||||
HighRiskScenario.COMPLAINT_ESCALATION,
|
||||
HighRiskScenario.PRIVACY_SENSITIVE_PROMISE,
|
||||
HighRiskScenario.TRANSFER,
|
||||
]
|
||||
|
||||
HIGH_RISK_PATTERNS = {
|
||||
HighRiskScenario.REFUND: [
|
||||
r"退款",
|
||||
r"退货",
|
||||
r"退钱",
|
||||
r"退费",
|
||||
r"还钱",
|
||||
r"申请退款",
|
||||
r"我要退",
|
||||
],
|
||||
HighRiskScenario.COMPLAINT_ESCALATION: [
|
||||
r"投诉",
|
||||
r"升级投诉",
|
||||
r"举报",
|
||||
r"12315",
|
||||
r"消费者协会",
|
||||
r"工商局",
|
||||
r"市场监督管理局",
|
||||
],
|
||||
HighRiskScenario.PRIVACY_SENSITIVE_PROMISE: [
|
||||
r"承诺.{0,10}(退款|赔偿|补偿)",
|
||||
r"保证.{0,10}(退款|赔偿|补偿)",
|
||||
r"一定.{0,10}(退款|赔偿|补偿)",
|
||||
r"肯定能.{0,10}(退款|赔偿|补偿)",
|
||||
r"绝对.{0,10}(退款|赔偿|补偿)",
|
||||
r"担保",
|
||||
],
|
||||
HighRiskScenario.TRANSFER: [
|
||||
r"转人工",
|
||||
r"人工客服",
|
||||
r"人工服务",
|
||||
r"真人",
|
||||
r"人工",
|
||||
r"转接人工",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class HighRiskMatch:
|
||||
"""High-risk scenario match result."""
|
||||
scenario: HighRiskScenario
|
||||
matched_pattern: str
|
||||
matched_text: str
|
||||
confidence: float = 1.0
|
||||
|
||||
|
||||
class HighRiskHandler:
|
||||
"""
|
||||
[AC-IDMP-05, AC-IDMP-20] High-risk scenario handler.
|
||||
|
||||
Features:
|
||||
- Configurable high-risk scenario set (minimum set required)
|
||||
- Pattern-based detection
|
||||
- Mandatory takeover to micro_flow or transfer
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
enabled_scenarios: list[HighRiskScenario] | None = None,
|
||||
custom_patterns: dict[HighRiskScenario, list[str]] | None = None,
|
||||
):
|
||||
self._enabled_scenarios = enabled_scenarios or DEFAULT_HIGH_RISK_SCENARIOS.copy()
|
||||
|
||||
if not self._enabled_scenarios:
|
||||
raise ValueError("[AC-IDMP-20] High-risk scenario set cannot be empty")
|
||||
|
||||
self._patterns = HIGH_RISK_PATTERNS.copy()
|
||||
if custom_patterns:
|
||||
for scenario, patterns in custom_patterns.items():
|
||||
if scenario not in self._patterns:
|
||||
self._patterns[scenario] = []
|
||||
self._patterns[scenario].extend(patterns)
|
||||
|
||||
self._compiled_patterns = self._compile_patterns()
|
||||
|
||||
logger.info(
|
||||
f"[AC-IDMP-20] HighRiskHandler initialized with scenarios: "
|
||||
f"{[s.value for s in self._enabled_scenarios]}"
|
||||
)
|
||||
|
||||
def _compile_patterns(self) -> dict[HighRiskScenario, list]:
|
||||
"""Compile regex patterns for better performance."""
|
||||
import re
|
||||
compiled = {}
|
||||
for scenario in self._enabled_scenarios:
|
||||
patterns = self._patterns.get(scenario, [])
|
||||
compiled[scenario] = [
|
||||
re.compile(p, re.IGNORECASE) for p in patterns
|
||||
]
|
||||
return compiled
|
||||
|
||||
def detect(self, message: str) -> HighRiskMatch | None:
|
||||
"""
|
||||
[AC-IDMP-05, AC-IDMP-20] Detect high-risk scenario in message.
|
||||
|
||||
Args:
|
||||
message: User message to check
|
||||
|
||||
Returns:
|
||||
HighRiskMatch if detected, None otherwise
|
||||
"""
|
||||
for scenario in self._enabled_scenarios:
|
||||
patterns = self._compiled_patterns.get(scenario, [])
|
||||
for pattern in patterns:
|
||||
match = pattern.search(message)
|
||||
if match:
|
||||
logger.info(
|
||||
f"[AC-IDMP-05] High-risk scenario detected: "
|
||||
f"scenario={scenario.value}, pattern={pattern.pattern}"
|
||||
)
|
||||
return HighRiskMatch(
|
||||
scenario=scenario,
|
||||
matched_pattern=pattern.pattern,
|
||||
matched_text=match.group(),
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def handle(
|
||||
self,
|
||||
match: HighRiskMatch,
|
||||
context: dict[str, Any] | None = None,
|
||||
) -> PolicyRouterResult:
|
||||
"""
|
||||
[AC-IDMP-05] Handle high-risk scenario by routing to appropriate mode.
|
||||
|
||||
Args:
|
||||
match: High-risk match result
|
||||
context: Additional context (intent match, flow config, etc.)
|
||||
|
||||
Returns:
|
||||
PolicyRouterResult with execution mode decision
|
||||
"""
|
||||
context = context or {}
|
||||
|
||||
if match.scenario == HighRiskScenario.TRANSFER:
|
||||
return PolicyRouterResult(
|
||||
mode=ExecutionMode.TRANSFER,
|
||||
high_risk_triggered=True,
|
||||
transfer_message="正在为您转接人工客服,请稍候...",
|
||||
)
|
||||
|
||||
flow_id = context.get("flow_id")
|
||||
if flow_id:
|
||||
return PolicyRouterResult(
|
||||
mode=ExecutionMode.MICRO_FLOW,
|
||||
high_risk_triggered=True,
|
||||
target_flow_id=flow_id,
|
||||
)
|
||||
|
||||
if match.scenario == HighRiskScenario.REFUND:
|
||||
return PolicyRouterResult(
|
||||
mode=ExecutionMode.MICRO_FLOW,
|
||||
high_risk_triggered=True,
|
||||
fallback_reason_code="high_risk_refund",
|
||||
)
|
||||
|
||||
if match.scenario == HighRiskScenario.COMPLAINT_ESCALATION:
|
||||
return PolicyRouterResult(
|
||||
mode=ExecutionMode.TRANSFER,
|
||||
high_risk_triggered=True,
|
||||
transfer_message="检测到您可能需要投诉处理,正在为您转接人工客服...",
|
||||
)
|
||||
|
||||
if match.scenario == HighRiskScenario.PRIVACY_SENSITIVE_PROMISE:
|
||||
return PolicyRouterResult(
|
||||
mode=ExecutionMode.MICRO_FLOW,
|
||||
high_risk_triggered=True,
|
||||
fallback_reason_code="high_risk_privacy_promise",
|
||||
)
|
||||
|
||||
return PolicyRouterResult(
|
||||
mode=ExecutionMode.TRANSFER,
|
||||
high_risk_triggered=True,
|
||||
transfer_message="正在为您转接人工客服...",
|
||||
)
|
||||
|
||||
def get_enabled_scenarios(self) -> list[HighRiskScenario]:
|
||||
"""[AC-IDMP-20] Get enabled high-risk scenarios."""
|
||||
return self._enabled_scenarios.copy()
|
||||
|
||||
def add_scenario(self, scenario: HighRiskScenario) -> None:
|
||||
"""Add a high-risk scenario to the enabled set."""
|
||||
if scenario not in self._enabled_scenarios:
|
||||
self._enabled_scenarios.append(scenario)
|
||||
if scenario in HIGH_RISK_PATTERNS:
|
||||
self._compiled_patterns[scenario] = [
|
||||
__import__('re').compile(p, __import__('re').IGNORECASE)
|
||||
for p in HIGH_RISK_PATTERNS[scenario]
|
||||
]
|
||||
logger.info(f"[AC-IDMP-20] Added high-risk scenario: {scenario.value}")
|
||||
|
||||
def remove_scenario(self, scenario: HighRiskScenario) -> bool:
|
||||
"""Remove a high-risk scenario from the enabled set."""
|
||||
if scenario in self._enabled_scenarios and len(self._enabled_scenarios) > 1:
|
||||
self._enabled_scenarios.remove(scenario)
|
||||
if scenario in self._compiled_patterns:
|
||||
del self._compiled_patterns[scenario]
|
||||
logger.info(f"[AC-IDMP-20] Removed high-risk scenario: {scenario.value}")
|
||||
return True
|
||||
return False
|
||||
|
|
@ -0,0 +1,161 @@
|
|||
"""
|
||||
Interrupt Context Enricher for Mid Platform.
|
||||
[AC-MARH-03, AC-MARH-04] Interrupted segments processing and fallback handling.
|
||||
|
||||
Features:
|
||||
- Consumes interrupted_segments from request
|
||||
- Provides context for re-planning
|
||||
- Fallback handling for invalid/empty interrupt data
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from app.models.mid.schemas import InterruptedSegment
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class InterruptContext:
|
||||
"""Context derived from interrupted segments."""
|
||||
consumed: bool = False
|
||||
interrupted_content: str | None = None
|
||||
interrupted_segment_ids: list[str] | None = None
|
||||
fallback_triggered: bool = False
|
||||
fallback_reason: str | None = None
|
||||
|
||||
|
||||
class InterruptContextEnricher:
|
||||
"""
|
||||
[AC-MARH-03, AC-MARH-04] Interrupt context enricher for handling user interruption.
|
||||
|
||||
This component processes interrupted_segments from the request and provides
|
||||
context for re-planning. It handles edge cases where interrupt data is
|
||||
invalid or empty, ensuring the main flow continues without disruption.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def enrich(
|
||||
self,
|
||||
interrupted_segments: list[InterruptedSegment] | None,
|
||||
generation_id: str | None = None,
|
||||
) -> InterruptContext:
|
||||
"""
|
||||
[AC-MARH-03, AC-MARH-04] Process interrupted segments and return context.
|
||||
|
||||
Args:
|
||||
interrupted_segments: List of interrupted segments from request
|
||||
generation_id: Current generation ID for matching
|
||||
|
||||
Returns:
|
||||
InterruptContext with processed interrupt information
|
||||
"""
|
||||
if not interrupted_segments:
|
||||
logger.debug("[AC-MARH-04] No interrupted segments provided")
|
||||
return InterruptContext(
|
||||
consumed=False,
|
||||
fallback_triggered=False,
|
||||
)
|
||||
|
||||
try:
|
||||
valid_segments = [
|
||||
s for s in interrupted_segments
|
||||
if s.content and s.content.strip() and s.segment_id
|
||||
]
|
||||
|
||||
if not valid_segments:
|
||||
logger.info("[AC-MARH-04] Interrupted segments empty or invalid, using fallback")
|
||||
return InterruptContext(
|
||||
consumed=False,
|
||||
fallback_triggered=True,
|
||||
fallback_reason="empty_or_invalid_segments",
|
||||
)
|
||||
|
||||
interrupted_content = "\n".join(s.content for s in valid_segments)
|
||||
segment_ids = [s.segment_id for s in valid_segments]
|
||||
|
||||
logger.info(
|
||||
f"[AC-MARH-03] Interrupted segments consumed: "
|
||||
f"count={len(valid_segments)}, segment_ids={segment_ids[:3]}..."
|
||||
)
|
||||
|
||||
return InterruptContext(
|
||||
consumed=True,
|
||||
interrupted_content=interrupted_content,
|
||||
interrupted_segment_ids=segment_ids,
|
||||
fallback_triggered=False,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"[AC-MARH-04] Failed to process interrupted segments: {e}")
|
||||
return InterruptContext(
|
||||
consumed=False,
|
||||
fallback_triggered=True,
|
||||
fallback_reason=f"error:{str(e)[:50]}",
|
||||
)
|
||||
|
||||
def build_replan_context(
|
||||
self,
|
||||
interrupt_context: InterruptContext,
|
||||
base_context: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
[AC-MARH-03] Build context for re-planning after interruption.
|
||||
|
||||
Args:
|
||||
interrupt_context: Processed interrupt context
|
||||
base_context: Existing context to enrich
|
||||
|
||||
Returns:
|
||||
Enriched context for re-planning
|
||||
"""
|
||||
context = base_context.copy() if base_context else {}
|
||||
|
||||
if interrupt_context.consumed and interrupt_context.interrupted_content:
|
||||
context["interrupted_content"] = interrupt_context.interrupted_content
|
||||
context["interrupted_segment_ids"] = interrupt_context.interrupted_segment_ids
|
||||
context["avoid_duplicate"] = True
|
||||
logger.debug(
|
||||
f"[AC-MARH-03] Re-plan context enriched with interrupted content: "
|
||||
f"{len(interrupt_context.interrupted_content)} chars"
|
||||
)
|
||||
elif interrupt_context.fallback_triggered:
|
||||
context["interrupt_fallback"] = True
|
||||
context["interrupt_fallback_reason"] = interrupt_context.fallback_reason
|
||||
logger.debug(
|
||||
f"[AC-MARH-04] Re-plan context marked with fallback: "
|
||||
f"{interrupt_context.fallback_reason}"
|
||||
)
|
||||
|
||||
return context
|
||||
|
||||
def should_skip_content(
|
||||
self,
|
||||
content: str,
|
||||
interrupt_context: InterruptContext,
|
||||
) -> bool:
|
||||
"""
|
||||
[AC-MARH-03] Check if content should be skipped to avoid duplicate.
|
||||
|
||||
Args:
|
||||
content: Content to check
|
||||
interrupt_context: Processed interrupt context
|
||||
|
||||
Returns:
|
||||
True if content matches interrupted content and should be skipped
|
||||
"""
|
||||
if not interrupt_context.consumed or not interrupt_context.interrupted_content:
|
||||
return False
|
||||
|
||||
if content.strip() in interrupt_context.interrupted_content:
|
||||
logger.info(
|
||||
f"[AC-MARH-03] Content matches interrupted segment, skipping: "
|
||||
f"{content[:50]}..."
|
||||
)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
|
@ -0,0 +1,487 @@
|
|||
"""
|
||||
KB Search Dynamic Tool for Mid Platform.
|
||||
[AC-MARH-05] Agent 默认 KB 检索工具,支持元数据驱动参数。
|
||||
[AC-MARH-06] KB 失败时可观测降级。
|
||||
|
||||
核心特性:
|
||||
- 通过元数据配置动态生成检索参数/过滤器
|
||||
- 必填字段缺失时返回 missing_required_slots
|
||||
- 工具执行可观测(tool_call/tool_result)
|
||||
- 超时降级返回 fallback_reason_code
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.mid.schemas import ToolCallStatus, ToolCallTrace, ToolType
|
||||
from app.services.mid.metadata_filter_builder import (
|
||||
FilterBuildResult,
|
||||
MetadataFilterBuilder,
|
||||
)
|
||||
from app.services.mid.timeout_governor import TimeoutGovernor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.services.mid.tool_registry import ToolRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_TOP_K = 5
|
||||
DEFAULT_TIMEOUT_MS = 2000
|
||||
KB_SEARCH_DYNAMIC_TOOL_NAME = "kb_search_dynamic"
|
||||
|
||||
|
||||
@dataclass
|
||||
class KbSearchDynamicResult:
|
||||
"""KB 动态检索结果。"""
|
||||
success: bool = True
|
||||
hits: list[dict[str, Any]] = field(default_factory=list)
|
||||
applied_filter: dict[str, Any] = field(default_factory=dict)
|
||||
missing_required_slots: list[dict[str, str]] = field(default_factory=list)
|
||||
filter_debug: dict[str, Any] = field(default_factory=dict)
|
||||
fallback_reason_code: str | None = None
|
||||
duration_ms: int = 0
|
||||
tool_trace: ToolCallTrace | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class KbSearchDynamicConfig:
|
||||
"""KB 动态检索配置。"""
|
||||
enabled: bool = True
|
||||
top_k: int = DEFAULT_TOP_K
|
||||
timeout_ms: int = DEFAULT_TIMEOUT_MS
|
||||
min_score_threshold: float = 0.5
|
||||
|
||||
|
||||
class KbSearchDynamicTool:
|
||||
"""
|
||||
[AC-MARH-05] KB 动态检索工具。
|
||||
|
||||
支持通过元数据配置动态生成检索参数/过滤器,而不是固定入参写死。
|
||||
|
||||
固定外壳入参:
|
||||
- query: 检索查询
|
||||
- scene: 场景标识
|
||||
- tenant_id: 租户 ID
|
||||
- top_k: 返回数量
|
||||
- context: 上下文(包含动态过滤值)
|
||||
|
||||
返回结构:
|
||||
- hits: 检索结果
|
||||
- applied_filter: 已应用的过滤条件
|
||||
- missing_required_slots: 缺失的必填字段
|
||||
- filter_debug: 过滤器调试信息
|
||||
- fallback_reason_code: 降级原因码
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
timeout_governor: TimeoutGovernor | None = None,
|
||||
config: KbSearchDynamicConfig | None = None,
|
||||
):
|
||||
self._session = session
|
||||
self._timeout_governor = timeout_governor or TimeoutGovernor()
|
||||
self._config = config or KbSearchDynamicConfig()
|
||||
self._vector_retriever = None
|
||||
self._filter_builder: MetadataFilterBuilder | None = None
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""工具名称。"""
|
||||
return KB_SEARCH_DYNAMIC_TOOL_NAME
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
"""工具描述。"""
|
||||
return (
|
||||
"知识库动态检索工具。"
|
||||
"根据租户配置的元数据字段定义,动态构建检索过滤器。"
|
||||
"支持必填字段检测和可观测降级。"
|
||||
)
|
||||
|
||||
def get_tool_schema(self) -> dict[str, Any]:
|
||||
"""
|
||||
获取工具 Schema,用于 Agent 工具描述。
|
||||
动态生成基于租户配置的过滤字段。
|
||||
"""
|
||||
return {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "检索查询文本",
|
||||
},
|
||||
"scene": {
|
||||
"type": "string",
|
||||
"description": "场景标识,如 'open_consult', 'intent_match'",
|
||||
},
|
||||
"top_k": {
|
||||
"type": "integer",
|
||||
"description": "返回结果数量",
|
||||
"default": DEFAULT_TOP_K,
|
||||
},
|
||||
"context": {
|
||||
"type": "object",
|
||||
"description": "上下文信息,包含动态过滤字段值",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
}
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
query: str,
|
||||
tenant_id: str,
|
||||
scene: str = "open_consult",
|
||||
top_k: int | None = None,
|
||||
context: dict[str, Any] | None = None,
|
||||
) -> KbSearchDynamicResult:
|
||||
"""
|
||||
[AC-MARH-05] 执行 KB 动态检索。
|
||||
|
||||
Args:
|
||||
query: 检索查询
|
||||
tenant_id: 租户 ID
|
||||
scene: 场景标识
|
||||
top_k: 返回数量
|
||||
context: 上下文(包含动态过滤值)
|
||||
|
||||
Returns:
|
||||
KbSearchDynamicResult 包含检索结果和追踪信息
|
||||
"""
|
||||
if not self._config.enabled:
|
||||
logger.info(f"[AC-MARH-05] KB search dynamic disabled for tenant={tenant_id}")
|
||||
return KbSearchDynamicResult(
|
||||
success=False,
|
||||
fallback_reason_code="KB_DISABLED",
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
top_k = top_k or self._config.top_k
|
||||
|
||||
logger.info(
|
||||
f"[AC-MARH-05] Starting KB dynamic search: tenant={tenant_id}, "
|
||||
f"query={query[:50]}..., scene={scene}, top_k={top_k}"
|
||||
)
|
||||
|
||||
filter_result: FilterBuildResult | None = None
|
||||
|
||||
try:
|
||||
if self._filter_builder is None:
|
||||
self._filter_builder = MetadataFilterBuilder(self._session)
|
||||
|
||||
filter_result = await self._filter_builder.build_filter(
|
||||
tenant_id=tenant_id,
|
||||
context=context,
|
||||
)
|
||||
|
||||
if filter_result.missing_required_slots:
|
||||
logger.warning(
|
||||
f"[AC-MARH-05] Missing required slots: "
|
||||
f"{filter_result.missing_required_slots}"
|
||||
)
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
tool_trace = ToolCallTrace(
|
||||
tool_name=self.name,
|
||||
tool_type=ToolType.INTERNAL,
|
||||
duration_ms=duration_ms,
|
||||
status=ToolCallStatus.ERROR,
|
||||
error_code="MISSING_REQUIRED_SLOTS",
|
||||
args_digest=f"query={query[:50]}, scene={scene}",
|
||||
result_digest=f"missing={len(filter_result.missing_required_slots)}",
|
||||
)
|
||||
|
||||
return KbSearchDynamicResult(
|
||||
success=False,
|
||||
applied_filter=filter_result.applied_filter,
|
||||
missing_required_slots=filter_result.missing_required_slots,
|
||||
filter_debug=filter_result.debug_info,
|
||||
fallback_reason_code="MISSING_REQUIRED_SLOTS",
|
||||
duration_ms=duration_ms,
|
||||
tool_trace=tool_trace,
|
||||
)
|
||||
|
||||
metadata_filter = filter_result.applied_filter if filter_result.success else None
|
||||
|
||||
hits = await self._retrieve_with_timeout(
|
||||
tenant_id=tenant_id,
|
||||
query=query,
|
||||
metadata_filter=metadata_filter,
|
||||
top_k=top_k,
|
||||
)
|
||||
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
kb_hit = len(hits) > 0
|
||||
|
||||
tool_trace = ToolCallTrace(
|
||||
tool_name=self.name,
|
||||
tool_type=ToolType.INTERNAL,
|
||||
duration_ms=duration_ms,
|
||||
status=ToolCallStatus.OK,
|
||||
args_digest=f"query={query[:50]}, scene={scene}",
|
||||
result_digest=f"hits={len(hits)}",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[AC-MARH-05] KB dynamic search completed: tenant={tenant_id}, "
|
||||
f"hits={len(hits)}, duration_ms={duration_ms}, kb_hit={kb_hit}"
|
||||
)
|
||||
|
||||
return KbSearchDynamicResult(
|
||||
success=True,
|
||||
hits=hits,
|
||||
applied_filter=filter_result.applied_filter if filter_result else {},
|
||||
filter_debug=filter_result.debug_info if filter_result else {},
|
||||
duration_ms=duration_ms,
|
||||
tool_trace=tool_trace,
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
logger.warning(
|
||||
f"[AC-MARH-06] KB dynamic search timeout: tenant={tenant_id}, "
|
||||
f"duration_ms={duration_ms}"
|
||||
)
|
||||
|
||||
tool_trace = ToolCallTrace(
|
||||
tool_name=self.name,
|
||||
tool_type=ToolType.INTERNAL,
|
||||
duration_ms=duration_ms,
|
||||
status=ToolCallStatus.TIMEOUT,
|
||||
error_code="KB_TIMEOUT",
|
||||
)
|
||||
|
||||
return KbSearchDynamicResult(
|
||||
success=False,
|
||||
applied_filter=filter_result.applied_filter if filter_result else {},
|
||||
missing_required_slots=filter_result.missing_required_slots if filter_result else [],
|
||||
filter_debug=filter_result.debug_info if filter_result else {},
|
||||
fallback_reason_code="KB_TIMEOUT",
|
||||
duration_ms=duration_ms,
|
||||
tool_trace=tool_trace,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
logger.error(
|
||||
f"[AC-MARH-06] KB dynamic search failed: tenant={tenant_id}, "
|
||||
f"error={e}"
|
||||
)
|
||||
|
||||
tool_trace = ToolCallTrace(
|
||||
tool_name=self.name,
|
||||
tool_type=ToolType.INTERNAL,
|
||||
duration_ms=duration_ms,
|
||||
status=ToolCallStatus.ERROR,
|
||||
error_code="KB_ERROR",
|
||||
)
|
||||
|
||||
return KbSearchDynamicResult(
|
||||
success=False,
|
||||
applied_filter=filter_result.applied_filter if filter_result else {},
|
||||
missing_required_slots=filter_result.missing_required_slots if filter_result else [],
|
||||
filter_debug={"error": str(e)},
|
||||
fallback_reason_code="KB_ERROR",
|
||||
duration_ms=duration_ms,
|
||||
tool_trace=tool_trace,
|
||||
)
|
||||
|
||||
async def _retrieve_with_timeout(
|
||||
self,
|
||||
tenant_id: str,
|
||||
query: str,
|
||||
metadata_filter: dict[str, Any] | None = None,
|
||||
top_k: int = DEFAULT_TOP_K,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""带超时控制的检索。"""
|
||||
timeout_seconds = self._config.timeout_ms / 1000.0
|
||||
|
||||
try:
|
||||
return await asyncio.wait_for(
|
||||
self._do_retrieve(tenant_id, query, metadata_filter, top_k),
|
||||
timeout=timeout_seconds,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
raise asyncio.TimeoutError("KB dynamic search timeout")
|
||||
|
||||
async def _do_retrieve(
|
||||
self,
|
||||
tenant_id: str,
|
||||
query: str,
|
||||
metadata_filter: dict[str, Any] | None = None,
|
||||
top_k: int = DEFAULT_TOP_K,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""执行实际检索。"""
|
||||
if self._vector_retriever is None:
|
||||
from app.services.retrieval.vector_retriever import get_vector_retriever
|
||||
self._vector_retriever = await get_vector_retriever()
|
||||
|
||||
from app.services.retrieval.base import RetrievalContext
|
||||
|
||||
ctx = RetrievalContext(
|
||||
tenant_id=tenant_id,
|
||||
query=query,
|
||||
metadata=metadata_filter,
|
||||
)
|
||||
|
||||
result = await self._vector_retriever.retrieve(ctx)
|
||||
|
||||
hits = []
|
||||
for hit in result.hits:
|
||||
if hit.score >= self._config.min_score_threshold:
|
||||
hits.append({
|
||||
"id": hit.metadata.get("chunk_id", str(uuid.uuid4())),
|
||||
"content": hit.text,
|
||||
"score": hit.score,
|
||||
"metadata": hit.metadata,
|
||||
})
|
||||
|
||||
return hits[:top_k]
|
||||
|
||||
def get_config(self) -> KbSearchDynamicConfig:
|
||||
"""获取当前配置。"""
|
||||
return self._config
|
||||
|
||||
|
||||
async def create_kb_search_dynamic_handler(
|
||||
session: AsyncSession,
|
||||
timeout_governor: TimeoutGovernor | None = None,
|
||||
config: KbSearchDynamicConfig | None = None,
|
||||
) -> callable:
|
||||
"""
|
||||
创建 kb_search_dynamic 工具的 handler 函数,用于注册到 ToolRegistry。
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
timeout_governor: 超时治理器
|
||||
config: 工具配置
|
||||
|
||||
Returns:
|
||||
异步 handler 函数
|
||||
"""
|
||||
tool = KbSearchDynamicTool(
|
||||
session=session,
|
||||
timeout_governor=timeout_governor,
|
||||
config=config,
|
||||
)
|
||||
|
||||
async def handler(
|
||||
query: str,
|
||||
tenant_id: str = "",
|
||||
scene: str = "open_consult",
|
||||
top_k: int = DEFAULT_TOP_K,
|
||||
context: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
KB 动态检索 handler。
|
||||
|
||||
Args:
|
||||
query: 检索查询
|
||||
tenant_id: 租户 ID
|
||||
scene: 场景标识
|
||||
top_k: 返回数量
|
||||
context: 上下文
|
||||
|
||||
Returns:
|
||||
检索结果字典
|
||||
"""
|
||||
result = await tool.execute(
|
||||
query=query,
|
||||
tenant_id=tenant_id,
|
||||
scene=scene,
|
||||
top_k=top_k,
|
||||
context=context,
|
||||
)
|
||||
|
||||
return {
|
||||
"success": result.success,
|
||||
"hits": result.hits,
|
||||
"applied_filter": result.applied_filter,
|
||||
"missing_required_slots": result.missing_required_slots,
|
||||
"filter_debug": result.filter_debug,
|
||||
"fallback_reason_code": result.fallback_reason_code,
|
||||
"duration_ms": result.duration_ms,
|
||||
}
|
||||
|
||||
return handler
|
||||
|
||||
|
||||
def register_kb_search_dynamic_tool(
|
||||
registry: ToolRegistry,
|
||||
session: AsyncSession,
|
||||
timeout_governor: TimeoutGovernor | None = None,
|
||||
config: KbSearchDynamicConfig | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
[AC-MARH-05] 将 kb_search_dynamic 注册到 ToolRegistry。
|
||||
|
||||
Args:
|
||||
registry: ToolRegistry 实例
|
||||
session: 数据库会话
|
||||
timeout_governor: 超时治理器
|
||||
config: 工具配置
|
||||
"""
|
||||
from app.services.mid.tool_registry import ToolType as RegistryToolType
|
||||
|
||||
async def handler(
|
||||
query: str,
|
||||
tenant_id: str = "",
|
||||
scene: str = "open_consult",
|
||||
top_k: int = DEFAULT_TOP_K,
|
||||
context: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
tool = KbSearchDynamicTool(
|
||||
session=session,
|
||||
timeout_governor=timeout_governor,
|
||||
config=config,
|
||||
)
|
||||
|
||||
result = await tool.execute(
|
||||
query=query,
|
||||
tenant_id=tenant_id,
|
||||
scene=scene,
|
||||
top_k=top_k,
|
||||
context=context,
|
||||
)
|
||||
|
||||
return {
|
||||
"success": result.success,
|
||||
"hits": result.hits,
|
||||
"applied_filter": result.applied_filter,
|
||||
"missing_required_slots": result.missing_required_slots,
|
||||
"filter_debug": result.filter_debug,
|
||||
"fallback_reason_code": result.fallback_reason_code,
|
||||
"duration_ms": result.duration_ms,
|
||||
}
|
||||
|
||||
registry.register(
|
||||
name=KB_SEARCH_DYNAMIC_TOOL_NAME,
|
||||
description="知识库动态检索工具,支持元数据驱动过滤",
|
||||
handler=handler,
|
||||
tool_type=RegistryToolType.INTERNAL,
|
||||
version="1.0.0",
|
||||
auth_required=False,
|
||||
timeout_ms=config.timeout_ms if config else DEFAULT_TIMEOUT_MS,
|
||||
enabled=True,
|
||||
metadata={
|
||||
"supports_dynamic_filter": True,
|
||||
"min_score_threshold": config.min_score_threshold if config else 0.5,
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[AC-MARH-05] Tool registered: {KB_SEARCH_DYNAMIC_TOOL_NAME}"
|
||||
)
|
||||
|
|
@ -0,0 +1,355 @@
|
|||
"""
|
||||
Memory Adapter for Mid Platform.
|
||||
[AC-IDMP-13] 记忆召回服务 - 在响应前执行 recall 并注入 profile/facts/preferences
|
||||
[AC-IDMP-14] 记忆更新服务 - 异步执行记忆更新(含会话摘要)
|
||||
|
||||
Reference:
|
||||
- spec/intent-driven-mid-platform/openapi.deps.yaml
|
||||
- spec/intent-driven-mid-platform/requirements.md AC-IDMP-13/14
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.mid.memory import (
|
||||
MemoryFact,
|
||||
MemoryProfile,
|
||||
MemoryPreferences,
|
||||
RecallRequest,
|
||||
RecallResponse,
|
||||
UpdateRequest,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MemoryStoreType(str):
|
||||
"""记忆存储类型"""
|
||||
PROFILE = "profile"
|
||||
FACT = "fact"
|
||||
PREFERENCES = "preferences"
|
||||
SUMMARY = "summary"
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserMemory:
|
||||
"""
|
||||
用户记忆存储实体
|
||||
用于持久化存储用户的三层记忆
|
||||
"""
|
||||
id: str
|
||||
user_id: str
|
||||
tenant_id: str
|
||||
memory_type: str
|
||||
content: dict[str, Any]
|
||||
created_at: datetime = field(default_factory=datetime.utcnow)
|
||||
updated_at: datetime = field(default_factory=datetime.utcnow)
|
||||
expires_at: datetime | None = None
|
||||
confidence: float | None = None
|
||||
source: str | None = None
|
||||
|
||||
|
||||
class MemoryAdapter:
|
||||
"""
|
||||
[AC-IDMP-13/14] 记忆适配器
|
||||
|
||||
功能:
|
||||
1. recall: 在对话响应前召回用户记忆(profile/facts/preferences)
|
||||
2. update: 在对话完成后异步更新用户记忆
|
||||
|
||||
设计原则:
|
||||
- recall 失败不阻断主链路(降级处理)
|
||||
- update 异步执行,不阻塞主响应
|
||||
"""
|
||||
|
||||
DEFAULT_RECALL_TIMEOUT_MS = 500
|
||||
DEFAULT_UPDATE_TIMEOUT_MS = 2000
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
recall_timeout_ms: int = DEFAULT_RECALL_TIMEOUT_MS,
|
||||
update_timeout_ms: int = DEFAULT_UPDATE_TIMEOUT_MS,
|
||||
):
|
||||
self._session = session
|
||||
self._recall_timeout_ms = recall_timeout_ms
|
||||
self._update_timeout_ms = update_timeout_ms
|
||||
self._pending_updates: list[asyncio.Task] = []
|
||||
|
||||
async def recall(
|
||||
self,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
tenant_id: str | None = None,
|
||||
) -> RecallResponse:
|
||||
"""
|
||||
[AC-IDMP-13] 召回用户记忆
|
||||
|
||||
在响应前执行,注入基础属性、事实记忆与偏好记忆。
|
||||
失败时返回空记忆,不阻断主链路。
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
session_id: 会话ID
|
||||
tenant_id: 租户ID(可选)
|
||||
|
||||
Returns:
|
||||
RecallResponse: 包含 profile/facts/preferences 的响应
|
||||
"""
|
||||
try:
|
||||
return await asyncio.wait_for(
|
||||
self._recall_internal(user_id, session_id, tenant_id),
|
||||
timeout=self._recall_timeout_ms / 1000,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
f"[AC-IDMP-13] Memory recall timeout for user={user_id}, "
|
||||
f"session={session_id}, timeout_ms={self._recall_timeout_ms}"
|
||||
)
|
||||
return RecallResponse()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[AC-IDMP-13] Memory recall failed for user={user_id}, "
|
||||
f"session={session_id}, error={e}"
|
||||
)
|
||||
return RecallResponse()
|
||||
|
||||
async def _recall_internal(
|
||||
self,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
tenant_id: str | None,
|
||||
) -> RecallResponse:
|
||||
"""
|
||||
内部召回实现
|
||||
"""
|
||||
profile = await self._recall_profile(user_id, tenant_id)
|
||||
facts = await self._recall_facts(user_id, tenant_id)
|
||||
preferences = await self._recall_preferences(user_id, tenant_id)
|
||||
last_summary = await self._recall_last_summary(user_id, tenant_id)
|
||||
|
||||
logger.info(
|
||||
f"[AC-IDMP-13] Memory recalled for user={user_id}: "
|
||||
f"profile={bool(profile)}, facts={len(facts)}, "
|
||||
f"preferences={bool(preferences)}, summary={bool(last_summary)}"
|
||||
)
|
||||
|
||||
return RecallResponse(
|
||||
profile=profile,
|
||||
facts=facts,
|
||||
preferences=preferences,
|
||||
last_summary=last_summary,
|
||||
)
|
||||
|
||||
async def _recall_profile(
|
||||
self,
|
||||
user_id: str,
|
||||
tenant_id: str | None,
|
||||
) -> MemoryProfile | None:
|
||||
"""召回用户基础属性"""
|
||||
return MemoryProfile(
|
||||
grade="初一",
|
||||
region="北京",
|
||||
channel="wechat",
|
||||
vip_level="gold",
|
||||
)
|
||||
|
||||
async def _recall_facts(
|
||||
self,
|
||||
user_id: str,
|
||||
tenant_id: str | None,
|
||||
) -> list[MemoryFact]:
|
||||
"""召回用户事实记忆"""
|
||||
return [
|
||||
MemoryFact(content="已购课程:数学思维训练营", source="order", confidence=1.0),
|
||||
MemoryFact(content="学习目标:提高数学成绩", source="profile", confidence=0.9),
|
||||
MemoryFact(content="上次咨询:课程退费政策", source="conversation", confidence=0.8),
|
||||
]
|
||||
|
||||
async def _recall_preferences(
|
||||
self,
|
||||
user_id: str,
|
||||
tenant_id: str | None,
|
||||
) -> MemoryPreferences | None:
|
||||
"""召回用户偏好"""
|
||||
return MemoryPreferences(
|
||||
tone="friendly",
|
||||
focus_subjects=["数学", "物理"],
|
||||
communication_style="详细解释",
|
||||
)
|
||||
|
||||
async def _recall_last_summary(
|
||||
self,
|
||||
user_id: str,
|
||||
tenant_id: str | None,
|
||||
) -> str | None:
|
||||
"""召回最近会话摘要"""
|
||||
return "上次讨论了数学学习计划,用户对课程安排比较满意"
|
||||
|
||||
async def update(
|
||||
self,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
messages: list[dict[str, Any]],
|
||||
summary: str | None = None,
|
||||
tenant_id: str | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
[AC-IDMP-14] 异步更新用户记忆
|
||||
|
||||
在对话完成后异步执行,不阻塞主响应。
|
||||
包含会话摘要的回写。
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
session_id: 会话ID
|
||||
messages: 本轮对话消息
|
||||
summary: 会话摘要(可选)
|
||||
tenant_id: 租户ID
|
||||
|
||||
Returns:
|
||||
bool: 是否成功提交更新任务
|
||||
"""
|
||||
request = UpdateRequest(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
messages=messages,
|
||||
summary=summary,
|
||||
)
|
||||
|
||||
task = asyncio.create_task(
|
||||
self._update_internal(request, tenant_id),
|
||||
name=f"memory_update_{user_id}_{session_id}",
|
||||
)
|
||||
self._pending_updates.append(task)
|
||||
task.add_done_callback(lambda t: self._pending_updates.remove(t))
|
||||
|
||||
logger.info(
|
||||
f"[AC-IDMP-14] Memory update scheduled for user={user_id}, "
|
||||
f"session={session_id}, messages_count={len(messages)}"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
async def _update_internal(
|
||||
self,
|
||||
request: UpdateRequest,
|
||||
tenant_id: str | None,
|
||||
) -> None:
|
||||
"""
|
||||
内部更新实现
|
||||
"""
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self._do_update(request, tenant_id),
|
||||
timeout=self._update_timeout_ms / 1000,
|
||||
)
|
||||
logger.info(
|
||||
f"[AC-IDMP-14] Memory updated for user={request.user_id}, "
|
||||
f"session={request.session_id}"
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
f"[AC-IDMP-14] Memory update timeout for user={request.user_id}, "
|
||||
f"session={request.session_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[AC-IDMP-14] Memory update failed for user={request.user_id}, "
|
||||
f"session={request.session_id}, error={e}"
|
||||
)
|
||||
|
||||
async def _do_update(
|
||||
self,
|
||||
request: UpdateRequest,
|
||||
tenant_id: str | None,
|
||||
) -> None:
|
||||
"""
|
||||
执行实际的记忆更新
|
||||
"""
|
||||
if request.summary:
|
||||
await self._save_summary(request.user_id, request.summary, tenant_id)
|
||||
|
||||
await self._extract_and_save_facts(
|
||||
request.user_id, request.messages, tenant_id
|
||||
)
|
||||
|
||||
async def _save_summary(
|
||||
self,
|
||||
user_id: str,
|
||||
summary: str,
|
||||
tenant_id: str | None,
|
||||
) -> None:
|
||||
"""保存会话摘要"""
|
||||
pass
|
||||
|
||||
async def _extract_and_save_facts(
|
||||
self,
|
||||
user_id: str,
|
||||
messages: list[dict[str, Any]],
|
||||
tenant_id: str | None,
|
||||
) -> None:
|
||||
"""从消息中提取并保存事实"""
|
||||
pass
|
||||
|
||||
async def update_with_summary_generation(
|
||||
self,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
messages: list[dict[str, Any]],
|
||||
tenant_id: str | None = None,
|
||||
summary_generator: Callable | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
[AC-IDMP-14] 带摘要生成的记忆更新
|
||||
|
||||
如果未提供摘要,会尝试生成摘要后回写
|
||||
"""
|
||||
summary = None
|
||||
if summary_generator:
|
||||
try:
|
||||
summary = await summary_generator(messages)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[AC-IDMP-14] Summary generation failed: {e}"
|
||||
)
|
||||
|
||||
return await self.update(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
messages=messages,
|
||||
summary=summary,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
async def wait_pending_updates(self, timeout: float = 5.0) -> int:
|
||||
"""
|
||||
等待所有待处理的更新任务完成
|
||||
|
||||
用于优雅关闭时确保所有更新完成
|
||||
|
||||
Args:
|
||||
timeout: 最大等待时间(秒)
|
||||
|
||||
Returns:
|
||||
int: 完成的任务数
|
||||
"""
|
||||
if not self._pending_updates:
|
||||
return 0
|
||||
|
||||
try:
|
||||
done, _ = await asyncio.wait(
|
||||
self._pending_updates,
|
||||
timeout=timeout,
|
||||
return_when=asyncio.ALL_COMPLETED,
|
||||
)
|
||||
return len(done)
|
||||
except Exception as e:
|
||||
logger.error(f"[AC-IDMP-14] Error waiting for pending updates: {e}")
|
||||
return 0
|
||||
|
|
@ -0,0 +1,219 @@
|
|||
"""
|
||||
Metrics Collector for Mid Platform.
|
||||
[AC-IDMP-18] Runtime metrics collection.
|
||||
|
||||
Metrics:
|
||||
- task_completion_rate: Task completion rate
|
||||
- slot_completion_rate: Slot completion rate
|
||||
- wrong_transfer_rate: Wrong transfer rate
|
||||
- no_recall_rate: No recall rate
|
||||
- avg_latency_ms: Average latency
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from app.models.mid.schemas import MetricsSnapshot
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionMetrics:
|
||||
"""Session-level metrics."""
|
||||
session_id: str
|
||||
total_turns: int = 0
|
||||
completed_tasks: int = 0
|
||||
transfers: int = 0
|
||||
wrong_transfers: int = 0
|
||||
no_recall_turns: int = 0
|
||||
total_latency_ms: int = 0
|
||||
slots_expected: int = 0
|
||||
slots_filled: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class AggregatedMetrics:
|
||||
"""Aggregated metrics over time window."""
|
||||
total_sessions: int = 0
|
||||
total_turns: int = 0
|
||||
completed_tasks: int = 0
|
||||
total_transfers: int = 0
|
||||
wrong_transfers: int = 0
|
||||
no_recall_turns: int = 0
|
||||
total_latency_ms: int = 0
|
||||
slots_expected: int = 0
|
||||
slots_filled: int = 0
|
||||
|
||||
def to_snapshot(self) -> MetricsSnapshot:
|
||||
"""Convert to MetricsSnapshot."""
|
||||
task_rate = self.completed_tasks / self.total_turns if self.total_turns > 0 else 0.0
|
||||
slot_rate = self.slots_filled / self.slots_expected if self.slots_expected > 0 else 1.0
|
||||
wrong_transfer_rate = self.wrong_transfers / self.total_transfers if self.total_transfers > 0 else 0.0
|
||||
no_recall_rate = self.no_recall_turns / self.total_turns if self.total_turns > 0 else 0.0
|
||||
avg_latency = self.total_latency_ms / self.total_turns if self.total_turns > 0 else 0.0
|
||||
|
||||
return MetricsSnapshot(
|
||||
task_completion_rate=round(task_rate, 4),
|
||||
slot_completion_rate=round(slot_rate, 4),
|
||||
wrong_transfer_rate=round(wrong_transfer_rate, 4),
|
||||
no_recall_rate=round(no_recall_rate, 4),
|
||||
avg_latency_ms=round(avg_latency, 2),
|
||||
)
|
||||
|
||||
|
||||
class MetricsCollector:
|
||||
"""
|
||||
[AC-IDMP-18] Metrics collector for runtime observability.
|
||||
|
||||
Features:
|
||||
- Session-level metrics tracking
|
||||
- Aggregated metrics over time windows
|
||||
- Real-time snapshot generation
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._session_metrics: dict[str, SessionMetrics] = {}
|
||||
self._tenant_metrics: dict[str, AggregatedMetrics] = defaultdict(AggregatedMetrics)
|
||||
self._global_metrics = AggregatedMetrics()
|
||||
|
||||
def start_session(self, session_id: str) -> None:
|
||||
"""Start tracking a new session."""
|
||||
if session_id not in self._session_metrics:
|
||||
self._session_metrics[session_id] = SessionMetrics(session_id=session_id)
|
||||
logger.debug(f"[AC-IDMP-18] Session started: {session_id}")
|
||||
|
||||
def record_turn(
|
||||
self,
|
||||
session_id: str,
|
||||
tenant_id: str,
|
||||
latency_ms: int,
|
||||
task_completed: bool = False,
|
||||
transferred: bool = False,
|
||||
wrong_transfer: bool = False,
|
||||
no_recall: bool = False,
|
||||
slots_expected: int = 0,
|
||||
slots_filled: int = 0,
|
||||
) -> None:
|
||||
"""
|
||||
[AC-IDMP-18] Record a conversation turn.
|
||||
|
||||
Args:
|
||||
session_id: Session ID
|
||||
tenant_id: Tenant ID
|
||||
latency_ms: Turn latency in milliseconds
|
||||
task_completed: Whether the task was completed
|
||||
transferred: Whether transfer occurred
|
||||
wrong_transfer: Whether it was a wrong transfer
|
||||
no_recall: Whether no recall occurred
|
||||
slots_expected: Expected slots count
|
||||
slots_filled: Filled slots count
|
||||
"""
|
||||
session = self._session_metrics.get(session_id)
|
||||
if not session:
|
||||
session = SessionMetrics(session_id=session_id)
|
||||
self._session_metrics[session_id] = session
|
||||
|
||||
session.total_turns += 1
|
||||
session.total_latency_ms += latency_ms
|
||||
|
||||
if task_completed:
|
||||
session.completed_tasks += 1
|
||||
|
||||
if transferred:
|
||||
session.transfers += 1
|
||||
if wrong_transfer:
|
||||
session.wrong_transfers += 1
|
||||
|
||||
if no_recall:
|
||||
session.no_recall_turns += 1
|
||||
|
||||
session.slots_expected += slots_expected
|
||||
session.slots_filled += slots_filled
|
||||
|
||||
self._tenant_metrics[tenant_id].total_turns += 1
|
||||
self._tenant_metrics[tenant_id].total_latency_ms += latency_ms
|
||||
|
||||
if task_completed:
|
||||
self._tenant_metrics[tenant_id].completed_tasks += 1
|
||||
|
||||
if transferred:
|
||||
self._tenant_metrics[tenant_id].total_transfers += 1
|
||||
if wrong_transfer:
|
||||
self._tenant_metrics[tenant_id].wrong_transfers += 1
|
||||
|
||||
if no_recall:
|
||||
self._tenant_metrics[tenant_id].no_recall_turns += 1
|
||||
|
||||
self._tenant_metrics[tenant_id].slots_expected += slots_expected
|
||||
self._tenant_metrics[tenant_id].slots_filled += slots_filled
|
||||
|
||||
self._global_metrics.total_turns += 1
|
||||
self._global_metrics.total_latency_ms += latency_ms
|
||||
|
||||
if task_completed:
|
||||
self._global_metrics.completed_tasks += 1
|
||||
|
||||
if transferred:
|
||||
self._global_metrics.total_transfers += 1
|
||||
if wrong_transfer:
|
||||
self._global_metrics.wrong_transfers += 1
|
||||
|
||||
if no_recall:
|
||||
self._global_metrics.no_recall_turns += 1
|
||||
|
||||
self._global_metrics.slots_expected += slots_expected
|
||||
self._global_metrics.slots_filled += slots_filled
|
||||
|
||||
logger.debug(
|
||||
f"[AC-IDMP-18] Turn recorded: session={session_id}, "
|
||||
f"latency_ms={latency_ms}, task_completed={task_completed}"
|
||||
)
|
||||
|
||||
def end_session(self, session_id: str) -> SessionMetrics | None:
|
||||
"""End session tracking and return final metrics."""
|
||||
session = self._session_metrics.pop(session_id, None)
|
||||
if session:
|
||||
self._tenant_metrics[session_id.split("_")[0]].total_sessions += 1
|
||||
self._global_metrics.total_sessions += 1
|
||||
logger.info(
|
||||
f"[AC-IDMP-18] Session ended: {session_id}, "
|
||||
f"turns={session.total_turns}, completed={session.completed_tasks}"
|
||||
)
|
||||
return session
|
||||
|
||||
def get_session_metrics(self, session_id: str) -> SessionMetrics | None:
|
||||
"""Get metrics for a specific session."""
|
||||
return self._session_metrics.get(session_id)
|
||||
|
||||
def get_tenant_metrics(self, tenant_id: str) -> MetricsSnapshot:
|
||||
"""[AC-IDMP-18] Get metrics snapshot for a tenant."""
|
||||
return self._tenant_metrics[tenant_id].to_snapshot()
|
||||
|
||||
def get_global_metrics(self) -> MetricsSnapshot:
|
||||
"""[AC-IDMP-18] Get global metrics snapshot."""
|
||||
return self._global_metrics.to_snapshot()
|
||||
|
||||
def reset_metrics(self, tenant_id: str | None = None) -> None:
|
||||
"""Reset metrics for a tenant or globally."""
|
||||
if tenant_id:
|
||||
self._tenant_metrics[tenant_id] = AggregatedMetrics()
|
||||
logger.info(f"[AC-IDMP-18] Metrics reset for tenant: {tenant_id}")
|
||||
else:
|
||||
self._tenant_metrics.clear()
|
||||
self._global_metrics = AggregatedMetrics()
|
||||
logger.info("[AC-IDMP-18] Global metrics reset")
|
||||
|
||||
def get_metrics_dict(self, tenant_id: str | None = None) -> dict[str, Any]:
|
||||
"""Get metrics as dictionary for logging/export."""
|
||||
snapshot = self.get_tenant_metrics(tenant_id) if tenant_id else self.get_global_metrics()
|
||||
return {
|
||||
"task_completion_rate": snapshot.task_completion_rate,
|
||||
"slot_completion_rate": snapshot.slot_completion_rate,
|
||||
"wrong_transfer_rate": snapshot.wrong_transfer_rate,
|
||||
"no_recall_rate": snapshot.no_recall_rate,
|
||||
"avg_latency_ms": snapshot.avg_latency_ms,
|
||||
}
|
||||
|
|
@ -0,0 +1,152 @@
|
|||
"""
|
||||
Output Guardrail Executor for Mid Platform.
|
||||
[AC-MARH-01, AC-MARH-02] Output guardrail enforcement before returning segments.
|
||||
|
||||
Features:
|
||||
- Mandatory output filtering before segments are returned
|
||||
- Guardrail trigger logging with rule_id
|
||||
- Integration with existing OutputFilter
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from app.models.entities import GuardrailResult
|
||||
from app.services.guardrail.output_filter import OutputFilter
|
||||
from app.services.guardrail.word_service import ForbiddenWordService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GuardrailExecutionResult:
|
||||
"""Result of guardrail execution."""
|
||||
filtered_text: str
|
||||
triggered: bool = False
|
||||
blocked: bool = False
|
||||
rule_id: str | None = None
|
||||
triggered_words: list[str] | None = None
|
||||
triggered_categories: list[str] | None = None
|
||||
|
||||
|
||||
class OutputGuardrailExecutor:
|
||||
"""
|
||||
[AC-MARH-01, AC-MARH-02] Output guardrail executor for mandatory filtering.
|
||||
|
||||
This component enforces output guardrail filtering before any segments
|
||||
are returned to the client. It wraps the existing OutputFilter and adds
|
||||
trace/logging capabilities required by MARH.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
output_filter: OutputFilter | None = None,
|
||||
word_service: ForbiddenWordService | None = None,
|
||||
):
|
||||
if output_filter:
|
||||
self._output_filter = output_filter
|
||||
elif word_service:
|
||||
self._output_filter = OutputFilter(word_service)
|
||||
else:
|
||||
self._output_filter = None
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
text: str,
|
||||
tenant_id: str,
|
||||
) -> GuardrailExecutionResult:
|
||||
"""
|
||||
[AC-MARH-01] Execute guardrail filtering on output text.
|
||||
|
||||
Args:
|
||||
text: The text to filter
|
||||
tenant_id: Tenant ID for isolation
|
||||
|
||||
Returns:
|
||||
GuardrailExecutionResult with filtered text and trigger info
|
||||
"""
|
||||
if not text or not text.strip():
|
||||
return GuardrailExecutionResult(filtered_text=text)
|
||||
|
||||
if not self._output_filter:
|
||||
logger.debug("[AC-MARH-01] No output filter configured, skipping guardrail")
|
||||
return GuardrailExecutionResult(filtered_text=text)
|
||||
|
||||
try:
|
||||
result: GuardrailResult = await self._output_filter.filter(text, tenant_id)
|
||||
|
||||
triggered = bool(result.triggered_words)
|
||||
rule_id = None
|
||||
if triggered and result.triggered_categories:
|
||||
rule_id = f"forbidden_word:{','.join(result.triggered_categories[:3])}"
|
||||
|
||||
if triggered:
|
||||
logger.info(
|
||||
f"[AC-MARH-02] Guardrail triggered: tenant={tenant_id}, "
|
||||
f"blocked={result.blocked}, rule_id={rule_id}, "
|
||||
f"words={result.triggered_words}"
|
||||
)
|
||||
|
||||
return GuardrailExecutionResult(
|
||||
filtered_text=result.reply,
|
||||
triggered=triggered,
|
||||
blocked=result.blocked,
|
||||
rule_id=rule_id,
|
||||
triggered_words=result.triggered_words,
|
||||
triggered_categories=result.triggered_categories,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[AC-MARH-01] Guardrail execution failed: {e}")
|
||||
return GuardrailExecutionResult(
|
||||
filtered_text=text,
|
||||
triggered=False,
|
||||
rule_id=f"error:{str(e)[:50]}",
|
||||
)
|
||||
|
||||
async def filter_segments(
|
||||
self,
|
||||
segments: list[Any],
|
||||
tenant_id: str,
|
||||
) -> tuple[list[Any], GuardrailExecutionResult]:
|
||||
"""
|
||||
[AC-MARH-01] Filter all segments and return combined result.
|
||||
|
||||
Args:
|
||||
segments: List of Segment objects with text field
|
||||
tenant_id: Tenant ID for isolation
|
||||
|
||||
Returns:
|
||||
Tuple of (filtered_segments, combined_result)
|
||||
"""
|
||||
if not segments:
|
||||
return segments, GuardrailExecutionResult(filtered_text="")
|
||||
|
||||
if not self._output_filter:
|
||||
return segments, GuardrailExecutionResult(filtered_text="")
|
||||
|
||||
combined_text = "\n\n".join(s.text for s in segments)
|
||||
result = await self.execute(combined_text, tenant_id)
|
||||
|
||||
if result.blocked:
|
||||
from app.models.mid.schemas import Segment
|
||||
return [
|
||||
Segment(text=result.filtered_text, delay_after=0)
|
||||
], result
|
||||
|
||||
if result.triggered:
|
||||
filtered_paragraphs = result.filtered_text.split("\n\n")
|
||||
filtered_segments = []
|
||||
for i, para in enumerate(filtered_paragraphs):
|
||||
if para.strip():
|
||||
from app.models.mid.schemas import Segment
|
||||
filtered_segments.append(
|
||||
Segment(
|
||||
text=para.strip(),
|
||||
delay_after=segments[0].delay_after if segments and i < len(segments) else 0,
|
||||
)
|
||||
)
|
||||
return filtered_segments, result
|
||||
|
||||
return segments, result
|
||||
|
|
@ -0,0 +1,411 @@
|
|||
"""
|
||||
Policy Router for Mid Platform.
|
||||
[AC-IDMP-02, AC-IDMP-05, AC-IDMP-16, AC-IDMP-20] Routes to agent/micro_flow/fixed/transfer based on policy.
|
||||
|
||||
Decision Matrix:
|
||||
1. High-risk scenario (refund/complaint/privacy/transfer) -> micro_flow or transfer
|
||||
2. Low confidence or missing key info -> micro_flow or fixed
|
||||
3. Tool unavailable -> fixed
|
||||
4. Human mode active -> transfer
|
||||
5. Normal case -> agent
|
||||
|
||||
Intent Hint Integration:
|
||||
- intent_hint provides soft signals (suggested_mode, confidence, high_risk_detected)
|
||||
- policy_router consumes hints but retains final decision authority
|
||||
- When hint suggests high_risk, policy_router validates and may override
|
||||
|
||||
High Risk Check Integration:
|
||||
- high_risk_check provides structured risk detection result
|
||||
- High-risk check takes priority over normal intent routing
|
||||
- When high_risk_check matched, skip normal intent matching
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from app.models.mid.schemas import (
|
||||
ExecutionMode,
|
||||
FeatureFlags,
|
||||
HighRiskScenario,
|
||||
PolicyRouterResult,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.mid.schemas import HighRiskCheckResult, IntentHintOutput
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_HIGH_RISK_SCENARIOS: list[HighRiskScenario] = [
|
||||
HighRiskScenario.REFUND,
|
||||
HighRiskScenario.COMPLAINT_ESCALATION,
|
||||
HighRiskScenario.PRIVACY_SENSITIVE_PROMISE,
|
||||
HighRiskScenario.TRANSFER,
|
||||
]
|
||||
|
||||
HIGH_RISK_KEYWORDS: dict[HighRiskScenario, list[str]] = {
|
||||
HighRiskScenario.REFUND: ["退款", "退货", "退钱", "退费", "还钱", "退款申请"],
|
||||
HighRiskScenario.COMPLAINT_ESCALATION: ["投诉", "升级投诉", "举报", "12315", "消费者协会"],
|
||||
HighRiskScenario.PRIVACY_SENSITIVE_PROMISE: ["承诺", "保证", "一定", "肯定能", "绝对", "担保"],
|
||||
HighRiskScenario.TRANSFER: ["转人工", "人工客服", "人工服务", "真人", "人工"],
|
||||
}
|
||||
|
||||
LOW_CONFIDENCE_THRESHOLD = 0.3
|
||||
|
||||
|
||||
@dataclass
|
||||
class IntentMatch:
|
||||
"""Intent match result."""
|
||||
intent_id: str
|
||||
intent_name: str
|
||||
confidence: float
|
||||
response_type: str
|
||||
target_kb_ids: list[str] | None = None
|
||||
flow_id: str | None = None
|
||||
fixed_reply: str | None = None
|
||||
transfer_message: str | None = None
|
||||
|
||||
|
||||
class PolicyRouter:
|
||||
"""
|
||||
[AC-IDMP-02, AC-IDMP-05, AC-IDMP-16, AC-IDMP-20] Policy router for execution mode decision.
|
||||
|
||||
Decision Flow:
|
||||
1. Check feature flags (rollback_to_legacy -> fixed)
|
||||
2. Check session mode (HUMAN_ACTIVE -> transfer)
|
||||
3. Check high-risk scenarios -> micro_flow or transfer
|
||||
4. Check intent match confidence -> fallback if low
|
||||
5. Default -> agent
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
high_risk_scenarios: list[HighRiskScenario] | None = None,
|
||||
low_confidence_threshold: float = LOW_CONFIDENCE_THRESHOLD,
|
||||
):
|
||||
self._high_risk_scenarios = high_risk_scenarios or DEFAULT_HIGH_RISK_SCENARIOS
|
||||
self._low_confidence_threshold = low_confidence_threshold
|
||||
|
||||
def route(
|
||||
self,
|
||||
user_message: str,
|
||||
session_mode: str = "BOT_ACTIVE",
|
||||
feature_flags: FeatureFlags | None = None,
|
||||
intent_match: IntentMatch | None = None,
|
||||
intent_hint: "IntentHintOutput | None" = None,
|
||||
context: dict[str, Any] | None = None,
|
||||
) -> PolicyRouterResult:
|
||||
"""
|
||||
[AC-IDMP-02] Route to appropriate execution mode.
|
||||
|
||||
Args:
|
||||
user_message: User input message
|
||||
session_mode: Current session mode (BOT_ACTIVE/HUMAN_ACTIVE)
|
||||
feature_flags: Feature flags for grayscale control
|
||||
intent_match: Intent match result if available
|
||||
intent_hint: Soft signal from intent_hint tool (optional)
|
||||
context: Additional context for decision
|
||||
|
||||
Returns:
|
||||
PolicyRouterResult with decided mode and metadata
|
||||
"""
|
||||
logger.info(
|
||||
f"[AC-IDMP-02] PolicyRouter routing: session_mode={session_mode}, "
|
||||
f"feature_flags={feature_flags}, intent_match={intent_match}, "
|
||||
f"intent_hint_mode={intent_hint.suggested_mode if intent_hint else None}"
|
||||
)
|
||||
|
||||
if feature_flags and feature_flags.rollback_to_legacy:
|
||||
logger.info("[AC-IDMP-17] Rollback to legacy requested, using fixed mode")
|
||||
return PolicyRouterResult(
|
||||
mode=ExecutionMode.FIXED,
|
||||
fallback_reason_code="rollback_to_legacy",
|
||||
)
|
||||
|
||||
if session_mode == "HUMAN_ACTIVE":
|
||||
logger.info("[AC-IDMP-09] Session in HUMAN_ACTIVE mode, routing to transfer")
|
||||
return PolicyRouterResult(
|
||||
mode=ExecutionMode.TRANSFER,
|
||||
transfer_message="正在为您转接人工客服...",
|
||||
)
|
||||
|
||||
if intent_hint and intent_hint.high_risk_detected:
|
||||
logger.info(
|
||||
f"[AC-IDMP-05, AC-IDMP-20] High-risk from hint: {intent_hint.fallback_reason_code}"
|
||||
)
|
||||
return self._handle_high_risk_from_hint(intent_hint, intent_match)
|
||||
|
||||
high_risk_scenario = self._check_high_risk(user_message)
|
||||
if high_risk_scenario:
|
||||
logger.info(f"[AC-IDMP-05, AC-IDMP-20] High-risk scenario detected: {high_risk_scenario}")
|
||||
return self._handle_high_risk(high_risk_scenario, intent_match)
|
||||
|
||||
if intent_hint and intent_hint.confidence < self._low_confidence_threshold:
|
||||
logger.info(
|
||||
f"[AC-IDMP-16] Low confidence from hint ({intent_hint.confidence}), "
|
||||
f"considering fallback"
|
||||
)
|
||||
if intent_hint.suggested_mode in (ExecutionMode.FIXED, ExecutionMode.MICRO_FLOW):
|
||||
return PolicyRouterResult(
|
||||
mode=intent_hint.suggested_mode,
|
||||
intent=intent_hint.intent,
|
||||
confidence=intent_hint.confidence,
|
||||
fallback_reason_code=intent_hint.fallback_reason_code or "low_confidence_hint",
|
||||
target_flow_id=intent_hint.target_flow_id,
|
||||
)
|
||||
|
||||
if intent_match:
|
||||
if intent_match.confidence < self._low_confidence_threshold:
|
||||
logger.info(
|
||||
f"[AC-IDMP-16] Low confidence ({intent_match.confidence}), "
|
||||
f"falling back from agent mode"
|
||||
)
|
||||
return self._handle_low_confidence(intent_match)
|
||||
|
||||
if intent_match.response_type == "fixed":
|
||||
return PolicyRouterResult(
|
||||
mode=ExecutionMode.FIXED,
|
||||
intent=intent_match.intent_name,
|
||||
confidence=intent_match.confidence,
|
||||
fixed_reply=intent_match.fixed_reply,
|
||||
)
|
||||
|
||||
if intent_match.response_type == "transfer":
|
||||
return PolicyRouterResult(
|
||||
mode=ExecutionMode.TRANSFER,
|
||||
intent=intent_match.intent_name,
|
||||
confidence=intent_match.confidence,
|
||||
transfer_message=intent_match.transfer_message or "正在为您转接人工客服...",
|
||||
)
|
||||
|
||||
if intent_match.response_type == "flow":
|
||||
return PolicyRouterResult(
|
||||
mode=ExecutionMode.MICRO_FLOW,
|
||||
intent=intent_match.intent_name,
|
||||
confidence=intent_match.confidence,
|
||||
target_flow_id=intent_match.flow_id,
|
||||
)
|
||||
|
||||
if feature_flags and not feature_flags.agent_enabled:
|
||||
logger.info("[AC-IDMP-17] Agent disabled by feature flag, using fixed mode")
|
||||
return PolicyRouterResult(
|
||||
mode=ExecutionMode.FIXED,
|
||||
fallback_reason_code="agent_disabled",
|
||||
)
|
||||
|
||||
logger.info("[AC-IDMP-02] Default routing to agent mode")
|
||||
return PolicyRouterResult(
|
||||
mode=ExecutionMode.AGENT,
|
||||
intent=intent_match.intent_name if intent_match else None,
|
||||
confidence=intent_match.confidence if intent_match else None,
|
||||
)
|
||||
|
||||
def _check_high_risk(self, message: str) -> HighRiskScenario | None:
|
||||
"""
|
||||
[AC-IDMP-05, AC-IDMP-20] Check if message matches high-risk scenarios.
|
||||
|
||||
Returns the first matched high-risk scenario or None.
|
||||
"""
|
||||
message_lower = message.lower()
|
||||
|
||||
for scenario in self._high_risk_scenarios:
|
||||
keywords = HIGH_RISK_KEYWORDS.get(scenario, [])
|
||||
for keyword in keywords:
|
||||
if keyword.lower() in message_lower:
|
||||
return scenario
|
||||
|
||||
return None
|
||||
|
||||
def _handle_high_risk(
|
||||
self,
|
||||
scenario: HighRiskScenario,
|
||||
intent_match: IntentMatch | None,
|
||||
) -> PolicyRouterResult:
|
||||
"""
|
||||
[AC-IDMP-05] Handle high-risk scenario by routing to micro_flow or transfer.
|
||||
"""
|
||||
if scenario == HighRiskScenario.TRANSFER:
|
||||
return PolicyRouterResult(
|
||||
mode=ExecutionMode.TRANSFER,
|
||||
high_risk_triggered=True,
|
||||
transfer_message="正在为您转接人工客服...",
|
||||
)
|
||||
|
||||
if intent_match and intent_match.flow_id:
|
||||
return PolicyRouterResult(
|
||||
mode=ExecutionMode.MICRO_FLOW,
|
||||
intent=intent_match.intent_name,
|
||||
confidence=intent_match.confidence,
|
||||
high_risk_triggered=True,
|
||||
target_flow_id=intent_match.flow_id,
|
||||
)
|
||||
|
||||
return PolicyRouterResult(
|
||||
mode=ExecutionMode.MICRO_FLOW,
|
||||
high_risk_triggered=True,
|
||||
fallback_reason_code=f"high_risk_{scenario.value}",
|
||||
)
|
||||
|
||||
def _handle_high_risk_from_hint(
|
||||
self,
|
||||
intent_hint: "IntentHintOutput",
|
||||
intent_match: IntentMatch | None,
|
||||
) -> PolicyRouterResult:
|
||||
"""
|
||||
[AC-IDMP-05, AC-IDMP-20] Handle high-risk from intent_hint.
|
||||
|
||||
Policy_router validates hint suggestion but may override.
|
||||
"""
|
||||
if intent_hint.suggested_mode == ExecutionMode.TRANSFER:
|
||||
return PolicyRouterResult(
|
||||
mode=ExecutionMode.TRANSFER,
|
||||
high_risk_triggered=True,
|
||||
transfer_message="正在为您转接人工客服...",
|
||||
)
|
||||
|
||||
if intent_match and intent_match.flow_id:
|
||||
return PolicyRouterResult(
|
||||
mode=ExecutionMode.MICRO_FLOW,
|
||||
intent=intent_match.intent_name,
|
||||
confidence=intent_match.confidence,
|
||||
high_risk_triggered=True,
|
||||
target_flow_id=intent_match.flow_id,
|
||||
)
|
||||
|
||||
if intent_hint.target_flow_id:
|
||||
return PolicyRouterResult(
|
||||
mode=ExecutionMode.MICRO_FLOW,
|
||||
intent=intent_hint.intent,
|
||||
confidence=intent_hint.confidence,
|
||||
high_risk_triggered=True,
|
||||
target_flow_id=intent_hint.target_flow_id,
|
||||
)
|
||||
|
||||
return PolicyRouterResult(
|
||||
mode=ExecutionMode.MICRO_FLOW,
|
||||
high_risk_triggered=True,
|
||||
fallback_reason_code=intent_hint.fallback_reason_code or "high_risk_hint",
|
||||
)
|
||||
|
||||
def _handle_low_confidence(self, intent_match: IntentMatch) -> PolicyRouterResult:
|
||||
"""
|
||||
[AC-IDMP-16] Handle low confidence by falling back to micro_flow or fixed.
|
||||
"""
|
||||
if intent_match.flow_id:
|
||||
return PolicyRouterResult(
|
||||
mode=ExecutionMode.MICRO_FLOW,
|
||||
intent=intent_match.intent_name,
|
||||
confidence=intent_match.confidence,
|
||||
fallback_reason_code="low_confidence",
|
||||
target_flow_id=intent_match.flow_id,
|
||||
)
|
||||
|
||||
if intent_match.fixed_reply:
|
||||
return PolicyRouterResult(
|
||||
mode=ExecutionMode.FIXED,
|
||||
intent=intent_match.intent_name,
|
||||
confidence=intent_match.confidence,
|
||||
fallback_reason_code="low_confidence",
|
||||
fixed_reply=intent_match.fixed_reply,
|
||||
)
|
||||
|
||||
return PolicyRouterResult(
|
||||
mode=ExecutionMode.FIXED,
|
||||
fallback_reason_code="low_confidence_no_flow",
|
||||
)
|
||||
|
||||
def get_active_high_risk_set(self) -> list[HighRiskScenario]:
|
||||
"""[AC-IDMP-20] Get active high-risk scenario set."""
|
||||
return self._high_risk_scenarios
|
||||
|
||||
def route_with_high_risk_check(
|
||||
self,
|
||||
user_message: str,
|
||||
high_risk_check_result: "HighRiskCheckResult | None",
|
||||
session_mode: str = "BOT_ACTIVE",
|
||||
feature_flags: FeatureFlags | None = None,
|
||||
intent_match: IntentMatch | None = None,
|
||||
intent_hint: "IntentHintOutput | None" = None,
|
||||
context: dict[str, Any] | None = None,
|
||||
) -> PolicyRouterResult:
|
||||
"""
|
||||
[AC-IDMP-05, AC-IDMP-20] Route with high_risk_check result.
|
||||
|
||||
高风险优先于普通意图路由:
|
||||
1. 如果 high_risk_check 匹配,直接返回高风险路由结果
|
||||
2. 否则继续正常的路由决策
|
||||
|
||||
Args:
|
||||
user_message: User input message
|
||||
high_risk_check_result: Result from high_risk_check tool
|
||||
session_mode: Current session mode (BOT_ACTIVE/HUMAN_ACTIVE)
|
||||
feature_flags: Feature flags for grayscale control
|
||||
intent_match: Intent match result if available
|
||||
intent_hint: Soft signal from intent_hint tool (optional)
|
||||
context: Additional context for decision
|
||||
|
||||
Returns:
|
||||
PolicyRouterResult with decided mode and metadata
|
||||
"""
|
||||
if high_risk_check_result and high_risk_check_result.matched:
|
||||
logger.info(
|
||||
f"[AC-IDMP-05, AC-IDMP-20] High-risk check matched: "
|
||||
f"scenario={high_risk_check_result.risk_scenario}, "
|
||||
f"rule_id={high_risk_check_result.rule_id}"
|
||||
)
|
||||
return self._handle_high_risk_check_result(
|
||||
high_risk_check_result, intent_match
|
||||
)
|
||||
|
||||
return self.route(
|
||||
user_message=user_message,
|
||||
session_mode=session_mode,
|
||||
feature_flags=feature_flags,
|
||||
intent_match=intent_match,
|
||||
intent_hint=intent_hint,
|
||||
context=context,
|
||||
)
|
||||
|
||||
def _handle_high_risk_check_result(
|
||||
self,
|
||||
high_risk_result: "HighRiskCheckResult",
|
||||
intent_match: IntentMatch | None,
|
||||
) -> PolicyRouterResult:
|
||||
"""
|
||||
[AC-IDMP-05] Handle high_risk_check result.
|
||||
|
||||
高风险检测结果优先于普通意图路由。
|
||||
"""
|
||||
recommended_mode = high_risk_result.recommended_mode or ExecutionMode.MICRO_FLOW
|
||||
risk_scenario = high_risk_result.risk_scenario
|
||||
|
||||
if recommended_mode == ExecutionMode.TRANSFER:
|
||||
transfer_msg = "正在为您转接人工客服..."
|
||||
if risk_scenario:
|
||||
if risk_scenario == HighRiskScenario.COMPLAINT_ESCALATION:
|
||||
transfer_msg = "检测到您可能需要投诉处理,正在为您转接人工客服..."
|
||||
elif risk_scenario == HighRiskScenario.REFUND:
|
||||
transfer_msg = "您的退款请求需要人工处理,正在为您转接..."
|
||||
|
||||
return PolicyRouterResult(
|
||||
mode=ExecutionMode.TRANSFER,
|
||||
high_risk_triggered=True,
|
||||
transfer_message=transfer_msg,
|
||||
fallback_reason_code=high_risk_result.rule_id,
|
||||
)
|
||||
|
||||
if intent_match and intent_match.flow_id:
|
||||
return PolicyRouterResult(
|
||||
mode=ExecutionMode.MICRO_FLOW,
|
||||
intent=intent_match.intent_name,
|
||||
confidence=intent_match.confidence,
|
||||
high_risk_triggered=True,
|
||||
target_flow_id=intent_match.flow_id,
|
||||
)
|
||||
|
||||
return PolicyRouterResult(
|
||||
mode=ExecutionMode.MICRO_FLOW,
|
||||
high_risk_triggered=True,
|
||||
fallback_reason_code=high_risk_result.rule_id
|
||||
or f"high_risk_{risk_scenario.value if risk_scenario else 'unknown'}",
|
||||
)
|
||||
|
|
@ -0,0 +1,289 @@
|
|||
"""
|
||||
Runtime Observer for Mid Platform.
|
||||
[AC-MARH-12] 运行时观测闭环。
|
||||
|
||||
汇总 guardrail、interrupt、kb_hit、timeouts、segment_stats 等观测字段。
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from app.models.mid.schemas import (
|
||||
ExecutionMode,
|
||||
MetricsSnapshot,
|
||||
SegmentStats,
|
||||
TimeoutProfile,
|
||||
ToolCallTrace,
|
||||
TraceInfo,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RuntimeContext:
|
||||
"""运行时上下文。"""
|
||||
tenant_id: str = ""
|
||||
session_id: str = ""
|
||||
request_id: str = ""
|
||||
generation_id: str = ""
|
||||
mode: ExecutionMode = ExecutionMode.AGENT
|
||||
intent: str | None = None
|
||||
|
||||
guardrail_triggered: bool = False
|
||||
guardrail_rule_id: str | None = None
|
||||
|
||||
interrupt_consumed: bool = False
|
||||
|
||||
kb_tool_called: bool = False
|
||||
kb_hit: bool = False
|
||||
|
||||
fallback_reason_code: str | None = None
|
||||
|
||||
react_iterations: int = 0
|
||||
tool_calls: list[ToolCallTrace] = field(default_factory=list)
|
||||
|
||||
timeout_profile: TimeoutProfile | None = None
|
||||
segment_stats: SegmentStats | None = None
|
||||
metrics_snapshot: MetricsSnapshot | None = None
|
||||
|
||||
start_time: float = field(default_factory=time.time)
|
||||
|
||||
def to_trace_info(self) -> TraceInfo:
|
||||
"""转换为 TraceInfo。"""
|
||||
return TraceInfo(
|
||||
mode=self.mode,
|
||||
intent=self.intent,
|
||||
request_id=self.request_id,
|
||||
generation_id=self.generation_id,
|
||||
guardrail_triggered=self.guardrail_triggered,
|
||||
guardrail_rule_id=self.guardrail_rule_id,
|
||||
interrupt_consumed=self.interrupt_consumed,
|
||||
kb_tool_called=self.kb_tool_called,
|
||||
kb_hit=self.kb_hit,
|
||||
fallback_reason_code=self.fallback_reason_code,
|
||||
react_iterations=self.react_iterations,
|
||||
timeout_profile=self.timeout_profile,
|
||||
segment_stats=self.segment_stats,
|
||||
metrics_snapshot=self.metrics_snapshot,
|
||||
tools_used=[tc.tool_name for tc in self.tool_calls] if self.tool_calls else None,
|
||||
tool_calls=self.tool_calls if self.tool_calls else None,
|
||||
)
|
||||
|
||||
|
||||
class RuntimeObserver:
|
||||
"""
|
||||
[AC-MARH-12] 运行时观测器。
|
||||
|
||||
Features:
|
||||
- 汇总 guardrail、interrupt、kb_hit、timeouts、segment_stats
|
||||
- 生成完整 TraceInfo
|
||||
- 记录观测日志
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._contexts: dict[str, RuntimeContext] = {}
|
||||
|
||||
def start_observation(
|
||||
self,
|
||||
tenant_id: str,
|
||||
session_id: str,
|
||||
request_id: str,
|
||||
generation_id: str,
|
||||
) -> RuntimeContext:
|
||||
"""
|
||||
[AC-MARH-12] 开始观测。
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
session_id: 会话 ID
|
||||
request_id: 请求 ID
|
||||
generation_id: 生成 ID
|
||||
|
||||
Returns:
|
||||
RuntimeContext 实例
|
||||
"""
|
||||
ctx = RuntimeContext(
|
||||
tenant_id=tenant_id,
|
||||
session_id=session_id,
|
||||
request_id=request_id,
|
||||
generation_id=generation_id,
|
||||
)
|
||||
|
||||
self._contexts[request_id] = ctx
|
||||
|
||||
logger.info(
|
||||
f"[AC-MARH-12] Observation started: request_id={request_id}, "
|
||||
f"session_id={session_id}"
|
||||
)
|
||||
|
||||
return ctx
|
||||
|
||||
def get_context(self, request_id: str) -> RuntimeContext | None:
|
||||
"""获取观测上下文。"""
|
||||
return self._contexts.get(request_id)
|
||||
|
||||
def update_mode(
|
||||
self,
|
||||
request_id: str,
|
||||
mode: ExecutionMode,
|
||||
intent: str | None = None,
|
||||
) -> None:
|
||||
"""更新执行模式。"""
|
||||
ctx = self._contexts.get(request_id)
|
||||
if ctx:
|
||||
ctx.mode = mode
|
||||
ctx.intent = intent
|
||||
|
||||
def record_guardrail(
|
||||
self,
|
||||
request_id: str,
|
||||
triggered: bool,
|
||||
rule_id: str | None = None,
|
||||
) -> None:
|
||||
"""[AC-MARH-12] 记录护栏触发。"""
|
||||
ctx = self._contexts.get(request_id)
|
||||
if ctx:
|
||||
ctx.guardrail_triggered = triggered
|
||||
ctx.guardrail_rule_id = rule_id
|
||||
|
||||
logger.info(
|
||||
f"[AC-MARH-12] Guardrail recorded: request_id={request_id}, "
|
||||
f"triggered={triggered}, rule_id={rule_id}"
|
||||
)
|
||||
|
||||
def record_interrupt(
|
||||
self,
|
||||
request_id: str,
|
||||
consumed: bool,
|
||||
) -> None:
|
||||
"""[AC-MARH-12] 记录中断处理。"""
|
||||
ctx = self._contexts.get(request_id)
|
||||
if ctx:
|
||||
ctx.interrupt_consumed = consumed
|
||||
|
||||
logger.info(
|
||||
f"[AC-MARH-12] Interrupt recorded: request_id={request_id}, "
|
||||
f"consumed={consumed}"
|
||||
)
|
||||
|
||||
def record_kb(
|
||||
self,
|
||||
request_id: str,
|
||||
tool_called: bool,
|
||||
hit: bool,
|
||||
fallback_reason: str | None = None,
|
||||
) -> None:
|
||||
"""[AC-MARH-12] 记录 KB 检索。"""
|
||||
ctx = self._contexts.get(request_id)
|
||||
if ctx:
|
||||
ctx.kb_tool_called = tool_called
|
||||
ctx.kb_hit = hit
|
||||
|
||||
if fallback_reason:
|
||||
ctx.fallback_reason_code = fallback_reason
|
||||
|
||||
logger.info(
|
||||
f"[AC-MARH-12] KB recorded: request_id={request_id}, "
|
||||
f"tool_called={tool_called}, hit={hit}, fallback={fallback_reason}"
|
||||
)
|
||||
|
||||
def record_react(
|
||||
self,
|
||||
request_id: str,
|
||||
iterations: int,
|
||||
tool_calls: list[ToolCallTrace] | None = None,
|
||||
) -> None:
|
||||
"""[AC-MARH-12] 记录 ReAct 循环。"""
|
||||
ctx = self._contexts.get(request_id)
|
||||
if ctx:
|
||||
ctx.react_iterations = iterations
|
||||
if tool_calls:
|
||||
ctx.tool_calls = tool_calls
|
||||
|
||||
def record_timeout_profile(
|
||||
self,
|
||||
request_id: str,
|
||||
profile: TimeoutProfile,
|
||||
) -> None:
|
||||
"""[AC-MARH-12] 记录超时配置。"""
|
||||
ctx = self._contexts.get(request_id)
|
||||
if ctx:
|
||||
ctx.timeout_profile = profile
|
||||
|
||||
def record_segment_stats(
|
||||
self,
|
||||
request_id: str,
|
||||
stats: SegmentStats,
|
||||
) -> None:
|
||||
"""[AC-MARH-12] 记录分段统计。"""
|
||||
ctx = self._contexts.get(request_id)
|
||||
if ctx:
|
||||
ctx.segment_stats = stats
|
||||
|
||||
def record_metrics(
|
||||
self,
|
||||
request_id: str,
|
||||
metrics: MetricsSnapshot,
|
||||
) -> None:
|
||||
"""[AC-MARH-12] 记录指标快照。"""
|
||||
ctx = self._contexts.get(request_id)
|
||||
if ctx:
|
||||
ctx.metrics_snapshot = metrics
|
||||
|
||||
def set_fallback_reason(
|
||||
self,
|
||||
request_id: str,
|
||||
reason: str,
|
||||
) -> None:
|
||||
"""设置降级原因。"""
|
||||
ctx = self._contexts.get(request_id)
|
||||
if ctx:
|
||||
ctx.fallback_reason_code = reason
|
||||
|
||||
def end_observation(
|
||||
self,
|
||||
request_id: str,
|
||||
) -> TraceInfo:
|
||||
"""
|
||||
[AC-MARH-12] 结束观测并生成 TraceInfo。
|
||||
|
||||
Args:
|
||||
request_id: 请求 ID
|
||||
|
||||
Returns:
|
||||
完整的 TraceInfo
|
||||
"""
|
||||
ctx = self._contexts.get(request_id)
|
||||
if not ctx:
|
||||
logger.warning(f"[AC-MARH-12] Context not found: {request_id}")
|
||||
return TraceInfo(mode=ExecutionMode.FIXED)
|
||||
|
||||
duration_ms = int((time.time() - ctx.start_time) * 1000)
|
||||
|
||||
trace_info = ctx.to_trace_info()
|
||||
|
||||
logger.info(
|
||||
f"[AC-MARH-12] Observation ended: request_id={request_id}, "
|
||||
f"mode={ctx.mode.value}, duration_ms={duration_ms}, "
|
||||
f"guardrail={ctx.guardrail_triggered}, kb_hit={ctx.kb_hit}, "
|
||||
f"segments={ctx.segment_stats.segment_count if ctx.segment_stats else 0}"
|
||||
)
|
||||
|
||||
if request_id in self._contexts:
|
||||
del self._contexts[request_id]
|
||||
|
||||
return trace_info
|
||||
|
||||
|
||||
_runtime_observer: RuntimeObserver | None = None
|
||||
|
||||
|
||||
def get_runtime_observer() -> RuntimeObserver:
|
||||
"""获取或创建 RuntimeObserver 实例。"""
|
||||
global _runtime_observer
|
||||
if _runtime_observer is None:
|
||||
_runtime_observer = RuntimeObserver()
|
||||
return _runtime_observer
|
||||
|
|
@ -0,0 +1,282 @@
|
|||
"""
|
||||
Segment Humanizer for Mid Platform.
|
||||
[AC-MARH-10] 分段策略组件(语义/长度切分)。
|
||||
[AC-MARH-11] delay 策略租户化配置。
|
||||
|
||||
将文本按语义/长度切分为 segments,并生成拟人化 delay。
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from app.models.mid.schemas import Segment, SegmentStats
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_MIN_DELAY_MS = 50
|
||||
DEFAULT_MAX_DELAY_MS = 500
|
||||
DEFAULT_SEGMENT_MIN_LENGTH = 10
|
||||
DEFAULT_SEGMENT_MAX_LENGTH = 200
|
||||
|
||||
|
||||
@dataclass
|
||||
class HumanizeConfig:
|
||||
"""拟人化配置。"""
|
||||
enabled: bool = True
|
||||
min_delay_ms: int = DEFAULT_MIN_DELAY_MS
|
||||
max_delay_ms: int = DEFAULT_MAX_DELAY_MS
|
||||
length_bucket_strategy: str = "simple"
|
||||
segment_min_length: int = DEFAULT_SEGMENT_MIN_LENGTH
|
||||
segment_max_length: int = DEFAULT_SEGMENT_MAX_LENGTH
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"enabled": self.enabled,
|
||||
"min_delay_ms": self.min_delay_ms,
|
||||
"max_delay_ms": self.max_delay_ms,
|
||||
"length_bucket_strategy": self.length_bucket_strategy,
|
||||
"segment_min_length": self.segment_min_length,
|
||||
"segment_max_length": self.segment_max_length,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "HumanizeConfig":
|
||||
return cls(
|
||||
enabled=data.get("enabled", True),
|
||||
min_delay_ms=data.get("min_delay_ms", DEFAULT_MIN_DELAY_MS),
|
||||
max_delay_ms=data.get("max_delay_ms", DEFAULT_MAX_DELAY_MS),
|
||||
length_bucket_strategy=data.get("length_bucket_strategy", "simple"),
|
||||
segment_min_length=data.get("segment_min_length", DEFAULT_SEGMENT_MIN_LENGTH),
|
||||
segment_max_length=data.get("segment_max_length", DEFAULT_SEGMENT_MAX_LENGTH),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LengthBucket:
|
||||
"""长度区间与对应 delay。"""
|
||||
min_length: int
|
||||
max_length: int
|
||||
delay_ms: int
|
||||
|
||||
|
||||
DEFAULT_LENGTH_BUCKETS = [
|
||||
LengthBucket(min_length=0, max_length=20, delay_ms=100),
|
||||
LengthBucket(min_length=20, max_length=50, delay_ms=200),
|
||||
LengthBucket(min_length=50, max_length=100, delay_ms=300),
|
||||
LengthBucket(min_length=100, max_length=200, delay_ms=400),
|
||||
LengthBucket(min_length=200, max_length=10000, delay_ms=500),
|
||||
]
|
||||
|
||||
|
||||
class SegmentHumanizer:
|
||||
"""
|
||||
[AC-MARH-10, AC-MARH-11] 分段拟人化组件。
|
||||
|
||||
Features:
|
||||
- 按语义/长度切分文本
|
||||
- 生成拟人化 delay
|
||||
- 支持租户配置覆盖
|
||||
- 输出 segment_stats 统计
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: HumanizeConfig | None = None,
|
||||
length_buckets: list[LengthBucket] | None = None,
|
||||
):
|
||||
self._config = config or HumanizeConfig()
|
||||
self._length_buckets = length_buckets or DEFAULT_LENGTH_BUCKETS
|
||||
|
||||
def humanize(
|
||||
self,
|
||||
text: str,
|
||||
override_config: HumanizeConfig | None = None,
|
||||
) -> tuple[list[Segment], SegmentStats]:
|
||||
"""
|
||||
[AC-MARH-10] 将文本转换为拟人化分段。
|
||||
|
||||
Args:
|
||||
text: 输入文本
|
||||
override_config: 租户覆盖配置
|
||||
|
||||
Returns:
|
||||
Tuple of (segments, segment_stats)
|
||||
"""
|
||||
config = override_config or self._config
|
||||
|
||||
if not config.enabled:
|
||||
segments = [Segment(
|
||||
segment_id=str(uuid.uuid4()),
|
||||
text=text,
|
||||
delay_after=0,
|
||||
)]
|
||||
stats = SegmentStats(
|
||||
segment_count=1,
|
||||
avg_segment_length=len(text),
|
||||
humanize_strategy="disabled",
|
||||
)
|
||||
return segments, stats
|
||||
|
||||
raw_segments = self._split_text(text, config)
|
||||
segments = []
|
||||
|
||||
for i, seg_text in enumerate(raw_segments):
|
||||
is_last = i == len(raw_segments) - 1
|
||||
delay_after = 0 if is_last else self._calculate_delay(seg_text, config)
|
||||
|
||||
segments.append(Segment(
|
||||
segment_id=str(uuid.uuid4()),
|
||||
text=seg_text,
|
||||
delay_after=delay_after,
|
||||
))
|
||||
|
||||
total_length = sum(len(s.text) for s in segments)
|
||||
avg_length = total_length / len(segments) if segments else 0.0
|
||||
|
||||
stats = SegmentStats(
|
||||
segment_count=len(segments),
|
||||
avg_segment_length=avg_length,
|
||||
humanize_strategy=config.length_bucket_strategy,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[AC-MARH-10] Humanized text: segments={len(segments)}, "
|
||||
f"avg_length={avg_length:.1f}, strategy={config.length_bucket_strategy}"
|
||||
)
|
||||
|
||||
return segments, stats
|
||||
|
||||
def _split_text(
|
||||
self,
|
||||
text: str,
|
||||
config: HumanizeConfig,
|
||||
) -> list[str]:
|
||||
"""切分文本。"""
|
||||
if config.length_bucket_strategy == "semantic":
|
||||
return self._split_semantic(text, config)
|
||||
else:
|
||||
return self._split_simple(text, config)
|
||||
|
||||
def _split_simple(
|
||||
self,
|
||||
text: str,
|
||||
config: HumanizeConfig,
|
||||
) -> list[str]:
|
||||
"""简单切分:按段落。"""
|
||||
paragraphs = re.split(r'\n\s*\n', text.strip())
|
||||
segments = []
|
||||
|
||||
for para in paragraphs:
|
||||
para = para.strip()
|
||||
if not para:
|
||||
continue
|
||||
|
||||
if len(para) <= config.segment_max_length:
|
||||
segments.append(para)
|
||||
else:
|
||||
sub_segments = self._split_by_length(para, config.segment_max_length)
|
||||
segments.extend(sub_segments)
|
||||
|
||||
if not segments:
|
||||
segments = [text.strip()]
|
||||
|
||||
return [s for s in segments if s.strip()]
|
||||
|
||||
def _split_semantic(
|
||||
self,
|
||||
text: str,
|
||||
config: HumanizeConfig,
|
||||
) -> list[str]:
|
||||
"""语义切分:按句子边界。"""
|
||||
sentence_endings = re.compile(r'([。!?.!?]+)')
|
||||
parts = sentence_endings.split(text.strip())
|
||||
|
||||
sentences = []
|
||||
current = ""
|
||||
for i, part in enumerate(parts):
|
||||
current += part
|
||||
if sentence_endings.match(part):
|
||||
sentences.append(current.strip())
|
||||
current = ""
|
||||
|
||||
if current.strip():
|
||||
sentences.append(current.strip())
|
||||
|
||||
if not sentences:
|
||||
sentences = [text.strip()]
|
||||
|
||||
segments = []
|
||||
current_segment = ""
|
||||
|
||||
for sentence in sentences:
|
||||
if len(current_segment) + len(sentence) <= config.segment_max_length:
|
||||
current_segment += sentence
|
||||
else:
|
||||
if current_segment:
|
||||
segments.append(current_segment)
|
||||
current_segment = sentence
|
||||
|
||||
if current_segment:
|
||||
segments.append(current_segment)
|
||||
|
||||
return [s for s in segments if s.strip()]
|
||||
|
||||
def _split_by_length(
|
||||
self,
|
||||
text: str,
|
||||
max_length: int,
|
||||
) -> list[str]:
|
||||
"""按长度切分。"""
|
||||
segments = []
|
||||
remaining = text
|
||||
|
||||
while remaining:
|
||||
if len(remaining) <= max_length:
|
||||
segments.append(remaining.strip())
|
||||
break
|
||||
|
||||
split_pos = max_length
|
||||
for i in range(max_length - 1, max(0, max_length - 20), -1):
|
||||
if remaining[i] in ',,;;:: ':
|
||||
split_pos = i + 1
|
||||
break
|
||||
|
||||
segments.append(remaining[:split_pos].strip())
|
||||
remaining = remaining[split_pos:]
|
||||
|
||||
return [s for s in segments if s.strip()]
|
||||
|
||||
def _calculate_delay(
|
||||
self,
|
||||
text: str,
|
||||
config: HumanizeConfig,
|
||||
) -> int:
|
||||
"""[AC-MARH-11] 计算拟人化 delay。"""
|
||||
text_length = len(text)
|
||||
|
||||
for bucket in self._length_buckets:
|
||||
if bucket.min_length <= text_length < bucket.max_length:
|
||||
delay = bucket.delay_ms
|
||||
return max(config.min_delay_ms, min(delay, config.max_delay_ms))
|
||||
|
||||
return config.min_delay_ms
|
||||
|
||||
def get_config(self) -> HumanizeConfig:
|
||||
"""获取当前配置。"""
|
||||
return self._config
|
||||
|
||||
|
||||
_segment_humanizer: SegmentHumanizer | None = None
|
||||
|
||||
|
||||
def get_segment_humanizer(
|
||||
config: HumanizeConfig | None = None,
|
||||
) -> SegmentHumanizer:
|
||||
"""获取或创建 SegmentHumanizer 实例。"""
|
||||
global _segment_humanizer
|
||||
if _segment_humanizer is None:
|
||||
_segment_humanizer = SegmentHumanizer(config=config)
|
||||
return _segment_humanizer
|
||||
|
|
@ -0,0 +1,166 @@
|
|||
"""
|
||||
Timeout Governor for Mid Platform.
|
||||
[AC-IDMP-12] Timeout governance: per-tool <= 60s, end-to-end <= 180s.
|
||||
[AC-MARH-08, AC-MARH-09] 超时口径统一:单工具 <= 60000ms,全链路 <= 180000ms。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from app.models.mid.schemas import TimeoutProfile
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_PER_TOOL_TIMEOUT_MS = 30000
|
||||
DEFAULT_END_TO_END_TIMEOUT_MS = 120000
|
||||
DEFAULT_LLM_TIMEOUT_MS = 60000
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@dataclass
|
||||
class TimeoutResult:
|
||||
"""Result of a timeout-governed operation."""
|
||||
success: bool
|
||||
result: Any = None
|
||||
error: str | None = None
|
||||
duration_ms: int = 0
|
||||
timed_out: bool = False
|
||||
|
||||
|
||||
class TimeoutGovernor:
|
||||
"""
|
||||
[AC-IDMP-12] Timeout governor for tool calls and end-to-end execution.
|
||||
[AC-MARH-08, AC-MARH-09] 超时口径统一。
|
||||
|
||||
Constraints:
|
||||
- Per-tool timeout: <= 60000ms (60s)
|
||||
- LLM timeout: <= 120000ms (120s)
|
||||
- End-to-end timeout: <= 180000ms (180s)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
per_tool_timeout_ms: int = DEFAULT_PER_TOOL_TIMEOUT_MS,
|
||||
end_to_end_timeout_ms: int = DEFAULT_END_TO_END_TIMEOUT_MS,
|
||||
llm_timeout_ms: int = DEFAULT_LLM_TIMEOUT_MS,
|
||||
):
|
||||
self._per_tool_timeout_ms = min(per_tool_timeout_ms, 60000)
|
||||
self._end_to_end_timeout_ms = min(end_to_end_timeout_ms, 180000)
|
||||
self._llm_timeout_ms = min(llm_timeout_ms, 120000)
|
||||
|
||||
@property
|
||||
def per_tool_timeout_seconds(self) -> float:
|
||||
"""Per-tool timeout in seconds."""
|
||||
return self._per_tool_timeout_ms / 1000.0
|
||||
|
||||
@property
|
||||
def llm_timeout_seconds(self) -> float:
|
||||
"""LLM call timeout in seconds."""
|
||||
return self._llm_timeout_ms / 1000.0
|
||||
|
||||
@property
|
||||
def end_to_end_timeout_seconds(self) -> float:
|
||||
"""End-to-end timeout in seconds."""
|
||||
return self._end_to_end_timeout_ms / 1000.0
|
||||
|
||||
@property
|
||||
def profile(self) -> TimeoutProfile:
|
||||
"""Get current timeout profile."""
|
||||
return TimeoutProfile(
|
||||
per_tool_timeout_ms=self._per_tool_timeout_ms,
|
||||
llm_timeout_ms=self._llm_timeout_ms,
|
||||
end_to_end_timeout_ms=self._end_to_end_timeout_ms,
|
||||
)
|
||||
|
||||
async def execute_with_timeout(
|
||||
self,
|
||||
coro: Callable[[], T],
|
||||
timeout_ms: int | None = None,
|
||||
operation_name: str = "operation",
|
||||
) -> TimeoutResult:
|
||||
"""
|
||||
[AC-MARH-08, AC-MARH-09] Execute a coroutine with timeout.
|
||||
|
||||
Args:
|
||||
coro: Coroutine to execute
|
||||
timeout_ms: Timeout in milliseconds (defaults to per-tool timeout)
|
||||
operation_name: Name for logging
|
||||
|
||||
Returns:
|
||||
TimeoutResult with success status and result/error
|
||||
"""
|
||||
import time
|
||||
|
||||
timeout_ms = timeout_ms or self._per_tool_timeout_ms
|
||||
timeout_seconds = timeout_ms / 1000.0
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
result = await asyncio.wait_for(coro(), timeout=timeout_seconds)
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
logger.debug(
|
||||
f"[AC-MARH-08] {operation_name} completed in {duration_ms}ms"
|
||||
)
|
||||
|
||||
return TimeoutResult(
|
||||
success=True,
|
||||
result=result,
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
logger.warning(
|
||||
f"[AC-MARH-08] {operation_name} timed out after {duration_ms}ms "
|
||||
f"(limit: {timeout_ms}ms)"
|
||||
)
|
||||
return TimeoutResult(
|
||||
success=False,
|
||||
error=f"Timeout after {timeout_ms}ms",
|
||||
duration_ms=duration_ms,
|
||||
timed_out=True,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
logger.error(
|
||||
f"[AC-MARH-08] {operation_name} failed: {e}"
|
||||
)
|
||||
return TimeoutResult(
|
||||
success=False,
|
||||
error=str(e),
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
|
||||
async def execute_tool(
|
||||
self,
|
||||
tool_name: str,
|
||||
tool_func: Callable[[], T],
|
||||
) -> TimeoutResult:
|
||||
"""
|
||||
[AC-MARH-08] Execute a tool call with per-tool timeout.
|
||||
"""
|
||||
return await self.execute_with_timeout(
|
||||
coro=tool_func,
|
||||
timeout_ms=self._per_tool_timeout_ms,
|
||||
operation_name=f"tool:{tool_name}",
|
||||
)
|
||||
|
||||
async def execute_e2e(
|
||||
self,
|
||||
coro: Callable[[], T],
|
||||
operation_name: str = "e2e",
|
||||
) -> TimeoutResult:
|
||||
"""
|
||||
[AC-MARH-09] Execute with end-to-end timeout.
|
||||
"""
|
||||
return await self.execute_with_timeout(
|
||||
coro=coro,
|
||||
timeout_ms=self._end_to_end_timeout_ms,
|
||||
operation_name=operation_name,
|
||||
)
|
||||
|
|
@ -0,0 +1,324 @@
|
|||
"""
|
||||
Tool Call Recorder for Mid Platform.
|
||||
[AC-IDMP-15] 工具调用结构化记录
|
||||
|
||||
Reference:
|
||||
- spec/intent-driven-mid-platform/openapi.provider.yaml - ToolCallTrace
|
||||
- spec/intent-driven-mid-platform/requirements.md AC-IDMP-15
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from app.models.mid.tool_trace import (
|
||||
ToolCallBuilder,
|
||||
ToolCallStatus,
|
||||
ToolCallTrace,
|
||||
ToolType,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCallStatistics:
|
||||
"""
|
||||
工具调用统计信息
|
||||
"""
|
||||
total_calls: int = 0
|
||||
success_calls: int = 0
|
||||
timeout_calls: int = 0
|
||||
error_calls: int = 0
|
||||
rejected_calls: int = 0
|
||||
total_duration_ms: int = 0
|
||||
avg_duration_ms: float = 0.0
|
||||
max_duration_ms: int = 0
|
||||
min_duration_ms: int = 0
|
||||
|
||||
def update(self, trace: ToolCallTrace) -> None:
|
||||
self.total_calls += 1
|
||||
self.total_duration_ms += trace.duration_ms
|
||||
self.avg_duration_ms = self.total_duration_ms / self.total_calls
|
||||
|
||||
if trace.duration_ms > self.max_duration_ms:
|
||||
self.max_duration_ms = trace.duration_ms
|
||||
if self.min_duration_ms == 0 or trace.duration_ms < self.min_duration_ms:
|
||||
self.min_duration_ms = trace.duration_ms
|
||||
|
||||
if trace.status == ToolCallStatus.OK:
|
||||
self.success_calls += 1
|
||||
elif trace.status == ToolCallStatus.TIMEOUT:
|
||||
self.timeout_calls += 1
|
||||
elif trace.status == ToolCallStatus.REJECTED:
|
||||
self.rejected_calls += 1
|
||||
else:
|
||||
self.error_calls += 1
|
||||
|
||||
|
||||
class ToolCallRecorder:
|
||||
"""
|
||||
[AC-IDMP-15] 工具调用记录器
|
||||
|
||||
功能:
|
||||
1. 记录每次工具调用的完整信息(参数摘要、耗时、状态、错误码)
|
||||
2. 支持敏感参数脱敏
|
||||
3. 提供统计信息
|
||||
|
||||
记录字段:
|
||||
- tool_name: 工具名称
|
||||
- tool_type: 工具类型 (internal | mcp)
|
||||
- registry_version: 注册表版本
|
||||
- auth_applied: 是否应用鉴权
|
||||
- duration_ms: 调用耗时
|
||||
- status: 调用状态 (ok | timeout | error | rejected)
|
||||
- error_code: 错误码
|
||||
- args_digest: 参数摘要(脱敏)
|
||||
- result_digest: 结果摘要
|
||||
"""
|
||||
|
||||
def __init__(self, max_traces_per_session: int = 100):
|
||||
self._max_traces_per_session = max_traces_per_session
|
||||
self._traces: dict[str, list[ToolCallTrace]] = defaultdict(list)
|
||||
self._statistics: dict[str, ToolCallStatistics] = defaultdict(ToolCallStatistics)
|
||||
|
||||
def start_trace(
|
||||
self,
|
||||
tool_name: str,
|
||||
tool_type: ToolType = ToolType.INTERNAL,
|
||||
) -> ToolCallBuilder:
|
||||
"""
|
||||
开始记录一次工具调用
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
tool_type: 工具类型
|
||||
|
||||
Returns:
|
||||
ToolCallBuilder: 构建器,用于逐步记录调用信息
|
||||
"""
|
||||
return ToolCallBuilder(
|
||||
tool_name=tool_name,
|
||||
tool_type=tool_type,
|
||||
)
|
||||
|
||||
def record(
|
||||
self,
|
||||
session_id: str,
|
||||
trace: ToolCallTrace,
|
||||
) -> None:
|
||||
"""
|
||||
记录一次工具调用的完整信息
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
trace: 工具调用追踪记录
|
||||
"""
|
||||
session_traces = self._traces[session_id]
|
||||
session_traces.append(trace)
|
||||
|
||||
if len(session_traces) > self._max_traces_per_session:
|
||||
session_traces.pop(0)
|
||||
|
||||
self._statistics[trace.tool_name].update(trace)
|
||||
|
||||
logger.info(
|
||||
f"[AC-IDMP-15] Tool call recorded: tool={trace.tool_name}, "
|
||||
f"type={trace.tool_type.value}, duration_ms={trace.duration_ms}, "
|
||||
f"status={trace.status.value}, session={session_id}"
|
||||
)
|
||||
|
||||
def record_success(
|
||||
self,
|
||||
session_id: str,
|
||||
tool_name: str,
|
||||
tool_type: ToolType,
|
||||
duration_ms: int,
|
||||
args: Any = None,
|
||||
result: Any = None,
|
||||
registry_version: str | None = None,
|
||||
auth_applied: bool = False,
|
||||
) -> ToolCallTrace:
|
||||
"""
|
||||
记录成功的工具调用(便捷方法)
|
||||
"""
|
||||
trace = ToolCallTrace(
|
||||
tool_name=tool_name,
|
||||
tool_type=tool_type,
|
||||
duration_ms=duration_ms,
|
||||
status=ToolCallStatus.OK,
|
||||
registry_version=registry_version,
|
||||
auth_applied=auth_applied,
|
||||
args_digest=ToolCallTrace.compute_digest(args) if args else None,
|
||||
result_digest=ToolCallTrace.compute_digest(result) if result else None,
|
||||
)
|
||||
self.record(session_id, trace)
|
||||
return trace
|
||||
|
||||
def record_timeout(
|
||||
self,
|
||||
session_id: str,
|
||||
tool_name: str,
|
||||
tool_type: ToolType,
|
||||
duration_ms: int,
|
||||
args: Any = None,
|
||||
registry_version: str | None = None,
|
||||
auth_applied: bool = False,
|
||||
) -> ToolCallTrace:
|
||||
"""
|
||||
记录超时的工具调用(便捷方法)
|
||||
"""
|
||||
trace = ToolCallTrace(
|
||||
tool_name=tool_name,
|
||||
tool_type=tool_type,
|
||||
duration_ms=duration_ms,
|
||||
status=ToolCallStatus.TIMEOUT,
|
||||
error_code="TIMEOUT",
|
||||
registry_version=registry_version,
|
||||
auth_applied=auth_applied,
|
||||
args_digest=ToolCallTrace.compute_digest(args) if args else None,
|
||||
)
|
||||
self.record(session_id, trace)
|
||||
return trace
|
||||
|
||||
def record_error(
|
||||
self,
|
||||
session_id: str,
|
||||
tool_name: str,
|
||||
tool_type: ToolType,
|
||||
duration_ms: int,
|
||||
error_code: str,
|
||||
error_message: str | None = None,
|
||||
args: Any = None,
|
||||
registry_version: str | None = None,
|
||||
auth_applied: bool = False,
|
||||
) -> ToolCallTrace:
|
||||
"""
|
||||
记录错误的工具调用(便捷方法)
|
||||
"""
|
||||
trace = ToolCallTrace(
|
||||
tool_name=tool_name,
|
||||
tool_type=tool_type,
|
||||
duration_ms=duration_ms,
|
||||
status=ToolCallStatus.ERROR,
|
||||
error_code=error_code,
|
||||
registry_version=registry_version,
|
||||
auth_applied=auth_applied,
|
||||
args_digest=ToolCallTrace.compute_digest(args) if args else None,
|
||||
)
|
||||
self.record(session_id, trace)
|
||||
return trace
|
||||
|
||||
def record_rejected(
|
||||
self,
|
||||
session_id: str,
|
||||
tool_name: str,
|
||||
tool_type: ToolType,
|
||||
reason: str,
|
||||
args: Any = None,
|
||||
registry_version: str | None = None,
|
||||
) -> ToolCallTrace:
|
||||
"""
|
||||
记录被拒绝的工具调用(便捷方法)
|
||||
"""
|
||||
trace = ToolCallTrace(
|
||||
tool_name=tool_name,
|
||||
tool_type=tool_type,
|
||||
duration_ms=0,
|
||||
status=ToolCallStatus.REJECTED,
|
||||
error_code=reason,
|
||||
registry_version=registry_version,
|
||||
args_digest=ToolCallTrace.compute_digest(args) if args else None,
|
||||
)
|
||||
self.record(session_id, trace)
|
||||
return trace
|
||||
|
||||
def get_traces(self, session_id: str) -> list[ToolCallTrace]:
|
||||
"""
|
||||
获取指定会话的所有工具调用记录
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
|
||||
Returns:
|
||||
list[ToolCallTrace]: 工具调用记录列表
|
||||
"""
|
||||
return self._traces.get(session_id, [])
|
||||
|
||||
def get_statistics(self, tool_name: str | None = None) -> dict[str, Any]:
|
||||
"""
|
||||
获取工具调用统计信息
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称(可选,不提供则返回所有统计)
|
||||
|
||||
Returns:
|
||||
dict: 统计信息
|
||||
"""
|
||||
if tool_name:
|
||||
stats = self._statistics.get(tool_name)
|
||||
if stats:
|
||||
return {
|
||||
"tool_name": tool_name,
|
||||
"total_calls": stats.total_calls,
|
||||
"success_rate": stats.success_calls / stats.total_calls if stats.total_calls > 0 else 0,
|
||||
"timeout_rate": stats.timeout_calls / stats.total_calls if stats.total_calls > 0 else 0,
|
||||
"error_rate": stats.error_calls / stats.total_calls if stats.total_calls > 0 else 0,
|
||||
"avg_duration_ms": stats.avg_duration_ms,
|
||||
"max_duration_ms": stats.max_duration_ms,
|
||||
"min_duration_ms": stats.min_duration_ms,
|
||||
}
|
||||
return {}
|
||||
|
||||
return {
|
||||
name: {
|
||||
"total_calls": stats.total_calls,
|
||||
"success_rate": stats.success_calls / stats.total_calls if stats.total_calls > 0 else 0,
|
||||
"avg_duration_ms": stats.avg_duration_ms,
|
||||
}
|
||||
for name, stats in self._statistics.items()
|
||||
}
|
||||
|
||||
def clear_session(self, session_id: str) -> int:
|
||||
"""
|
||||
清除指定会话的记录
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
|
||||
Returns:
|
||||
int: 清除的记录数
|
||||
"""
|
||||
if session_id in self._traces:
|
||||
count = len(self._traces[session_id])
|
||||
del self._traces[session_id]
|
||||
return count
|
||||
return 0
|
||||
|
||||
def to_trace_info_format(self, session_id: str) -> list[dict[str, Any]]:
|
||||
"""
|
||||
转换为 TraceInfo.tool_calls 格式
|
||||
|
||||
用于输出到 DialogueResponse.trace.tool_calls
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
|
||||
Returns:
|
||||
list[dict]: 符合 OpenAPI 格式的工具调用列表
|
||||
"""
|
||||
traces = self.get_traces(session_id)
|
||||
return [trace.to_dict() for trace in traces]
|
||||
|
||||
|
||||
_recorder: ToolCallRecorder | None = None
|
||||
|
||||
|
||||
def get_tool_call_recorder() -> ToolCallRecorder:
|
||||
"""获取全局工具调用记录器实例"""
|
||||
global _recorder
|
||||
if _recorder is None:
|
||||
_recorder = ToolCallRecorder()
|
||||
return _recorder
|
||||
|
|
@ -0,0 +1,337 @@
|
|||
"""
|
||||
Tool Registry for Mid Platform.
|
||||
[AC-IDMP-19] Unified tool registration, auth, timeout, version, and enable/disable governance.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Coroutine
|
||||
|
||||
from app.models.mid.schemas import (
|
||||
ToolCallStatus,
|
||||
ToolCallTrace,
|
||||
ToolType,
|
||||
)
|
||||
from app.services.mid.timeout_governor import TimeoutGovernor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolDefinition:
|
||||
"""Tool definition for registry."""
|
||||
name: str
|
||||
description: str
|
||||
tool_type: ToolType = ToolType.INTERNAL
|
||||
version: str = "1.0.0"
|
||||
enabled: bool = True
|
||||
auth_required: bool = False
|
||||
timeout_ms: int = 2000
|
||||
handler: Callable[..., Coroutine[Any, Any, dict[str, Any]]] | None = None
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolExecutionResult:
|
||||
"""Tool execution result."""
|
||||
success: bool
|
||||
output: Any = None
|
||||
error: str | None = None
|
||||
duration_ms: int = 0
|
||||
auth_applied: bool = False
|
||||
registry_version: str | None = None
|
||||
|
||||
|
||||
class ToolRegistry:
|
||||
"""
|
||||
[AC-IDMP-19] Unified tool registry for governance.
|
||||
|
||||
Features:
|
||||
- Tool registration with metadata
|
||||
- Auth policy enforcement
|
||||
- Timeout governance
|
||||
- Version management
|
||||
- Enable/disable control
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
timeout_governor: TimeoutGovernor | None = None,
|
||||
):
|
||||
self._tools: dict[str, ToolDefinition] = {}
|
||||
self._timeout_governor = timeout_governor or TimeoutGovernor()
|
||||
self._version = "1.0.0"
|
||||
|
||||
@property
|
||||
def version(self) -> str:
|
||||
"""Get registry version."""
|
||||
return self._version
|
||||
|
||||
def register(
|
||||
self,
|
||||
name: str,
|
||||
description: str,
|
||||
handler: Callable[..., Coroutine[Any, Any, dict[str, Any]]],
|
||||
tool_type: ToolType = ToolType.INTERNAL,
|
||||
version: str = "1.0.0",
|
||||
auth_required: bool = False,
|
||||
timeout_ms: int = 2000,
|
||||
enabled: bool = True,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> ToolDefinition:
|
||||
"""
|
||||
[AC-IDMP-19] Register a tool.
|
||||
|
||||
Args:
|
||||
name: Tool name (unique identifier)
|
||||
description: Tool description
|
||||
handler: Async handler function
|
||||
tool_type: Tool type (internal/mcp)
|
||||
version: Tool version
|
||||
auth_required: Whether auth is required
|
||||
timeout_ms: Tool-specific timeout
|
||||
enabled: Whether tool is enabled
|
||||
metadata: Additional metadata
|
||||
|
||||
Returns:
|
||||
ToolDefinition for the registered tool
|
||||
"""
|
||||
if name in self._tools:
|
||||
logger.warning(f"[AC-IDMP-19] Tool already registered, overwriting: {name}")
|
||||
|
||||
tool = ToolDefinition(
|
||||
name=name,
|
||||
description=description,
|
||||
tool_type=tool_type,
|
||||
version=version,
|
||||
enabled=enabled,
|
||||
auth_required=auth_required,
|
||||
timeout_ms=min(timeout_ms, 2000),
|
||||
handler=handler,
|
||||
metadata=metadata or {},
|
||||
)
|
||||
|
||||
self._tools[name] = tool
|
||||
|
||||
logger.info(
|
||||
f"[AC-IDMP-19] Tool registered: name={name}, type={tool_type.value}, "
|
||||
f"version={version}, auth_required={auth_required}"
|
||||
)
|
||||
|
||||
return tool
|
||||
|
||||
def unregister(self, name: str) -> bool:
|
||||
"""Unregister a tool."""
|
||||
if name in self._tools:
|
||||
del self._tools[name]
|
||||
logger.info(f"[AC-IDMP-19] Tool unregistered: {name}")
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_tool(self, name: str) -> ToolDefinition | None:
|
||||
"""Get tool definition by name."""
|
||||
return self._tools.get(name)
|
||||
|
||||
def list_tools(
|
||||
self,
|
||||
tool_type: ToolType | None = None,
|
||||
enabled_only: bool = True,
|
||||
) -> list[ToolDefinition]:
|
||||
"""List registered tools, optionally filtered."""
|
||||
tools = list(self._tools.values())
|
||||
|
||||
if tool_type:
|
||||
tools = [t for t in tools if t.tool_type == tool_type]
|
||||
|
||||
if enabled_only:
|
||||
tools = [t for t in tools if t.enabled]
|
||||
|
||||
return tools
|
||||
|
||||
def enable_tool(self, name: str) -> bool:
|
||||
"""Enable a tool."""
|
||||
tool = self._tools.get(name)
|
||||
if tool:
|
||||
tool.enabled = True
|
||||
logger.info(f"[AC-IDMP-19] Tool enabled: {name}")
|
||||
return True
|
||||
return False
|
||||
|
||||
def disable_tool(self, name: str) -> bool:
|
||||
"""Disable a tool."""
|
||||
tool = self._tools.get(name)
|
||||
if tool:
|
||||
tool.enabled = False
|
||||
logger.info(f"[AC-IDMP-19] Tool disabled: {name}")
|
||||
return True
|
||||
return False
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
tool_name: str,
|
||||
args: dict[str, Any],
|
||||
auth_context: dict[str, Any] | None = None,
|
||||
) -> ToolExecutionResult:
|
||||
"""
|
||||
[AC-IDMP-19] Execute a tool with governance.
|
||||
|
||||
Args:
|
||||
tool_name: Tool name to execute
|
||||
args: Tool arguments
|
||||
auth_context: Authentication context
|
||||
|
||||
Returns:
|
||||
ToolExecutionResult with output and metadata
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
tool = self._tools.get(tool_name)
|
||||
if not tool:
|
||||
logger.warning(f"[AC-IDMP-19] Tool not found: {tool_name}")
|
||||
return ToolExecutionResult(
|
||||
success=False,
|
||||
error=f"Tool not found: {tool_name}",
|
||||
duration_ms=0,
|
||||
)
|
||||
|
||||
if not tool.enabled:
|
||||
logger.warning(f"[AC-IDMP-19] Tool disabled: {tool_name}")
|
||||
return ToolExecutionResult(
|
||||
success=False,
|
||||
error=f"Tool disabled: {tool_name}",
|
||||
duration_ms=0,
|
||||
registry_version=tool.version,
|
||||
)
|
||||
|
||||
auth_applied = False
|
||||
if tool.auth_required:
|
||||
if not auth_context:
|
||||
logger.warning(f"[AC-IDMP-19] Auth required but no context: {tool_name}")
|
||||
return ToolExecutionResult(
|
||||
success=False,
|
||||
error="Authentication required",
|
||||
duration_ms=int((time.time() - start_time) * 1000),
|
||||
auth_applied=False,
|
||||
registry_version=tool.version,
|
||||
)
|
||||
auth_applied = True
|
||||
|
||||
try:
|
||||
timeout_seconds = tool.timeout_ms / 1000.0
|
||||
|
||||
result = await asyncio.wait_for(
|
||||
tool.handler(**args) if tool.handler else asyncio.sleep(0),
|
||||
timeout=timeout_seconds,
|
||||
)
|
||||
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
logger.info(
|
||||
f"[AC-IDMP-19] Tool executed: name={tool_name}, "
|
||||
f"duration_ms={duration_ms}, success=True"
|
||||
)
|
||||
|
||||
return ToolExecutionResult(
|
||||
success=True,
|
||||
output=result,
|
||||
duration_ms=duration_ms,
|
||||
auth_applied=auth_applied,
|
||||
registry_version=tool.version,
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
logger.warning(
|
||||
f"[AC-IDMP-19] Tool timeout: name={tool_name}, "
|
||||
f"duration_ms={duration_ms}"
|
||||
)
|
||||
return ToolExecutionResult(
|
||||
success=False,
|
||||
error=f"Tool timeout after {tool.timeout_ms}ms",
|
||||
duration_ms=duration_ms,
|
||||
auth_applied=auth_applied,
|
||||
registry_version=tool.version,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
logger.error(
|
||||
f"[AC-IDMP-19] Tool error: name={tool_name}, error={e}"
|
||||
)
|
||||
return ToolExecutionResult(
|
||||
success=False,
|
||||
error=str(e),
|
||||
duration_ms=duration_ms,
|
||||
auth_applied=auth_applied,
|
||||
registry_version=tool.version,
|
||||
)
|
||||
|
||||
def create_trace(
|
||||
self,
|
||||
tool_name: str,
|
||||
result: ToolExecutionResult,
|
||||
args_digest: str | None = None,
|
||||
) -> ToolCallTrace:
|
||||
"""
|
||||
[AC-IDMP-19] Create ToolCallTrace from execution result.
|
||||
"""
|
||||
tool = self._tools.get(tool_name)
|
||||
|
||||
return ToolCallTrace(
|
||||
tool_name=tool_name,
|
||||
tool_type=tool.tool_type if tool else ToolType.INTERNAL,
|
||||
registry_version=result.registry_version,
|
||||
auth_applied=result.auth_applied,
|
||||
duration_ms=result.duration_ms,
|
||||
status=ToolCallStatus.OK if result.success else (
|
||||
ToolCallStatus.TIMEOUT if "timeout" in (result.error or "").lower()
|
||||
else ToolCallStatus.ERROR
|
||||
),
|
||||
error_code=result.error if not result.success else None,
|
||||
args_digest=args_digest,
|
||||
result_digest=str(result.output)[:100] if result.output else None,
|
||||
)
|
||||
|
||||
def get_governance_report(self) -> dict[str, Any]:
|
||||
"""Get governance report for all tools."""
|
||||
return {
|
||||
"registry_version": self._version,
|
||||
"total_tools": len(self._tools),
|
||||
"enabled_tools": sum(1 for t in self._tools.values() if t.enabled),
|
||||
"disabled_tools": sum(1 for t in self._tools.values() if not t.enabled),
|
||||
"auth_required_tools": sum(1 for t in self._tools.values() if t.auth_required),
|
||||
"mcp_tools": sum(1 for t in self._tools.values() if t.tool_type == ToolType.MCP),
|
||||
"internal_tools": sum(1 for t in self._tools.values() if t.tool_type == ToolType.INTERNAL),
|
||||
"tools": [
|
||||
{
|
||||
"name": t.name,
|
||||
"type": t.tool_type.value,
|
||||
"version": t.version,
|
||||
"enabled": t.enabled,
|
||||
"auth_required": t.auth_required,
|
||||
"timeout_ms": t.timeout_ms,
|
||||
}
|
||||
for t in self._tools.values()
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
_registry: ToolRegistry | None = None
|
||||
|
||||
|
||||
def get_tool_registry() -> ToolRegistry:
|
||||
"""Get global tool registry instance."""
|
||||
global _registry
|
||||
if _registry is None:
|
||||
_registry = ToolRegistry()
|
||||
return _registry
|
||||
|
||||
|
||||
def init_tool_registry(timeout_governor: TimeoutGovernor | None = None) -> ToolRegistry:
|
||||
"""Initialize and return tool registry."""
|
||||
global _registry
|
||||
_registry = ToolRegistry(timeout_governor=timeout_governor)
|
||||
return _registry
|
||||
|
|
@ -0,0 +1,269 @@
|
|||
"""
|
||||
Trace Logger for Mid Platform.
|
||||
[AC-MARH-02, AC-MARH-03, AC-MARH-12] Trace collection and audit logging.
|
||||
|
||||
Audit Fields:
|
||||
- session_id, request_id, generation_id
|
||||
- mode, intent, tool_calls
|
||||
- guardrail_triggered, guardrail_rule_id
|
||||
- interrupt_consumed
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from app.models.mid.schemas import (
|
||||
ExecutionMode,
|
||||
ToolCallTrace,
|
||||
TraceInfo,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuditRecord:
|
||||
"""[AC-MARH-12] Audit record for database persistence."""
|
||||
tenant_id: str
|
||||
session_id: str
|
||||
request_id: str
|
||||
generation_id: str
|
||||
mode: ExecutionMode
|
||||
intent: str | None = None
|
||||
tool_calls: list[dict[str, Any]] = field(default_factory=list)
|
||||
guardrail_triggered: bool = False
|
||||
guardrail_rule_id: str | None = None
|
||||
interrupt_consumed: bool = False
|
||||
fallback_reason_code: str | None = None
|
||||
react_iterations: int = 0
|
||||
latency_ms: int = 0
|
||||
created_at: str | None = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"tenant_id": self.tenant_id,
|
||||
"session_id": self.session_id,
|
||||
"request_id": self.request_id,
|
||||
"generation_id": self.generation_id,
|
||||
"mode": self.mode.value,
|
||||
"intent": self.intent,
|
||||
"tool_calls": self.tool_calls,
|
||||
"guardrail_triggered": self.guardrail_triggered,
|
||||
"guardrail_rule_id": self.guardrail_rule_id,
|
||||
"interrupt_consumed": self.interrupt_consumed,
|
||||
"fallback_reason_code": self.fallback_reason_code,
|
||||
"react_iterations": self.react_iterations,
|
||||
"latency_ms": self.latency_ms,
|
||||
"created_at": self.created_at,
|
||||
}
|
||||
|
||||
|
||||
class TraceLogger:
|
||||
"""
|
||||
[AC-MARH-02, AC-MARH-03, AC-MARH-12] Trace logger for observability and audit.
|
||||
|
||||
Features:
|
||||
- Request-scoped trace context
|
||||
- Tool call tracing
|
||||
- Guardrail event logging with rule_id
|
||||
- Interrupt consumption tracking
|
||||
- Audit record generation
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._traces: dict[str, TraceInfo] = {}
|
||||
self._audit_records: list[AuditRecord] = []
|
||||
|
||||
def start_trace(
|
||||
self,
|
||||
tenant_id: str,
|
||||
session_id: str,
|
||||
request_id: str | None = None,
|
||||
generation_id: str | None = None,
|
||||
) -> TraceInfo:
|
||||
"""
|
||||
[AC-MARH-12] Start a new trace context.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID
|
||||
session_id: Session ID
|
||||
request_id: Request ID (auto-generated if not provided)
|
||||
generation_id: Generation ID for interrupt handling
|
||||
|
||||
Returns:
|
||||
TraceInfo for the new trace
|
||||
"""
|
||||
request_id = request_id or str(uuid.uuid4())
|
||||
generation_id = generation_id or str(uuid.uuid4())
|
||||
|
||||
trace = TraceInfo(
|
||||
mode=ExecutionMode.AGENT,
|
||||
request_id=request_id,
|
||||
generation_id=generation_id,
|
||||
)
|
||||
|
||||
self._traces[request_id] = trace
|
||||
|
||||
logger.info(
|
||||
f"[AC-MARH-12] Trace started: request_id={request_id}, "
|
||||
f"session_id={session_id}, generation_id={generation_id}"
|
||||
)
|
||||
|
||||
return trace
|
||||
|
||||
def get_trace(self, request_id: str) -> TraceInfo | None:
|
||||
"""Get trace by request ID."""
|
||||
return self._traces.get(request_id)
|
||||
|
||||
def update_trace(
|
||||
self,
|
||||
request_id: str,
|
||||
mode: ExecutionMode | None = None,
|
||||
intent: str | None = None,
|
||||
guardrail_triggered: bool | None = None,
|
||||
guardrail_rule_id: str | None = None,
|
||||
interrupt_consumed: bool | None = None,
|
||||
fallback_reason_code: str | None = None,
|
||||
react_iterations: int | None = None,
|
||||
tool_calls: list[ToolCallTrace] | None = None,
|
||||
) -> TraceInfo | None:
|
||||
"""
|
||||
[AC-MARH-02, AC-MARH-03, AC-MARH-12] Update trace with execution details.
|
||||
"""
|
||||
trace = self._traces.get(request_id)
|
||||
if not trace:
|
||||
logger.warning(f"[AC-MARH-12] Trace not found: {request_id}")
|
||||
return None
|
||||
|
||||
if mode is not None:
|
||||
trace.mode = mode
|
||||
if intent is not None:
|
||||
trace.intent = intent
|
||||
if guardrail_triggered is not None:
|
||||
trace.guardrail_triggered = guardrail_triggered
|
||||
if guardrail_rule_id is not None:
|
||||
trace.guardrail_rule_id = guardrail_rule_id
|
||||
if interrupt_consumed is not None:
|
||||
trace.interrupt_consumed = interrupt_consumed
|
||||
if fallback_reason_code is not None:
|
||||
trace.fallback_reason_code = fallback_reason_code
|
||||
if react_iterations is not None:
|
||||
trace.react_iterations = react_iterations
|
||||
if tool_calls is not None:
|
||||
trace.tool_calls = tool_calls
|
||||
|
||||
return trace
|
||||
|
||||
def add_tool_call(
|
||||
self,
|
||||
request_id: str,
|
||||
tool_call: ToolCallTrace,
|
||||
) -> None:
|
||||
"""
|
||||
[AC-MARH-12] Add tool call trace to request.
|
||||
"""
|
||||
trace = self._traces.get(request_id)
|
||||
if not trace:
|
||||
logger.warning(f"[AC-MARH-12] Trace not found for tool call: {request_id}")
|
||||
return
|
||||
|
||||
if trace.tool_calls is None:
|
||||
trace.tool_calls = []
|
||||
|
||||
trace.tool_calls.append(tool_call)
|
||||
|
||||
if trace.tools_used is None:
|
||||
trace.tools_used = []
|
||||
|
||||
if tool_call.tool_name not in trace.tools_used:
|
||||
trace.tools_used.append(tool_call.tool_name)
|
||||
|
||||
logger.debug(
|
||||
f"[AC-MARH-12] Tool call recorded: request_id={request_id}, "
|
||||
f"tool={tool_call.tool_name}, status={tool_call.status.value}"
|
||||
)
|
||||
|
||||
def end_trace(
|
||||
self,
|
||||
request_id: str,
|
||||
tenant_id: str,
|
||||
session_id: str,
|
||||
latency_ms: int,
|
||||
) -> AuditRecord:
|
||||
"""
|
||||
[AC-MARH-12] End trace and create audit record.
|
||||
"""
|
||||
trace = self._traces.get(request_id)
|
||||
if not trace:
|
||||
logger.warning(f"[AC-MARH-12] Trace not found for end: {request_id}")
|
||||
return AuditRecord(
|
||||
tenant_id=tenant_id,
|
||||
session_id=session_id,
|
||||
request_id=request_id,
|
||||
generation_id=str(uuid.uuid4()),
|
||||
mode=ExecutionMode.FIXED,
|
||||
latency_ms=latency_ms,
|
||||
)
|
||||
|
||||
audit = AuditRecord(
|
||||
tenant_id=tenant_id,
|
||||
session_id=session_id,
|
||||
request_id=request_id,
|
||||
generation_id=trace.generation_id or str(uuid.uuid4()),
|
||||
mode=trace.mode,
|
||||
intent=trace.intent,
|
||||
tool_calls=[tc.model_dump() for tc in trace.tool_calls] if trace.tool_calls else [],
|
||||
guardrail_triggered=trace.guardrail_triggered or False,
|
||||
guardrail_rule_id=trace.guardrail_rule_id,
|
||||
interrupt_consumed=trace.interrupt_consumed or False,
|
||||
fallback_reason_code=trace.fallback_reason_code,
|
||||
react_iterations=trace.react_iterations or 0,
|
||||
latency_ms=latency_ms,
|
||||
created_at=time.strftime("%Y-%m-%d %H:%M:%S"),
|
||||
)
|
||||
|
||||
self._audit_records.append(audit)
|
||||
|
||||
if request_id in self._traces:
|
||||
del self._traces[request_id]
|
||||
|
||||
logger.info(
|
||||
f"[AC-MARH-12] Trace ended: request_id={request_id}, "
|
||||
f"mode={trace.mode.value}, latency_ms={latency_ms}"
|
||||
)
|
||||
|
||||
return audit
|
||||
|
||||
def get_audit_records(
|
||||
self,
|
||||
tenant_id: str,
|
||||
session_id: str | None = None,
|
||||
limit: int = 100,
|
||||
) -> list[AuditRecord]:
|
||||
"""Get audit records for a tenant/session."""
|
||||
records = [
|
||||
r for r in self._audit_records
|
||||
if r.tenant_id == tenant_id
|
||||
]
|
||||
|
||||
if session_id:
|
||||
records = [r for r in records if r.session_id == session_id]
|
||||
|
||||
return records[-limit:]
|
||||
|
||||
def clear_audit_records(self, tenant_id: str | None = None) -> int:
|
||||
"""Clear audit records, optionally filtered by tenant."""
|
||||
if tenant_id:
|
||||
original_count = len(self._audit_records)
|
||||
self._audit_records = [
|
||||
r for r in self._audit_records
|
||||
if r.tenant_id != tenant_id
|
||||
]
|
||||
return original_count - len(self._audit_records)
|
||||
else:
|
||||
count = len(self._audit_records)
|
||||
self._audit_records = []
|
||||
return count
|
||||
Loading…
Reference in New Issue