""" 测试 kb_search_dynamic 工具 - 课程咨询场景 """ import asyncio import sys from pathlib import Path sys.path.insert(0, str(Path(__file__).parent.parent)) from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.orm import sessionmaker from app.services.mid.kb_search_dynamic_tool import ( KbSearchDynamicTool, KbSearchDynamicConfig, StepKbConfig, ) from app.core.config import get_settings async def test_kb_search(): """测试知识库搜索 - 课程咨询场景""" settings = get_settings() engine = create_async_engine(settings.database_url) async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) async with async_session() as session: config = KbSearchDynamicConfig( enabled=True, top_k=10, timeout_ms=15000, min_score_threshold=0.3, ) tool = KbSearchDynamicTool(session=session, config=config) course_kb_id = "75c465fe-277d-455d-a30b-4b168adcc03b" step_kb_config = StepKbConfig( allowed_kb_ids=[course_kb_id], preferred_kb_ids=[course_kb_id], step_id="test_course_query", ) test_params = { "query": "课程介绍", "tenant_id": "szmp@ash@2026", "top_k": 10, "context": { "grade": "五年级", }, "step_kb_config": step_kb_config, } print(f"\n{'='*80}") print(f"测试: kb_search_dynamic - 课程知识库") print(f"{'='*80}") print(f"参数: query={test_params['query']}") print(f" tenant_id={test_params['tenant_id']}") print(f" context={test_params['context']}") print(f" step_kb_config.allowed_kb_ids={step_kb_config.allowed_kb_ids}") print(f"超时设置: {config.timeout_ms}ms") print(f"最低分数阈值: {config.min_score_threshold}") try: result = await tool.execute(**test_params) print(f"\n结果:") print(f" success: {result.success}") print(f" hits count: {len(result.hits)}") print(f" applied_filter: {result.applied_filter}") print(f" fallback_reason_code: {result.fallback_reason_code}") print(f" duration_ms: {result.duration_ms}") if result.filter_debug: print(f" filter_debug: {result.filter_debug}") if result.step_kb_binding: print(f" step_kb_binding: {result.step_kb_binding}") if result.tool_trace: print(f"\n Tool Trace:") print(f" tool_name: {result.tool_trace.tool_name}") print(f" status: {result.tool_trace.status}") print(f" duration_ms: {result.tool_trace.duration_ms}") print(f" args_digest: {result.tool_trace.args_digest}") print(f" result_digest: {result.tool_trace.result_digest}") if hasattr(result.tool_trace, 'arguments') and result.tool_trace.arguments: print(f" arguments: {result.tool_trace.arguments}") if result.hits: print(f"\n 检索结果 (共 {len(result.hits)} 条):") for i, hit in enumerate(result.hits, 1): text = hit.get('text', '') text_preview = text[:200] + '...' if len(text) > 200 else text score = hit.get('score', 0) metadata = hit.get('metadata', {}) collection = hit.get('collection', 'unknown') kb_id = hit.get('kb_id', 'unknown') print(f"\n [{i}] score={score:.4f}") print(f" collection: {collection}") print(f" kb_id: {kb_id}") print(f" metadata: {metadata}") print(f" text: {text_preview}") else: print(f"\n ⚠️ 没有命中任何结果") print(f" 请检查:") print(f" 1. 知识库是否有数据") print(f" 2. 向量是否正确生成") print(f" 3. 过滤条件是否过于严格") except Exception as e: print(f"\n ❌ 错误: {e}") import traceback traceback.print_exc() if __name__ == "__main__": asyncio.run(test_kb_search())