# Provenance: commit 4bd2b76d1c7abf230fb0722b40d669e5120b8d8a
# Author: MerCry, Thu, 5 Mar 2026 17:17:54 +0800
# Subject: feat: refactor metadata_filter_builder to use RoleBasedFieldProvider [AC-MRS-11]
# Path: ai-service/app/services/mid/metadata_filter_builder.py
"""
Metadata Filter Builder for KB Search Dynamic.

[AC-MARH-05] Dynamically builds retrieval filters from metadata field definitions.
[AC-MRS-11] Only consumes fields whose field_roles contain resource_filter.

Core logic:
1. Fetch the field definitions carrying the resource_filter role (the status
   filtering is delegated to RoleBasedFieldProvider and is not visible here).
2. Build filter conditions from the values found in the context.
3. Report required fields that have no usable value in missing_required_slots.
"""

import logging
import math
from dataclasses import dataclass, field
from typing import Any

# NOTE: several of the names imported below are not referenced in this module;
# they are kept so that anything importing them from here keeps working.
from sqlalchemy import cast, select
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.entities import (
    FieldRole,
    MetadataFieldDefinition,
    MetadataFieldStatus,
    MetadataFieldType,
    MetadataScope,
)
from app.services.mid.role_based_field_provider import RoleBasedFieldProvider

logger = logging.getLogger(__name__)

KB_DOCUMENT_SCOPE = MetadataScope.KB_DOCUMENT.value

# String tokens accepted (after strip + lower-casing) for boolean fields.
_TRUE_TOKENS = ("true", "1")
_FALSE_TOKENS = ("false", "0")


@dataclass
class FilterBuildResult:
    """Outcome of a filter build."""

    # field_key -> filter condition (e.g. {"$eq": ...} or {"$in": [...]}).
    applied_filter: dict[str, Any] = field(default_factory=dict)
    # One entry per required field without a usable value; each entry has the
    # keys "field_key", "label" and "reason".
    missing_required_slots: list[dict[str, str]] = field(default_factory=list)
    # Diagnostic data collected while building the filter.
    debug_info: dict[str, Any] = field(default_factory=dict)
    # False only when the build raised an exception.
    success: bool = True
    # str(exception) when success is False, otherwise None.
    error_message: str | None = None


@dataclass
class FilterFieldInfo:
    """Description of one field that may be used for filtering."""

    field_key: str
    label: str
    field_type: str
    required: bool
    options: list[str] | None
    default_value: Any
    is_filterable: bool


class MetadataFilterBuilder:
    """
    [AC-MARH-05] Metadata filter builder.
    [AC-MRS-11] Only consumes fields whose field_roles contain resource_filter.

    Dynamically builds the KB retrieval filter from a tenant's metadata field
    definitions:
    - fields are obtained from RoleBasedFieldProvider for the resource_filter role
    - the field type decides the filter operator ($eq / $in)
    - required fields without a usable value are reported back to the caller
    """

    def __init__(self, session: AsyncSession):
        self._session = session
        self._role_provider = RoleBasedFieldProvider(session)

    async def build_filter(
        self,
        tenant_id: str,
        context: dict[str, Any] | None = None,
    ) -> FilterBuildResult:
        """
        Build the metadata filter for a tenant.

        Args:
            tenant_id: Tenant ID.
            context: Candidate filter values keyed by field_key. ``None`` is
                treated as an empty mapping.

        Returns:
            FilterBuildResult with:
            - applied_filter: conditions that were built successfully
            - missing_required_slots: required fields that are absent
              (reason ``required_field_missing``) or whose value was rejected
              (reason ``invalid_value``)
            - debug_info: diagnostic information

            Exceptions raised while building are not propagated; they yield a
            result with ``success=False`` and ``error_message`` set.
        """
        context = context or {}
        debug_info: dict[str, Any] = {
            "tenant_id": tenant_id,
            "context_keys": list(context.keys()),
            "filterable_fields": [],
            "applied_fields": [],
        }

        try:
            filterable_fields = await self._get_filterable_fields(tenant_id)
            debug_info["filterable_fields"] = [f.field_key for f in filterable_fields]

            if not filterable_fields:
                logger.info(
                    "[AC-MARH-05] No filterable fields found for tenant=%s",
                    tenant_id,
                )
                return FilterBuildResult(
                    applied_filter={},
                    missing_required_slots=[],
                    debug_info=debug_info,
                    success=True,
                )

            applied_filter: dict[str, Any] = {}
            missing_required_slots: list[dict[str, str]] = []

            for field_info in filterable_fields:
                field_key = field_info.field_key

                # Context value wins; fall back to the field's default.
                value = context.get(field_key)
                if value is None:
                    value = field_info.default_value

                if value is None:
                    if field_info.required:
                        missing_required_slots.append({
                            "field_key": field_key,
                            "label": field_info.label,
                            "reason": "required_field_missing",
                        })
                    continue

                filter_value = self._build_field_filter(field_info, value)
                if filter_value is not None:
                    applied_filter[field_key] = filter_value
                    debug_info["applied_fields"].append(field_key)
                elif field_info.required:
                    # A required field whose value was rejected must not be
                    # dropped silently: the search would otherwise run without
                    # a filter the tenant declared mandatory.
                    missing_required_slots.append({
                        "field_key": field_key,
                        "label": field_info.label,
                        "reason": "invalid_value",
                    })

            logger.info(
                "[AC-MARH-05] Filter built: tenant=%s, applied=%d, missing=%d",
                tenant_id,
                len(applied_filter),
                len(missing_required_slots),
            )

            return FilterBuildResult(
                applied_filter=applied_filter,
                missing_required_slots=missing_required_slots,
                debug_info=debug_info,
                success=True,
            )

        except Exception as e:
            # Deliberate best-effort boundary: callers receive a failed result
            # instead of an exception. logger.exception keeps the traceback.
            logger.exception(
                "[AC-MARH-05] Filter build failed: tenant=%s, error=%s",
                tenant_id,
                e,
            )
            return FilterBuildResult(
                applied_filter={},
                missing_required_slots=[],
                # Keep what was collected so far; it helps locating the failure.
                debug_info={**debug_info, "error": str(e)},
                success=False,
                error_message=str(e),
            )

    async def _get_filterable_fields(
        self,
        tenant_id: str,
    ) -> list[FilterFieldInfo]:
        """
        [AC-MRS-11] Fetch the field definitions usable for filtering.

        Asks the role provider for the tenant's fields with the
        resource_filter role and converts each one to a FilterFieldInfo.
        Any further selection (e.g. by status) happens inside the provider.
        """
        fields = await self._role_provider.get_fields_by_role(
            tenant_id=tenant_id,
            role=FieldRole.RESOURCE_FILTER.value,
        )

        logger.info(
            "[AC-MRS-11] Retrieved %d resource_filter fields for tenant=%s",
            len(fields),
            tenant_id,
        )

        return [
            FilterFieldInfo(
                field_key=f.field_key,
                label=f.label,
                field_type=f.type,
                required=f.required,
                options=f.options,
                default_value=f.default_value,
                is_filterable=f.is_filterable,
            )
            for f in fields
        ]

    def _build_field_filter(
        self,
        field_info: FilterFieldInfo,
        value: Any,
    ) -> dict[str, Any] | str | list[str] | None:
        """
        Build the filter condition for one field according to its type.

        Args:
            field_info: Field description.
            value: Non-None candidate value for the field.

        Returns:
            ``{"$eq": ...}`` for enum / number / boolean / other types,
            ``{"$in": [...]}`` for array_enum, or ``None`` when the value is
            not acceptable for the field (a warning is logged in that case).
        """
        field_type = field_info.field_type
        field_key = field_info.field_key
        options = field_info.options

        if field_type == MetadataFieldType.ENUM.value:
            # Without declared options any value is accepted.
            if options and value not in options:
                logger.warning(
                    "[AC-MARH-05] Invalid enum value: field=%s, value=%s, options=%s",
                    field_key,
                    value,
                    options,
                )
                return None
            return {"$eq": value}

        if field_type == MetadataFieldType.ARRAY_ENUM.value:
            # Accept any common collection; a scalar becomes a one-item list.
            if isinstance(value, (list, tuple, set, frozenset)):
                values = list(value)
            else:
                values = [value]
            if options:
                invalid = [v for v in values if v not in options]
                if invalid:
                    logger.warning(
                        "[AC-MARH-05] Invalid array_enum values: field=%s, invalid=%s",
                        field_key,
                        invalid,
                    )
                values = [v for v in values if v in options]
            if not values:
                return None
            return {"$in": values}

        if field_type == MetadataFieldType.NUMBER.value:
            try:
                num_value = float(value)
            except (ValueError, TypeError, OverflowError):
                # OverflowError: integers too large for a float.
                num_value = None
            # NaN / infinity can never match an equality filter and are not
            # representable in strict JSON, so they are rejected as well.
            if num_value is None or not math.isfinite(num_value):
                logger.warning(
                    "[AC-MARH-05] Invalid number value: field=%s, value=%s",
                    field_key,
                    value,
                )
                return None
            return {"$eq": num_value}

        if field_type == MetadataFieldType.BOOLEAN.value:
            if isinstance(value, bool):
                return {"$eq": value}
            # Strings are matched case-insensitively, ignoring outer blanks.
            token = value.strip().lower() if isinstance(value, str) else value
            if token in (*_TRUE_TOKENS, 1):
                return {"$eq": True}
            if token in (*_FALSE_TOKENS, 0):
                return {"$eq": False}
            logger.warning(
                "[AC-MARH-05] Invalid boolean value: field=%s, value=%s",
                field_key,
                value,
            )
            return None

        # Every other field type is compared as text.
        return {"$eq": str(value)}

    async def get_filter_schema(
        self,
        tenant_id: str,
    ) -> list[dict[str, Any]]:
        """
        Return the filter field schema, e.g. for dynamic front-end rendering
        or for an agent tool description.

        Args:
            tenant_id: Tenant ID.

        Returns:
            One dict per filterable field with the keys ``field_key``,
            ``label``, ``type``, ``required``, ``options`` and ``default``.
        """
        filterable_fields = await self._get_filterable_fields(tenant_id)

        return [
            {
                "field_key": f.field_key,
                "label": f.label,
                "type": f.field_type,
                "required": f.required,
                "options": f.options,
                "default": f.default_value,
            }
            for f in filterable_fields
        ]