ai-robot-core/ai-service/tests/test_dialogue_slot_integrat...

309 lines
10 KiB
Python
Raw Normal View History

"""
Tests for Dialogue API with Slot State Integration.
[AC-MRS-SLOT-META-03] 对话 API 与槽位状态集成测试
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from app.api.mid.dialogue import _generate_ask_back_for_missing_slots
from app.models.mid.schemas import ExecutionMode, Segment, TraceInfo
from app.services.mid.slot_state_aggregator import SlotState
class TestGenerateAskBackResponse:
    """Tests for building the ask-back (follow-up question) reply for missing slots."""

    @staticmethod
    async def _ask_back(missing_slots):
        """Call the ask-back generator with an empty SlotState and a mocked session."""
        return await _generate_ask_back_for_missing_slots(
            slot_state=SlotState(),
            missing_slots=missing_slots,
            session=AsyncMock(),
            tenant_id="test_tenant",
        )

    @pytest.mark.asyncio
    async def test_generate_ask_back_with_prompt(self):
        """A slot's configured ask_back_prompt is returned verbatim."""
        region_slot = {
            "slot_key": "region",
            "label": "地区",
            "ask_back_prompt": "请问您在哪个地区?",
        }
        reply = await self._ask_back([region_slot])
        assert reply == "请问您在哪个地区?"

    @pytest.mark.asyncio
    async def test_generate_ask_back_generic(self):
        """Without an ask_back_prompt, the generic template mentions the slot label."""
        # Deliberately has no "ask_back_prompt" key.
        product_slot = {
            "slot_key": "product_line",
            "label": "产品线",
        }
        reply = await self._ask_back([product_slot])
        assert "产品线" in reply

    @pytest.mark.asyncio
    async def test_generate_ask_back_empty_slots(self):
        """An empty missing-slot list still yields a generic request for more information."""
        reply = await self._ask_back([])
        assert "更多信息" in reply
class TestDialogueAskBackResponse:
    """Tests for the structure of a dialogue response that carries an ask-back."""

    def test_dialogue_response_with_ask_back(self):
        """An ask-back response exposes its question segment and fallback trace fields."""
        from app.models.mid.schemas import DialogueResponse

        question_segment = Segment(text="请问您咨询的是哪个产品线?", delay_after=0)
        trace_info = TraceInfo(
            mode=ExecutionMode.AGENT,
            request_id="test_request_id",
            generation_id="test_generation_id",
            fallback_reason_code="missing_required_slots",
            kb_tool_called=True,
            kb_hit=False,
        )
        response = DialogueResponse(segments=[question_segment], trace=trace_info)

        # Segment checks: exactly one segment containing the follow-up question.
        segments = response.segments
        assert len(segments) == 1
        assert "哪个产品线" in segments[0].text

        # Trace checks: fallback reason recorded, KB tool called but no hit.
        trace = response.trace
        assert trace.fallback_reason_code == "missing_required_slots"
        assert trace.kb_tool_called is True
        assert trace.kb_hit is False
class TestSlotStateAggregationFlow:
    """Tests for the slot-state aggregation flow of SlotStateAggregator.aggregate()."""

    @pytest.mark.asyncio
    async def test_memory_slots_included_in_state(self):
        """Slots passed in via memory_slots (from memory_recall) appear in the aggregated state."""
        from app.models.mid.schemas import MemorySlot, SlotSource
        from app.services.mid.slot_state_aggregator import SlotStateAggregator

        mock_session = AsyncMock()
        aggregator = SlotStateAggregator(
            session=mock_session,
            tenant_id="test_tenant",
        )
        memory_slots = {
            "product_line": MemorySlot(
                key="product_line",
                value="vip_course",
                source=SlotSource.USER_CONFIRMED,
                confidence=1.0,
            )
        }
        # Stub slot definitions so aggregation runs against an empty definition list.
        with patch.object(
            aggregator._slot_def_service,
            "list_slot_definitions",
            return_value=[],
        ):
            state = await aggregator.aggregate(
                memory_slots=memory_slots,
                current_input_slots=None,
                context=None,
            )
        assert "product_line" in state.filled_slots
        assert state.filled_slots["product_line"] == "vip_course"
        # Compared against a plain string; presumably SlotSource is str-valued
        # (or the aggregator stores .value) — TODO confirm if this ever fails.
        assert state.slot_sources["product_line"] == "user_confirmed"

    @pytest.mark.asyncio
    async def test_missing_slots_identified(self):
        """A required slot definition with no value is reported in missing_required_slots."""
        # MagicMock comes from the module-level unittest.mock import.
        from app.services.mid.slot_state_aggregator import SlotStateAggregator

        mock_session = AsyncMock()
        aggregator = SlotStateAggregator(
            session=mock_session,
            tenant_id="test_tenant",
        )
        # Simulate a single required slot definition that carries an ask-back prompt.
        mock_slot_def = MagicMock()
        mock_slot_def.slot_key = "region"
        mock_slot_def.required = True
        mock_slot_def.ask_back_prompt = "请问您在哪个地区?"
        mock_slot_def.linked_field_id = None
        with patch.object(
            aggregator._slot_def_service,
            "list_slot_definitions",
            return_value=[mock_slot_def],
        ):
            # No memory slots, input slots or context: the required slot stays unfilled.
            state = await aggregator.aggregate(
                memory_slots={},
                current_input_slots=None,
                context=None,
            )
        assert len(state.missing_required_slots) == 1
        assert state.missing_required_slots[0]["slot_key"] == "region"
        assert state.missing_required_slots[0]["ask_back_prompt"] == "请问您在哪个地区?"
class TestSlotMetadataLinkage:
    """Tests for linking slot definitions to metadata field definitions."""

    @pytest.mark.asyncio
    async def test_slot_to_field_mapping(self):
        """A slot with linked_field_id is mapped to the linked field's field_key."""
        # MagicMock and patch come from the module-level unittest.mock import.
        from app.services.mid.slot_state_aggregator import SlotStateAggregator
        from app.services.metadata_field_definition_service import MetadataFieldDefinitionService

        mock_session = AsyncMock()
        aggregator = SlotStateAggregator(
            session=mock_session,
            tenant_id="test_tenant",
        )
        # Simulated slot definition "product", linked to a metadata field by id.
        mock_slot_def = MagicMock()
        mock_slot_def.slot_key = "product"
        mock_slot_def.linked_field_id = "field-uuid-123"
        mock_slot_def.required = False
        mock_slot_def.type = "string"
        mock_slot_def.options = None
        # Simulated metadata field returned for that linked id.
        mock_field = MagicMock()
        mock_field.field_key = "product_line"
        mock_field.label = "产品线"
        mock_field.type = "string"
        mock_field.required = False
        mock_field.options = None
        # Patch both the slot-definition listing and the field lookup in one `with`.
        with patch.object(
            aggregator._slot_def_service,
            "list_slot_definitions",
            return_value=[mock_slot_def],
        ), patch.object(
            MetadataFieldDefinitionService,
            "get_field_definition",
            return_value=mock_field,
        ):
            state = await aggregator.aggregate(
                memory_slots={},
                current_input_slots=None,
                context=None,
            )
        # The slot key must map to the linked field's field_key.
        assert state.slot_to_field_map.get("product") == "product_line"
class TestBackwardCompatibility:
    """Tests that KB search keeps working for callers that do not pass slot_state."""

    @staticmethod
    async def _execute_kb_search(context):
        """Run KbSearchDynamicTool.execute() with its lookups stubbed out.

        Builds an enabled KbSearchDynamicTool over a mocked session.
        ``MetadataFilterBuilder._get_filterable_fields`` and the tool's
        ``_retrieve_with_timeout`` are both patched to return ``[]``.
        ``slot_state`` is always passed as None.

        Args:
            context: The context dict forwarded to ``execute``.

        Returns:
            Whatever ``kb_tool.execute`` returns.
        """
        from app.services.mid.kb_search_dynamic_tool import (
            KbSearchDynamicConfig,
            KbSearchDynamicTool,
        )
        from app.services.mid.metadata_filter_builder import MetadataFilterBuilder

        mock_session = AsyncMock()
        kb_tool = KbSearchDynamicTool(
            session=mock_session,
            config=KbSearchDynamicConfig(enabled=True),
        )
        with patch.object(
            MetadataFilterBuilder,
            "_get_filterable_fields",
            return_value=[],
        ), patch.object(
            kb_tool,
            "_retrieve_with_timeout",
            return_value=[],
        ):
            return await kb_tool.execute(
                query="退款政策",
                tenant_id="test_tenant",
                context=context,
                slot_state=None,  # no slot_state supplied (legacy callers)
            )

    @pytest.mark.asyncio
    async def test_kb_search_without_slot_state(self):
        """KB search succeeds without slot_state and reports no fallback reason."""
        result = await self._execute_kb_search(context={})
        # Execution should succeed.
        assert result.success is True
        assert result.fallback_reason_code is None

    @pytest.mark.asyncio
    async def test_legacy_context_filter(self):
        """A plain context dict drives the filter on the legacy (no slot_state) path."""
        context = {"product_line": "vip_course", "region": "beijing"}
        result = await self._execute_kb_search(context=context)
        # Execution should succeed.
        assert result.success is True
        # A simple context is expected to be used directly as the filter.
        assert result.applied_filter.get("product_line") == "vip_course"