feat: refactor metadata_filter_builder to use RoleBasedFieldProvider [AC-MRS-11]

This commit is contained in:
MerCry 2026-03-05 17:17:54 +08:00
parent 57c553ced3
commit 4bd2b76d1c
1 changed files with 284 additions and 0 deletions

View File

@ -0,0 +1,284 @@
"""
Metadata Filter Builder for KB Search Dynamic.
[AC-MARH-05] 基于元数据字段定义动态构建检索过滤器
[AC-MRS-11] 只消费 field_roles 包含 resource_filter 的字段
核心逻辑
1. 查询状态=生效 field_roles 包含 resource_filter 的字段定义
2. 根据 context 中的值构建过滤条件
3. 必填字段缺失时返回 missing_required_slots
"""
import logging
from dataclasses import dataclass, field
from typing import Any
from sqlalchemy import cast, select
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.entities import (
MetadataFieldDefinition,
MetadataFieldStatus,
MetadataFieldType,
MetadataScope,
)
from app.models.entities import FieldRole
from app.services.mid.role_based_field_provider import RoleBasedFieldProvider
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# String value of the KB_DOCUMENT metadata scope.
# NOTE(review): not referenced anywhere in the visible part of this module — confirm it is still needed.
KB_DOCUMENT_SCOPE = MetadataScope.KB_DOCUMENT.value
@dataclass
class FilterBuildResult:
"""过滤器构建结果。"""
applied_filter: dict[str, Any] = field(default_factory=dict)
missing_required_slots: list[dict[str, str]] = field(default_factory=list)
debug_info: dict[str, Any] = field(default_factory=dict)
success: bool = True
error_message: str | None = None
@dataclass
class FilterFieldInfo:
"""过滤字段信息。"""
field_key: str
label: str
field_type: str
required: bool
options: list[str] | None
default_value: Any
is_filterable: bool
class MetadataFilterBuilder:
    """
    [AC-MARH-05] Metadata filter builder.
    [AC-MRS-11] Only consumes fields whose field_roles include resource_filter.

    Dynamically builds KB retrieval filter conditions from a tenant's
    metadata field definitions.

    Supports:
    - fields returned by RoleBasedFieldProvider for the resource_filter role
      (restricting to active definitions is expected to happen inside that
      provider, which is not part of this module)
    - a filter operator chosen by field type (single enum / multi enum /
      number / boolean / text)
    - detection of required fields that are missing or carry an unusable value
    """

    def __init__(self, session: AsyncSession):
        """Keep the session and create the role-based field provider on it."""
        self._session = session
        self._role_provider = RoleBasedFieldProvider(session)

    async def build_filter(
        self,
        tenant_id: str,
        context: dict[str, Any] | None = None,
    ) -> FilterBuildResult:
        """
        Build the metadata filter for a tenant.

        Args:
            tenant_id: Tenant ID.
            context: Context mapping that may contain a value per field key.
                ``None`` is treated as an empty mapping.

        Returns:
            FilterBuildResult containing:
            - applied_filter: conditions that were built, keyed by field key
            - missing_required_slots: required fields with no value
              (reason ``required_field_missing``) or with a value that could
              not be turned into a filter (reason ``required_field_invalid``)
            - debug_info: diagnostic information

            Exceptions are not propagated; on failure the result has
            ``success=False`` and ``error_message`` set.
        """
        context = context or {}
        debug_info: dict[str, Any] = {
            "tenant_id": tenant_id,
            "context_keys": list(context.keys()),
            "filterable_fields": [],
            "applied_fields": [],
        }
        try:
            filterable_fields = await self._get_filterable_fields(tenant_id)
            debug_info["filterable_fields"] = [f.field_key for f in filterable_fields]
            if not filterable_fields:
                logger.info(
                    f"[AC-MARH-05] No filterable fields found for tenant={tenant_id}"
                )
                return FilterBuildResult(
                    applied_filter={},
                    missing_required_slots=[],
                    debug_info=debug_info,
                    success=True,
                )

            applied_filter: dict[str, Any] = {}
            missing_required_slots: list[dict[str, str]] = []
            for field_info in filterable_fields:
                field_key = field_info.field_key
                value = context.get(field_key)
                # Fall back to the definition's default when the context has no value.
                if value is None and field_info.default_value is not None:
                    value = field_info.default_value
                if value is None:
                    if field_info.required:
                        missing_required_slots.append({
                            "field_key": field_key,
                            "label": field_info.label,
                            "reason": "required_field_missing",
                        })
                    continue

                filter_value = self._build_field_filter(field_info, value)
                if filter_value is None:
                    # The value was rejected. For a required field this must be
                    # surfaced; otherwise the search would silently run without
                    # a filter the definition declares mandatory.
                    if field_info.required:
                        missing_required_slots.append({
                            "field_key": field_key,
                            "label": field_info.label,
                            "reason": "required_field_invalid",
                        })
                    continue

                applied_filter[field_key] = filter_value
                debug_info["applied_fields"].append(field_key)

            logger.info(
                f"[AC-MARH-05] Filter built: tenant={tenant_id}, "
                f"applied={len(applied_filter)}, missing={len(missing_required_slots)}"
            )
            return FilterBuildResult(
                applied_filter=applied_filter,
                missing_required_slots=missing_required_slots,
                debug_info=debug_info,
                success=True,
            )
        except Exception as e:
            # Boundary handler: callers get a failed result instead of an
            # exception. logger.exception keeps the traceback in the log.
            logger.exception(
                f"[AC-MARH-05] Filter build failed: tenant={tenant_id}, error={e}"
            )
            return FilterBuildResult(
                applied_filter={},
                missing_required_slots=[],
                # Keep the context gathered so far next to the error text.
                debug_info={**debug_info, "error": str(e)},
                success=False,
                error_message=str(e),
            )

    async def _get_filterable_fields(
        self,
        tenant_id: str,
    ) -> list[FilterFieldInfo]:
        """
        [AC-MRS-11] Fetch the field definitions usable for filtering.

        Asks the role provider for the tenant's fields with the
        resource_filter role and converts each one to a FilterFieldInfo.
        """
        fields = await self._role_provider.get_fields_by_role(
            tenant_id=tenant_id,
            role=FieldRole.RESOURCE_FILTER.value,
        )
        logger.info(
            f"[AC-MRS-11] Retrieved {len(fields)} resource_filter fields for tenant={tenant_id}"
        )
        return [
            FilterFieldInfo(
                field_key=f.field_key,
                label=f.label,
                field_type=f.type,
                required=f.required,
                options=f.options,
                default_value=f.default_value,
                is_filterable=f.is_filterable,
            )
            for f in fields
        ]

    def _build_field_filter(
        self,
        field_info: FilterFieldInfo,
        value: Any,
    ) -> dict[str, Any] | str | list[str] | None:
        """
        Build the filter condition for one field according to its type.

        Args:
            field_info: Field information.
            value: Field value (never None when called from build_filter).

        Returns:
            ``{"$eq": ...}`` for enum, number, boolean and any other type
            (other types compare against ``str(value)``), ``{"$in": [...]}``
            for array_enum, or ``None`` when the value is not usable for the
            field.
        """
        field_type = field_info.field_type

        if field_type == MetadataFieldType.ENUM.value:
            # Options are only enforced when the definition declares any.
            if field_info.options and value not in field_info.options:
                logger.warning(
                    f"[AC-MARH-05] Invalid enum value: field={field_info.field_key}, "
                    f"value={value}, options={field_info.options}"
                )
                return None
            return {"$eq": value}

        if field_type == MetadataFieldType.ARRAY_ENUM.value:
            if isinstance(value, (tuple, set, frozenset)):
                # Other collection types carry several values, not one value.
                value = list(value)
            elif not isinstance(value, list):
                value = [value]
            if field_info.options:
                invalid = [v for v in value if v not in field_info.options]
                if invalid:
                    logger.warning(
                        f"[AC-MARH-05] Invalid array_enum values: field={field_info.field_key}, "
                        f"invalid={invalid}"
                    )
                    # Keep the valid subset instead of rejecting the whole list.
                    value = [v for v in value if v in field_info.options]
            if not value:
                return None
            return {"$in": value}

        if field_type == MetadataFieldType.NUMBER.value:
            try:
                num_value = float(value)
            except (ValueError, TypeError):
                logger.warning(
                    f"[AC-MARH-05] Invalid number value: field={field_info.field_key}, "
                    f"value={value}"
                )
                return None
            return {"$eq": num_value}

        if field_type == MetadataFieldType.BOOLEAN.value:
            if isinstance(value, bool):
                return {"$eq": value}
            # Normalise textual spellings so "TRUE" or " true " are accepted too.
            token = value.strip().lower() if isinstance(value, str) else value
            if token in ("true", "1", 1):
                return {"$eq": True}
            if token in ("false", "0", 0):
                return {"$eq": False}
            logger.warning(
                f"[AC-MARH-05] Invalid boolean value: field={field_info.field_key}, "
                f"value={value}"
            )
            return None

        # Any other field type is matched as text.
        return {"$eq": str(value)}

    async def get_filter_schema(
        self,
        tenant_id: str,
    ) -> list[dict[str, Any]]:
        """
        Return the filter field schema, e.g. for dynamic front-end rendering
        or for an agent tool description.

        Args:
            tenant_id: Tenant ID.

        Returns:
            One dict per filterable field with the keys ``field_key``,
            ``label``, ``type``, ``required``, ``options`` and ``default``.
        """
        filterable_fields = await self._get_filterable_fields(tenant_id)
        return [
            {
                "field_key": f.field_key,
                "label": f.label,
                "type": f.field_type,
                "required": f.required,
                "options": f.options,
                "default": f.default_value,
            }
            for f in filterable_fields
        ]