ai-robot-core/ai-service/test_kb_metadata_search.py

167 lines
5.5 KiB
Python
Raw Permalink Normal View History

"""
测试KB元数据过滤查询
"""
import asyncio
import json
from app.core.database import async_session_maker
from app.services.mid.metadata_filter_builder import MetadataFilterBuilder
from app.services.mid.default_kb_tool_runner import DefaultKbToolRunner
from app.core.qdrant_client import get_qdrant_client
async def test_metadata_filter():
    """Build a metadata filter for a fixed query context and print the outcome.

    Afterwards, print the tenant's filterable-field schema as returned by
    ``MetadataFilterBuilder.get_filter_schema``.
    """
    tenant = "szmp@ash@2026"
    # Simulated user query context.
    # NOTE(review): the original comment describes this as the query
    # "grade-8 math pain points", but ``subject`` is "通用" (general), not
    # "数学" (math) — confirm which one is intended.
    ctx = {"grade": "初二", "subject": "通用", "kb_scene": "痛点"}
    rule = "=" * 60

    def _pretty(obj):
        # Indented, non-ASCII-preserving JSON for readable console output.
        return json.dumps(obj, ensure_ascii=False, indent=2)

    header = (
        rule,
        "测试元数据过滤器构建",
        rule,
        f"租户: {tenant}",
        f"查询上下文: {json.dumps(ctx, ensure_ascii=False)}",
    )
    for line in header:
        print(line)
    print()

    async with async_session_maker() as session:
        builder = MetadataFilterBuilder(session)

        # Step 1: build the filter for the context above and dump the result.
        outcome = await builder.build_filter(tenant, ctx)
        print("过滤器构建结果:")
        print(f" 成功: {outcome.success}")
        print(f" 应用的过滤器: {_pretty(outcome.applied_filter)}")
        print(f" 缺失的必填字段: {outcome.missing_required_slots}")
        print(f" 调试信息: {_pretty(outcome.debug_info)}")
        print()

        # Step 2: list every field the tenant can filter on.
        schema = await builder.get_filter_schema(tenant)
        print("可过滤字段配置:")
        for entry in schema:
            key, label = entry['field_key'], entry['label']
            kind, required = entry['type'], entry['required']
            print(f" - {key}: {label} (类型: {kind}, 必填: {required})")
            options = entry['options']
            if options:
                print(f" 选项: {options}")
        print()
async def test_kb_search():
    """Run a KB vector search restricted by a metadata filter and print the hits.

    Steps:
      1. Build a metadata filter from a fixed query context.
      2. List the tenant's knowledge bases, returning early when there are none.
      3. Call ``DefaultKbToolRunner.execute`` with the query and the built
         filter, then print the result summary and each hit.
    """
    tenant_id = "szmp@ash@2026"
    # Query under test ("what difficulties do grade-8 students have learning math").
    query = "初二学生数学学习有什么困难"
    # Metadata context the filter is built from.
    context = {
        "grade": "初二",
        "subject": "数学",
        "kb_scene": "痛点",
    }
    # Number of characters of each hit's text shown in the output.
    preview_chars = 200

    # The knowledge base is discovered at runtime below, so no placeholder
    # KB id is printed in the header any more.
    print("=" * 60)
    print("测试KB向量检索带元数据过滤")
    print("=" * 60)
    print(f"租户: {tenant_id}")
    print(f"查询: {query}")
    print(f"上下文: {json.dumps(context, ensure_ascii=False)}")
    print()

    async with async_session_maker() as session:
        # 1. Build the metadata filter first.
        filter_builder = MetadataFilterBuilder(session)
        filter_result = await filter_builder.build_filter(tenant_id, context)
        print(f"过滤器: {json.dumps(filter_result.applied_filter, ensure_ascii=False)}")
        if not filter_result.success:
            # Surface the failure: the search below still runs, but with
            # whatever filter the builder produced.
            print(f"过滤器构建未成功, 缺失的必填字段: {filter_result.missing_required_slots}")
        print()

        # Configure the runner with a 10-second timeout.
        from app.services.mid.timeout_governor import TimeoutGovernor
        from app.services.mid.default_kb_tool_runner import KbToolConfig

        config = KbToolConfig(
            enabled=True,
            top_k=5,
            timeout_ms=10000,  # 10 s
            min_score_threshold=0.5,
        )
        kb_runner = DefaultKbToolRunner(
            timeout_governor=TimeoutGovernor(),
            config=config,
        )

        # 2. List the tenant's knowledge bases; nothing to search without one.
        from app.services.knowledge_base_service import KnowledgeBaseService

        kb_service = KnowledgeBaseService(session)
        kbs = await kb_service.list_knowledge_bases(tenant_id)
        if not kbs:
            print("未找到知识库,请先创建知识库并上传文档")
            return
        print(f"找到 {len(kbs)} 个知识库:")
        for kb in kbs:
            print(f" - {kb.name} (ID: {kb.id})")
        print()

        # Report the first knowledge base as the one under test.
        # NOTE(review): test_kb_id is only printed — it is not passed to
        # kb_runner.execute below, so the search is not visibly scoped to this
        # KB. If execute accepts a KB selector, pass it here (TODO confirm).
        selected_kb = kbs[0]
        test_kb_id = str(selected_kb.id)
        print(f"使用知识库: {selected_kb.name} (ID: {test_kb_id})")
        print()

        # 3. Run the search with the built metadata filter.
        result = await kb_runner.execute(
            tenant_id=tenant_id,
            query=query,
            metadata_filter=filter_result.applied_filter,
        )

        print("检索结果:")
        print(f" 成功: {result.success}")
        print(f" 命中数: {len(result.hits)}")
        print(f" 回退原因: {result.fallback_reason_code}")
        print()

        if result.hits:
            print("命中文档:")
            for i, hit in enumerate(result.hits, 1):
                print(f"\n [{i}] 分数: {hit.score:.4f}")
                text = hit.text
                # Only mark truncation when the text was actually cut.
                if len(text) > preview_chars:
                    preview = f"{text[:preview_chars]}..."
                else:
                    preview = text
                print(f" 内容: {preview}")
                print(f" 元数据: {json.dumps(hit.metadata, ensure_ascii=False)}")
        else:
            print("未命中任何文档")
            print("\n可能原因:")
            print(" 1. 知识库中没有匹配的文档")
            print(" 2. 元数据过滤器过于严格")
            print(" 3. 向量相似度阈值过高")
async def main():
    """Run the filter-construction check, then the vector-search check.

    Returns:
        0 when both checks complete without raising, 1 otherwise. The
        ``__main__`` guard at the bottom of this block turns the value into
        the process exit status, so a failed run is visible to callers/CI.
    """
    print("\n" + "=" * 60)
    print("KB元数据过滤查询测试")
    print("=" * 60 + "\n")
    try:
        # Test 1: filter construction
        await test_metadata_filter()
        print("\n" + "=" * 60 + "\n")
        # Test 2: vector search
        await test_kb_search()
    except Exception as e:
        # Top-level boundary: report the failure and signal it via the return
        # code instead of letting the script exit successfully.
        import traceback

        print(f"\n测试失败: {e}", flush=True)
        traceback.print_exc()  # full traceback to stderr
        return 1
    return 0


if __name__ == "__main__":
    # SystemExit(0) / SystemExit(1) become the process exit status.
    raise SystemExit(asyncio.run(main()))