test: add mid-platform service tests [AC-IDMP-01~20, AC-MARH-01~12]
- Add memory tool tests - Add mid-platform services tests
This commit is contained in:
parent
c7c94e8cd9
commit
38130f7a27
|
|
@ -0,0 +1,550 @@
|
|||
"""
|
||||
Tests for Mid Platform Memory and Tool Governance services.
|
||||
[AC-IDMP-13/14/15/19] 记忆与工具治理测试
|
||||
|
||||
覆盖路径:
|
||||
- 成功路径:正常 recall/update,工具调用成功
|
||||
- 超时路径:recall 超时降级,工具调用超时
|
||||
- 错误路径:recall 失败降级,工具调用错误
|
||||
- 降级路径:记忆服务不可用时继续主链路
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from datetime import datetime
|
||||
|
||||
from app.services.mid.memory_adapter import MemoryAdapter, UserMemory
|
||||
from app.services.mid.tool_call_recorder import ToolCallRecorder, ToolCallStatistics, get_tool_call_recorder
|
||||
from app.services.mid.tool_registry import ToolRegistry, ToolDefinition, ToolExecutionResult, get_tool_registry, init_tool_registry
|
||||
from app.models.mid.memory import (
|
||||
RecallRequest,
|
||||
RecallResponse,
|
||||
UpdateRequest,
|
||||
MemoryProfile,
|
||||
MemoryFact,
|
||||
MemoryPreferences,
|
||||
)
|
||||
from app.models.mid.tool_trace import (
|
||||
ToolCallTrace,
|
||||
ToolCallBuilder,
|
||||
ToolCallStatus,
|
||||
ToolType,
|
||||
)
|
||||
from app.models.mid.schemas import ToolCallTrace as ToolCallTraceSchema
|
||||
|
||||
|
||||
class TestMemoryAdapter:
|
||||
"""[AC-IDMP-13/14] 记忆适配器测试"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self):
|
||||
return AsyncMock()
|
||||
|
||||
@pytest.fixture
|
||||
def adapter(self, mock_session):
|
||||
return MemoryAdapter(mock_session)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recall_success(self, adapter):
|
||||
"""[AC-IDMP-13] 成功召回用户记忆"""
|
||||
response = await adapter.recall(
|
||||
user_id="user123",
|
||||
session_id="session456",
|
||||
tenant_id="tenant789",
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert response.profile is not None
|
||||
assert response.profile.grade == "初一"
|
||||
assert response.profile.region == "北京"
|
||||
assert len(response.facts) > 0
|
||||
assert response.preferences is not None
|
||||
assert response.preferences.tone == "friendly"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recall_returns_context_for_prompt(self, adapter):
|
||||
"""[AC-IDMP-13] 召回记忆可注入 Prompt"""
|
||||
response = await adapter.recall(
|
||||
user_id="user123",
|
||||
session_id="session456",
|
||||
)
|
||||
|
||||
context = response.get_context_for_prompt()
|
||||
|
||||
assert "【用户属性】" in context
|
||||
assert "年级" in context
|
||||
assert "【已知事实】" in context
|
||||
assert "【用户偏好】" in context
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recall_timeout_fallback(self, mock_session):
|
||||
"""[AC-IDMP-13] recall 超时降级,不阻断主链路"""
|
||||
adapter = MemoryAdapter(mock_session, recall_timeout_ms=1)
|
||||
|
||||
async def slow_recall(*args, **kwargs):
|
||||
await asyncio.sleep(1)
|
||||
return RecallResponse()
|
||||
|
||||
with patch.object(adapter, '_recall_internal', slow_recall):
|
||||
response = await adapter.recall(
|
||||
user_id="user123",
|
||||
session_id="session456",
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert response.profile is None
|
||||
assert response.facts == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recall_error_fallback(self, adapter, mock_session):
|
||||
"""[AC-IDMP-13] recall 错误降级,不阻断主链路"""
|
||||
with patch.object(adapter, '_recall_internal', side_effect=Exception("DB error")):
|
||||
response = await adapter.recall(
|
||||
user_id="user123",
|
||||
session_id="session456",
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert response.profile is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_async(self, adapter):
|
||||
"""[AC-IDMP-14] 异步更新用户记忆"""
|
||||
result = await adapter.update(
|
||||
user_id="user123",
|
||||
session_id="session456",
|
||||
messages=[{"role": "user", "content": "你好"}],
|
||||
summary="用户打招呼",
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
completed = await adapter.wait_pending_updates(timeout=1.0)
|
||||
assert completed >= 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_with_summary_generation(self, adapter):
|
||||
"""[AC-IDMP-14] 带摘要生成的记忆更新"""
|
||||
async def summary_gen(messages):
|
||||
return "这是一个测试摘要"
|
||||
|
||||
result = await adapter.update_with_summary_generation(
|
||||
user_id="user123",
|
||||
session_id="session456",
|
||||
messages=[{"role": "user", "content": "你好"}],
|
||||
summary_generator=summary_gen,
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
|
||||
class TestToolCallRecorder:
|
||||
"""[AC-IDMP-15] 工具调用记录器测试"""
|
||||
|
||||
@pytest.fixture
|
||||
def recorder(self):
|
||||
return ToolCallRecorder()
|
||||
|
||||
def test_record_success(self, recorder):
|
||||
"""[AC-IDMP-15] 记录成功的工具调用"""
|
||||
trace = recorder.record_success(
|
||||
session_id="session123",
|
||||
tool_name="kb_search",
|
||||
tool_type=ToolType.INTERNAL,
|
||||
duration_ms=150,
|
||||
args={"query": "测试查询"},
|
||||
result={"docs": ["doc1", "doc2"]},
|
||||
registry_version="1.0.0",
|
||||
auth_applied=False,
|
||||
)
|
||||
|
||||
assert trace.tool_name == "kb_search"
|
||||
assert trace.status == ToolCallStatus.OK
|
||||
assert trace.duration_ms == 150
|
||||
assert trace.args_digest is not None
|
||||
assert trace.result_digest is not None
|
||||
|
||||
def test_record_timeout(self, recorder):
|
||||
"""[AC-IDMP-15] 记录超时的工具调用"""
|
||||
trace = recorder.record_timeout(
|
||||
session_id="session123",
|
||||
tool_name="kb_search",
|
||||
tool_type=ToolType.INTERNAL,
|
||||
duration_ms=2000,
|
||||
args={"query": "测试查询"},
|
||||
)
|
||||
|
||||
assert trace.status == ToolCallStatus.TIMEOUT
|
||||
assert trace.error_code == "TIMEOUT"
|
||||
|
||||
def test_record_error(self, recorder):
|
||||
"""[AC-IDMP-15] 记录错误的工具调用"""
|
||||
trace = recorder.record_error(
|
||||
session_id="session123",
|
||||
tool_name="kb_search",
|
||||
tool_type=ToolType.INTERNAL,
|
||||
duration_ms=500,
|
||||
error_code="DB_CONNECTION_ERROR",
|
||||
error_message="Database connection failed",
|
||||
)
|
||||
|
||||
assert trace.status == ToolCallStatus.ERROR
|
||||
assert trace.error_code == "DB_CONNECTION_ERROR"
|
||||
|
||||
def test_record_rejected(self, recorder):
|
||||
"""[AC-IDMP-15] 记录被拒绝的工具调用"""
|
||||
trace = recorder.record_rejected(
|
||||
session_id="session123",
|
||||
tool_name="sensitive_tool",
|
||||
tool_type=ToolType.MCP,
|
||||
reason="AUTH_REQUIRED",
|
||||
)
|
||||
|
||||
assert trace.status == ToolCallStatus.REJECTED
|
||||
assert trace.error_code == "AUTH_REQUIRED"
|
||||
|
||||
def test_get_traces(self, recorder):
|
||||
"""[AC-IDMP-15] 获取会话的工具调用记录"""
|
||||
recorder.record_success(
|
||||
session_id="session123",
|
||||
tool_name="tool1",
|
||||
tool_type=ToolType.INTERNAL,
|
||||
duration_ms=100,
|
||||
)
|
||||
recorder.record_success(
|
||||
session_id="session123",
|
||||
tool_name="tool2",
|
||||
tool_type=ToolType.INTERNAL,
|
||||
duration_ms=200,
|
||||
)
|
||||
|
||||
traces = recorder.get_traces("session123")
|
||||
|
||||
assert len(traces) == 2
|
||||
assert traces[0].tool_name == "tool1"
|
||||
assert traces[1].tool_name == "tool2"
|
||||
|
||||
def test_get_statistics(self, recorder):
|
||||
"""[AC-IDMP-15] 获取工具调用统计"""
|
||||
recorder.record_success(
|
||||
session_id="session123",
|
||||
tool_name="kb_search",
|
||||
tool_type=ToolType.INTERNAL,
|
||||
duration_ms=100,
|
||||
)
|
||||
recorder.record_timeout(
|
||||
session_id="session123",
|
||||
tool_name="kb_search",
|
||||
tool_type=ToolType.INTERNAL,
|
||||
duration_ms=2000,
|
||||
)
|
||||
|
||||
stats = recorder.get_statistics("kb_search")
|
||||
|
||||
assert stats["total_calls"] == 2
|
||||
assert stats["success_rate"] == 0.5
|
||||
assert stats["timeout_rate"] == 0.5
|
||||
|
||||
def test_to_trace_info_format(self, recorder):
|
||||
"""[AC-IDMP-15] 转换为 TraceInfo 格式"""
|
||||
recorder.record_success(
|
||||
session_id="session123",
|
||||
tool_name="kb_search",
|
||||
tool_type=ToolType.INTERNAL,
|
||||
duration_ms=100,
|
||||
)
|
||||
|
||||
traces = recorder.to_trace_info_format("session123")
|
||||
|
||||
assert len(traces) == 1
|
||||
assert traces[0]["tool_name"] == "kb_search"
|
||||
assert traces[0]["status"] == "ok"
|
||||
|
||||
def test_compute_digest(self, recorder):
|
||||
"""[AC-IDMP-15] 敏感参数脱敏"""
|
||||
long_data = {"query": "x" * 100}
|
||||
digest = ToolCallTrace.compute_digest(long_data)
|
||||
|
||||
assert len(digest) < 100
|
||||
assert "..." in digest or len(digest) <= 64
|
||||
|
||||
|
||||
class TestToolRegistry:
|
||||
"""[AC-IDMP-19] Tool Registry 治理测试"""
|
||||
|
||||
@pytest.fixture
|
||||
def registry(self):
|
||||
return ToolRegistry()
|
||||
|
||||
def test_register_tool(self, registry):
|
||||
"""[AC-IDMP-19] 注册工具"""
|
||||
async def dummy_handler(**kwargs):
|
||||
return {"result": "ok"}
|
||||
|
||||
tool = registry.register(
|
||||
name="test_tool",
|
||||
description="测试工具",
|
||||
handler=dummy_handler,
|
||||
tool_type=ToolType.INTERNAL,
|
||||
version="1.0.0",
|
||||
auth_required=False,
|
||||
timeout_ms=1000,
|
||||
)
|
||||
|
||||
assert tool.name == "test_tool"
|
||||
assert tool.enabled is True
|
||||
assert tool.timeout_ms == 1000
|
||||
|
||||
def test_register_mcp_tool(self, registry):
|
||||
"""[AC-IDMP-19] 注册 MCP 工具"""
|
||||
async def mcp_handler(**kwargs):
|
||||
return {"result": "ok"}
|
||||
|
||||
tool = registry.register(
|
||||
name="mcp_tool",
|
||||
description="MCP 工具",
|
||||
handler=mcp_handler,
|
||||
tool_type=ToolType.MCP,
|
||||
)
|
||||
|
||||
assert tool.tool_type == ToolType.MCP
|
||||
|
||||
def test_enable_disable_tool(self, registry):
|
||||
"""[AC-IDMP-19] 启停工具"""
|
||||
async def handler(**kwargs):
|
||||
return {}
|
||||
|
||||
registry.register(
|
||||
name="toggle_tool",
|
||||
description="可切换工具",
|
||||
handler=handler,
|
||||
)
|
||||
|
||||
assert registry.get_tool("toggle_tool").enabled is True
|
||||
|
||||
registry.disable_tool("toggle_tool")
|
||||
assert registry.get_tool("toggle_tool").enabled is False
|
||||
|
||||
registry.enable_tool("toggle_tool")
|
||||
assert registry.get_tool("toggle_tool").enabled is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_tool_success(self, registry):
|
||||
"""[AC-IDMP-19] 执行工具成功"""
|
||||
async def handler(**kwargs):
|
||||
return {"data": kwargs.get("query")}
|
||||
|
||||
registry.register(
|
||||
name="search",
|
||||
description="搜索工具",
|
||||
handler=handler,
|
||||
timeout_ms=1000,
|
||||
)
|
||||
|
||||
result = await registry.execute(
|
||||
tool_name="search",
|
||||
args={"query": "test"},
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.output["data"] == "test"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_tool_timeout(self, registry):
|
||||
"""[AC-IDMP-19] 工具执行超时"""
|
||||
async def slow_handler(**kwargs):
|
||||
await asyncio.sleep(5)
|
||||
return {"data": "ok"}
|
||||
|
||||
registry.register(
|
||||
name="slow_tool",
|
||||
description="慢工具",
|
||||
handler=slow_handler,
|
||||
timeout_ms=100,
|
||||
)
|
||||
|
||||
result = await registry.execute(
|
||||
tool_name="slow_tool",
|
||||
args={},
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert "timeout" in result.error.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_tool_not_found(self, registry):
|
||||
"""[AC-IDMP-19] 工具不存在"""
|
||||
result = await registry.execute(
|
||||
tool_name="nonexistent",
|
||||
args={},
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert "not found" in result.error.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_tool_disabled(self, registry):
|
||||
"""[AC-IDMP-19] 工具已禁用"""
|
||||
async def handler(**kwargs):
|
||||
return {}
|
||||
|
||||
registry.register(
|
||||
name="disabled_tool",
|
||||
description="禁用工具",
|
||||
handler=handler,
|
||||
enabled=False,
|
||||
)
|
||||
|
||||
result = await registry.execute(
|
||||
tool_name="disabled_tool",
|
||||
args={},
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert "disabled" in result.error.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_tool_auth_required(self, registry):
|
||||
"""[AC-IDMP-19] 工具需要鉴权"""
|
||||
async def handler(**kwargs):
|
||||
return {}
|
||||
|
||||
registry.register(
|
||||
name="auth_tool",
|
||||
description="需要鉴权的工具",
|
||||
handler=handler,
|
||||
auth_required=True,
|
||||
)
|
||||
|
||||
result = await registry.execute(
|
||||
tool_name="auth_tool",
|
||||
args={},
|
||||
auth_context=None,
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert "auth" in result.error.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_tool_with_auth_context(self, registry):
|
||||
"""[AC-IDMP-19] 带鉴权上下文执行工具"""
|
||||
async def handler(**kwargs):
|
||||
return {"authenticated": True}
|
||||
|
||||
registry.register(
|
||||
name="auth_tool2",
|
||||
description="需要鉴权的工具",
|
||||
handler=handler,
|
||||
auth_required=True,
|
||||
)
|
||||
|
||||
result = await registry.execute(
|
||||
tool_name="auth_tool2",
|
||||
args={},
|
||||
auth_context={"user_id": "user123"},
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.auth_applied is True
|
||||
|
||||
def test_create_trace(self, registry):
|
||||
"""[AC-IDMP-19] 创建工具调用追踪"""
|
||||
result = ToolExecutionResult(
|
||||
success=True,
|
||||
output={"data": "ok"},
|
||||
duration_ms=100,
|
||||
auth_applied=False,
|
||||
registry_version="1.0.0",
|
||||
)
|
||||
|
||||
trace = registry.create_trace(
|
||||
tool_name="test_tool",
|
||||
result=result,
|
||||
args_digest="query=test",
|
||||
)
|
||||
|
||||
assert trace.tool_name == "test_tool"
|
||||
assert trace.status == ToolCallStatus.OK
|
||||
|
||||
def test_get_governance_report(self, registry):
|
||||
"""[AC-IDMP-19] 获取治理报告"""
|
||||
async def handler(**kwargs):
|
||||
return {}
|
||||
|
||||
registry.register(name="tool1", description="工具1", handler=handler)
|
||||
registry.register(name="tool2", description="工具2", handler=handler, enabled=False)
|
||||
registry.register(name="tool3", description="工具3", handler=handler, auth_required=True)
|
||||
|
||||
report = registry.get_governance_report()
|
||||
|
||||
assert report["total_tools"] == 3
|
||||
assert report["enabled_tools"] == 2
|
||||
assert report["disabled_tools"] == 1
|
||||
assert report["auth_required_tools"] == 1
|
||||
|
||||
def test_get_tool_registry_singleton(self):
|
||||
"""[AC-IDMP-19] 获取全局注册表实例"""
|
||||
registry1 = get_tool_registry()
|
||||
registry2 = get_tool_registry()
|
||||
|
||||
assert registry1 is registry2
|
||||
|
||||
def test_init_tool_registry(self):
|
||||
"""[AC-IDMP-19] 初始化注册表"""
|
||||
from app.services.mid.timeout_governor import TimeoutGovernor
|
||||
|
||||
governor = TimeoutGovernor()
|
||||
registry = init_tool_registry(timeout_governor=governor)
|
||||
|
||||
assert registry is not None
|
||||
assert registry._timeout_governor is governor
|
||||
|
||||
|
||||
class TestToolCallBuilder:
|
||||
"""[AC-IDMP-15] 工具调用构建器测试"""
|
||||
|
||||
def test_build_success_trace(self):
|
||||
"""[AC-IDMP-15] 构建成功追踪"""
|
||||
builder = ToolCallBuilder(
|
||||
tool_name="test_tool",
|
||||
tool_type=ToolType.INTERNAL,
|
||||
)
|
||||
|
||||
trace = builder.with_args({"query": "test"}).with_result({"data": "ok"}).build()
|
||||
|
||||
assert trace.tool_name == "test_tool"
|
||||
assert trace.status == ToolCallStatus.OK
|
||||
assert trace.args_digest is not None
|
||||
assert trace.result_digest is not None
|
||||
|
||||
def test_build_error_trace(self):
|
||||
"""[AC-IDMP-15] 构建错误追踪"""
|
||||
builder = ToolCallBuilder(tool_name="test_tool")
|
||||
|
||||
trace = builder.with_error(
|
||||
Exception("Something went wrong"),
|
||||
error_code="INTERNAL_ERROR",
|
||||
).build()
|
||||
|
||||
assert trace.status == ToolCallStatus.ERROR
|
||||
assert trace.error_code == "INTERNAL_ERROR"
|
||||
|
||||
def test_build_timeout_trace(self):
|
||||
"""[AC-IDMP-15] 构建超时追踪"""
|
||||
builder = ToolCallBuilder(tool_name="test_tool")
|
||||
|
||||
trace = builder.with_error(TimeoutError("Timeout")).build()
|
||||
|
||||
assert trace.status == ToolCallStatus.TIMEOUT
|
||||
|
||||
def test_build_rejected_trace(self):
|
||||
"""[AC-IDMP-15] 构建拒绝追踪"""
|
||||
builder = ToolCallBuilder(tool_name="test_tool")
|
||||
|
||||
trace = builder.with_rejected("AUTH_DENIED").build()
|
||||
|
||||
assert trace.status == ToolCallStatus.REJECTED
|
||||
assert trace.error_code == "AUTH_DENIED"
|
||||
|
|
@ -0,0 +1,337 @@
|
|||
"""
|
||||
Tests for Mid Platform services.
|
||||
[AC-IDMP-05/07/08/09/18/20] 中台服务测试
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from datetime import datetime
|
||||
|
||||
from app.services.mid import (
|
||||
HighRiskHandler,
|
||||
HighRiskMatchResult,
|
||||
PolicyRouter,
|
||||
PolicyRouteResult,
|
||||
TraceLogger,
|
||||
MetricsCollector,
|
||||
MetricsRecord,
|
||||
SessionModeService,
|
||||
DEFAULT_HIGH_RISK_SCENARIOS,
|
||||
)
|
||||
from app.models.entities import (
|
||||
HighRiskPolicy,
|
||||
HighRiskScenarioType,
|
||||
SessionModeRecord,
|
||||
MidAuditLog,
|
||||
)
|
||||
from app.models.mid import Mode, SessionMode
|
||||
|
||||
|
||||
class TestHighRiskHandler:
|
||||
"""[AC-IDMP-05/20] 高风险场景处理器测试"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self):
|
||||
session = AsyncMock()
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def handler(self, mock_session):
|
||||
return HighRiskHandler(mock_session)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_active_scenario_set_returns_default_when_empty(self, handler, mock_session):
|
||||
"""[AC-IDMP-20] 空集保护:数据库无配置时返回默认最小集"""
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = []
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
scenarios = await handler.get_active_scenario_set("test_tenant")
|
||||
|
||||
assert scenarios == DEFAULT_HIGH_RISK_SCENARIOS
|
||||
assert "refund" in scenarios
|
||||
assert "complaint_escalation" in scenarios
|
||||
assert "privacy_sensitive_promise" in scenarios
|
||||
assert "transfer" in scenarios
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_detect_high_risk_refund_keyword(self, handler, mock_session):
|
||||
"""[AC-IDMP-05] 检测退款场景"""
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = []
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = await handler.detect_high_risk("test_tenant", "我要退款")
|
||||
|
||||
assert result.matched is True
|
||||
assert result.scenario == HighRiskScenarioType.REFUND.value
|
||||
assert result.handler_mode == "micro_flow"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_detect_high_risk_transfer_keyword(self, handler, mock_session):
|
||||
"""[AC-IDMP-05] 检测转人工场景"""
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = []
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = await handler.detect_high_risk("test_tenant", "转人工")
|
||||
|
||||
assert result.matched is True
|
||||
assert result.scenario == HighRiskScenarioType.TRANSFER.value
|
||||
assert result.handler_mode == "transfer"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_detect_high_risk_complaint_keyword(self, handler, mock_session):
|
||||
"""[AC-IDMP-05] 检测投诉升级场景"""
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = []
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = await handler.detect_high_risk("test_tenant", "我要投诉你们")
|
||||
|
||||
assert result.matched is True
|
||||
assert result.scenario == HighRiskScenarioType.COMPLAINT_ESCALATION.value
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_detect_high_risk_privacy_keyword(self, handler, mock_session):
|
||||
"""[AC-IDMP-05] 检测隐私敏感承诺场景"""
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = []
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = await handler.detect_high_risk("test_tenant", "你能保证我的隐私安全吗")
|
||||
|
||||
assert result.matched is True
|
||||
assert result.scenario == HighRiskScenarioType.PRIVACY_SENSITIVE_PROMISE.value
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_detect_high_risk_no_match(self, handler, mock_session):
|
||||
"""[AC-IDMP-05] 正常消息不触发高风险"""
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = []
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = await handler.detect_high_risk("test_tenant", "你好,请问有什么可以帮助我的")
|
||||
|
||||
assert result.matched is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_detect_high_risk_with_policy(self, handler, mock_session):
|
||||
"""[AC-IDMP-05] 使用数据库策略检测高风险"""
|
||||
policy = HighRiskPolicy(
|
||||
tenant_id="test_tenant",
|
||||
scenario=HighRiskScenarioType.REFUND.value,
|
||||
handler_mode="micro_flow",
|
||||
keywords=["退货退款"],
|
||||
priority=100,
|
||||
is_enabled=True,
|
||||
)
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = [policy]
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = await handler.detect_high_risk("test_tenant", "我要退货退款")
|
||||
|
||||
assert result.matched is True
|
||||
assert result.scenario == HighRiskScenarioType.REFUND.value
|
||||
assert result.matched_keyword == "退货退款"
|
||||
|
||||
|
||||
class TestPolicyRouter:
|
||||
"""[AC-IDMP-05] 策略路由器测试"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_high_risk_handler(self):
|
||||
return AsyncMock(spec=HighRiskHandler)
|
||||
|
||||
@pytest.fixture
|
||||
def router(self, mock_high_risk_handler):
|
||||
return PolicyRouter(mock_high_risk_handler)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_route_human_mode_returns_transfer(self, router, mock_high_risk_handler):
|
||||
"""[AC-IDMP-05] 人工模式返回 transfer"""
|
||||
mock_high_risk_handler.detect_high_risk.return_value = HighRiskMatchResult(matched=False)
|
||||
|
||||
result = await router.route(
|
||||
tenant_id="test_tenant",
|
||||
user_message="你好",
|
||||
session_mode="HUMAN_ACTIVE",
|
||||
)
|
||||
|
||||
assert result.mode == "transfer"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_route_high_risk_returns_micro_flow(self, router, mock_high_risk_handler):
|
||||
"""[AC-IDMP-05] 高风险场景返回 micro_flow 或 transfer"""
|
||||
mock_high_risk_handler.detect_high_risk.return_value = HighRiskMatchResult(
|
||||
matched=True,
|
||||
scenario="refund",
|
||||
handler_mode="micro_flow",
|
||||
)
|
||||
|
||||
result = await router.route(
|
||||
tenant_id="test_tenant",
|
||||
user_message="我要退款",
|
||||
session_mode="BOT_ACTIVE",
|
||||
)
|
||||
|
||||
assert result.mode == "micro_flow"
|
||||
assert result.high_risk_scenario == "refund"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_route_low_confidence_returns_fixed(self, router, mock_high_risk_handler):
|
||||
"""[AC-IDMP-05] 低置信度返回 fixed"""
|
||||
mock_high_risk_handler.detect_high_risk.return_value = HighRiskMatchResult(matched=False)
|
||||
|
||||
result = await router.route(
|
||||
tenant_id="test_tenant",
|
||||
user_message="你好",
|
||||
session_mode="BOT_ACTIVE",
|
||||
confidence=0.2,
|
||||
)
|
||||
|
||||
assert result.mode == "fixed"
|
||||
assert result.fallback_reason_code == "low_confidence"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_route_normal_returns_agent(self, router, mock_high_risk_handler):
|
||||
"""[AC-IDMP-05] 正常场景返回 agent"""
|
||||
mock_high_risk_handler.detect_high_risk.return_value = HighRiskMatchResult(matched=False)
|
||||
|
||||
result = await router.route(
|
||||
tenant_id="test_tenant",
|
||||
user_message="你好",
|
||||
session_mode="BOT_ACTIVE",
|
||||
confidence=0.8,
|
||||
)
|
||||
|
||||
assert result.mode == "agent"
|
||||
|
||||
|
||||
class TestMetricsCollector:
|
||||
"""[AC-IDMP-18] 指标采集器测试"""
|
||||
|
||||
@pytest.fixture
|
||||
def collector(self):
|
||||
return MetricsCollector()
|
||||
|
||||
def test_record_and_get_metrics(self, collector):
|
||||
"""[AC-IDMP-18] 记录并获取指标"""
|
||||
record1 = MetricsRecord(
|
||||
tenant_id="tenant1",
|
||||
session_id="session1",
|
||||
request_id="req1",
|
||||
task_completed=True,
|
||||
slots_filled=3,
|
||||
slots_total=5,
|
||||
was_transferred=False,
|
||||
had_recall=True,
|
||||
latency_ms=100,
|
||||
)
|
||||
|
||||
record2 = MetricsRecord(
|
||||
tenant_id="tenant1",
|
||||
session_id="session2",
|
||||
request_id="req2",
|
||||
task_completed=False,
|
||||
slots_filled=2,
|
||||
slots_total=5,
|
||||
was_transferred=True,
|
||||
had_recall=False,
|
||||
latency_ms=200,
|
||||
)
|
||||
|
||||
collector.record(record1)
|
||||
collector.record(record2)
|
||||
|
||||
metrics = collector.get_metrics_snapshot("tenant1")
|
||||
|
||||
assert metrics["task_completion_rate"] == 0.5
|
||||
assert metrics["slot_completion_rate"] == 0.5
|
||||
assert metrics["wrong_transfer_rate"] == 0.5
|
||||
assert metrics["no_recall_rate"] == 0.5
|
||||
assert metrics["avg_latency_ms"] == 150.0
|
||||
|
||||
def test_get_metrics_empty_tenant(self, collector):
|
||||
"""[AC-IDMP-18] 空租户返回零值指标"""
|
||||
metrics = collector.get_metrics_snapshot("unknown_tenant")
|
||||
|
||||
assert metrics["task_completion_rate"] == 0.0
|
||||
assert metrics["slot_completion_rate"] == 0.0
|
||||
assert metrics["wrong_transfer_rate"] == 0.0
|
||||
assert metrics["no_recall_rate"] == 0.0
|
||||
assert metrics["avg_latency_ms"] == 0.0
|
||||
|
||||
|
||||
class TestSessionModeService:
|
||||
"""[AC-IDMP-09] 会话模式服务测试"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self):
|
||||
return AsyncMock()
|
||||
|
||||
@pytest.fixture
|
||||
def service(self, mock_session):
|
||||
return SessionModeService(mock_session)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_mode_returns_default(self, service, mock_session):
|
||||
"""[AC-IDMP-09] 获取默认模式"""
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = None
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
mode = await service.get_mode("tenant1", "session1")
|
||||
|
||||
assert mode == "BOT_ACTIVE"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_switch_mode_creates_new(self, service, mock_session):
|
||||
"""[AC-IDMP-09] 切换模式创建新记录"""
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = None
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = await service.switch_mode(
|
||||
tenant_id="tenant1",
|
||||
session_id="session1",
|
||||
mode="HUMAN_ACTIVE",
|
||||
reason="user_request",
|
||||
)
|
||||
|
||||
assert result.mode == "HUMAN_ACTIVE"
|
||||
assert result.reason == "user_request"
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
|
||||
class TestTraceLogger:
|
||||
"""[AC-IDMP-07] Trace 日志服务测试"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self):
|
||||
return AsyncMock()
|
||||
|
||||
@pytest.fixture
|
||||
def logger(self, mock_session):
|
||||
return TraceLogger(mock_session)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_log_creates_audit_record(self, logger, mock_session):
|
||||
"""[AC-IDMP-07] 记录审计日志"""
|
||||
result = await logger.log(
|
||||
tenant_id="tenant1",
|
||||
session_id="session1",
|
||||
request_id="req1",
|
||||
generation_id="gen1",
|
||||
mode="agent",
|
||||
intent="greeting",
|
||||
tool_calls=[{"tool": "search", "duration_ms": 100}],
|
||||
guardrail_triggered=False,
|
||||
latency_ms=500,
|
||||
)
|
||||
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
Loading…
Reference in New Issue