260 lines
9.0 KiB
Python
260 lines
9.0 KiB
Python
|
|
"""
|
|||
|
|
Knowledge-base retrieval performance profiling script.

Breaks down the time spent in each stage of the retrieval pipeline.
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import asyncio
|
|||
|
|
import sys
|
|||
|
|
import time
|
|||
|
|
from pathlib import Path
|
|||
|
|
|
|||
|
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|||
|
|
|
|||
|
|
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
|||
|
|
from sqlalchemy.orm import sessionmaker
|
|||
|
|
from app.services.mid.kb_search_dynamic_tool import KbSearchDynamicTool, KbSearchDynamicConfig
|
|||
|
|
from app.services.retrieval.vector_retriever import VectorRetriever
|
|||
|
|
from app.core.config import get_settings
|
|||
|
|
from app.core.qdrant_client import QdrantClient
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def profile_embedding_generation(query: str):
    """Measure how long it takes to produce an embedding for ``query``.

    Two steps are timed separately: obtaining the provider through
    ``get_embedding_provider()`` and the ``embed_query`` call itself.

    Args:
        query: Text to embed.

    Returns:
        A dict with:
            ``init_time_ms``: time spent awaiting ``get_embedding_provider()``.
            ``embed_time_ms``: time spent awaiting ``embed_query(query)``.
            ``dimension``: ``len()`` of the extracted vector.
    """
    from app.services.embedding import get_embedding_provider

    # perf_counter() is monotonic and high-resolution, so the measured
    # durations cannot be skewed by system clock adjustments the way
    # time.time() differences can.
    start = time.perf_counter()
    embedding_service = await get_embedding_provider()
    init_time = time.perf_counter() - start

    start = time.perf_counter()
    embedding = await embedding_service.embed_query(query)
    embed_time = time.perf_counter() - start

    # The result shape differs between providers: prefer the
    # ``embedding_full`` attribute, then ``embedding``, and otherwise
    # treat the returned object itself as the vector.
    if hasattr(embedding, "embedding_full"):
        vector = embedding.embedding_full
    elif hasattr(embedding, "embedding"):
        vector = embedding.embedding
    else:
        vector = embedding

    return {
        "init_time_ms": init_time * 1000,
        "embed_time_ms": embed_time * 1000,
        "dimension": len(vector),
    }
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def profile_qdrant_search(tenant_id: str, query_vector: list, metadata_filter: dict = None):
    """Measure Qdrant search latency across one tenant's collections.

    All collections are listed, those whose name starts with
    ``kb_<tenant_id>`` (with every ``@`` in the tenant id replaced by
    ``_``) are kept, and each one is then queried sequentially while the
    existence check and the search are timed.

    Args:
        tenant_id: Tenant identifier used to build the collection prefix.
        query_vector: Query embedding passed to ``query_points``.
        metadata_filter: Optional mapping of metadata key to value. Each
            entry becomes an exact-match condition on ``metadata.<key>``
            and all conditions must hold. A falsy value means no filter.

    Returns:
        A dict with:
            ``list_collections_time_ms``: time to obtain the client, list
                collections and select the tenant's ones.
            ``collections_count``: number of matching collections.
            ``collection_times``: one dict per collection. For a collection
                that exists: ``collection``, ``exists`` (True),
                ``check_time_ms``, ``search_time_ms``, ``results_count``.
                For one reported as missing: ``collection``, ``exists``
                (False) and ``time_ms`` (the existence-check time).

    Raises:
        Exception: Any error raised by ``query_points`` is propagated,
            except one whose message contains "vector name"; in that case
            the query is retried once without the ``using`` argument.
    """
    from app.core.qdrant_client import get_qdrant_client

    client = await get_qdrant_client()

    # Time the collection listing together with the tenant-prefix selection.
    start = time.perf_counter()
    qdrant_client = await client.get_client()
    collections = await qdrant_client.get_collections()
    safe_tenant_id = tenant_id.replace("@", "_")
    prefix = f"kb_{safe_tenant_id}"
    tenant_collections = [
        c.name for c in collections.collections
        if c.name.startswith(prefix)
    ]
    list_collections_time = time.perf_counter() - start

    # Build the filter once, before any timer starts: it is identical for
    # every collection, and constructing it (including the first-time
    # import) must not be counted as search time.
    qdrant_filter = None
    if metadata_filter:
        from qdrant_client.models import FieldCondition, Filter, MatchValue

        qdrant_filter = Filter(
            must=[
                FieldCondition(
                    key=f"metadata.{key}",
                    match=MatchValue(value=value),
                )
                for key, value in metadata_filter.items()
            ]
        )

    # Search the collections one by one so each gets its own timing.
    collection_times = []
    for collection_name in tenant_collections:
        start = time.perf_counter()
        exists = await qdrant_client.collection_exists(collection_name)
        check_time = time.perf_counter() - start

        if not exists:
            collection_times.append({
                "collection": collection_name,
                "exists": False,
                "time_ms": check_time * 1000,
            })
            continue

        # Arguments shared by the named-vector attempt and the fallback.
        search_kwargs = {
            "collection_name": collection_name,
            "query": query_vector,
            "limit": 5,
            "score_threshold": 0.5,
            "query_filter": qdrant_filter,
        }

        start = time.perf_counter()
        try:
            # First try the named vector "full".
            results = await qdrant_client.query_points(using="full", **search_kwargs)
        except Exception as e:
            if "vector name" not in str(e).lower():
                raise
            # The error mentions the vector name: retry without one.
            # The failed attempt stays inside the measured search time.
            results = await qdrant_client.query_points(**search_kwargs)
        search_time = time.perf_counter() - start

        collection_times.append({
            "collection": collection_name,
            "exists": True,
            "check_time_ms": check_time * 1000,
            "search_time_ms": search_time * 1000,
            "results_count": len(results.points),
        })

    return {
        "list_collections_time_ms": list_collections_time * 1000,
        "collections_count": len(tenant_collections),
        "collection_times": collection_times,
    }
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def profile_full_kb_search(
    query: str = "三年级语文学习",
    tenant_id: str = "szmp@ash@2026",
    metadata_filter: dict = None,
):
    """Profile the whole knowledge-base search flow and print a report.

    The report has four sections: embedding generation timing, per-collection
    Qdrant search timing, one end-to-end ``KbSearchDynamicTool.execute`` call,
    and a bottleneck breakdown followed by optimisation hints.

    Args:
        query: Query text to embed and search for.
        tenant_id: Tenant whose collections are searched.
        metadata_filter: Metadata key/value pairs used both as the Qdrant
            filter and as the ``context`` of the tool call. ``None`` selects
            the built-in default filter; pass ``{}`` to profile without one.

    Returns:
        None. All results are written to stdout.
    """
    if metadata_filter is None:
        metadata_filter = {"grade": "三年级", "subject": "语文"}

    settings = get_settings()

    print("=" * 80)
    print("知识库检索性能分析")
    print("=" * 80)

    # 1. Embedding generation
    print("\n📊 1. Embedding 生成分析")
    print("-" * 80)
    embed_result = await profile_embedding_generation(query)
    print(f" 初始化时间: {embed_result['init_time_ms']:.2f} ms")
    print(f" Embedding 生成时间: {embed_result['embed_time_ms']:.2f} ms")
    print(f" 向量维度: {embed_result['dimension']}")

    # 2. Qdrant search
    print("\n📊 2. Qdrant 搜索分析")
    print("-" * 80)

    # The profiling helper above does not return the vector, so embed the
    # query again here to obtain one for the Qdrant search.
    from app.services.embedding import get_embedding_provider

    embedding_service = await get_embedding_provider()
    embedding_result = await embedding_service.embed_query(query)
    # The result shape differs between providers: prefer ``embedding_full``,
    # then ``embedding``, otherwise use the returned object itself.
    if hasattr(embedding_result, "embedding_full"):
        query_vector = embedding_result.embedding_full
    elif hasattr(embedding_result, "embedding"):
        query_vector = embedding_result.embedding
    else:
        query_vector = embedding_result

    qdrant_result = await profile_qdrant_search(tenant_id, query_vector, metadata_filter)
    print(f" 获取 collections 列表时间: {qdrant_result['list_collections_time_ms']:.2f} ms")
    print(f" Collections 数量: {qdrant_result['collections_count']}")
    print("\n 各 Collection 搜索耗时:")
    for ct in qdrant_result['collection_times']:
        if ct['exists']:
            print(f" - {ct['collection']}: {ct['search_time_ms']:.2f} ms (结果: {ct['results_count']} 条)")
        else:
            print(f" - {ct['collection']}: 不存在 ({ct['time_ms']:.2f} ms)")

    # Entries for missing collections carry no search time and count as 0.
    total_search_time = sum(
        ct.get('search_time_ms', 0) for ct in qdrant_result['collection_times']
    )
    print(f"\n 总搜索时间(串行): {total_search_time:.2f} ms")

    # 3. End-to-end tool call
    print("\n📊 3. 完整 KB Search 流程分析")
    print("-" * 80)

    engine = create_async_engine(settings.database_url)
    try:
        async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)

        async with async_session() as session:
            config = KbSearchDynamicConfig(
                enabled=True,
                top_k=5,
                timeout_ms=30000,  # 30 seconds
                min_score_threshold=0.5,
            )
            tool = KbSearchDynamicTool(session=session, config=config)

            # Monotonic clock, so the externally measured duration is not
            # affected by wall-clock adjustments.
            start_total = time.perf_counter()
            result = await tool.execute(
                query=query,
                tenant_id=tenant_id,
                scene="学习方案",
                top_k=5,
                context=metadata_filter,
            )
            total_time = (time.perf_counter() - start_total) * 1000

            print(f" 总耗时: {total_time:.2f} ms")
            print(f" 结果: success={result.success}, hits={len(result.hits)}")
            print(f" 工具内部耗时: {result.duration_ms} ms")

            # Difference between the external measurement and the duration
            # reported by the tool itself.
            overhead = total_time - result.duration_ms
            print(f" 额外开销(初始化等): {overhead:.2f} ms")
    finally:
        # Release the connection pool explicitly; otherwise pooled
        # connections are still open when asyncio.run() closes the loop.
        await engine.dispose()

    # 4. Bottleneck breakdown
    print("\n📊 4. 性能瓶颈分析")
    print("-" * 80)

    embedding_time = embed_result['embed_time_ms']
    qdrant_time = total_search_time
    total_measured = embedding_time + qdrant_time

    def _share(part):
        # Percentage of the measured total; 0 when nothing was measured,
        # which would otherwise raise ZeroDivisionError.
        return part / total_measured * 100 if total_measured else 0.0

    print(f" Embedding 生成: {embedding_time:.2f} ms ({_share(embedding_time):.1f}%)")
    print(f" Qdrant 搜索: {qdrant_time:.2f} ms ({_share(qdrant_time):.1f}%)")
    print(f" 其他开销: {total_time - total_measured:.2f} ms")

    print("\n" + "=" * 80)
    print("优化建议:")
    print("=" * 80)

    # Hints are printed when a stage exceeds 1000 ms.
    if embedding_time > 1000:
        print(" ⚠️ Embedding 生成较慢,考虑:")
        print(" - 使用更快的 embedding 模型")
        print(" - 增加 embedding 服务缓存")

    if qdrant_time > 1000:
        print(" ⚠️ Qdrant 搜索较慢,考虑:")
        print(" - 并行查询多个 collections")
        print(" - 优化 Qdrant 索引")
        print(" - 减少 collections 数量")

    # More than three collections triggers the consolidation hint.
    if len(qdrant_result['collection_times']) > 3:
        print(f" ⚠️ Collections 数量较多 ({len(qdrant_result['collection_times'])} 个)")
        print(" - 建议合并或归档空/少数据的 collections")
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
    # Script entry point: build the profiling coroutine and drive it to
    # completion on a fresh event loop.
    profiling_run = profile_full_kb_search()
    asyncio.run(profiling_run)
|