# ai-robot-core/ai-service/app/services/mid/slot_manager.py
#
# Repository viewer header (not part of the module): 380 lines, 11 KiB, Python
# Raw / Normal View / History

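"""
Slot Manager Service.

Slot management service: a single entry point for slot writes with
validation built in.

Responsibilities:
1. Validate a slot value before it is written.
2. Manage the source and confidence of slot values.
3. Provide a unified interface for slot writes.
4. Return an ask-back prompt when validation fails.
"""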
import logging
from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.entities import SlotDefinition
from app.models.mid.schemas import SlotSource
from app.services.mid.slot_validation_service import (
BatchValidationResult,
SlotValidationError,
SlotValidationService,
)
from app.services.slot_definition_service import SlotDefinitionService
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class SlotWriteResult:
    """
    Outcome of a single slot write attempt.

    Attributes:
        success: True when the value passed validation and was accepted.
        slot_key: Key of the slot the write targeted.
        value: The value that was finally accepted (None on failure unless
            the caller supplied one).
        error: Validation error details; set when validation failed.
        ask_back_prompt: Prompt to ask the user again; set when validation
            failed and a prompt is available.
    """

    def __init__(
        self,
        success: bool,
        slot_key: str,
        value: Any | None = None,
        error: SlotValidationError | None = None,
        ask_back_prompt: str | None = None,
    ):
        self.success = success
        self.slot_key = slot_key
        self.value = value
        self.error = error
        self.ask_back_prompt = ask_back_prompt

    def __repr__(self) -> str:
        """Return a constructor-like representation for logs and debugging."""
        return (
            f"{type(self).__name__}(success={self.success!r}, "
            f"slot_key={self.slot_key!r}, value={self.value!r}, "
            f"error={self.error!r}, ask_back_prompt={self.ask_back_prompt!r})"
        )

    def to_dict(self) -> dict[str, Any]:
        """
        Convert the result to a plain dictionary.

        Returns:
            A dict that always contains ``success``, ``slot_key`` and
            ``value``. ``error`` (with ``error_code`` / ``error_message``)
            is added when an error object is present, and
            ``ask_back_prompt`` is added when the prompt is non-empty.
        """
        result: dict[str, Any] = {
            "success": self.success,
            "slot_key": self.slot_key,
            "value": self.value,
        }
        # Test for presence, not truthiness: an error object that happens to
        # evaluate as falsy must still be reported.
        if self.error is not None:
            result["error"] = {
                "error_code": self.error.error_code,
                "error_message": self.error.error_message,
            }
        # An empty prompt carries no information, so it is omitted as well.
        if self.ask_back_prompt:
            result["ask_back_prompt"] = self.ask_back_prompt
        return result
class SlotManager:
    """
    Single entry point for slot writes that validates values first.

    Slot definitions are looked up through a ``SlotDefinitionService`` and
    cached per instance; values are checked with a ``SlotValidationService``.
    A slot without a definition is treated as a dynamic slot and accepted
    without validation.

    NOTE(review): no method of this class persists a value; each one only
    validates and reports the outcome. Storing the value is presumably the
    caller's job - confirm against the call sites.
    """

    def __init__(
        self,
        session: AsyncSession | None = None,
        validation_service: SlotValidationService | None = None,
        slot_def_service: SlotDefinitionService | None = None,
    ):
        """
        Initialise the slot manager.

        Args:
            session: Database session, used to build a
                ``SlotDefinitionService`` when none is injected.
            validation_service: Validation service instance; a default
                ``SlotValidationService`` is created when omitted.
            slot_def_service: Slot definition service instance; takes
                precedence over ``session`` for definition lookups.
        """
        self._session = session
        self._validation_service = validation_service or SlotValidationService()
        self._slot_def_service = slot_def_service
        # Keyed by (tenant_id, slot_key) so that keys containing ":" cannot
        # collide. Misses are cached as None too, until clear_cache().
        self._slot_def_cache: dict[
            tuple[str, str], SlotDefinition | dict[str, Any] | None
        ] = {}

    async def write_slot(
        self,
        tenant_id: str,
        slot_key: str,
        value: Any,
        source: SlotSource = SlotSource.USER_CONFIRMED,
        confidence: float = 1.0,
        skip_validation: bool = False,
    ) -> SlotWriteResult:
        """
        Write a single slot value, validating it first.

        Flow:
        1. Load the slot definition.
        2. Run validation (unless skipped or the slot has no definition).
        3. Return the outcome.

        Args:
            tenant_id: Tenant ID.
            slot_key: Slot key.
            value: Slot value.
            source: Origin of the value. Accepted for interface
                compatibility; not read by this method.
            confidence: Confidence of the value. Accepted for interface
                compatibility; not read by this method.
            skip_validation: Skip validation (for special scenarios).

        Returns:
            SlotWriteResult: ``success=True`` with the accepted value (the
            validator's normalized value when validation ran), or
            ``success=False`` with the error and ask-back prompt.
        """
        slot_def = await self._get_slot_definition(tenant_id, slot_key)

        if slot_def is None:
            # Undefined (dynamic) slots are always accepted as-is; only the
            # log level and message depend on whether validation was skipped.
            if skip_validation:
                logger.debug(
                    "[SlotManager] Writing slot without definition: "
                    "tenant_id=%s, slot_key=%s",
                    tenant_id,
                    slot_key,
                )
            else:
                logger.info(
                    "[SlotManager] Slot definition not found, allowing write: "
                    "tenant_id=%s, slot_key=%s",
                    tenant_id,
                    slot_key,
                )
            return SlotWriteResult(success=True, slot_key=slot_key, value=value)

        if skip_validation:
            return SlotWriteResult(success=True, slot_key=slot_key, value=value)

        outcome = self._validation_service.validate_slot_value(
            slot_def, value, tenant_id
        )
        if not outcome.ok:
            logger.info(
                "[SlotManager] Slot validation failed: "
                "tenant_id=%s, slot_key=%s, error_code=%s",
                tenant_id,
                slot_key,
                outcome.error_code,
            )
            return SlotWriteResult(
                success=False,
                slot_key=slot_key,
                error=self._build_validation_error(slot_key, outcome),
                ask_back_prompt=outcome.ask_back_prompt,
            )

        logger.debug(
            "[SlotManager] Slot validation passed: tenant_id=%s, slot_key=%s",
            tenant_id,
            slot_key,
        )
        # Hand back the validator's normalized value, not the raw input.
        return SlotWriteResult(
            success=True,
            slot_key=slot_key,
            value=outcome.normalized_value,
        )

    async def write_slots(
        self,
        tenant_id: str,
        values: dict[str, Any],
        source: SlotSource = SlotSource.USER_CONFIRMED,
        confidence: float = 1.0,
        skip_validation: bool = False,
    ) -> BatchValidationResult:
        """
        Write several slot values at once, validating them first.

        Args:
            tenant_id: Tenant ID.
            values: Slot values as ``{slot_key: value}``.
            source: Origin of the values. Accepted for interface
                compatibility; not read by this method.
            confidence: Confidence of the values. Accepted for interface
                compatibility; not read by this method.
            skip_validation: Skip validation.

        Returns:
            BatchValidationResult: The batch validation result. When
            validation is skipped, an ``ok`` result holding a copy of
            ``values``.
        """
        if skip_validation:
            # Copy so the result never aliases the caller's dict.
            return BatchValidationResult(ok=True, validated_values=dict(values))

        # Only slots that have a definition are returned here; undefined
        # keys are left for the validation service to deal with.
        slot_defs = await self._get_slot_definitions(tenant_id, list(values))
        result = self._validation_service.validate_slots(
            slot_defs, values, tenant_id
        )

        if result.ok:
            logger.debug(
                "[SlotManager] Batch slot validation passed: "
                "tenant_id=%s, slots=%s",
                tenant_id,
                list(values),
            )
        else:
            logger.info(
                "[SlotManager] Batch slot validation failed: "
                "tenant_id=%s, errors=%s",
                tenant_id,
                [e.slot_key for e in result.errors],
            )
        return result

    async def validate_before_write(
        self,
        tenant_id: str,
        slot_key: str,
        value: Any,
    ) -> tuple[bool, SlotValidationError | None]:
        """
        Pre-validate a slot value without writing it.

        Args:
            tenant_id: Tenant ID.
            slot_key: Slot key.
            value: Slot value.

        Returns:
            Tuple of (passed, error). ``error`` is None when the value
            passed or the slot has no definition.
        """
        slot_def = await self._get_slot_definition(tenant_id, slot_key)
        if slot_def is None:
            # Undefined slots are treated as passing.
            return True, None

        outcome = self._validation_service.validate_slot_value(
            slot_def, value, tenant_id
        )
        if outcome.ok:
            return True, None
        return False, self._build_validation_error(slot_key, outcome)

    async def get_ask_back_prompt(
        self,
        tenant_id: str,
        slot_key: str,
    ) -> str | None:
        """
        Get the ask-back prompt configured for a slot.

        Args:
            tenant_id: Tenant ID.
            slot_key: Slot key.

        Returns:
            The ask-back prompt, or None when the slot has no definition
            or the definition carries no prompt.
        """
        slot_def = await self._get_slot_definition(tenant_id, slot_key)
        if slot_def is None:
            return None
        # Definitions may arrive as ORM entities or as plain dicts.
        if isinstance(slot_def, SlotDefinition):
            return slot_def.ask_back_prompt
        return slot_def.get("ask_back_prompt")

    @staticmethod
    def _build_validation_error(slot_key: str, outcome: Any) -> SlotValidationError:
        """Turn a failed validation outcome into a ``SlotValidationError``."""
        return SlotValidationError(
            slot_key=slot_key,
            error_code=outcome.error_code or "VALIDATION_FAILED",
            error_message=outcome.error_message or "校验失败",
            ask_back_prompt=outcome.ask_back_prompt,
        )

    def _resolve_slot_def_service(self) -> SlotDefinitionService | None:
        """Return the definition service, creating it once from the session."""
        if self._slot_def_service is None and self._session is not None:
            self._slot_def_service = SlotDefinitionService(self._session)
        return self._slot_def_service

    async def _get_slot_definition(
        self,
        tenant_id: str,
        slot_key: str,
    ) -> SlotDefinition | dict[str, Any] | None:
        """
        Get a slot definition, using the per-instance cache.

        Args:
            tenant_id: Tenant ID.
            slot_key: Slot key.

        Returns:
            The slot definition, or None when it does not exist or no
            definition service / session is configured.
        """
        cache_key = (tenant_id, slot_key)
        if cache_key in self._slot_def_cache:
            return self._slot_def_cache[cache_key]

        slot_def = None
        service = self._resolve_slot_def_service()
        if service is not None:
            slot_def = await service.get_slot_definition_by_key(tenant_id, slot_key)

        # Cache misses as well, so repeated lookups of an undefined slot do
        # not hit the service again until clear_cache() is called.
        self._slot_def_cache[cache_key] = slot_def
        return slot_def

    async def _get_slot_definitions(
        self,
        tenant_id: str,
        slot_keys: list[str],
    ) -> list[SlotDefinition | dict[str, Any]]:
        """
        Get the definitions for several slot keys.

        Args:
            tenant_id: Tenant ID.
            slot_keys: Slot keys to look up.

        Returns:
            The definitions that exist, in the order of ``slot_keys``;
            keys without a definition are skipped.
        """
        slot_defs = []
        # Lookups stay sequential: they may share one AsyncSession, which
        # SQLAlchemy documents as unsafe for use in concurrent tasks.
        for key in slot_keys:
            slot_def = await self._get_slot_definition(tenant_id, key)
            # "is not None" keeps this consistent with write_slot, which also
            # treats any non-None definition as present.
            if slot_def is not None:
                slot_defs.append(slot_def)
        return slot_defs

    def clear_cache(self) -> None:
        """Clear the slot definition cache."""
        self._slot_def_cache.clear()
def create_slot_manager(
    session: AsyncSession | None = None,
    validation_service: SlotValidationService | None = None,
    slot_def_service: SlotDefinitionService | None = None,
) -> SlotManager:
    """
    Build a ``SlotManager`` from the given collaborators.

    Every argument is forwarded unchanged to the ``SlotManager``
    constructor under the same keyword.

    Args:
        session: Database session.
        validation_service: Validation service instance.
        slot_def_service: Slot definition service instance.

    Returns:
        SlotManager: The newly created slot manager.
    """
    manager_kwargs = {
        "session": session,
        "validation_service": validation_service,
        "slot_def_service": slot_def_service,
    }
    return SlotManager(**manager_kwargs)