feat: inject metadata filters and add fallback reason codes [AC-IDSMETA-18, AC-IDSMETA-19, AC-IDSMETA-20]
This commit is contained in:
parent
d3ae92dec5
commit
c4ad6eb8ce
|
|
@ -653,6 +653,7 @@ class OrchestratorService:
|
|||
"""
|
||||
[AC-AISVC-16, AC-AISVC-17] RAG retrieval for evidence.
|
||||
Step 5-6: Multi-KB retrieval with target KBs from intent matching.
|
||||
[AC-IDSMETA-19] Inject metadata filters (grade/subject/scene) from context.
|
||||
"""
|
||||
# Skip if flow or intent already handled
|
||||
if ctx.diagnostics.get("flow_handled") or ctx.diagnostics.get("intent_handled"):
|
||||
|
|
@ -678,12 +679,24 @@ class OrchestratorService:
|
|||
retrieval_ctx.metadata["target_kb_ids"] = ctx.target_kb_ids
|
||||
logger.info(f"[AC-AISVC-16] Using target_kb_ids from intent: {ctx.target_kb_ids}")
|
||||
|
||||
# [AC-IDSMETA-19] Inject metadata filters from context
|
||||
metadata_filters = await self._build_metadata_filters(ctx)
|
||||
if metadata_filters:
|
||||
retrieval_ctx.tag_filter = metadata_filters
|
||||
logger.info(
|
||||
f"[AC-IDSMETA-19] Injected metadata filters: "
|
||||
f"intent_id={ctx.intent_match.rule.id if ctx.intent_match else None}, "
|
||||
f"target_kbs={ctx.target_kb_ids}, "
|
||||
f"applied_metadata_filters={metadata_filters.fields}"
|
||||
)
|
||||
|
||||
ctx.retrieval_result = await self._retriever.retrieve(retrieval_ctx)
|
||||
|
||||
ctx.diagnostics["retrieval"] = {
|
||||
"hit_count": ctx.retrieval_result.hit_count,
|
||||
"max_score": ctx.retrieval_result.max_score,
|
||||
"is_empty": ctx.retrieval_result.is_empty,
|
||||
"applied_metadata_filters": metadata_filters.fields if metadata_filters else None,
|
||||
}
|
||||
|
||||
logger.info(
|
||||
|
|
@ -708,6 +721,48 @@ class OrchestratorService:
|
|||
)
|
||||
ctx.diagnostics["retrieval_error"] = str(e)
|
||||
|
||||
async def _build_metadata_filters(self, ctx: GenerationContext):
|
||||
"""
|
||||
[AC-IDSMETA-19] Build metadata filters from context.
|
||||
|
||||
Sources:
|
||||
1. Intent rule metadata (if matched)
|
||||
2. Session metadata
|
||||
3. Request metadata
|
||||
4. Extracted slots from conversation
|
||||
|
||||
Returns:
|
||||
TagFilter with at least grade, subject, scene if available
|
||||
"""
|
||||
from app.services.retrieval.metadata import TagFilter
|
||||
|
||||
filter_fields = {}
|
||||
|
||||
# 1. From intent rule metadata
|
||||
if ctx.intent_match and hasattr(ctx.intent_match.rule, 'metadata_') and ctx.intent_match.rule.metadata_:
|
||||
intent_metadata = ctx.intent_match.rule.metadata_
|
||||
for key in ['grade', 'subject', 'scene']:
|
||||
if key in intent_metadata:
|
||||
filter_fields[key] = intent_metadata[key]
|
||||
|
||||
# 2. From session/request metadata
|
||||
if ctx.request_metadata:
|
||||
for key in ['grade', 'subject', 'scene']:
|
||||
if key in ctx.request_metadata and key not in filter_fields:
|
||||
filter_fields[key] = ctx.request_metadata[key]
|
||||
|
||||
# 3. From merged context (extracted slots)
|
||||
if ctx.merged_context and hasattr(ctx.merged_context, 'slots'):
|
||||
slots = ctx.merged_context.slots or {}
|
||||
for key in ['grade', 'subject', 'scene']:
|
||||
if key in slots and key not in filter_fields:
|
||||
filter_fields[key] = slots[key]
|
||||
|
||||
if not filter_fields:
|
||||
return None
|
||||
|
||||
return TagFilter(fields=filter_fields)
|
||||
|
||||
async def _build_system_prompt(self, ctx: GenerationContext) -> None:
|
||||
"""
|
||||
[AC-AISVC-56, AC-AISVC-84] Step 7: Build system prompt with template + behavior rules.
|
||||
|
|
@ -919,17 +974,66 @@ class OrchestratorService:
|
|||
def _fallback_response(self, ctx: GenerationContext) -> str:
    """
    [AC-AISVC-17] Return the canned reply used when no LLM answer is available.
    [AC-IDSMETA-20] When nothing was recalled, also record a structured reason code.

    If ``ctx.retrieval_result`` exists and is not empty, a "found related
    content but could not compose a reply" message is returned and nothing
    else happens. Otherwise the reason code from
    ``_determine_fallback_reason_code`` is stored under
    ``ctx.diagnostics["fallback_reason_code"]``, a warning is logged, and a
    generic apology message is returned.
    """
    partial_answer_message = (
        "根据知识库信息,我找到了一些相关内容,"
        "但暂时无法生成完整回复。建议您稍后重试或联系人工客服。"
    )
    generic_apology_message = (
        "抱歉,我暂时无法处理您的请求。"
        "请稍后重试或联系人工客服获取帮助。"
    )

    retrieval = ctx.retrieval_result
    nothing_recalled = not retrieval or retrieval.is_empty
    if not nothing_recalled:
        return partial_answer_message

    # [AC-IDSMETA-20] Persist the reason code before logging it.
    fallback_reason_code = self._determine_fallback_reason_code(ctx)
    ctx.diagnostics["fallback_reason_code"] = fallback_reason_code

    logger.warning(
        f"[AC-IDSMETA-20] No recall, using fallback: "
        f"intent_id={ctx.intent_match.rule.id if ctx.intent_match else None}, "
        f"target_kbs={ctx.target_kb_ids}, "
        f"applied_metadata_filters={ctx.diagnostics.get('retrieval', {}).get('applied_metadata_filters')}, "
        f"fallback_reason_code={fallback_reason_code}"
    )

    return generic_apology_message
|
||||
|
||||
def _determine_fallback_reason_code(self, ctx: GenerationContext) -> str:
|
||||
"""
|
||||
[AC-IDSMETA-20] Determine structured fallback reason code.
|
||||
|
||||
Reason codes:
|
||||
- no_recall_after_metadata_filter: No results after applying metadata filters
|
||||
- no_recall_no_kb: No target knowledge bases configured
|
||||
- no_recall_kb_empty: Knowledge base is empty
|
||||
- no_recall_low_score: Results found but below threshold
|
||||
- no_recall_error: Retrieval error occurred
|
||||
"""
|
||||
retrieval_diag = ctx.diagnostics.get("retrieval", {})
|
||||
|
||||
# Check for retrieval error
|
||||
if ctx.diagnostics.get("retrieval_error"):
|
||||
return "no_recall_error"
|
||||
|
||||
# Check if metadata filters were applied
|
||||
if retrieval_diag.get("applied_metadata_filters"):
|
||||
return "no_recall_after_metadata_filter"
|
||||
|
||||
# Check if target KBs were configured
|
||||
if not ctx.target_kb_ids:
|
||||
return "no_recall_no_kb"
|
||||
|
||||
# Check if KB is empty (no candidates at all)
|
||||
if retrieval_diag.get("total_candidates", 0) == 0:
|
||||
return "no_recall_kb_empty"
|
||||
|
||||
# Results found but filtered out by score threshold
|
||||
if retrieval_diag.get("total_candidates", 0) > 0 and retrieval_diag.get("filtered_hits", 0) == 0:
|
||||
return "no_recall_low_score"
|
||||
|
||||
return "no_recall_unknown"
|
||||
|
||||
def _calculate_confidence(self, ctx: GenerationContext) -> None:
|
||||
"""
|
||||
[AC-AISVC-17, AC-AISVC-18, AC-AISVC-19] Calculate confidence score.
|
||||
|
|
|
|||
|
|
@ -6,7 +6,10 @@ Retrieval layer for AI Service.
|
|||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.services.retrieval.metadata import TagFilter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -23,6 +26,14 @@ class RetrievalContext:
|
|||
session_id: str | None = None
|
||||
channel_type: str | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
tag_filter: "TagFilter | None" = None
|
||||
kb_ids: list[str] | None = None
|
||||
|
||||
def get_tag_filter_dict(self) -> dict[str, str | list[str] | None] | None:
|
||||
"""获取标签过滤器的字典表示"""
|
||||
if self.tag_filter and not self.tag_filter.is_empty():
|
||||
return self.tag_filter.fields
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from datetime import date, datetime
|
|||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class RetrievalStrategy(str, Enum):
|
||||
|
|
@ -32,6 +32,39 @@ class ChunkMetadataModel(BaseModel):
|
|||
valid_until: str | None = None
|
||||
priority: int = 5
|
||||
keywords: list[str] = []
|
||||
grade: str = ""
|
||||
subject: str = ""
|
||||
type: str = ""
|
||||
|
||||
|
||||
class GradeEnum(str, Enum):
    """School grade levels; each value is the Chinese label used in metadata."""

    GRADE_7 = "初一"  # junior secondary, year 1
    GRADE_8 = "初二"  # junior secondary, year 2
    GRADE_9 = "初三"  # junior secondary, year 3
    GRADE_10 = "高一"  # senior secondary, year 1
    GRADE_11 = "高二"  # senior secondary, year 2
    GRADE_12 = "高三"  # senior secondary, year 3
|
||||
|
||||
|
||||
class SubjectEnum(str, Enum):
    """School subjects; each value is the Chinese label used in metadata."""

    GENERAL = "通用"  # not tied to a single subject
    PHYSICS = "物理"
    CHINESE = "语文"
    MATH = "数学"
    ENGLISH = "英语"
    CHEMISTRY = "化学"
    BIOLOGY = "生物"
|
||||
|
||||
|
||||
class TypeEnum(str, Enum):
    """Content categories; each value is the Chinese label used in metadata."""

    PAIN_POINT = "痛点"
    SUBJECT_FEATURE = "学科特点"
    ABILITY_REQUIREMENT = "能力要求"
    COURSE_VALUE = "课程价值"
    VIEWPOINT = "观点"
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -50,6 +83,9 @@ class ChunkMetadata:
|
|||
valid_until: date | None = None
|
||||
priority: int = 5
|
||||
keywords: list[str] = field(default_factory=list)
|
||||
grade: str = ""
|
||||
subject: str = ""
|
||||
type: str = ""
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary for storage."""
|
||||
|
|
@ -64,6 +100,9 @@ class ChunkMetadata:
|
|||
"valid_until": self.valid_until.isoformat() if self.valid_until else None,
|
||||
"priority": self.priority,
|
||||
"keywords": self.keywords,
|
||||
"grade": self.grade,
|
||||
"subject": self.subject,
|
||||
"type": self.type,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
|
@ -80,6 +119,9 @@ class ChunkMetadata:
|
|||
valid_until=date.fromisoformat(data["valid_until"]) if data.get("valid_until") else None,
|
||||
priority=data.get("priority", 5),
|
||||
keywords=data.get("keywords", []),
|
||||
grade=data.get("grade", ""),
|
||||
subject=data.get("subject", ""),
|
||||
type=data.get("type", ""),
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -142,6 +184,103 @@ class MetadataFilter:
|
|||
return {"must": conditions}
|
||||
|
||||
|
||||
class TagFilterModel(BaseModel):
    """Pydantic counterpart of the tag filter, used for API serialization."""

    # Maps a metadata field name to the value (or list of values) to match;
    # defaults to an empty mapping.
    fields: dict[str, str | list[str] | None] = Field(
        default_factory=dict,
        description="动态过滤字段",
    )
|
||||
|
||||
|
||||
@dataclass
|
||||
class TagFilter:
|
||||
"""
|
||||
动态标签过滤器
|
||||
用于在 RAG 检索时根据元数据字段进行过滤
|
||||
支持任意动态字段,不再硬编码 grade/subject/type
|
||||
|
||||
示例:
|
||||
- 教育行业: {"grade": "初一", "type": "痛点"}
|
||||
- 医疗行业: {"department": "内科", "disease_type": "慢性病"}
|
||||
"""
|
||||
fields: dict[str, str | list[str] | None] = field(default_factory=dict)
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
"""检查过滤器是否为空"""
|
||||
if not self.fields:
|
||||
return True
|
||||
return all(v is None or v == "" or (isinstance(v, list) and len(v) == 0) for v in self.fields.values())
|
||||
|
||||
def to_qdrant_filter(self) -> dict[str, Any] | None:
|
||||
"""转换为 Qdrant 过滤格式"""
|
||||
conditions = []
|
||||
|
||||
for field_name, field_value in self.fields.items():
|
||||
if field_value is None or field_value == "":
|
||||
continue
|
||||
|
||||
if isinstance(field_value, list):
|
||||
if len(field_value) == 0:
|
||||
continue
|
||||
conditions.append({
|
||||
"key": f"metadata.{field_name}",
|
||||
"match": {"any": field_value}
|
||||
})
|
||||
else:
|
||||
conditions.append({
|
||||
"key": f"metadata.{field_name}",
|
||||
"match": {"value": field_value}
|
||||
})
|
||||
|
||||
if not conditions:
|
||||
return None
|
||||
|
||||
return {"must": conditions}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "TagFilter":
|
||||
"""从字典创建"""
|
||||
return cls(fields=dict(data))
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, model: TagFilterModel) -> "TagFilter":
|
||||
"""从 Pydantic 模型创建"""
|
||||
return cls(fields=dict(model.fields))
|
||||
|
||||
def merge_with_context(self, context: dict[str, Any]) -> "TagFilter":
|
||||
"""
|
||||
与上下文合并,支持模板变量替换
|
||||
|
||||
例如:
|
||||
tag_filter = TagFilter(fields={"grade": "${context.grade}", "type": "痛点"})
|
||||
context = {"grade": "初一"}
|
||||
result = tag_filter.merge_with_context(context)
|
||||
# result = TagFilter(fields={"grade": "初一", "type": "痛点"})
|
||||
"""
|
||||
def resolve_value(value: str | list | None) -> str | list | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, list):
|
||||
return [resolve_value(v) for v in value]
|
||||
if isinstance(value, str):
|
||||
if value.startswith("${context.") and value.endswith("}"):
|
||||
key = value[10:-1]
|
||||
return context.get(key)
|
||||
return value
|
||||
return str(value)
|
||||
|
||||
resolved_fields = {}
|
||||
for field_name, field_value in self.fields.items():
|
||||
resolved = resolve_value(field_value)
|
||||
if resolved is not None and resolved != "":
|
||||
resolved_fields[field_name] = resolved
|
||||
elif field_name in context:
|
||||
resolved_fields[field_name] = context[field_name]
|
||||
|
||||
return TagFilter(fields=resolved_fields)
|
||||
|
||||
def get_field(self, field_name: str) -> str | list[str] | None:
|
||||
"""获取指定字段的值"""
|
||||
return self.fields.get(field_name)
|
||||
|
||||
|
||||
@dataclass
|
||||
class KnowledgeChunk:
|
||||
"""
|
||||
|
|
|
|||
Loading…
Reference in New Issue