feat(AISVC-T6.9): 集成Ollama嵌入模型修复RAG检索问题
## 问题修复

- 替换假嵌入(SHA256 hash)为真实Ollama nomic-embed-text嵌入
- 修复Qdrant客户端版本不兼容导致score_threshold参数失效
- 降低默认分数阈值从0.7到0.3

## 新增文件

- ai-service/app/services/embedding/ollama_embedding.py

## 修改文件

- ai-service/app/api/admin/kb.py: 索引任务使用真实嵌入
- ai-service/app/core/config.py: 新增Ollama配置,向量维度改为768
- ai-service/app/core/qdrant_client.py: 移除score_threshold参数
- ai-service/app/services/retrieval/vector_retriever.py: 使用Ollama嵌入
This commit is contained in:
parent
5148c6ef42
commit
4b64a4dbf4
|
|
@ -212,14 +212,13 @@ async def upload_document(
|
||||||
async def _index_document(tenant_id: str, job_id: str, doc_id: str, content: bytes):
|
async def _index_document(tenant_id: str, job_id: str, doc_id: str, content: bytes):
|
||||||
"""
|
"""
|
||||||
Background indexing task.
|
Background indexing task.
|
||||||
For MVP, we simulate indexing with a simple text extraction.
|
Uses Ollama nomic-embed-text for real embeddings.
|
||||||
In production, this would use a task queue like Celery.
|
|
||||||
"""
|
"""
|
||||||
from app.core.database import async_session_maker
|
from app.core.database import async_session_maker
|
||||||
from app.services.kb import KBService
|
from app.services.kb import KBService
|
||||||
from app.core.qdrant_client import get_qdrant_client
|
from app.core.qdrant_client import get_qdrant_client
|
||||||
|
from app.services.embedding.ollama_embedding import get_embedding
|
||||||
from qdrant_client.models import PointStruct
|
from qdrant_client.models import PointStruct
|
||||||
import hashlib
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
|
|
@ -241,24 +240,12 @@ async def _index_document(tenant_id: str, job_id: str, doc_id: str, content: byt
|
||||||
|
|
||||||
points = []
|
points = []
|
||||||
for i, chunk in enumerate(chunks):
|
for i, chunk in enumerate(chunks):
|
||||||
hash_obj = hashlib.sha256(chunk.encode())
|
embedding = await get_embedding(chunk)
|
||||||
hash_bytes = hash_obj.digest()
|
|
||||||
embedding = []
|
|
||||||
for j in range(0, min(len(hash_bytes) * 8, 1536)):
|
|
||||||
byte_idx = j // 8
|
|
||||||
bit_idx = j % 8
|
|
||||||
if byte_idx < len(hash_bytes):
|
|
||||||
val = (hash_bytes[byte_idx] >> bit_idx) & 1
|
|
||||||
embedding.append(float(val))
|
|
||||||
else:
|
|
||||||
embedding.append(0.0)
|
|
||||||
while len(embedding) < 1536:
|
|
||||||
embedding.append(0.0)
|
|
||||||
|
|
||||||
points.append(
|
points.append(
|
||||||
PointStruct(
|
PointStruct(
|
||||||
id=str(uuid.uuid4()),
|
id=str(uuid.uuid4()),
|
||||||
vector=embedding[:1536],
|
vector=embedding,
|
||||||
payload={
|
payload={
|
||||||
"text": chunk,
|
"text": chunk,
|
||||||
"source": doc_id,
|
"source": doc_id,
|
||||||
|
|
|
||||||
|
|
@ -38,10 +38,13 @@ class Settings(BaseSettings):
|
||||||
|
|
||||||
qdrant_url: str = "http://localhost:6333"
|
qdrant_url: str = "http://localhost:6333"
|
||||||
qdrant_collection_prefix: str = "kb_"
|
qdrant_collection_prefix: str = "kb_"
|
||||||
qdrant_vector_size: int = 1536
|
qdrant_vector_size: int = 768
|
||||||
|
|
||||||
|
ollama_base_url: str = "http://localhost:11434"
|
||||||
|
ollama_embedding_model: str = "nomic-embed-text"
|
||||||
|
|
||||||
rag_top_k: int = 5
|
rag_top_k: int = 5
|
||||||
rag_score_threshold: float = 0.7
|
rag_score_threshold: float = 0.3
|
||||||
rag_min_hits: int = 1
|
rag_min_hits: int = 1
|
||||||
rag_max_evidence_tokens: int = 2000
|
rag_max_evidence_tokens: int = 2000
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -119,7 +119,6 @@ class QdrantClient:
|
||||||
collection_name=collection_name,
|
collection_name=collection_name,
|
||||||
query_vector=query_vector,
|
query_vector=query_vector,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
score_threshold=score_threshold,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
hits = [
|
hits = [
|
||||||
|
|
@ -129,6 +128,7 @@ class QdrantClient:
|
||||||
"payload": result.payload or {},
|
"payload": result.payload or {},
|
||||||
}
|
}
|
||||||
for result in results
|
for result in results
|
||||||
|
if score_threshold is None or result.score >= score_threshold
|
||||||
]
|
]
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,58 @@
|
||||||
|
"""
|
||||||
|
Ollama embedding service for generating text embeddings.
|
||||||
|
Uses nomic-embed-text model via Ollama API.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import httpx
|
||||||
|
from app.core.config import get_settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_embedding(text: str) -> list[float]:
|
||||||
|
"""
|
||||||
|
Generate embedding vector for text using Ollama nomic-embed-text model.
|
||||||
|
"""
|
||||||
|
settings = get_settings()
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||||
|
try:
|
||||||
|
response = await client.post(
|
||||||
|
f"{settings.ollama_base_url}/api/embeddings",
|
||||||
|
json={
|
||||||
|
"model": settings.ollama_embedding_model,
|
||||||
|
"prompt": text,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
embedding = data.get("embedding", [])
|
||||||
|
|
||||||
|
if not embedding:
|
||||||
|
logger.warning(f"Empty embedding returned for text length={len(text)}")
|
||||||
|
return [0.0] * settings.qdrant_vector_size
|
||||||
|
|
||||||
|
logger.debug(f"Generated embedding: dim={len(embedding)}")
|
||||||
|
return embedding
|
||||||
|
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
logger.error(f"Ollama API error: {e.response.status_code} - {e.response.text}")
|
||||||
|
raise
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
logger.error(f"Ollama connection error: {e}")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Embedding generation failed: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
async def get_embeddings_batch(texts: list[str]) -> list[list[float]]:
|
||||||
|
"""
|
||||||
|
Generate embedding vectors for multiple texts.
|
||||||
|
"""
|
||||||
|
embeddings = []
|
||||||
|
for text in texts:
|
||||||
|
embedding = await get_embedding(text)
|
||||||
|
embeddings.append(embedding)
|
||||||
|
return embeddings
|
||||||
|
|
@ -119,30 +119,11 @@ class VectorRetriever(BaseRetriever):
|
||||||
|
|
||||||
async def _get_embedding(self, text: str) -> list[float]:
    """Embed *text* with the Ollama nomic-embed-text model.

    Delegates to the shared embedding service so query vectors are produced
    the same way as the vectors stored at indexing time.
    """
    from app.services.embedding import ollama_embedding

    return await ollama_embedding.get_embedding(text)
|
|
||||||
async def health_check(self) -> bool:
|
async def health_check(self) -> bool:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue