feat: add metadata discovery tool for dynamic metadata extraction [AC-METADATA-DISCOVERY]

This commit is contained in:
MerCry 2026-03-10 12:11:31 +08:00
parent 812af6c7a1
commit 3b354ba041
1 changed files with 281 additions and 0 deletions

View File

@ -0,0 +1,281 @@
"""
Metadata Discovery Tool for Mid Platform.
[AC-MARH-XX] 元数据发现工具用于查询当前可用的元数据字段及其常见值
核心特性
- 列出当前知识库文档中使用的元数据字段
- 返回每个字段的常见取值从现有文档中聚合
- 支持按知识库过滤
- 返回字段定义信息类型用途说明等
"""
from __future__ import annotations
import asyncio
import logging
from collections import Counter
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.entities import Document, MetadataFieldDefinition, MetadataFieldStatus
from app.services.mid.timeout_governor import TimeoutGovernor
if TYPE_CHECKING:
from app.services.mid.tool_registry import ToolRegistry
logger = logging.getLogger(__name__)
DEFAULT_TIMEOUT_MS = 2000
DEFAULT_TOP_VALUES = 10
@dataclass
class MetadataFieldDiscoveryConfig:
"""Configuration for metadata field discovery tool."""
timeout_ms: int = DEFAULT_TIMEOUT_MS
top_values_count: int = DEFAULT_TOP_VALUES
@dataclass
class MetadataFieldInfo:
"""Information about a metadata field."""
field_key: str
field_type: str = "string"
label: str = ""
description: str | None = None
common_values: list[str] = field(default_factory=list)
value_count: int = 0
is_filterable: bool = True
options: list[str] | None = None
@dataclass
class MetadataDiscoveryResult:
"""Result of metadata discovery."""
success: bool
fields: list[MetadataFieldInfo] = field(default_factory=list)
total_documents: int = 0
error: str | None = None
duration_ms: int = 0
class MetadataDiscoveryTool:
"""
[AC-MARH-XX] 元数据发现工具
用于查询当前知识库文档中使用的元数据字段及其常见值
帮助 AI 了解可用的过滤字段从而更好地构造搜索请求
"""
def __init__(
self,
session: AsyncSession,
timeout_governor: TimeoutGovernor | None = None,
config: MetadataFieldDiscoveryConfig | None = None,
):
self._session = session
self._timeout_governor = timeout_governor or TimeoutGovernor()
self._config = config or MetadataFieldDiscoveryConfig()
async def execute(
self,
tenant_id: str,
kb_id: str | None = None,
include_values: bool = True,
top_n: int | None = None,
) -> MetadataDiscoveryResult:
"""
Execute metadata discovery.
Args:
tenant_id: Tenant ID
kb_id: Optional knowledge base ID to filter
include_values: Whether to include common values (default True)
top_n: Number of top values to return per field (default from config)
Returns:
MetadataDiscoveryResult with field information
"""
start_time = asyncio.get_event_loop().time()
try:
top_n = top_n or self._config.top_values_count
field_definitions = await self._get_field_definitions(tenant_id)
document_metadata = await self._get_document_metadata(tenant_id, kb_id)
total_docs = len(document_metadata)
field_values: dict[str, Counter] = {}
for doc_meta in document_metadata:
if not doc_meta:
continue
for key, value in doc_meta.items():
if key not in field_values:
field_values[key] = Counter()
str_value = str(value) if value is not None else ""
field_values[key].update([str_value])
fields: list[MetadataFieldInfo] = []
for field_key, values_counter in field_values.items():
field_def = field_definitions.get(field_key)
common_values = []
if include_values:
most_common = values_counter.most_common(top_n)
common_values = [v for v, _ in most_common if v]
field_info = MetadataFieldInfo(
field_key=field_key,
field_type=field_def.type if field_def else "string",
label=field_def.label if field_def else field_key,
description=field_def.usage_description if field_def else None,
common_values=common_values,
value_count=len(values_counter),
is_filterable=field_def.is_filterable if field_def else True,
options=field_def.options if field_def else None,
)
fields.append(field_info)
fields.sort(key=lambda f: f.value_count, reverse=True)
duration_ms = int((asyncio.get_event_loop().time() - start_time) * 1000)
logger.info(
f"[MetadataDiscovery] Discovered {len(fields)} fields from {total_docs} documents, "
f"duration={duration_ms}ms"
)
return MetadataDiscoveryResult(
success=True,
fields=fields,
total_documents=total_docs,
duration_ms=duration_ms,
)
except Exception as e:
logger.error(f"[MetadataDiscovery] Discovery failed: {e}")
return MetadataDiscoveryResult(
success=False,
error=str(e),
)
async def _get_field_definitions(
self,
tenant_id: str,
) -> dict[str, MetadataFieldDefinition]:
"""Get field definitions for tenant."""
stmt = select(MetadataFieldDefinition).where(
MetadataFieldDefinition.tenant_id == tenant_id,
MetadataFieldDefinition.status == MetadataFieldStatus.ACTIVE.value,
)
result = await self._session.execute(stmt)
definitions = result.scalars().all()
return {d.field_key: d for d in definitions}
async def _get_document_metadata(
self,
tenant_id: str,
kb_id: str | None = None,
) -> list[dict[str, Any]]:
"""Get all document metadata for tenant."""
stmt = select(Document.doc_metadata).where(
Document.tenant_id == tenant_id,
)
if kb_id:
stmt = stmt.where(Document.kb_id == kb_id)
result = await self._session.execute(stmt)
rows = result.scalars().all()
return [row for row in rows if row]
def register_metadata_discovery_tool(
registry: "ToolRegistry",
session: AsyncSession,
timeout_governor: TimeoutGovernor | None = None,
config: MetadataFieldDiscoveryConfig | None = None,
) -> None:
"""Register metadata discovery tool to registry."""
from app.services.mid.tool_registry import ToolType
cfg = config or MetadataFieldDiscoveryConfig()
async def metadata_discovery_handler(
tenant_id: str = "",
kb_id: str | None = None,
include_values: bool = True,
top_n: int | None = None,
**kwargs, # 接受系统注入的额外参数user_id, session_id 等)
) -> dict[str, Any]:
"""Metadata discovery tool handler."""
tool = MetadataDiscoveryTool(
session=session,
timeout_governor=timeout_governor,
config=cfg,
)
result = await tool.execute(
tenant_id=tenant_id,
kb_id=kb_id,
include_values=include_values,
top_n=top_n,
)
# 将 dataclass 转换为 dict
return {
"success": result.success,
"fields": [
{
"field_key": f.field_key,
"field_type": f.field_type,
"label": f.label,
"description": f.description,
"common_values": f.common_values,
"value_count": f.value_count,
"is_filterable": f.is_filterable,
"options": f.options,
}
for f in result.fields
],
"total_documents": result.total_documents,
"error": result.error,
"duration_ms": result.duration_ms,
}
registry.register(
name="list_document_metadata_fields",
description="列出当前知识库文档中使用的元数据字段及其常见取值,用于后续的知识库搜索过滤",
handler=metadata_discovery_handler,
tool_type=ToolType.INTERNAL,
version="1.0.0",
auth_required=False,
timeout_ms=cfg.timeout_ms,
enabled=True,
metadata={
"when_to_use": "当需要了解知识库中有哪些可用的元数据过滤字段时使用。",
"when_not_to_use": "当已知可用的过滤字段,或不需要元数据过滤时不需要调用。",
"parameters": {
"type": "object",
"properties": {
"tenant_id": {"type": "string", "description": "租户 ID"},
"kb_id": {"type": "string", "description": "知识库 ID可选用于限定范围"},
"include_values": {"type": "boolean", "description": "是否包含常见值列表,默认 true"},
"top_n": {"type": "integer", "description": "每个字段返回的常见值数量,默认 10"},
},
"required": [],
},
"example_action_input": {
"include_values": True,
"top_n": 5,
},
"result_interpretation": "fields 数组包含每个字段的详细信息common_values 是该字段在文档中的常见取值value_count 表示该字段在多少文档中出现。",
},
)
logger.info("[MetadataDiscovery] Tool registered: list_document_metadata_fields")