feat: refactor memory_recall_tool to only consume slot role fields [AC-MRS-12]
This commit is contained in:
parent
4bd2b76d1c
commit
f9fe6ec615
|
|
@ -0,0 +1,582 @@
|
|||
"""
|
||||
Memory Recall Tool for Mid Platform.
|
||||
[AC-IDMP-13] 记忆召回工具 - 短期可用记忆注入
|
||||
[AC-MRS-12] 只消费 field_roles 包含 slot 的字段
|
||||
|
||||
定位:短期可用记忆注入,不是完整中长期记忆系统。
|
||||
功能:读取可用记忆包(profile/facts/preferences/last_summary/slots)
|
||||
|
||||
关键特性:
|
||||
1. 优先读取已有结构化记忆
|
||||
2. 若缺失,使用最近窗口历史做最小回填
|
||||
3. 槽位冲突优先级:user_confirmed > rule_extracted > llm_inferred > default
|
||||
4. 超时 <= 1000ms,失败不抛硬异常
|
||||
5. 多租户隔离正确
|
||||
6. 只消费 slot 角色的字段
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.mid.schemas import (
|
||||
MemoryRecallResult,
|
||||
MemorySlot,
|
||||
SlotSource,
|
||||
ToolCallStatus,
|
||||
ToolCallTrace,
|
||||
ToolType,
|
||||
)
|
||||
from app.models.entities import FieldRole
|
||||
from app.services.mid.role_based_field_provider import RoleBasedFieldProvider
|
||||
from app.services.mid.timeout_governor import TimeoutGovernor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_RECALL_TIMEOUT_MS = 1000
|
||||
DEFAULT_MAX_RECENT_MESSAGES = 8
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryRecallConfig:
|
||||
"""记忆召回工具配置。"""
|
||||
enabled: bool = True
|
||||
timeout_ms: int = DEFAULT_RECALL_TIMEOUT_MS
|
||||
max_recent_messages: int = DEFAULT_MAX_RECENT_MESSAGES
|
||||
default_recall_scope: list[str] = field(
|
||||
default_factory=lambda: ["profile", "facts", "preferences", "summary", "slots"]
|
||||
)
|
||||
|
||||
|
||||
class MemoryRecallTool:
|
||||
"""
|
||||
[AC-IDMP-13] 记忆召回工具。
|
||||
[AC-MRS-12] 只消费 field_roles 包含 slot 的字段
|
||||
|
||||
用于在对话前读取用户可用记忆,减少重复追问。
|
||||
|
||||
Features:
|
||||
- 读取 profile/facts/preferences/last_summary/slots
|
||||
- 槽位冲突优先级处理
|
||||
- 超时控制与降级
|
||||
- 多租户隔离
|
||||
- 只消费 slot 角色的字段
|
||||
"""
|
||||
|
||||
SLOT_PRIORITY: dict[SlotSource, int] = {
|
||||
SlotSource.USER_CONFIRMED: 4,
|
||||
SlotSource.RULE_EXTRACTED: 3,
|
||||
SlotSource.LLM_INFERRED: 2,
|
||||
SlotSource.DEFAULT: 1,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
timeout_governor: TimeoutGovernor | None = None,
|
||||
config: MemoryRecallConfig | None = None,
|
||||
):
|
||||
self._session = session
|
||||
self._timeout_governor = timeout_governor or TimeoutGovernor()
|
||||
self._config = config or MemoryRecallConfig()
|
||||
self._role_provider = RoleBasedFieldProvider(session)
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
recall_scope: list[str] | None = None,
|
||||
max_recent_messages: int | None = None,
|
||||
) -> MemoryRecallResult:
|
||||
"""
|
||||
[AC-IDMP-13] 执行记忆召回。
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
user_id: 用户 ID
|
||||
session_id: 会话 ID
|
||||
recall_scope: 召回范围,默认 ["profile","facts","preferences","summary","slots"]
|
||||
max_recent_messages: 最大最近消息数,默认 8
|
||||
|
||||
Returns:
|
||||
MemoryRecallResult: 记忆召回结果
|
||||
"""
|
||||
if not self._config.enabled:
|
||||
logger.info(f"[AC-IDMP-13] Memory recall disabled for tenant={tenant_id}")
|
||||
return MemoryRecallResult(
|
||||
fallback_reason_code="MEMORY_RECALL_DISABLED",
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
scope = recall_scope or self._config.default_recall_scope
|
||||
max_msgs = max_recent_messages or self._config.max_recent_messages
|
||||
|
||||
logger.info(
|
||||
f"[AC-IDMP-13] Starting memory recall: tenant={tenant_id}, "
|
||||
f"user={user_id}, session={session_id}, scope={scope}"
|
||||
)
|
||||
|
||||
try:
|
||||
result = await asyncio.wait_for(
|
||||
self._recall_internal(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
scope=scope,
|
||||
max_recent_messages=max_msgs,
|
||||
),
|
||||
timeout=self._config.timeout_ms / 1000.0,
|
||||
)
|
||||
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
result.duration_ms = duration_ms
|
||||
|
||||
logger.info(
|
||||
f"[AC-IDMP-13] Memory recall completed: tenant={tenant_id}, "
|
||||
f"user={user_id}, duration_ms={duration_ms}, "
|
||||
f"profile={bool(result.profile)}, facts={len(result.facts)}, "
|
||||
f"slots={len(result.slots)}, missing_slots={len(result.missing_slots)}"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
logger.warning(
|
||||
f"[AC-IDMP-13] Memory recall timeout: tenant={tenant_id}, "
|
||||
f"user={user_id}, duration_ms={duration_ms}"
|
||||
)
|
||||
return MemoryRecallResult(
|
||||
fallback_reason_code="MEMORY_RECALL_TIMEOUT",
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
logger.error(
|
||||
f"[AC-IDMP-13] Memory recall failed: tenant={tenant_id}, "
|
||||
f"user={user_id}, error={e}"
|
||||
)
|
||||
return MemoryRecallResult(
|
||||
fallback_reason_code=f"MEMORY_RECALL_ERROR:{str(e)[:50]}",
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
|
||||
async def _recall_internal(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
scope: list[str],
|
||||
max_recent_messages: int,
|
||||
) -> MemoryRecallResult:
|
||||
"""内部召回实现。"""
|
||||
profile: dict[str, Any] = {}
|
||||
facts: list[str] = []
|
||||
preferences: dict[str, Any] = {}
|
||||
last_summary: str | None = None
|
||||
slots: dict[str, MemorySlot] = {}
|
||||
missing_slots: list[str] = []
|
||||
|
||||
if "profile" in scope:
|
||||
profile = await self._recall_profile(tenant_id, user_id)
|
||||
|
||||
if "facts" in scope:
|
||||
facts = await self._recall_facts(tenant_id, user_id)
|
||||
|
||||
if "preferences" in scope:
|
||||
preferences = await self._recall_preferences(tenant_id, user_id)
|
||||
|
||||
if "summary" in scope:
|
||||
last_summary = await self._recall_last_summary(tenant_id, user_id)
|
||||
|
||||
if "slots" in scope:
|
||||
slots, missing_slots = await self._recall_slots(
|
||||
tenant_id, user_id, session_id
|
||||
)
|
||||
|
||||
if not profile and not facts and not preferences and not last_summary and not slots:
|
||||
if "history" in scope or max_recent_messages > 0:
|
||||
history_facts = await self._recall_from_history(
|
||||
tenant_id, session_id, max_recent_messages
|
||||
)
|
||||
facts.extend(history_facts)
|
||||
|
||||
return MemoryRecallResult(
|
||||
profile=profile,
|
||||
facts=facts,
|
||||
preferences=preferences,
|
||||
last_summary=last_summary,
|
||||
slots=slots,
|
||||
missing_slots=missing_slots,
|
||||
)
|
||||
|
||||
async def _recall_profile(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
) -> dict[str, Any]:
|
||||
"""召回用户基础属性。"""
|
||||
try:
|
||||
from app.models.entities import ChatMessage
|
||||
from sqlmodel import col
|
||||
|
||||
stmt = (
|
||||
select(ChatMessage)
|
||||
.where(
|
||||
ChatMessage.tenant_id == tenant_id,
|
||||
ChatMessage.role == "user",
|
||||
)
|
||||
.order_by(col(ChatMessage.created_at).desc())
|
||||
.limit(5)
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
messages = result.scalars().all()
|
||||
|
||||
profile: dict[str, Any] = {}
|
||||
for msg in messages:
|
||||
content = msg.content.lower()
|
||||
if "年级" in content or "初" in content or "高" in content:
|
||||
if "grade" not in profile:
|
||||
profile["grade"] = self._extract_grade(msg.content)
|
||||
if "北京" in content or "上海" in content or "广州" in content:
|
||||
if "region" not in profile:
|
||||
profile["region"] = self._extract_region(msg.content)
|
||||
|
||||
return profile
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"[AC-IDMP-13] Failed to recall profile: {e}")
|
||||
return {}
|
||||
|
||||
async def _recall_facts(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
) -> list[str]:
|
||||
"""召回用户事实记忆。"""
|
||||
try:
|
||||
from app.models.entities import ChatMessage
|
||||
from sqlmodel import col
|
||||
|
||||
stmt = (
|
||||
select(ChatMessage)
|
||||
.where(
|
||||
ChatMessage.tenant_id == tenant_id,
|
||||
ChatMessage.role == "assistant",
|
||||
)
|
||||
.order_by(col(ChatMessage.created_at).desc())
|
||||
.limit(10)
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
messages = result.scalars().all()
|
||||
|
||||
facts: list[str] = []
|
||||
for msg in messages:
|
||||
content = msg.content
|
||||
if "已购" in content or "购买" in content:
|
||||
facts.append(self._extract_purchase_info(content))
|
||||
if "订单" in content:
|
||||
facts.append(self._extract_order_info(content))
|
||||
|
||||
return [f for f in facts if f]
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"[AC-IDMP-13] Failed to recall facts: {e}")
|
||||
return []
|
||||
|
||||
async def _recall_preferences(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
) -> dict[str, Any]:
|
||||
"""召回用户偏好。"""
|
||||
try:
|
||||
from app.models.entities import ChatMessage
|
||||
from sqlmodel import col
|
||||
|
||||
stmt = (
|
||||
select(ChatMessage)
|
||||
.where(
|
||||
ChatMessage.tenant_id == tenant_id,
|
||||
ChatMessage.role == "user",
|
||||
)
|
||||
.order_by(col(ChatMessage.created_at).desc())
|
||||
.limit(10)
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
messages = result.scalars().all()
|
||||
|
||||
preferences: dict[str, Any] = {}
|
||||
for msg in messages:
|
||||
content = msg.content.lower()
|
||||
if "详细" in content or "详细解释" in content:
|
||||
preferences["communication_style"] = "详细解释"
|
||||
elif "简单" in content or "简洁" in content:
|
||||
preferences["communication_style"] = "简洁"
|
||||
|
||||
return preferences
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"[AC-IDMP-13] Failed to recall preferences: {e}")
|
||||
return {}
|
||||
|
||||
async def _recall_last_summary(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
) -> str | None:
|
||||
"""召回最近会话摘要。"""
|
||||
try:
|
||||
from app.models.entities import MidAuditLog
|
||||
from sqlmodel import col
|
||||
|
||||
stmt = (
|
||||
select(MidAuditLog)
|
||||
.where(
|
||||
MidAuditLog.tenant_id == tenant_id,
|
||||
)
|
||||
.order_by(col(MidAuditLog.created_at).desc())
|
||||
.limit(1)
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
audit = result.scalar_one_or_none()
|
||||
|
||||
if audit:
|
||||
return f"上次会话模式: {audit.mode}"
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"[AC-IDMP-13] Failed to recall last summary: {e}")
|
||||
return None
|
||||
|
||||
async def _recall_slots(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
) -> tuple[dict[str, MemorySlot], list[str]]:
|
||||
"""
|
||||
[AC-MRS-12] 召回结构化槽位,只消费 slot 角色的字段。
|
||||
|
||||
Returns:
|
||||
Tuple of (slots_dict, missing_required_slots)
|
||||
"""
|
||||
try:
|
||||
slot_field_keys = await self._role_provider.get_slot_field_keys(tenant_id)
|
||||
|
||||
logger.info(
|
||||
f"[AC-MRS-12] Retrieved {len(slot_field_keys)} slot fields for tenant={tenant_id}: {slot_field_keys}"
|
||||
)
|
||||
|
||||
from app.models.entities import FlowInstance
|
||||
from sqlalchemy import desc
|
||||
|
||||
stmt = (
|
||||
select(FlowInstance)
|
||||
.where(
|
||||
FlowInstance.tenant_id == tenant_id,
|
||||
)
|
||||
.order_by(desc(FlowInstance.updated_at))
|
||||
.limit(1)
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
flow_instance = result.scalar_one_or_none()
|
||||
|
||||
slots: dict[str, MemorySlot] = {}
|
||||
missing_slots: list[str] = []
|
||||
|
||||
if flow_instance and flow_instance.context:
|
||||
context = flow_instance.context
|
||||
for key, value in context.items():
|
||||
if key in slot_field_keys and value is not None:
|
||||
slots[key] = MemorySlot(
|
||||
key=key,
|
||||
value=value,
|
||||
source=SlotSource.USER_CONFIRMED,
|
||||
confidence=1.0,
|
||||
updated_at=str(flow_instance.updated_at),
|
||||
)
|
||||
|
||||
return slots, missing_slots
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"[AC-IDMP-13] Failed to recall slots: {e}")
|
||||
return {}, []
|
||||
|
||||
async def _recall_from_history(
|
||||
self,
|
||||
tenant_id: str,
|
||||
session_id: str,
|
||||
max_messages: int,
|
||||
) -> list[str]:
|
||||
"""从最近历史中提取最小回填信息。"""
|
||||
try:
|
||||
from app.models.entities import ChatMessage
|
||||
from sqlmodel import col
|
||||
|
||||
stmt = (
|
||||
select(ChatMessage)
|
||||
.where(
|
||||
ChatMessage.tenant_id == tenant_id,
|
||||
ChatMessage.session_id == session_id,
|
||||
)
|
||||
.order_by(col(ChatMessage.created_at).desc())
|
||||
.limit(max_messages)
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
messages = result.scalars().all()
|
||||
|
||||
facts: list[str] = []
|
||||
for msg in messages:
|
||||
if msg.role == "user":
|
||||
facts.append(f"用户说过: {msg.content[:50]}")
|
||||
|
||||
return facts
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"[AC-IDMP-13] Failed to recall from history: {e}")
|
||||
return []
|
||||
|
||||
def _extract_grade(self, content: str) -> str:
|
||||
"""从内容中提取年级信息。"""
|
||||
grades = ["初一", "初二", "初三", "高一", "高二", "高三"]
|
||||
for grade in grades:
|
||||
if grade in content:
|
||||
return grade
|
||||
return "未知年级"
|
||||
|
||||
def _extract_region(self, content: str) -> str:
|
||||
"""从内容中提取地区信息。"""
|
||||
regions = ["北京", "上海", "广州", "深圳", "杭州", "成都", "武汉", "南京"]
|
||||
for region in regions:
|
||||
if region in content:
|
||||
return region
|
||||
return "未知地区"
|
||||
|
||||
def _extract_purchase_info(self, content: str) -> str:
|
||||
"""从内容中提取购买信息。"""
|
||||
return f"购买记录: {content[:30]}..."
|
||||
|
||||
def _extract_order_info(self, content: str) -> str:
|
||||
"""从内容中提取订单信息。"""
|
||||
return f"订单信息: {content[:30]}..."
|
||||
|
||||
def merge_slots(
|
||||
self,
|
||||
existing_slots: dict[str, MemorySlot],
|
||||
new_slots: dict[str, MemorySlot],
|
||||
) -> dict[str, MemorySlot]:
|
||||
"""
|
||||
合并槽位,按优先级处理冲突。
|
||||
|
||||
优先级:user_confirmed > rule_extracted > llm_inferred > default
|
||||
"""
|
||||
merged = dict(existing_slots)
|
||||
|
||||
for key, new_slot in new_slots.items():
|
||||
if key not in merged:
|
||||
merged[key] = new_slot
|
||||
else:
|
||||
existing_slot = merged[key]
|
||||
existing_priority = self.SLOT_PRIORITY.get(existing_slot.source, 0)
|
||||
new_priority = self.SLOT_PRIORITY.get(new_slot.source, 0)
|
||||
|
||||
if new_priority > existing_priority:
|
||||
merged[key] = new_slot
|
||||
elif new_priority == existing_priority:
|
||||
if new_slot.confidence > existing_slot.confidence:
|
||||
merged[key] = new_slot
|
||||
|
||||
return merged
|
||||
|
||||
def create_trace(
|
||||
self,
|
||||
result: MemoryRecallResult,
|
||||
tenant_id: str,
|
||||
) -> ToolCallTrace:
|
||||
"""创建工具调用追踪记录。"""
|
||||
status = ToolCallStatus.OK
|
||||
error_code = None
|
||||
|
||||
if result.fallback_reason_code:
|
||||
if "TIMEOUT" in result.fallback_reason_code:
|
||||
status = ToolCallStatus.TIMEOUT
|
||||
else:
|
||||
status = ToolCallStatus.ERROR
|
||||
error_code = result.fallback_reason_code
|
||||
|
||||
return ToolCallTrace(
|
||||
tool_name="memory_recall",
|
||||
tool_type=ToolType.INTERNAL,
|
||||
duration_ms=result.duration_ms,
|
||||
status=status,
|
||||
error_code=error_code,
|
||||
args_digest=f"tenant={tenant_id}",
|
||||
result_digest=f"profile={len(result.profile)}, facts={len(result.facts)}, slots={len(result.slots)}",
|
||||
)
|
||||
|
||||
|
||||
def register_memory_recall_tool(
|
||||
registry: Any,
|
||||
session: AsyncSession,
|
||||
timeout_governor: TimeoutGovernor | None = None,
|
||||
config: MemoryRecallConfig | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
[AC-IDMP-13] 注册 memory_recall 工具到 ToolRegistry。
|
||||
|
||||
Args:
|
||||
registry: ToolRegistry 实例
|
||||
session: 数据库会话
|
||||
timeout_governor: 超时治理器
|
||||
config: 工具配置
|
||||
"""
|
||||
cfg = config or MemoryRecallConfig()
|
||||
|
||||
async def memory_recall_handler(
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
recall_scope: list[str] | None = None,
|
||||
max_recent_messages: int | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""memory_recall 工具处理器。"""
|
||||
tool = MemoryRecallTool(
|
||||
session=session,
|
||||
timeout_governor=timeout_governor,
|
||||
config=cfg,
|
||||
)
|
||||
result = await tool.execute(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
recall_scope=recall_scope,
|
||||
max_recent_messages=max_recent_messages,
|
||||
)
|
||||
return result.model_dump()
|
||||
|
||||
registry.register(
|
||||
name="memory_recall",
|
||||
description="[AC-IDMP-13] 记忆召回工具,读取用户可用记忆包(profile/facts/preferences/summary/slots)",
|
||||
handler=memory_recall_handler,
|
||||
tool_type=ToolType.INTERNAL,
|
||||
version="1.0.0",
|
||||
auth_required=False,
|
||||
timeout_ms=min(cfg.timeout_ms, 1000),
|
||||
enabled=True,
|
||||
metadata={
|
||||
"ac_ids": ["AC-IDMP-13"],
|
||||
"recall_scope": cfg.default_recall_scope,
|
||||
"max_recent_messages": cfg.max_recent_messages,
|
||||
},
|
||||
)
|
||||
|
||||
logger.info("[AC-IDMP-13] memory_recall tool registered to registry")
|
||||
Loading…
Reference in New Issue