feat(AISVC-T6.9): 集成Ollama嵌入模型修复RAG检索问题
## 问题修复

- 替换假嵌入(SHA256 hash)为真实Ollama nomic-embed-text嵌入
- 修复Qdrant客户端版本不兼容导致score_threshold参数失效
- 降低默认分数阈值从0.7到0.3

## 新增文件

- ai-service/app/services/embedding/ollama_embedding.py

## 修改文件

- ai-service/app/api/admin/kb.py: 索引任务使用真实嵌入
- ai-service/app/core/config.py: 新增Ollama配置,向量维度改为768
- ai-service/app/core/qdrant_client.py: 移除score_threshold参数
- ai-service/app/services/retrieval/vector_retriever.py: 使用Ollama嵌入
This commit is contained in:
parent
5148c6ef42
commit
4b64a4dbf4
|
|
@ -212,14 +212,13 @@ async def upload_document(
|
|||
async def _index_document(tenant_id: str, job_id: str, doc_id: str, content: bytes):
|
||||
"""
|
||||
Background indexing task.
|
||||
For MVP, we simulate indexing with a simple text extraction.
|
||||
In production, this would use a task queue like Celery.
|
||||
Uses Ollama nomic-embed-text for real embeddings.
|
||||
"""
|
||||
from app.core.database import async_session_maker
|
||||
from app.services.kb import KBService
|
||||
from app.core.qdrant_client import get_qdrant_client
|
||||
from app.services.embedding.ollama_embedding import get_embedding
|
||||
from qdrant_client.models import PointStruct
|
||||
import hashlib
|
||||
import asyncio
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
|
@ -241,24 +240,12 @@ async def _index_document(tenant_id: str, job_id: str, doc_id: str, content: byt
|
|||
|
||||
points = []
|
||||
for i, chunk in enumerate(chunks):
|
||||
hash_obj = hashlib.sha256(chunk.encode())
|
||||
hash_bytes = hash_obj.digest()
|
||||
embedding = []
|
||||
for j in range(0, min(len(hash_bytes) * 8, 1536)):
|
||||
byte_idx = j // 8
|
||||
bit_idx = j % 8
|
||||
if byte_idx < len(hash_bytes):
|
||||
val = (hash_bytes[byte_idx] >> bit_idx) & 1
|
||||
embedding.append(float(val))
|
||||
else:
|
||||
embedding.append(0.0)
|
||||
while len(embedding) < 1536:
|
||||
embedding.append(0.0)
|
||||
|
||||
embedding = await get_embedding(chunk)
|
||||
|
||||
points.append(
|
||||
PointStruct(
|
||||
id=str(uuid.uuid4()),
|
||||
vector=embedding[:1536],
|
||||
vector=embedding,
|
||||
payload={
|
||||
"text": chunk,
|
||||
"source": doc_id,
|
||||
|
|
|
|||
|
|
@ -38,10 +38,13 @@ class Settings(BaseSettings):
|
|||
|
||||
qdrant_url: str = "http://localhost:6333"
|
||||
qdrant_collection_prefix: str = "kb_"
|
||||
qdrant_vector_size: int = 1536
|
||||
qdrant_vector_size: int = 768
|
||||
|
||||
ollama_base_url: str = "http://localhost:11434"
|
||||
ollama_embedding_model: str = "nomic-embed-text"
|
||||
|
||||
rag_top_k: int = 5
|
||||
rag_score_threshold: float = 0.7
|
||||
rag_score_threshold: float = 0.3
|
||||
rag_min_hits: int = 1
|
||||
rag_max_evidence_tokens: int = 2000
|
||||
|
||||
|
|
|
|||
|
|
@ -119,7 +119,6 @@ class QdrantClient:
|
|||
collection_name=collection_name,
|
||||
query_vector=query_vector,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
|
||||
hits = [
|
||||
|
|
@ -129,6 +128,7 @@ class QdrantClient:
|
|||
"payload": result.payload or {},
|
||||
}
|
||||
for result in results
|
||||
if score_threshold is None or result.score >= score_threshold
|
||||
]
|
||||
|
||||
logger.info(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,58 @@
|
|||
"""
|
||||
Ollama embedding service for generating text embeddings.
|
||||
Uses nomic-embed-text model via Ollama API.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import httpx
|
||||
from app.core.config import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_embedding(text: str) -> list[float]:
    """
    Generate an embedding vector for *text* via the Ollama embeddings API.

    Uses the model configured in ``settings.ollama_embedding_model``
    (nomic-embed-text by default).

    Args:
        text: The input text to embed.

    Returns:
        The embedding vector returned by Ollama, or a zero vector of
        ``settings.qdrant_vector_size`` when Ollama returns an empty
        embedding (so indexing can proceed instead of crashing).

    Raises:
        httpx.HTTPStatusError: On a non-2xx response from the Ollama API.
        httpx.RequestError: When the Ollama server is unreachable.
    """
    settings = get_settings()

    # A fresh client per call keeps the function self-contained; the 60s
    # timeout covers slow first-load of the embedding model in Ollama.
    async with httpx.AsyncClient(timeout=60.0) as client:
        try:
            response = await client.post(
                f"{settings.ollama_base_url}/api/embeddings",
                json={
                    "model": settings.ollama_embedding_model,
                    "prompt": text,
                },
            )
            response.raise_for_status()
            data = response.json()
            embedding = data.get("embedding", [])

            if not embedding:
                # Fall back to a zero vector rather than failing the whole
                # indexing job; such a vector matches nothing meaningful.
                logger.warning("Empty embedding returned for text length=%d", len(text))
                return [0.0] * settings.qdrant_vector_size

            # Lazy %-style args: the message is only formatted if DEBUG is on.
            logger.debug("Generated embedding: dim=%d", len(embedding))
            return embedding

        except httpx.HTTPStatusError as e:
            logger.error("Ollama API error: %s - %s", e.response.status_code, e.response.text)
            raise
        except httpx.RequestError as e:
            logger.error("Ollama connection error: %s", e)
            raise
        except Exception as e:
            # Boundary catch-all: log unexpected failures (e.g. malformed
            # JSON) and re-raise for the caller to handle.
            logger.error("Embedding generation failed: %s", e)
            raise
|
||||
|
||||
|
||||
async def get_embeddings_batch(texts: list[str]) -> list[list[float]]:
    """
    Generate embedding vectors for several texts, one request per text.

    Requests are issued sequentially and results are returned in input
    order; failures from ``get_embedding`` propagate to the caller.
    """
    # Sequential async comprehension: each await completes before the next
    # request is sent, mirroring a plain for-loop over the inputs.
    return [await get_embedding(text) for text in texts]
|
||||
|
|
@ -119,30 +119,11 @@ class VectorRetriever(BaseRetriever):
|
|||
|
||||
async def _get_embedding(self, text: str) -> list[float]:
|
||||
"""
|
||||
Generate embedding for text.
|
||||
[AC-AISVC-16] Placeholder for embedding generation.
|
||||
|
||||
TODO: Integrate with actual embedding provider (OpenAI, local model, etc.)
|
||||
Generate embedding for text using Ollama nomic-embed-text model.
|
||||
"""
|
||||
import hashlib
|
||||
|
||||
hash_obj = hashlib.sha256(text.encode())
|
||||
hash_bytes = hash_obj.digest()
|
||||
|
||||
embedding = []
|
||||
for i in range(0, min(len(hash_bytes) * 8, settings.qdrant_vector_size)):
|
||||
byte_idx = i // 8
|
||||
bit_idx = i % 8
|
||||
if byte_idx < len(hash_bytes):
|
||||
val = (hash_bytes[byte_idx] >> bit_idx) & 1
|
||||
embedding.append(float(val))
|
||||
else:
|
||||
embedding.append(0.0)
|
||||
|
||||
while len(embedding) < settings.qdrant_vector_size:
|
||||
embedding.append(0.0)
|
||||
|
||||
return embedding[: settings.qdrant_vector_size]
|
||||
from app.services.embedding.ollama_embedding import get_embedding as get_ollama_embedding
|
||||
|
||||
return await get_ollama_embedding(text)
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
"""
|
||||
|
|
|
|||
Loading…
Reference in New Issue