"""
Tests for Dialogue API with Slot State Integration. [AC-MRS-SLOT-META-03]

Covers:
- ask-back (follow-up question) generation for missing required slots;
- the structure of a DialogueResponse that carries an ask-back;
- slot state aggregation from memory slots and slot definitions;
- linkage between slot definitions and metadata field definitions;
- backward compatibility of KB search when no slot state is supplied.
"""

from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from app.api.mid.dialogue import _generate_ask_back_for_missing_slots
from app.models.mid.schemas import ExecutionMode, Segment, TraceInfo
from app.services.mid.slot_state_aggregator import SlotState


class TestGenerateAskBackResponse:
    """Tests for generating the ask-back (follow-up question) text."""

    @pytest.mark.asyncio
    async def test_generate_ask_back_with_prompt(self):
        """A configured ask_back_prompt is returned as the ask-back text."""
        slot_state = SlotState()
        missing_slots = [
            {
                "slot_key": "region",
                "label": "地区",
                "ask_back_prompt": "请问您在哪个地区?",
            }
        ]

        mock_session = AsyncMock()

        response = await _generate_ask_back_for_missing_slots(
            slot_state=slot_state,
            missing_slots=missing_slots,
            session=mock_session,
            tenant_id="test_tenant",
        )

        assert response == "请问您在哪个地区?"

    @pytest.mark.asyncio
    async def test_generate_ask_back_generic(self):
        """Without an ask_back_prompt, the generated text mentions the slot label."""
        slot_state = SlotState()
        missing_slots = [
            {
                "slot_key": "product_line",
                "label": "产品线",
                # no ask_back_prompt configured
            }
        ]

        mock_session = AsyncMock()

        response = await _generate_ask_back_for_missing_slots(
            slot_state=slot_state,
            missing_slots=missing_slots,
            session=mock_session,
            tenant_id="test_tenant",
        )

        assert "产品线" in response

    @pytest.mark.asyncio
    async def test_generate_ask_back_empty_slots(self):
        """An empty missing-slot list still yields a generic request for more information."""
        slot_state = SlotState()
        missing_slots = []

        mock_session = AsyncMock()

        response = await _generate_ask_back_for_missing_slots(
            slot_state=slot_state,
            missing_slots=missing_slots,
            session=mock_session,
            tenant_id="test_tenant",
        )

        assert "更多信息" in response


class TestDialogueAskBackResponse:
    """Tests for the structure of an ask-back dialogue response."""

    def test_dialogue_response_with_ask_back(self):
        """An ask-back response carries one segment and a missing-slots trace."""
        from app.models.mid.schemas import DialogueResponse

        response = DialogueResponse(
            segments=[Segment(text="请问您咨询的是哪个产品线?", delay_after=0)],
            trace=TraceInfo(
                mode=ExecutionMode.AGENT,
                request_id="test_request_id",
                generation_id="test_generation_id",
                fallback_reason_code="missing_required_slots",
                kb_tool_called=True,
                kb_hit=False,
            ),
        )

        assert len(response.segments) == 1
        assert "哪个产品线" in response.segments[0].text
        assert response.trace.fallback_reason_code == "missing_required_slots"
        assert response.trace.kb_tool_called is True
        assert response.trace.kb_hit is False


class TestSlotStateAggregationFlow:
    """Tests for the slot state aggregation flow."""

    @pytest.mark.asyncio
    async def test_memory_slots_included_in_state(self):
        """Slots recalled from memory are included in the aggregated state."""
        from app.models.mid.schemas import MemorySlot, SlotSource
        from app.services.mid.slot_state_aggregator import SlotStateAggregator

        mock_session = AsyncMock()
        aggregator = SlotStateAggregator(
            session=mock_session,
            tenant_id="test_tenant",
        )

        memory_slots = {
            "product_line": MemorySlot(
                key="product_line",
                value="vip_course",
                source=SlotSource.USER_CONFIRMED,
                confidence=1.0,
            )
        }

        # No slot definitions, so only the memory slots feed the state.
        with patch.object(
            aggregator._slot_def_service,
            "list_slot_definitions",
            return_value=[],
        ):
            state = await aggregator.aggregate(
                memory_slots=memory_slots,
                current_input_slots=None,
                context=None,
            )

        assert "product_line" in state.filled_slots
        assert state.filled_slots["product_line"] == "vip_course"
        assert state.slot_sources["product_line"] == "user_confirmed"

    @pytest.mark.asyncio
    async def test_missing_slots_identified(self):
        """A required slot with no value is reported as missing, with its ask-back prompt."""
        from app.services.mid.slot_state_aggregator import SlotStateAggregator

        mock_session = AsyncMock()
        aggregator = SlotStateAggregator(
            session=mock_session,
            tenant_id="test_tenant",
        )

        # Simulate a required slot definition.
        mock_slot_def = MagicMock()
        mock_slot_def.slot_key = "region"
        mock_slot_def.required = True
        mock_slot_def.ask_back_prompt = "请问您在哪个地区?"
        mock_slot_def.linked_field_id = None

        with patch.object(
            aggregator._slot_def_service,
            "list_slot_definitions",
            return_value=[mock_slot_def],
        ):
            state = await aggregator.aggregate(
                memory_slots={},
                current_input_slots=None,
                context=None,
            )

        assert len(state.missing_required_slots) == 1
        assert state.missing_required_slots[0]["slot_key"] == "region"
        assert state.missing_required_slots[0]["ask_back_prompt"] == "请问您在哪个地区?"


class TestSlotMetadataLinkage:
    """Tests for linking slot definitions to metadata field definitions."""

    @pytest.mark.asyncio
    async def test_slot_to_field_mapping(self):
        """A slot with linked_field_id is mapped to the linked field's field_key."""
        from app.services.metadata_field_definition_service import (
            MetadataFieldDefinitionService,
        )
        from app.services.mid.slot_state_aggregator import SlotStateAggregator

        mock_session = AsyncMock()
        aggregator = SlotStateAggregator(
            session=mock_session,
            tenant_id="test_tenant",
        )

        # Simulate a slot definition linked to a metadata field via linked_field_id.
        mock_slot_def = MagicMock()
        mock_slot_def.slot_key = "product"
        mock_slot_def.linked_field_id = "field-uuid-123"
        mock_slot_def.required = False
        mock_slot_def.type = "string"
        mock_slot_def.options = None

        # Simulate the linked metadata field.
        mock_field = MagicMock()
        mock_field.field_key = "product_line"
        mock_field.label = "产品线"
        mock_field.type = "string"
        mock_field.required = False
        mock_field.options = None

        with patch.object(
            aggregator._slot_def_service,
            "list_slot_definitions",
            return_value=[mock_slot_def],
        ), patch.object(
            MetadataFieldDefinitionService,
            "get_field_definition",
            return_value=mock_field,
        ):
            state = await aggregator.aggregate(
                memory_slots={},
                current_input_slots=None,
                context=None,
            )

        # The slot key should map to the linked field's field_key.
        assert state.slot_to_field_map.get("product") == "product_line"


class TestBackwardCompatibility:
    """Tests that KB search keeps working when no slot state is supplied."""

    @pytest.mark.asyncio
    async def test_kb_search_without_slot_state(self):
        """KB search succeeds without a slot_state argument."""
        from app.services.mid.kb_search_dynamic_tool import (
            KbSearchDynamicConfig,
            KbSearchDynamicTool,
        )
        from app.services.mid.metadata_filter_builder import MetadataFilterBuilder

        mock_session = AsyncMock()
        kb_tool = KbSearchDynamicTool(
            session=mock_session,
            config=KbSearchDynamicConfig(enabled=True),
        )

        # Stub the filter builder so no filterable fields are returned,
        # and stub retrieval so no hits are returned.
        with patch.object(
            MetadataFilterBuilder,
            "_get_filterable_fields",
            return_value=[],
        ), patch.object(
            kb_tool,
            "_retrieve_with_timeout",
            return_value=[],
        ):
            result = await kb_tool.execute(
                query="退款政策",
                tenant_id="test_tenant",
                context={},
                slot_state=None,  # no slot_state supplied
            )

        # Retrieval should succeed without any fallback.
        assert result.success is True
        assert result.fallback_reason_code is None

    @pytest.mark.asyncio
    async def test_legacy_context_filter(self):
        """A plain key/value context is used to build the filter."""
        from app.services.mid.kb_search_dynamic_tool import (
            KbSearchDynamicConfig,
            KbSearchDynamicTool,
        )
        from app.services.mid.metadata_filter_builder import MetadataFilterBuilder

        mock_session = AsyncMock()
        kb_tool = KbSearchDynamicTool(
            session=mock_session,
            config=KbSearchDynamicConfig(enabled=True),
        )

        # Use a plain key/value context.
        context = {"product_line": "vip_course", "region": "beijing"}

        with patch.object(
            MetadataFilterBuilder,
            "_get_filterable_fields",
            return_value=[],
        ), patch.object(
            kb_tool,
            "_retrieve_with_timeout",
            return_value=[],
        ):
            result = await kb_tool.execute(
                query="退款政策",
                tenant_id="test_tenant",
                context=context,
                slot_state=None,
            )

        # Retrieval should succeed.
        assert result.success is True
        # A plain context should be used directly as the filter. Check for None
        # first so a missing filter fails as an assertion, not an AttributeError.
        assert result.applied_filter is not None
        assert result.applied_filter.get("product_line") == "vip_course"