""" 知识库检索性能分析脚本 详细分析每个环节的耗时 """ import asyncio import sys import time from pathlib import Path sys.path.insert(0, str(Path(__file__).parent.parent)) from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.orm import sessionmaker from app.services.mid.kb_search_dynamic_tool import KbSearchDynamicTool, KbSearchDynamicConfig from app.services.retrieval.vector_retriever import VectorRetriever from app.core.config import get_settings from app.core.qdrant_client import QdrantClient async def profile_embedding_generation(query: str): """分析 embedding 生成耗时""" from app.services.embedding import get_embedding_provider start = time.time() embedding_service = await get_embedding_provider() init_time = time.time() - start start = time.time() embedding = await embedding_service.embed_query(query) embed_time = time.time() - start # 获取 embedding 向量(兼容不同 provider) if hasattr(embedding, 'embedding_full'): vector = embedding.embedding_full elif hasattr(embedding, 'embedding'): vector = embedding.embedding else: vector = embedding return { "init_time_ms": init_time * 1000, "embed_time_ms": embed_time * 1000, "dimension": len(vector), } async def profile_qdrant_search(tenant_id: str, query_vector: list, metadata_filter: dict = None): """分析 Qdrant 搜索耗时""" from app.core.qdrant_client import get_qdrant_client client = await get_qdrant_client() # 获取 collections start = time.time() qdrant_client = await client.get_client() collections = await qdrant_client.get_collections() safe_tenant_id = tenant_id.replace('@', '_') prefix = f"kb_{safe_tenant_id}" tenant_collections = [ c.name for c in collections.collections if c.name.startswith(prefix) ] list_collections_time = time.time() - start # 逐个 collection 搜索 collection_times = [] for collection_name in tenant_collections: start = time.time() exists = await qdrant_client.collection_exists(collection_name) check_time = time.time() - start if not exists: collection_times.append({ "collection": collection_name, "exists": False, "time_ms": check_time * 1000, }) continue start = time.time() # 构建 filter qdrant_filter = None if metadata_filter: from qdrant_client.models import FieldCondition, Filter, MatchValue must_conditions = [] for key, value in metadata_filter.items(): field_path = f"metadata.{key}" condition = FieldCondition( key=field_path, match=MatchValue(value=value), ) must_conditions.append(condition) qdrant_filter = Filter(must=must_conditions) if must_conditions else None try: results = await qdrant_client.query_points( collection_name=collection_name, query=query_vector, using="full", # 使用 full 向量 limit=5, score_threshold=0.5, query_filter=qdrant_filter, ) except Exception as e: if "vector name" in str(e).lower(): # 尝试不使用 vector name results = await qdrant_client.query_points( collection_name=collection_name, query=query_vector, limit=5, score_threshold=0.5, query_filter=qdrant_filter, ) else: raise search_time = time.time() - start collection_times.append({ "collection": collection_name, "exists": True, "check_time_ms": check_time * 1000, "search_time_ms": search_time * 1000, "results_count": len(results.points), }) return { "list_collections_time_ms": list_collections_time * 1000, "collections_count": len(tenant_collections), "collection_times": collection_times, } async def profile_full_kb_search(): """分析完整的知识库搜索流程""" settings = get_settings() print("=" * 80) print("知识库检索性能分析") print("=" * 80) # 1. 分析 Embedding 生成 print("\n📊 1. Embedding 生成分析") print("-" * 80) query = "三年级语文学习" embed_result = await profile_embedding_generation(query) print(f" 初始化时间: {embed_result['init_time_ms']:.2f} ms") print(f" Embedding 生成时间: {embed_result['embed_time_ms']:.2f} ms") print(f" 向量维度: {embed_result['dimension']}") # 2. 分析 Qdrant 搜索 print("\n📊 2. Qdrant 搜索分析") print("-" * 80) # 先生成 embedding from app.services.embedding import get_embedding_provider embedding_service = await get_embedding_provider() embedding_result = await embedding_service.embed_query(query) # 获取 embedding 向量(兼容不同 provider) if hasattr(embedding_result, 'embedding_full'): query_vector = embedding_result.embedding_full elif hasattr(embedding_result, 'embedding'): query_vector = embedding_result.embedding else: query_vector = embedding_result tenant_id = "szmp@ash@2026" metadata_filter = {"grade": "三年级", "subject": "语文"} qdrant_result = await profile_qdrant_search(tenant_id, query_vector, metadata_filter) print(f" 获取 collections 列表时间: {qdrant_result['list_collections_time_ms']:.2f} ms") print(f" Collections 数量: {qdrant_result['collections_count']}") print(f"\n 各 Collection 搜索耗时:") for ct in qdrant_result['collection_times']: if ct['exists']: print(f" - {ct['collection']}: {ct['search_time_ms']:.2f} ms (结果: {ct['results_count']} 条)") else: print(f" - {ct['collection']}: 不存在 ({ct['time_ms']:.2f} ms)") total_search_time = sum( ct.get('search_time_ms', 0) for ct in qdrant_result['collection_times'] ) print(f"\n 总搜索时间(串行): {total_search_time:.2f} ms") # 3. 分析完整流程 print("\n📊 3. 完整 KB Search 流程分析") print("-" * 80) engine = create_async_engine(settings.database_url) async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) async with async_session() as session: config = KbSearchDynamicConfig( enabled=True, top_k=5, timeout_ms=30000, # 30秒 min_score_threshold=0.5, ) tool = KbSearchDynamicTool(session=session, config=config) # 记录各阶段时间 stages = [] start_total = time.time() # 执行搜索 start = time.time() result = await tool.execute( query=query, tenant_id=tenant_id, scene="学习方案", top_k=5, context=metadata_filter, ) total_time = (time.time() - start_total) * 1000 print(f" 总耗时: {total_time:.2f} ms") print(f" 结果: success={result.success}, hits={len(result.hits)}") print(f" 工具内部耗时: {result.duration_ms} ms") # 计算时间差(工具内部 vs 外部测量) overhead = total_time - result.duration_ms print(f" 额外开销(初始化等): {overhead:.2f} ms") # 4. 性能瓶颈分析 print("\n📊 4. 性能瓶颈分析") print("-" * 80) embedding_time = embed_result['embed_time_ms'] qdrant_time = total_search_time total_measured = embedding_time + qdrant_time print(f" Embedding 生成: {embedding_time:.2f} ms ({embedding_time/total_measured*100:.1f}%)") print(f" Qdrant 搜索: {qdrant_time:.2f} ms ({qdrant_time/total_measured*100:.1f}%)") print(f" 其他开销: {total_time - total_measured:.2f} ms") print("\n" + "=" * 80) print("优化建议:") print("=" * 80) if embedding_time > 1000: print(" ⚠️ Embedding 生成较慢,考虑:") print(" - 使用更快的 embedding 模型") print(" - 增加 embedding 服务缓存") if qdrant_time > 1000: print(" ⚠️ Qdrant 搜索较慢,考虑:") print(" - 并行查询多个 collections") print(" - 优化 Qdrant 索引") print(" - 减少 collections 数量") if len(qdrant_result['collection_times']) > 3: print(f" ⚠️ Collections 数量较多 ({len(qdrant_result['collection_times'])} 个)") print(" - 建议合并或归档空/少数据的 collections") if __name__ == "__main__": asyncio.run(profile_full_kb_search())