From 38130f7a27d17dc77deb32ed928fe9e5b1d49230 Mon Sep 17 00:00:00 2001 From: MerCry Date: Thu, 5 Mar 2026 18:15:15 +0800 Subject: [PATCH] test: add mid-platform service tests [AC-IDMP-01~20, AC-MARH-01~12] - Add memory tool tests - Add mid-platform services tests --- ai-service/tests/test_mid_memory_tool.py | 550 +++++++++++++++++++++++ ai-service/tests/test_mid_services.py | 337 ++++++++++++++ 2 files changed, 887 insertions(+) create mode 100644 ai-service/tests/test_mid_memory_tool.py create mode 100644 ai-service/tests/test_mid_services.py diff --git a/ai-service/tests/test_mid_memory_tool.py b/ai-service/tests/test_mid_memory_tool.py new file mode 100644 index 0000000..b783dd1 --- /dev/null +++ b/ai-service/tests/test_mid_memory_tool.py @@ -0,0 +1,550 @@ +""" +Tests for Mid Platform Memory and Tool Governance services. +[AC-IDMP-13/14/15/19] 记忆与工具治理测试 + +覆盖路径: +- 成功路径:正常 recall/update,工具调用成功 +- 超时路径:recall 超时降级,工具调用超时 +- 错误路径:recall 失败降级,工具调用错误 +- 降级路径:记忆服务不可用时继续主链路 +""" + +import asyncio +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from datetime import datetime + +from app.services.mid.memory_adapter import MemoryAdapter, UserMemory +from app.services.mid.tool_call_recorder import ToolCallRecorder, ToolCallStatistics, get_tool_call_recorder +from app.services.mid.tool_registry import ToolRegistry, ToolDefinition, ToolExecutionResult, get_tool_registry, init_tool_registry +from app.models.mid.memory import ( + RecallRequest, + RecallResponse, + UpdateRequest, + MemoryProfile, + MemoryFact, + MemoryPreferences, +) +from app.models.mid.tool_trace import ( + ToolCallTrace, + ToolCallBuilder, + ToolCallStatus, + ToolType, +) +from app.models.mid.schemas import ToolCallTrace as ToolCallTraceSchema + + +class TestMemoryAdapter: + """[AC-IDMP-13/14] 记忆适配器测试""" + + @pytest.fixture + def mock_session(self): + return AsyncMock() + + @pytest.fixture + def adapter(self, mock_session): + return MemoryAdapter(mock_session) + + @pytest.mark.asyncio + async def test_recall_success(self, adapter): + """[AC-IDMP-13] 成功召回用户记忆""" + response = await adapter.recall( + user_id="user123", + session_id="session456", + tenant_id="tenant789", + ) + + assert response is not None + assert response.profile is not None + assert response.profile.grade == "初一" + assert response.profile.region == "北京" + assert len(response.facts) > 0 + assert response.preferences is not None + assert response.preferences.tone == "friendly" + + @pytest.mark.asyncio + async def test_recall_returns_context_for_prompt(self, adapter): + """[AC-IDMP-13] 召回记忆可注入 Prompt""" + response = await adapter.recall( + user_id="user123", + session_id="session456", + ) + + context = response.get_context_for_prompt() + + assert "【用户属性】" in context + assert "年级" in context + assert "【已知事实】" in context + assert "【用户偏好】" in context + + @pytest.mark.asyncio + async def test_recall_timeout_fallback(self, mock_session): + """[AC-IDMP-13] recall 超时降级,不阻断主链路""" + adapter = MemoryAdapter(mock_session, recall_timeout_ms=1) + + async def slow_recall(*args, **kwargs): + await asyncio.sleep(1) + return RecallResponse() + + with patch.object(adapter, '_recall_internal', slow_recall): + response = await adapter.recall( + user_id="user123", + session_id="session456", + ) + + assert response is not None + assert response.profile is None + assert response.facts == [] + + @pytest.mark.asyncio + async def test_recall_error_fallback(self, adapter, mock_session): + """[AC-IDMP-13] recall 错误降级,不阻断主链路""" + with patch.object(adapter, '_recall_internal', side_effect=Exception("DB error")): + response = await adapter.recall( + user_id="user123", + session_id="session456", + ) + + assert response is not None + assert response.profile is None + + @pytest.mark.asyncio + async def test_update_async(self, adapter): + """[AC-IDMP-14] 异步更新用户记忆""" + result = await adapter.update( + user_id="user123", + session_id="session456", + messages=[{"role": "user", "content": "你好"}], + summary="用户打招呼", + ) + + assert result is True + + await asyncio.sleep(0.1) + + completed = await adapter.wait_pending_updates(timeout=1.0) + assert completed >= 0 + + @pytest.mark.asyncio + async def test_update_with_summary_generation(self, adapter): + """[AC-IDMP-14] 带摘要生成的记忆更新""" + async def summary_gen(messages): + return "这是一个测试摘要" + + result = await adapter.update_with_summary_generation( + user_id="user123", + session_id="session456", + messages=[{"role": "user", "content": "你好"}], + summary_generator=summary_gen, + ) + + assert result is True + + +class TestToolCallRecorder: + """[AC-IDMP-15] 工具调用记录器测试""" + + @pytest.fixture + def recorder(self): + return ToolCallRecorder() + + def test_record_success(self, recorder): + """[AC-IDMP-15] 记录成功的工具调用""" + trace = recorder.record_success( + session_id="session123", + tool_name="kb_search", + tool_type=ToolType.INTERNAL, + duration_ms=150, + args={"query": "测试查询"}, + result={"docs": ["doc1", "doc2"]}, + registry_version="1.0.0", + auth_applied=False, + ) + + assert trace.tool_name == "kb_search" + assert trace.status == ToolCallStatus.OK + assert trace.duration_ms == 150 + assert trace.args_digest is not None + assert trace.result_digest is not None + + def test_record_timeout(self, recorder): + """[AC-IDMP-15] 记录超时的工具调用""" + trace = recorder.record_timeout( + session_id="session123", + tool_name="kb_search", + tool_type=ToolType.INTERNAL, + duration_ms=2000, + args={"query": "测试查询"}, + ) + + assert trace.status == ToolCallStatus.TIMEOUT + assert trace.error_code == "TIMEOUT" + + def test_record_error(self, recorder): + """[AC-IDMP-15] 记录错误的工具调用""" + trace = recorder.record_error( + session_id="session123", + tool_name="kb_search", + tool_type=ToolType.INTERNAL, + duration_ms=500, + error_code="DB_CONNECTION_ERROR", + error_message="Database connection failed", + ) + + assert trace.status == ToolCallStatus.ERROR + assert trace.error_code == "DB_CONNECTION_ERROR" + + def test_record_rejected(self, recorder): + """[AC-IDMP-15] 记录被拒绝的工具调用""" + trace = recorder.record_rejected( + session_id="session123", + tool_name="sensitive_tool", + tool_type=ToolType.MCP, + reason="AUTH_REQUIRED", + ) + + assert trace.status == ToolCallStatus.REJECTED + assert trace.error_code == "AUTH_REQUIRED" + + def test_get_traces(self, recorder): + """[AC-IDMP-15] 获取会话的工具调用记录""" + recorder.record_success( + session_id="session123", + tool_name="tool1", + tool_type=ToolType.INTERNAL, + duration_ms=100, + ) + recorder.record_success( + session_id="session123", + tool_name="tool2", + tool_type=ToolType.INTERNAL, + duration_ms=200, + ) + + traces = recorder.get_traces("session123") + + assert len(traces) == 2 + assert traces[0].tool_name == "tool1" + assert traces[1].tool_name == "tool2" + + def test_get_statistics(self, recorder): + """[AC-IDMP-15] 获取工具调用统计""" + recorder.record_success( + session_id="session123", + tool_name="kb_search", + tool_type=ToolType.INTERNAL, + duration_ms=100, + ) + recorder.record_timeout( + session_id="session123", + tool_name="kb_search", + tool_type=ToolType.INTERNAL, + duration_ms=2000, + ) + + stats = recorder.get_statistics("kb_search") + + assert stats["total_calls"] == 2 + assert stats["success_rate"] == 0.5 + assert stats["timeout_rate"] == 0.5 + + def test_to_trace_info_format(self, recorder): + """[AC-IDMP-15] 转换为 TraceInfo 格式""" + recorder.record_success( + session_id="session123", + tool_name="kb_search", + tool_type=ToolType.INTERNAL, + duration_ms=100, + ) + + traces = recorder.to_trace_info_format("session123") + + assert len(traces) == 1 + assert traces[0]["tool_name"] == "kb_search" + assert traces[0]["status"] == "ok" + + def test_compute_digest(self, recorder): + """[AC-IDMP-15] 敏感参数脱敏""" + long_data = {"query": "x" * 100} + digest = ToolCallTrace.compute_digest(long_data) + + assert len(digest) < 100 + assert "..." in digest or len(digest) <= 64 + + +class TestToolRegistry: + """[AC-IDMP-19] Tool Registry 治理测试""" + + @pytest.fixture + def registry(self): + return ToolRegistry() + + def test_register_tool(self, registry): + """[AC-IDMP-19] 注册工具""" + async def dummy_handler(**kwargs): + return {"result": "ok"} + + tool = registry.register( + name="test_tool", + description="测试工具", + handler=dummy_handler, + tool_type=ToolType.INTERNAL, + version="1.0.0", + auth_required=False, + timeout_ms=1000, + ) + + assert tool.name == "test_tool" + assert tool.enabled is True + assert tool.timeout_ms == 1000 + + def test_register_mcp_tool(self, registry): + """[AC-IDMP-19] 注册 MCP 工具""" + async def mcp_handler(**kwargs): + return {"result": "ok"} + + tool = registry.register( + name="mcp_tool", + description="MCP 工具", + handler=mcp_handler, + tool_type=ToolType.MCP, + ) + + assert tool.tool_type == ToolType.MCP + + def test_enable_disable_tool(self, registry): + """[AC-IDMP-19] 启停工具""" + async def handler(**kwargs): + return {} + + registry.register( + name="toggle_tool", + description="可切换工具", + handler=handler, + ) + + assert registry.get_tool("toggle_tool").enabled is True + + registry.disable_tool("toggle_tool") + assert registry.get_tool("toggle_tool").enabled is False + + registry.enable_tool("toggle_tool") + assert registry.get_tool("toggle_tool").enabled is True + + @pytest.mark.asyncio + async def test_execute_tool_success(self, registry): + """[AC-IDMP-19] 执行工具成功""" + async def handler(**kwargs): + return {"data": kwargs.get("query")} + + registry.register( + name="search", + description="搜索工具", + handler=handler, + timeout_ms=1000, + ) + + result = await registry.execute( + tool_name="search", + args={"query": "test"}, + ) + + assert result.success is True + assert result.output["data"] == "test" + + @pytest.mark.asyncio + async def test_execute_tool_timeout(self, registry): + """[AC-IDMP-19] 工具执行超时""" + async def slow_handler(**kwargs): + await asyncio.sleep(5) + return {"data": "ok"} + + registry.register( + name="slow_tool", + description="慢工具", + handler=slow_handler, + timeout_ms=100, + ) + + result = await registry.execute( + tool_name="slow_tool", + args={}, + ) + + assert result.success is False + assert "timeout" in result.error.lower() + + @pytest.mark.asyncio + async def test_execute_tool_not_found(self, registry): + """[AC-IDMP-19] 工具不存在""" + result = await registry.execute( + tool_name="nonexistent", + args={}, + ) + + assert result.success is False + assert "not found" in result.error.lower() + + @pytest.mark.asyncio + async def test_execute_tool_disabled(self, registry): + """[AC-IDMP-19] 工具已禁用""" + async def handler(**kwargs): + return {} + + registry.register( + name="disabled_tool", + description="禁用工具", + handler=handler, + enabled=False, + ) + + result = await registry.execute( + tool_name="disabled_tool", + args={}, + ) + + assert result.success is False + assert "disabled" in result.error.lower() + + @pytest.mark.asyncio + async def test_execute_tool_auth_required(self, registry): + """[AC-IDMP-19] 工具需要鉴权""" + async def handler(**kwargs): + return {} + + registry.register( + name="auth_tool", + description="需要鉴权的工具", + handler=handler, + auth_required=True, + ) + + result = await registry.execute( + tool_name="auth_tool", + args={}, + auth_context=None, + ) + + assert result.success is False + assert "auth" in result.error.lower() + + @pytest.mark.asyncio + async def test_execute_tool_with_auth_context(self, registry): + """[AC-IDMP-19] 带鉴权上下文执行工具""" + async def handler(**kwargs): + return {"authenticated": True} + + registry.register( + name="auth_tool2", + description="需要鉴权的工具", + handler=handler, + auth_required=True, + ) + + result = await registry.execute( + tool_name="auth_tool2", + args={}, + auth_context={"user_id": "user123"}, + ) + + assert result.success is True + assert result.auth_applied is True + + def test_create_trace(self, registry): + """[AC-IDMP-19] 创建工具调用追踪""" + result = ToolExecutionResult( + success=True, + output={"data": "ok"}, + duration_ms=100, + auth_applied=False, + registry_version="1.0.0", + ) + + trace = registry.create_trace( + tool_name="test_tool", + result=result, + args_digest="query=test", + ) + + assert trace.tool_name == "test_tool" + assert trace.status == ToolCallStatus.OK + + def test_get_governance_report(self, registry): + """[AC-IDMP-19] 获取治理报告""" + async def handler(**kwargs): + return {} + + registry.register(name="tool1", description="工具1", handler=handler) + registry.register(name="tool2", description="工具2", handler=handler, enabled=False) + registry.register(name="tool3", description="工具3", handler=handler, auth_required=True) + + report = registry.get_governance_report() + + assert report["total_tools"] == 3 + assert report["enabled_tools"] == 2 + assert report["disabled_tools"] == 1 + assert report["auth_required_tools"] == 1 + + def test_get_tool_registry_singleton(self): + """[AC-IDMP-19] 获取全局注册表实例""" + registry1 = get_tool_registry() + registry2 = get_tool_registry() + + assert registry1 is registry2 + + def test_init_tool_registry(self): + """[AC-IDMP-19] 初始化注册表""" + from app.services.mid.timeout_governor import TimeoutGovernor + + governor = TimeoutGovernor() + registry = init_tool_registry(timeout_governor=governor) + + assert registry is not None + assert registry._timeout_governor is governor + + +class TestToolCallBuilder: + """[AC-IDMP-15] 工具调用构建器测试""" + + def test_build_success_trace(self): + """[AC-IDMP-15] 构建成功追踪""" + builder = ToolCallBuilder( + tool_name="test_tool", + tool_type=ToolType.INTERNAL, + ) + + trace = builder.with_args({"query": "test"}).with_result({"data": "ok"}).build() + + assert trace.tool_name == "test_tool" + assert trace.status == ToolCallStatus.OK + assert trace.args_digest is not None + assert trace.result_digest is not None + + def test_build_error_trace(self): + """[AC-IDMP-15] 构建错误追踪""" + builder = ToolCallBuilder(tool_name="test_tool") + + trace = builder.with_error( + Exception("Something went wrong"), + error_code="INTERNAL_ERROR", + ).build() + + assert trace.status == ToolCallStatus.ERROR + assert trace.error_code == "INTERNAL_ERROR" + + def test_build_timeout_trace(self): + """[AC-IDMP-15] 构建超时追踪""" + builder = ToolCallBuilder(tool_name="test_tool") + + trace = builder.with_error(TimeoutError("Timeout")).build() + + assert trace.status == ToolCallStatus.TIMEOUT + + def test_build_rejected_trace(self): + """[AC-IDMP-15] 构建拒绝追踪""" + builder = ToolCallBuilder(tool_name="test_tool") + + trace = builder.with_rejected("AUTH_DENIED").build() + + assert trace.status == ToolCallStatus.REJECTED + assert trace.error_code == "AUTH_DENIED" diff --git a/ai-service/tests/test_mid_services.py b/ai-service/tests/test_mid_services.py new file mode 100644 index 0000000..be49cae --- /dev/null +++ b/ai-service/tests/test_mid_services.py @@ -0,0 +1,337 @@ +""" +Tests for Mid Platform services. +[AC-IDMP-05/07/08/09/18/20] 中台服务测试 +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from datetime import datetime + +from app.services.mid import ( + HighRiskHandler, + HighRiskMatchResult, + PolicyRouter, + PolicyRouteResult, + TraceLogger, + MetricsCollector, + MetricsRecord, + SessionModeService, + DEFAULT_HIGH_RISK_SCENARIOS, +) +from app.models.entities import ( + HighRiskPolicy, + HighRiskScenarioType, + SessionModeRecord, + MidAuditLog, +) +from app.models.mid import Mode, SessionMode + + +class TestHighRiskHandler: + """[AC-IDMP-05/20] 高风险场景处理器测试""" + + @pytest.fixture + def mock_session(self): + session = AsyncMock() + return session + + @pytest.fixture + def handler(self, mock_session): + return HighRiskHandler(mock_session) + + @pytest.mark.asyncio + async def test_get_active_scenario_set_returns_default_when_empty(self, handler, mock_session): + """[AC-IDMP-20] 空集保护:数据库无配置时返回默认最小集""" + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = [] + mock_session.execute.return_value = mock_result + + scenarios = await handler.get_active_scenario_set("test_tenant") + + assert scenarios == DEFAULT_HIGH_RISK_SCENARIOS + assert "refund" in scenarios + assert "complaint_escalation" in scenarios + assert "privacy_sensitive_promise" in scenarios + assert "transfer" in scenarios + + @pytest.mark.asyncio + async def test_detect_high_risk_refund_keyword(self, handler, mock_session): + """[AC-IDMP-05] 检测退款场景""" + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = [] + mock_session.execute.return_value = mock_result + + result = await handler.detect_high_risk("test_tenant", "我要退款") + + assert result.matched is True + assert result.scenario == HighRiskScenarioType.REFUND.value + assert result.handler_mode == "micro_flow" + + @pytest.mark.asyncio + async def test_detect_high_risk_transfer_keyword(self, handler, mock_session): + """[AC-IDMP-05] 检测转人工场景""" + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = [] + mock_session.execute.return_value = mock_result + + result = await handler.detect_high_risk("test_tenant", "转人工") + + assert result.matched is True + assert result.scenario == HighRiskScenarioType.TRANSFER.value + assert result.handler_mode == "transfer" + + @pytest.mark.asyncio + async def test_detect_high_risk_complaint_keyword(self, handler, mock_session): + """[AC-IDMP-05] 检测投诉升级场景""" + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = [] + mock_session.execute.return_value = mock_result + + result = await handler.detect_high_risk("test_tenant", "我要投诉你们") + + assert result.matched is True + assert result.scenario == HighRiskScenarioType.COMPLAINT_ESCALATION.value + + @pytest.mark.asyncio + async def test_detect_high_risk_privacy_keyword(self, handler, mock_session): + """[AC-IDMP-05] 检测隐私敏感承诺场景""" + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = [] + mock_session.execute.return_value = mock_result + + result = await handler.detect_high_risk("test_tenant", "你能保证我的隐私安全吗") + + assert result.matched is True + assert result.scenario == HighRiskScenarioType.PRIVACY_SENSITIVE_PROMISE.value + + @pytest.mark.asyncio + async def test_detect_high_risk_no_match(self, handler, mock_session): + """[AC-IDMP-05] 正常消息不触发高风险""" + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = [] + mock_session.execute.return_value = mock_result + + result = await handler.detect_high_risk("test_tenant", "你好,请问有什么可以帮助我的") + + assert result.matched is False + + @pytest.mark.asyncio + async def test_detect_high_risk_with_policy(self, handler, mock_session): + """[AC-IDMP-05] 使用数据库策略检测高风险""" + policy = HighRiskPolicy( + tenant_id="test_tenant", + scenario=HighRiskScenarioType.REFUND.value, + handler_mode="micro_flow", + keywords=["退货退款"], + priority=100, + is_enabled=True, + ) + + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = [policy] + mock_session.execute.return_value = mock_result + + result = await handler.detect_high_risk("test_tenant", "我要退货退款") + + assert result.matched is True + assert result.scenario == HighRiskScenarioType.REFUND.value + assert result.matched_keyword == "退货退款" + + +class TestPolicyRouter: + """[AC-IDMP-05] 策略路由器测试""" + + @pytest.fixture + def mock_high_risk_handler(self): + return AsyncMock(spec=HighRiskHandler) + + @pytest.fixture + def router(self, mock_high_risk_handler): + return PolicyRouter(mock_high_risk_handler) + + @pytest.mark.asyncio + async def test_route_human_mode_returns_transfer(self, router, mock_high_risk_handler): + """[AC-IDMP-05] 人工模式返回 transfer""" + mock_high_risk_handler.detect_high_risk.return_value = HighRiskMatchResult(matched=False) + + result = await router.route( + tenant_id="test_tenant", + user_message="你好", + session_mode="HUMAN_ACTIVE", + ) + + assert result.mode == "transfer" + + @pytest.mark.asyncio + async def test_route_high_risk_returns_micro_flow(self, router, mock_high_risk_handler): + """[AC-IDMP-05] 高风险场景返回 micro_flow 或 transfer""" + mock_high_risk_handler.detect_high_risk.return_value = HighRiskMatchResult( + matched=True, + scenario="refund", + handler_mode="micro_flow", + ) + + result = await router.route( + tenant_id="test_tenant", + user_message="我要退款", + session_mode="BOT_ACTIVE", + ) + + assert result.mode == "micro_flow" + assert result.high_risk_scenario == "refund" + + @pytest.mark.asyncio + async def test_route_low_confidence_returns_fixed(self, router, mock_high_risk_handler): + """[AC-IDMP-05] 低置信度返回 fixed""" + mock_high_risk_handler.detect_high_risk.return_value = HighRiskMatchResult(matched=False) + + result = await router.route( + tenant_id="test_tenant", + user_message="你好", + session_mode="BOT_ACTIVE", + confidence=0.2, + ) + + assert result.mode == "fixed" + assert result.fallback_reason_code == "low_confidence" + + @pytest.mark.asyncio + async def test_route_normal_returns_agent(self, router, mock_high_risk_handler): + """[AC-IDMP-05] 正常场景返回 agent""" + mock_high_risk_handler.detect_high_risk.return_value = HighRiskMatchResult(matched=False) + + result = await router.route( + tenant_id="test_tenant", + user_message="你好", + session_mode="BOT_ACTIVE", + confidence=0.8, + ) + + assert result.mode == "agent" + + +class TestMetricsCollector: + """[AC-IDMP-18] 指标采集器测试""" + + @pytest.fixture + def collector(self): + return MetricsCollector() + + def test_record_and_get_metrics(self, collector): + """[AC-IDMP-18] 记录并获取指标""" + record1 = MetricsRecord( + tenant_id="tenant1", + session_id="session1", + request_id="req1", + task_completed=True, + slots_filled=3, + slots_total=5, + was_transferred=False, + had_recall=True, + latency_ms=100, + ) + + record2 = MetricsRecord( + tenant_id="tenant1", + session_id="session2", + request_id="req2", + task_completed=False, + slots_filled=2, + slots_total=5, + was_transferred=True, + had_recall=False, + latency_ms=200, + ) + + collector.record(record1) + collector.record(record2) + + metrics = collector.get_metrics_snapshot("tenant1") + + assert metrics["task_completion_rate"] == 0.5 + assert metrics["slot_completion_rate"] == 0.5 + assert metrics["wrong_transfer_rate"] == 0.5 + assert metrics["no_recall_rate"] == 0.5 + assert metrics["avg_latency_ms"] == 150.0 + + def test_get_metrics_empty_tenant(self, collector): + """[AC-IDMP-18] 空租户返回零值指标""" + metrics = collector.get_metrics_snapshot("unknown_tenant") + + assert metrics["task_completion_rate"] == 0.0 + assert metrics["slot_completion_rate"] == 0.0 + assert metrics["wrong_transfer_rate"] == 0.0 + assert metrics["no_recall_rate"] == 0.0 + assert metrics["avg_latency_ms"] == 0.0 + + +class TestSessionModeService: + """[AC-IDMP-09] 会话模式服务测试""" + + @pytest.fixture + def mock_session(self): + return AsyncMock() + + @pytest.fixture + def service(self, mock_session): + return SessionModeService(mock_session) + + @pytest.mark.asyncio + async def test_get_mode_returns_default(self, service, mock_session): + """[AC-IDMP-09] 获取默认模式""" + mock_result = MagicMock() + mock_result.scalar_one_or_none.return_value = None + mock_session.execute.return_value = mock_result + + mode = await service.get_mode("tenant1", "session1") + + assert mode == "BOT_ACTIVE" + + @pytest.mark.asyncio + async def test_switch_mode_creates_new(self, service, mock_session): + """[AC-IDMP-09] 切换模式创建新记录""" + mock_result = MagicMock() + mock_result.scalar_one_or_none.return_value = None + mock_session.execute.return_value = mock_result + + result = await service.switch_mode( + tenant_id="tenant1", + session_id="session1", + mode="HUMAN_ACTIVE", + reason="user_request", + ) + + assert result.mode == "HUMAN_ACTIVE" + assert result.reason == "user_request" + mock_session.add.assert_called_once() + mock_session.commit.assert_called_once() + + +class TestTraceLogger: + """[AC-IDMP-07] Trace 日志服务测试""" + + @pytest.fixture + def mock_session(self): + return AsyncMock() + + @pytest.fixture + def logger(self, mock_session): + return TraceLogger(mock_session) + + @pytest.mark.asyncio + async def test_log_creates_audit_record(self, logger, mock_session): + """[AC-IDMP-07] 记录审计日志""" + result = await logger.log( + tenant_id="tenant1", + session_id="session1", + request_id="req1", + generation_id="gen1", + mode="agent", + intent="greeting", + tool_calls=[{"tool": "search", "duration_ms": 100}], + guardrail_triggered=False, + latency_ms=500, + ) + + mock_session.add.assert_called_once() + mock_session.commit.assert_called_once()