# Source path: ai-robot-core/ai-service/tests/test_slot_backfill_service.py
# (Repository viewer metadata: 420 lines, 15 KiB, Python.)

"""
Tests for Slot Backfill Service.
[AC-MRS-SLOT-BACKFILL-01] 槽位值回填确认测试 (slot value backfill confirmation tests)
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.models.mid.schemas import SlotSource
from app.services.mid.slot_backfill_service import (
BackfillResult,
BackfillStatus,
BatchBackfillResult,
SlotBackfillService,
create_slot_backfill_service,
)
from app.services.mid.slot_manager import SlotWriteResult
from app.services.mid.slot_strategy_executor import (
StrategyChainResult,
StrategyStepResult,
)
class TestBackfillResult:
    """Unit tests for the ``BackfillResult`` status predicates and serialization."""

    def test_is_success(self):
        """Only the SUCCESS status counts as a successful backfill."""
        succeeded = BackfillResult(status=BackfillStatus.SUCCESS)
        rejected = BackfillResult(status=BackfillStatus.VALIDATION_FAILED)

        assert succeeded.is_success() is True
        assert rejected.is_success() is False

    def test_needs_ask_back(self):
        """Validation and extraction failures ask again; success does not."""
        for failing_status in (
            BackfillStatus.VALIDATION_FAILED,
            BackfillStatus.EXTRACTION_FAILED,
        ):
            assert BackfillResult(status=failing_status).needs_ask_back() is True

        assert BackfillResult(status=BackfillStatus.SUCCESS).needs_ask_back() is False

    def test_needs_confirmation(self):
        """Only NEEDS_CONFIRMATION reports that a confirmation is required."""
        pending = BackfillResult(status=BackfillStatus.NEEDS_CONFIRMATION)
        assert pending.needs_confirmation() is True

        done = BackfillResult(status=BackfillStatus.SUCCESS)
        assert done.needs_confirmation() is False

    def test_to_dict(self):
        """``to_dict`` exposes the status as its string value plus the slot fields."""
        payload = BackfillResult(
            status=BackfillStatus.SUCCESS,
            slot_key="region",
            value="北京",
            normalized_value="北京",
            source="user_confirmed",
            confidence=1.0,
        ).to_dict()

        expected = {
            "status": "success",
            "slot_key": "region",
            "value": "北京",
            "source": "user_confirmed",
        }
        for key, wanted in expected.items():
            assert payload[key] == wanted
class TestBatchBackfillResult:
    """Unit tests for aggregating several ``BackfillResult`` objects."""

    def test_add_result(self):
        """Each added result bumps the counter that matches its status."""
        batch = BatchBackfillResult()
        for status, key in (
            (BackfillStatus.SUCCESS, "region"),
            (BackfillStatus.VALIDATION_FAILED, "product"),
            (BackfillStatus.NEEDS_CONFIRMATION, "grade"),
        ):
            batch.add_result(BackfillResult(status=status, slot_key=key))

        counts = (
            batch.success_count,
            batch.failed_count,
            batch.confirmation_needed_count,
        )
        assert counts == (1, 1, 1)

    def test_get_ask_back_prompts(self):
        """Ask-back prompts are gathered only from results that carry one."""
        batch = BatchBackfillResult()
        entries = [
            BackfillResult(
                status=BackfillStatus.VALIDATION_FAILED,
                ask_back_prompt="请重新输入",
            ),
            BackfillResult(status=BackfillStatus.SUCCESS),
            BackfillResult(
                status=BackfillStatus.EXTRACTION_FAILED,
                ask_back_prompt="无法识别,请重试",
            ),
        ]
        for entry in entries:
            batch.add_result(entry)

        collected = batch.get_ask_back_prompts()
        assert len(collected) == 2
        assert "请重新输入" in collected
        assert "无法识别,请重试" in collected

    def test_get_confirmation_prompts(self):
        """Only results awaiting confirmation contribute a confirmation prompt."""
        batch = BatchBackfillResult()
        batch.add_result(
            BackfillResult(
                status=BackfillStatus.NEEDS_CONFIRMATION,
                confirmation_prompt="我理解您说的是「北京」,对吗?",
            )
        )
        batch.add_result(BackfillResult(status=BackfillStatus.SUCCESS))

        collected = batch.get_confirmation_prompts()
        assert len(collected) == 1
        assert "北京" in collected[0]
class TestSlotBackfillService:
    """Unit tests for ``SlotBackfillService`` with its collaborators mocked out."""

    @pytest.fixture
    def mock_session(self):
        """Async stand-in for the database session."""
        return AsyncMock()

    @pytest.fixture
    def mock_slot_manager(self):
        """Slot manager whose ``write_slot`` and ``get_ask_back_prompt`` are async mocks."""
        manager = MagicMock()
        manager.write_slot = AsyncMock()
        manager.get_ask_back_prompt = AsyncMock(return_value="请提供信息")
        return manager

    @pytest.fixture
    def service(self, mock_session, mock_slot_manager):
        """Service under test, wired to the mocked session and slot manager."""
        return SlotBackfillService(
            session=mock_session,
            tenant_id="tenant_1",
            session_id="session_1",
            slot_manager=mock_slot_manager,
        )

    @staticmethod
    def _patch_state_aggregator(service):
        """Return a patcher making ``_get_state_aggregator`` yield an AsyncMock aggregator.

        ``update_slot`` and ``clear_slot`` are explicit async mocks so the service
        can await either one while the patch is active.
        """
        aggregator = AsyncMock()
        aggregator.update_slot = AsyncMock()
        aggregator.clear_slot = AsyncMock()
        return patch.object(service, "_get_state_aggregator", return_value=aggregator)

    def test_confidence_thresholds(self, service):
        """The low/high confidence cut-offs are 0.5 and 0.8."""
        low, high = service.CONFIDENCE_THRESHOLD_LOW, service.CONFIDENCE_THRESHOLD_HIGH
        assert low == 0.5
        assert high == 0.8

    def test_get_source_for_strategy(self, service):
        """Strategy names map onto SlotSource values; unknown names pass through."""
        expected = {
            "rule": SlotSource.RULE_EXTRACTED.value,
            "llm": SlotSource.LLM_INFERRED.value,
            "user_input": SlotSource.USER_CONFIRMED.value,
            "unknown": "unknown",
        }
        for strategy, source in expected.items():
            assert service._get_source_for_strategy(strategy) == source

    def test_get_confidence_for_strategy(self, service):
        """Each source maps to its default confidence score."""
        # A list (not a dict) so that enum values which might coincide with
        # "context" are still each checked.
        cases = [
            (SlotSource.USER_CONFIRMED.value, 1.0),
            (SlotSource.RULE_EXTRACTED.value, 0.9),
            (SlotSource.LLM_INFERRED.value, 0.7),
            ("context", 0.5),
            (SlotSource.DEFAULT.value, 0.3),
        ]
        for source, confidence in cases:
            assert service._get_confidence_for_strategy(source) == confidence

    def test_generate_confirmation_prompt(self, service):
        """The confirmation prompt echoes the value and asks whether it is right."""
        text = service._generate_confirmation_prompt("region", "北京")
        for fragment in ("北京", "对吗"):
            assert fragment in text

    @pytest.mark.asyncio
    async def test_backfill_single_slot_success(self, service, mock_slot_manager):
        """A confident value accepted by the slot manager yields SUCCESS."""
        mock_slot_manager.write_slot.return_value = SlotWriteResult(
            success=True,
            slot_key="region",
            value="北京",
        )

        with self._patch_state_aggregator(service):
            outcome = await service.backfill_single_slot(
                slot_key="region",
                candidate_value="北京",
                source="user_confirmed",
                confidence=1.0,
            )

        assert outcome.status == BackfillStatus.SUCCESS
        assert outcome.slot_key == "region"
        assert outcome.normalized_value == "北京"

    @pytest.mark.asyncio
    async def test_backfill_single_slot_validation_failed(self, service, mock_slot_manager):
        """A rejected write surfaces VALIDATION_FAILED with the manager's ask-back prompt."""
        from app.services.mid.slot_validation_service import SlotValidationError

        validation_error = SlotValidationError(
            slot_key="region",
            error_code="INVALID_VALUE",
            error_message="无效的地区",
        )
        mock_slot_manager.write_slot.return_value = SlotWriteResult(
            success=False,
            slot_key="region",
            error=validation_error,
            ask_back_prompt="请提供有效的地区",
        )

        outcome = await service.backfill_single_slot(
            slot_key="region",
            candidate_value="无效地区",
            source="user_confirmed",
            confidence=1.0,
        )

        assert outcome.status == BackfillStatus.VALIDATION_FAILED
        assert outcome.ask_back_prompt == "请提供有效的地区"

    @pytest.mark.asyncio
    async def test_backfill_single_slot_low_confidence(self, service, mock_slot_manager):
        """A value with confidence 0.4 is returned as NEEDS_CONFIRMATION with a prompt."""
        mock_slot_manager.write_slot.return_value = SlotWriteResult(
            success=True,
            slot_key="region",
            value="北京",
        )

        with self._patch_state_aggregator(service):
            outcome = await service.backfill_single_slot(
                slot_key="region",
                candidate_value="北京",
                source="llm_inferred",
                confidence=0.4,
            )

        assert outcome.status == BackfillStatus.NEEDS_CONFIRMATION
        prompt = outcome.confirmation_prompt
        assert prompt is not None
        assert "北京" in prompt

    @pytest.mark.asyncio
    async def test_backfill_multiple_slots(self, service, mock_slot_manager):
        """Batch backfill counts two accepted writes and one rejected write."""
        mock_slot_manager.write_slot.side_effect = [
            SlotWriteResult(success=True, slot_key="region", value="北京"),
            SlotWriteResult(success=True, slot_key="product", value="手机"),
            SlotWriteResult(success=False, slot_key="grade", error=MagicMock()),
        ]
        candidates = {
            "region": "北京",
            "product": "手机",
            "grade": "无效等级",
        }

        with self._patch_state_aggregator(service):
            outcome = await service.backfill_multiple_slots(
                candidates=candidates,
                source="user_confirmed",
            )

        assert outcome.success_count == 2
        assert outcome.failed_count == 1

    @pytest.mark.asyncio
    async def test_confirm_low_confidence_slot_confirmed(self, service):
        """Confirming a pending slot makes it a user-confirmed value with confidence 1.0."""
        with self._patch_state_aggregator(service):
            outcome = await service.confirm_low_confidence_slot(
                slot_key="region",
                confirmed=True,
            )

        assert outcome.status == BackfillStatus.SUCCESS
        assert outcome.source == SlotSource.USER_CONFIRMED.value
        assert outcome.confidence == 1.0

    @pytest.mark.asyncio
    async def test_confirm_low_confidence_slot_rejected(self, service, mock_slot_manager):
        """Rejecting a pending slot yields VALIDATION_FAILED with an ask-back prompt."""
        with self._patch_state_aggregator(service):
            outcome = await service.confirm_low_confidence_slot(
                slot_key="region",
                confirmed=False,
            )

        assert outcome.status == BackfillStatus.VALIDATION_FAILED
        assert outcome.ask_back_prompt is not None
class TestCreateSlotBackfillService:
    """Tests for the ``create_slot_backfill_service`` factory function."""

    def test_create(self):
        """The factory returns a service bound to the given tenant and session ids."""
        created = create_slot_backfill_service(
            session=AsyncMock(),
            tenant_id="tenant_1",
            session_id="session_1",
        )

        assert isinstance(created, SlotBackfillService)
        assert (created._tenant_id, created._session_id) == ("tenant_1", "session_1")
class TestBackfillFromUserResponse:
    """Tests for extracting slot values from a free-text user reply and backfilling them."""

    @pytest.fixture
    def service(self):
        """Service whose ``_slot_def_service`` is replaced by an AsyncMock."""
        instance = SlotBackfillService(
            session=AsyncMock(),
            tenant_id="tenant_1",
            session_id="session_1",
        )
        instance._slot_def_service = AsyncMock()
        return instance

    @staticmethod
    def _string_slot_definition():
        """Build a mock string-typed slot definition with no validation rule."""
        definition = MagicMock()
        definition.type = "string"
        definition.validation_rule = None
        definition.ask_back_prompt = "请提供地区"
        return definition

    @staticmethod
    def _stub_slot_definition(service, definition):
        """Make ``get_slot_definition_by_key`` resolve to *definition* (may be ``None``)."""
        service._slot_def_service.get_slot_definition_by_key = AsyncMock(
            return_value=definition
        )

    @pytest.mark.asyncio
    async def test_backfill_from_user_response_success(self, service):
        """A successful extraction followed by a successful backfill counts as one success."""
        self._stub_slot_definition(service, self._string_slot_definition())
        extraction = StrategyChainResult(
            slot_key="region",
            success=True,
            final_value="北京",
            final_strategy="rule",
        )
        backfilled = BackfillResult(
            status=BackfillStatus.SUCCESS,
            slot_key="region",
            value="北京",
        )
        extract_patch = patch.object(service, "_extract_value", return_value=extraction)
        backfill_patch = patch.object(
            service, "backfill_single_slot", return_value=backfilled
        )

        with extract_patch, backfill_patch:
            outcome = await service.backfill_from_user_response(
                user_response="我想查询北京的产品",
                expected_slots=["region"],
            )

        assert outcome.success_count == 1

    @pytest.mark.asyncio
    async def test_backfill_from_user_response_no_definition(self, service):
        """A slot with no definition produces neither a success nor a failure."""
        self._stub_slot_definition(service, None)

        outcome = await service.backfill_from_user_response(
            user_response="我想查询北京的产品",
            expected_slots=["unknown_slot"],
        )

        assert outcome.success_count == 0
        assert outcome.failed_count == 0

    @pytest.mark.asyncio
    async def test_backfill_from_user_response_extraction_failed(self, service):
        """A failed extraction is recorded as one EXTRACTION_FAILED result."""
        self._stub_slot_definition(service, self._string_slot_definition())
        failed_extraction = StrategyChainResult(slot_key="region", success=False)

        with patch.object(service, "_extract_value", return_value=failed_extraction):
            outcome = await service.backfill_from_user_response(
                user_response="我想查询产品",
                expected_slots=["region"],
            )

        assert outcome.failed_count == 1
        assert outcome.results[0].status == BackfillStatus.EXTRACTION_FAILED