feat: refactor metadata_filter_builder to use RoleBasedFieldProvider [AC-MRS-11]
This commit is contained in:
parent
57c553ced3
commit
4bd2b76d1c
|
|
@ -0,0 +1,284 @@
|
||||||
|
"""
|
||||||
|
Metadata Filter Builder for KB Search Dynamic.
|
||||||
|
[AC-MARH-05] 基于元数据字段定义动态构建检索过滤器。
|
||||||
|
[AC-MRS-11] 只消费 field_roles 包含 resource_filter 的字段
|
||||||
|
|
||||||
|
核心逻辑:
|
||||||
|
1. 查询状态=生效 且 field_roles 包含 resource_filter 的字段定义
|
||||||
|
2. 根据 context 中的值构建过滤条件
|
||||||
|
3. 必填字段缺失时返回 missing_required_slots
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy import cast, select
|
||||||
|
from sqlalchemy.dialects.postgresql import JSONB
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.models.entities import (
|
||||||
|
MetadataFieldDefinition,
|
||||||
|
MetadataFieldStatus,
|
||||||
|
MetadataFieldType,
|
||||||
|
MetadataScope,
|
||||||
|
)
|
||||||
|
from app.models.entities import FieldRole
|
||||||
|
from app.services.mid.role_based_field_provider import RoleBasedFieldProvider
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
KB_DOCUMENT_SCOPE = MetadataScope.KB_DOCUMENT.value
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FilterBuildResult:
|
||||||
|
"""过滤器构建结果。"""
|
||||||
|
applied_filter: dict[str, Any] = field(default_factory=dict)
|
||||||
|
missing_required_slots: list[dict[str, str]] = field(default_factory=list)
|
||||||
|
debug_info: dict[str, Any] = field(default_factory=dict)
|
||||||
|
success: bool = True
|
||||||
|
error_message: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FilterFieldInfo:
|
||||||
|
"""过滤字段信息。"""
|
||||||
|
field_key: str
|
||||||
|
label: str
|
||||||
|
field_type: str
|
||||||
|
required: bool
|
||||||
|
options: list[str] | None
|
||||||
|
default_value: Any
|
||||||
|
is_filterable: bool
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataFilterBuilder:
|
||||||
|
"""
|
||||||
|
[AC-MARH-05] 元数据过滤器构建器。
|
||||||
|
[AC-MRS-11] 只消费 field_roles 包含 resource_filter 的字段
|
||||||
|
|
||||||
|
根据租户的元数据字段定义,动态构建 KB 检索的过滤条件。
|
||||||
|
支持:
|
||||||
|
- 状态=生效的字段
|
||||||
|
- field_roles 包含 resource_filter
|
||||||
|
- 字段类型决定过滤操作(单选/多选/文本)
|
||||||
|
- 必填字段缺失检测
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, session: AsyncSession):
|
||||||
|
self._session = session
|
||||||
|
self._role_provider = RoleBasedFieldProvider(session)
|
||||||
|
|
||||||
|
async def build_filter(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
context: dict[str, Any] | None = None,
|
||||||
|
) -> FilterBuildResult:
|
||||||
|
"""
|
||||||
|
构建元数据过滤器。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tenant_id: 租户 ID
|
||||||
|
context: 上下文信息,包含可能的过滤值
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
FilterBuildResult 包含:
|
||||||
|
- applied_filter: 已应用的过滤条件
|
||||||
|
- missing_required_slots: 缺失的必填字段
|
||||||
|
- debug_info: 调试信息
|
||||||
|
"""
|
||||||
|
context = context or {}
|
||||||
|
debug_info = {
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
"context_keys": list(context.keys()),
|
||||||
|
"filterable_fields": [],
|
||||||
|
"applied_fields": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
filterable_fields = await self._get_filterable_fields(tenant_id)
|
||||||
|
debug_info["filterable_fields"] = [f.field_key for f in filterable_fields]
|
||||||
|
|
||||||
|
if not filterable_fields:
|
||||||
|
logger.info(
|
||||||
|
f"[AC-MARH-05] No filterable fields found for tenant={tenant_id}"
|
||||||
|
)
|
||||||
|
return FilterBuildResult(
|
||||||
|
applied_filter={},
|
||||||
|
missing_required_slots=[],
|
||||||
|
debug_info=debug_info,
|
||||||
|
success=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
applied_filter: dict[str, Any] = {}
|
||||||
|
missing_required_slots: list[dict[str, str]] = []
|
||||||
|
|
||||||
|
for field_info in filterable_fields:
|
||||||
|
field_key = field_info.field_key
|
||||||
|
value = context.get(field_key)
|
||||||
|
|
||||||
|
if value is None and field_info.default_value is not None:
|
||||||
|
value = field_info.default_value
|
||||||
|
|
||||||
|
if value is None:
|
||||||
|
if field_info.required:
|
||||||
|
missing_required_slots.append({
|
||||||
|
"field_key": field_key,
|
||||||
|
"label": field_info.label,
|
||||||
|
"reason": "required_field_missing",
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
|
||||||
|
filter_value = self._build_field_filter(field_info, value)
|
||||||
|
if filter_value is not None:
|
||||||
|
applied_filter[field_key] = filter_value
|
||||||
|
debug_info["applied_fields"].append(field_key)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[AC-MARH-05] Filter built: tenant={tenant_id}, "
|
||||||
|
f"applied={len(applied_filter)}, missing={len(missing_required_slots)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return FilterBuildResult(
|
||||||
|
applied_filter=applied_filter,
|
||||||
|
missing_required_slots=missing_required_slots,
|
||||||
|
debug_info=debug_info,
|
||||||
|
success=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"[AC-MARH-05] Filter build failed: tenant={tenant_id}, error={e}"
|
||||||
|
)
|
||||||
|
return FilterBuildResult(
|
||||||
|
applied_filter={},
|
||||||
|
missing_required_slots=[],
|
||||||
|
debug_info={"error": str(e)},
|
||||||
|
success=False,
|
||||||
|
error_message=str(e),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _get_filterable_fields(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
) -> list[FilterFieldInfo]:
|
||||||
|
"""
|
||||||
|
[AC-MRS-11] 获取可过滤的字段定义。
|
||||||
|
|
||||||
|
条件:
|
||||||
|
- 状态=生效 (active)
|
||||||
|
- field_roles 包含 resource_filter
|
||||||
|
"""
|
||||||
|
fields = await self._role_provider.get_fields_by_role(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
role=FieldRole.RESOURCE_FILTER.value,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[AC-MRS-11] Retrieved {len(fields)} resource_filter fields for tenant={tenant_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return [
|
||||||
|
FilterFieldInfo(
|
||||||
|
field_key=f.field_key,
|
||||||
|
label=f.label,
|
||||||
|
field_type=f.type,
|
||||||
|
required=f.required,
|
||||||
|
options=f.options,
|
||||||
|
default_value=f.default_value,
|
||||||
|
is_filterable=f.is_filterable,
|
||||||
|
)
|
||||||
|
for f in fields
|
||||||
|
]
|
||||||
|
|
||||||
|
def _build_field_filter(
|
||||||
|
self,
|
||||||
|
field_info: FilterFieldInfo,
|
||||||
|
value: Any,
|
||||||
|
) -> dict[str, Any] | str | list[str] | None:
|
||||||
|
"""
|
||||||
|
根据字段类型构建过滤条件。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
field_info: 字段信息
|
||||||
|
value: 字段值
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
过滤条件(格式取决于字段类型)
|
||||||
|
"""
|
||||||
|
field_type = field_info.field_type
|
||||||
|
|
||||||
|
if field_type == MetadataFieldType.ENUM.value:
|
||||||
|
if field_info.options and value not in field_info.options:
|
||||||
|
logger.warning(
|
||||||
|
f"[AC-MARH-05] Invalid enum value: field={field_info.field_key}, "
|
||||||
|
f"value={value}, options={field_info.options}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
return {"$eq": value}
|
||||||
|
|
||||||
|
elif field_type == MetadataFieldType.ARRAY_ENUM.value:
|
||||||
|
if not isinstance(value, list):
|
||||||
|
value = [value]
|
||||||
|
if field_info.options:
|
||||||
|
invalid = [v for v in value if v not in field_info.options]
|
||||||
|
if invalid:
|
||||||
|
logger.warning(
|
||||||
|
f"[AC-MARH-05] Invalid array_enum values: field={field_info.field_key}, "
|
||||||
|
f"invalid={invalid}"
|
||||||
|
)
|
||||||
|
value = [v for v in value if v in field_info.options]
|
||||||
|
if not value:
|
||||||
|
return None
|
||||||
|
return {"$in": value}
|
||||||
|
|
||||||
|
elif field_type == MetadataFieldType.NUMBER.value:
|
||||||
|
try:
|
||||||
|
num_value = float(value)
|
||||||
|
return {"$eq": num_value}
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
logger.warning(
|
||||||
|
f"[AC-MARH-05] Invalid number value: field={field_info.field_key}, "
|
||||||
|
f"value={value}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
elif field_type == MetadataFieldType.BOOLEAN.value:
|
||||||
|
if isinstance(value, bool):
|
||||||
|
return {"$eq": value}
|
||||||
|
if value in ["true", "1", 1, "True"]:
|
||||||
|
return {"$eq": True}
|
||||||
|
if value in ["false", "0", 0, "False"]:
|
||||||
|
return {"$eq": False}
|
||||||
|
return None
|
||||||
|
|
||||||
|
else:
|
||||||
|
return {"$eq": str(value)}
|
||||||
|
|
||||||
|
async def get_filter_schema(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
获取过滤字段 Schema,用于前端动态渲染或 Agent 工具描述。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tenant_id: 租户 ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
字段 Schema 列表
|
||||||
|
"""
|
||||||
|
filterable_fields = await self._get_filterable_fields(tenant_id)
|
||||||
|
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"field_key": f.field_key,
|
||||||
|
"label": f.label,
|
||||||
|
"type": f.field_type,
|
||||||
|
"required": f.required,
|
||||||
|
"options": f.options,
|
||||||
|
"default": f.default_value,
|
||||||
|
}
|
||||||
|
for f in filterable_fields
|
||||||
|
]
|
||||||
Loading…
Reference in New Issue