350 lines
11 KiB
Python
350 lines
11 KiB
Python
"""
|
|
Metadata models for RAG optimization.
|
|
Implements structured metadata for knowledge chunks.
|
|
Reference: rag-optimization/spec.md Section 3.2
|
|
"""
|
|
|
|
from dataclasses import dataclass, field
|
|
from datetime import date, datetime
|
|
from enum import Enum
|
|
from typing import Any
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
|
|
class RetrievalStrategy(str, Enum):
    """Closed set of retrieval strategies; each member's value is a plain string."""

    # Single-source strategies.
    VECTOR_ONLY = "vector"
    BM25_ONLY = "bm25"

    # Combined strategies.
    HYBRID = "hybrid"
    TWO_STAGE = "two_stage"
|
|
|
|
|
|
class ChunkMetadataModel(BaseModel):
    """
    Pydantic representation of chunk metadata, used for API serialization.

    Every field has a default, so an empty payload validates. The validity
    dates are carried as strings (or None) in this model.
    """

    # Classification
    category: str = ""
    subcategory: str = ""
    target_audience: list[str] = []

    # Provenance
    source_doc: str = ""
    source_url: str = ""
    department: str = ""

    # Validity window
    valid_from: str | None = None
    valid_until: str | None = None

    # Ranking and search hints
    priority: int = 5
    keywords: list[str] = []

    # Free-form tag fields
    grade: str = ""
    subject: str = ""
    type: str = ""
|
|
|
|
|
|
class GradeEnum(str, Enum):
    """School grade enumeration; values are the Chinese grade labels."""

    # Junior secondary grades.
    GRADE_7 = "初一"
    GRADE_8 = "初二"
    GRADE_9 = "初三"

    # Senior secondary grades.
    GRADE_10 = "高一"
    GRADE_11 = "高二"
    GRADE_12 = "高三"
|
|
|
|
|
|
class SubjectEnum(str, Enum):
    """School subject enumeration; values are the Chinese subject labels."""

    # Catch-all entry, not tied to a single subject.
    GENERAL = "通用"

    PHYSICS = "物理"
    CHINESE = "语文"
    MATH = "数学"
    ENGLISH = "英语"
    CHEMISTRY = "化学"
    BIOLOGY = "生物"
|
|
|
|
|
|
class TypeEnum(str, Enum):
    """Content type enumeration; values are the Chinese content-type labels."""

    PAIN_POINT = "痛点"
    SUBJECT_FEATURE = "学科特点"
    ABILITY_REQUIREMENT = "能力要求"
    COURSE_VALUE = "课程价值"
    VIEWPOINT = "观点"
|
|
|
|
|
|
@dataclass
|
|
class ChunkMetadata:
|
|
"""
|
|
Metadata for knowledge chunks.
|
|
Reference: rag-optimization/spec.md Section 3.2.2
|
|
"""
|
|
category: str = ""
|
|
subcategory: str = ""
|
|
target_audience: list[str] = field(default_factory=list)
|
|
source_doc: str = ""
|
|
source_url: str = ""
|
|
department: str = ""
|
|
valid_from: date | None = None
|
|
valid_until: date | None = None
|
|
priority: int = 5
|
|
keywords: list[str] = field(default_factory=list)
|
|
grade: str = ""
|
|
subject: str = ""
|
|
type: str = ""
|
|
|
|
def to_dict(self) -> dict[str, Any]:
|
|
"""Convert to dictionary for storage."""
|
|
return {
|
|
"category": self.category,
|
|
"subcategory": self.subcategory,
|
|
"target_audience": self.target_audience,
|
|
"source_doc": self.source_doc,
|
|
"source_url": self.source_url,
|
|
"department": self.department,
|
|
"valid_from": self.valid_from.isoformat() if self.valid_from else None,
|
|
"valid_until": self.valid_until.isoformat() if self.valid_until else None,
|
|
"priority": self.priority,
|
|
"keywords": self.keywords,
|
|
"grade": self.grade,
|
|
"subject": self.subject,
|
|
"type": self.type,
|
|
}
|
|
|
|
@classmethod
|
|
def from_dict(cls, data: dict[str, Any]) -> "ChunkMetadata":
|
|
"""Create from dictionary."""
|
|
return cls(
|
|
category=data.get("category", ""),
|
|
subcategory=data.get("subcategory", ""),
|
|
target_audience=data.get("target_audience", []),
|
|
source_doc=data.get("source_doc", ""),
|
|
source_url=data.get("source_url", ""),
|
|
department=data.get("department", ""),
|
|
valid_from=date.fromisoformat(data["valid_from"]) if data.get("valid_from") else None,
|
|
valid_until=date.fromisoformat(data["valid_until"]) if data.get("valid_until") else None,
|
|
priority=data.get("priority", 5),
|
|
keywords=data.get("keywords", []),
|
|
grade=data.get("grade", ""),
|
|
subject=data.get("subject", ""),
|
|
type=data.get("type", ""),
|
|
)
|
|
|
|
|
|
@dataclass
|
|
class MetadataFilter:
|
|
"""
|
|
Filter conditions for metadata-based retrieval.
|
|
Reference: rag-optimization/spec.md Section 4.1
|
|
"""
|
|
categories: list[str] | None = None
|
|
target_audiences: list[str] | None = None
|
|
departments: list[str] | None = None
|
|
valid_only: bool = True
|
|
min_priority: int | None = None
|
|
keywords: list[str] | None = None
|
|
|
|
def to_qdrant_filter(self) -> dict[str, Any] | None:
|
|
"""Convert to Qdrant filter format."""
|
|
conditions = []
|
|
|
|
if self.categories:
|
|
conditions.append({
|
|
"key": "metadata.category",
|
|
"match": {"any": self.categories}
|
|
})
|
|
|
|
if self.departments:
|
|
conditions.append({
|
|
"key": "metadata.department",
|
|
"match": {"any": self.departments}
|
|
})
|
|
|
|
if self.target_audiences:
|
|
conditions.append({
|
|
"key": "metadata.target_audience",
|
|
"match": {"any": self.target_audiences}
|
|
})
|
|
|
|
if self.valid_only:
|
|
today = date.today().isoformat()
|
|
conditions.append({
|
|
"should": [
|
|
{"key": "metadata.valid_until", "match": {"value": None}},
|
|
{"key": "metadata.valid_until", "range": {"gte": today}}
|
|
]
|
|
})
|
|
|
|
if self.min_priority is not None:
|
|
conditions.append({
|
|
"key": "metadata.priority",
|
|
"range": {"lte": self.min_priority}
|
|
})
|
|
|
|
if not conditions:
|
|
return None
|
|
|
|
if len(conditions) == 1:
|
|
return {"must": conditions}
|
|
|
|
return {"must": conditions}
|
|
|
|
|
|
class TagFilterModel(BaseModel):
    """
    Pydantic counterpart of the tag filter, used for API serialization.

    Holds a single mapping from field name to a string, a list of strings,
    or None; it defaults to an empty mapping.
    """

    fields: dict[str, str | list[str] | None] = Field(
        default_factory=dict,
        description="动态过滤字段",
    )
|
|
|
|
|
|
@dataclass
|
|
class TagFilter:
|
|
"""
|
|
动态标签过滤器
|
|
用于在 RAG 检索时根据元数据字段进行过滤
|
|
支持任意动态字段,不再硬编码 grade/subject/type
|
|
|
|
示例:
|
|
- 教育行业: {"grade": "初一", "type": "痛点"}
|
|
- 医疗行业: {"department": "内科", "disease_type": "慢性病"}
|
|
"""
|
|
fields: dict[str, str | list[str] | None] = field(default_factory=dict)
|
|
|
|
def is_empty(self) -> bool:
|
|
"""检查过滤器是否为空"""
|
|
if not self.fields:
|
|
return True
|
|
return all(v is None or v == "" or (isinstance(v, list) and len(v) == 0) for v in self.fields.values())
|
|
|
|
def to_qdrant_filter(self) -> dict[str, Any] | None:
|
|
"""转换为 Qdrant 过滤格式"""
|
|
conditions = []
|
|
|
|
for field_name, field_value in self.fields.items():
|
|
if field_value is None or field_value == "":
|
|
continue
|
|
|
|
if isinstance(field_value, list):
|
|
if len(field_value) == 0:
|
|
continue
|
|
conditions.append({
|
|
"key": f"metadata.{field_name}",
|
|
"match": {"any": field_value}
|
|
})
|
|
else:
|
|
conditions.append({
|
|
"key": f"metadata.{field_name}",
|
|
"match": {"value": field_value}
|
|
})
|
|
|
|
if not conditions:
|
|
return None
|
|
|
|
return {"must": conditions}
|
|
|
|
@classmethod
|
|
def from_dict(cls, data: dict[str, Any]) -> "TagFilter":
|
|
"""从字典创建"""
|
|
return cls(fields=dict(data))
|
|
|
|
@classmethod
|
|
def from_model(cls, model: TagFilterModel) -> "TagFilter":
|
|
"""从 Pydantic 模型创建"""
|
|
return cls(fields=dict(model.fields))
|
|
|
|
def merge_with_context(self, context: dict[str, Any]) -> "TagFilter":
|
|
"""
|
|
与上下文合并,支持模板变量替换
|
|
|
|
例如:
|
|
tag_filter = TagFilter(fields={"grade": "${context.grade}", "type": "痛点"})
|
|
context = {"grade": "初一"}
|
|
result = tag_filter.merge_with_context(context)
|
|
# result = TagFilter(fields={"grade": "初一", "type": "痛点"})
|
|
"""
|
|
def resolve_value(value: str | list | None) -> str | list | None:
|
|
if value is None:
|
|
return None
|
|
if isinstance(value, list):
|
|
return [resolve_value(v) for v in value]
|
|
if isinstance(value, str):
|
|
if value.startswith("${context.") and value.endswith("}"):
|
|
key = value[10:-1]
|
|
return context.get(key)
|
|
return value
|
|
return str(value)
|
|
|
|
resolved_fields = {}
|
|
for field_name, field_value in self.fields.items():
|
|
resolved = resolve_value(field_value)
|
|
if resolved is not None and resolved != "":
|
|
resolved_fields[field_name] = resolved
|
|
elif field_name in context:
|
|
resolved_fields[field_name] = context[field_name]
|
|
|
|
return TagFilter(fields=resolved_fields)
|
|
|
|
def get_field(self, field_name: str) -> str | list[str] | None:
|
|
"""获取指定字段的值"""
|
|
return self.fields.get(field_name)
|
|
|
|
|
|
@dataclass
class KnowledgeChunk:
    """
    A knowledge chunk carrying embeddings at several dimensionalities.

    Reference: rag-optimization/spec.md Section 3.2.1
    """

    # Identity and text
    chunk_id: str
    document_id: str
    content: str

    # Embeddings; each defaults to an empty list
    embedding_full: list[float] = field(default_factory=list)
    embedding_256: list[float] = field(default_factory=list)
    embedding_512: list[float] = field(default_factory=list)

    metadata: ChunkMetadata = field(default_factory=ChunkMetadata)

    # Timestamps come from datetime.utcnow, i.e. naive datetimes in UTC.
    created_at: datetime = field(default_factory=datetime.utcnow)
    updated_at: datetime = field(default_factory=datetime.utcnow)

    def to_qdrant_point(self, point_id: int | str) -> dict[str, Any]:
        """
        Build the Qdrant point dict for this chunk.

        Args:
            point_id: Identifier placed under the point's "id" key.

        Returns:
            A dict with "id", the named vectors ("full", "dim_256", "dim_512")
            and a payload holding ids, text, metadata and ISO timestamps.
        """
        named_vectors = {
            "full": self.embedding_full,
            "dim_256": self.embedding_256,
            "dim_512": self.embedding_512,
        }
        payload = {
            "chunk_id": self.chunk_id,
            "document_id": self.document_id,
            "text": self.content,
            "metadata": self.metadata.to_dict(),
            "created_at": self.created_at.isoformat(),
            "updated_at": self.updated_at.isoformat(),
        }
        return {"id": point_id, "vector": named_vectors, "payload": payload}
|
|
|
|
|
|
@dataclass
class RetrieveRequest:
    """
    Parameters of a knowledge retrieval call.

    Reference: rag-optimization/spec.md Section 4.1
    """

    query: str
    # Derived from `query` in __post_init__ when left empty.
    query_with_prefix: str = ""
    top_k: int = 10
    filters: MetadataFilter | None = None
    strategy: RetrievalStrategy = RetrievalStrategy.HYBRID

    def __post_init__(self):
        """Fill in `query_with_prefix` from `query` unless the caller set it."""
        if self.query_with_prefix:
            return
        self.query_with_prefix = f"search_query:{self.query}"
|
|
|
|
|
|
@dataclass
class RetrieveResult:
    """
    One hit returned by knowledge retrieval.

    Reference: rag-optimization/spec.md Section 4.1
    """

    # Required: which chunk matched, its text and its score.
    chunk_id: str
    content: str
    score: float

    # Optional per-retriever scores; both default to 0.0.
    vector_score: float = 0.0
    bm25_score: float = 0.0

    metadata: ChunkMetadata = field(default_factory=ChunkMetadata)
    rank: int = 0
|