"""
Optimized RAG retriever with two-stage retrieval and RRF hybrid ranking.

Reference: rag-optimization/spec.md Section 2.2, 2.4, 2.5
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
import re
|
|
from dataclasses import dataclass, field
|
|
from typing import Any
|
|
|
|
from app.core.config import get_settings
|
|
from app.core.qdrant_client import QdrantClient, get_qdrant_client
|
|
from app.services.embedding.nomic_provider import NomicEmbeddingProvider, NomicEmbeddingResult
|
|
from app.services.retrieval.base import (
|
|
BaseRetriever,
|
|
RetrievalContext,
|
|
RetrievalHit,
|
|
RetrievalResult,
|
|
)
|
|
from app.services.retrieval.metadata import (
|
|
ChunkMetadata,
|
|
MetadataFilter,
|
|
RetrieveResult,
|
|
RetrievalStrategy,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
settings = get_settings()
|
|
|
|
|
|
@dataclass
class TwoStageResult:
    """Result from two-stage retrieval."""

    # Stage-1 candidate hits (coarse search) before reranking.
    candidates: list[dict[str, Any]]
    # Final reranked hits returned to the caller.
    final_results: list[RetrieveResult]
    # Wall-clock latency of the coarse retrieval stage, in milliseconds.
    stage1_latency_ms: float = 0.0
    # Wall-clock latency of the reranking stage, in milliseconds.
    stage2_latency_ms: float = 0.0
|
|
|
|
|
|
class RRFCombiner:
    """
    Reciprocal Rank Fusion for combining multiple retrieval results.
    Reference: rag-optimization/spec.md Section 2.5

    Formula: score = Σ(1 / (k + rank_i))
    Default k = 60
    """

    def __init__(self, k: int = 60):
        # Smoothing constant in the RRF denominator; larger k flattens
        # the contribution difference between ranks.
        self._k = k

    def combine(
        self,
        vector_results: list[dict[str, Any]],
        bm25_results: list[dict[str, Any]],
        vector_weight: float = 0.7,
        bm25_weight: float = 0.3,
    ) -> list[dict[str, Any]]:
        """
        Combine vector and BM25 results using RRF.

        Args:
            vector_results: Results from vector search
            bm25_results: Results from BM25 search
            vector_weight: Weight for vector results
            bm25_weight: Weight for BM25 results

        Returns:
            Combined and sorted results
        """
        fused: dict[str, dict[str, Any]] = {}

        # Pass 1: seed the fused table from the vector ranking.
        for position, hit in enumerate(vector_results):
            key = hit.get("chunk_id") or hit.get("id", str(position))
            entry = fused.get(key)
            if entry is None:
                entry = {
                    "score": 0.0,
                    "vector_score": hit.get("score", 0.0),
                    "bm25_score": 0.0,
                    "vector_rank": position,
                    "bm25_rank": -1,
                    "payload": hit.get("payload", {}),
                    "id": key,
                }
                fused[key] = entry
            # Weighted reciprocal-rank contribution (rank is 1-based in
            # the RRF formula, hence position + 1).
            entry["score"] += vector_weight / (self._k + position + 1)

        # Pass 2: merge the BM25 ranking into the same table.
        for position, hit in enumerate(bm25_results):
            key = hit.get("chunk_id") or hit.get("id", str(position))
            entry = fused.get(key)
            if entry is None:
                entry = {
                    "score": 0.0,
                    "vector_score": 0.0,
                    "bm25_score": hit.get("score", 0.0),
                    "vector_rank": -1,
                    "bm25_rank": position,
                    "payload": hit.get("payload", {}),
                    "id": key,
                }
                fused[key] = entry
            else:
                entry["bm25_score"] = hit.get("score", 0.0)
                entry["bm25_rank"] = position
            entry["score"] += bm25_weight / (self._k + position + 1)

        # Highest fused score first.
        return sorted(fused.values(), key=lambda e: e["score"], reverse=True)
|
|
|
|
|
|
class OptimizedRetriever(BaseRetriever):
    """
    Optimized retriever with:
    - Task prefixes (search_document/search_query)
    - Two-stage retrieval (256 dim -> 768 dim)
    - RRF hybrid ranking (vector + BM25)
    - Metadata filtering

    Reference: rag-optimization/spec.md Section 2, 3, 4
    """

    def __init__(
        self,
        qdrant_client: QdrantClient | None = None,
        embedding_provider: NomicEmbeddingProvider | None = None,
        top_k: int | None = None,
        score_threshold: float | None = None,
        min_hits: int | None = None,
        two_stage_enabled: bool | None = None,
        two_stage_expand_factor: int | None = None,
        hybrid_enabled: bool | None = None,
        rrf_k: int | None = None,
    ):
        """
        Initialize the retriever.

        Every parameter left as None falls back to the corresponding
        application setting. Explicit values — including falsy ones such
        as score_threshold=0.0 or min_hits=0 — take precedence.
        """
        self._qdrant_client = qdrant_client
        self._embedding_provider = embedding_provider
        self._top_k = top_k or settings.rag_top_k
        # Use `is not None` rather than `or` so that explicit zero
        # overrides are honored instead of silently replaced by settings.
        self._score_threshold = (
            score_threshold if score_threshold is not None else settings.rag_score_threshold
        )
        self._min_hits = min_hits if min_hits is not None else settings.rag_min_hits
        self._two_stage_enabled = (
            two_stage_enabled if two_stage_enabled is not None else settings.rag_two_stage_enabled
        )
        self._two_stage_expand_factor = (
            two_stage_expand_factor or settings.rag_two_stage_expand_factor
        )
        self._hybrid_enabled = (
            hybrid_enabled if hybrid_enabled is not None else settings.rag_hybrid_enabled
        )
        self._rrf_k = rrf_k or settings.rag_rrf_k
        self._rrf_combiner = RRFCombiner(k=self._rrf_k)

    async def _get_client(self) -> QdrantClient:
        """Lazily create (and cache) the Qdrant client wrapper."""
        if self._qdrant_client is None:
            self._qdrant_client = await get_qdrant_client()
        return self._qdrant_client

    async def _get_embedding_provider(self) -> NomicEmbeddingProvider:
        """
        Lazily resolve (and cache) the embedding provider.

        Prefers the provider from the embedding config manager when it is
        already a NomicEmbeddingProvider; otherwise builds one directly
        from settings, since this retriever relies on Nomic-specific
        Matryoshka embeddings (embedding_256 / embedding_full).
        """
        if self._embedding_provider is None:
            # Local import to avoid a circular import at module load time.
            from app.services.embedding.factory import get_embedding_config_manager
            manager = get_embedding_config_manager()
            provider = await manager.get_provider()
            if isinstance(provider, NomicEmbeddingProvider):
                self._embedding_provider = provider
            else:
                self._embedding_provider = NomicEmbeddingProvider(
                    base_url=settings.ollama_base_url,
                    model=settings.ollama_embedding_model,
                    dimension=settings.qdrant_vector_size,
                )
        return self._embedding_provider

    async def retrieve(self, ctx: RetrievalContext) -> RetrievalResult:
        """
        Retrieve documents using optimized strategy.

        Strategy selection:
        1. If two_stage_enabled: use two-stage retrieval
        2. If hybrid_enabled: use RRF hybrid ranking
        3. Otherwise: simple vector search

        Never raises: on any failure an empty RetrievalResult is returned
        with {"error": ..., "is_insufficient": True} in diagnostics.
        """
        logger.info(
            f"[RAG-OPT] Starting retrieval for tenant={ctx.tenant_id}, "
            f"query={ctx.query[:50]}..., two_stage={self._two_stage_enabled}, hybrid={self._hybrid_enabled}"
        )
        logger.info(
            f"[RAG-OPT] Retrieval config: top_k={self._top_k}, "
            f"score_threshold={self._score_threshold}, min_hits={self._min_hits}"
        )

        try:
            provider = await self._get_embedding_provider()
            logger.info(f"[RAG-OPT] Using embedding provider: {type(provider).__name__}")

            embedding_result = await provider.embed_query(ctx.query)
            logger.info(
                f"[RAG-OPT] Embedding generated: full_dim={len(embedding_result.embedding_full)}, "
                f"dim_256={'available' if embedding_result.embedding_256 else 'not available'}"
            )

            # Two-stage takes precedence over hybrid when both are enabled.
            if self._two_stage_enabled:
                logger.info("[RAG-OPT] Using two-stage retrieval strategy")
                results = await self._two_stage_retrieve(
                    ctx.tenant_id,
                    embedding_result,
                    self._top_k,
                )
            elif self._hybrid_enabled:
                logger.info("[RAG-OPT] Using hybrid retrieval strategy")
                results = await self._hybrid_retrieve(
                    ctx.tenant_id,
                    embedding_result,
                    ctx.query,
                    self._top_k,
                )
            else:
                logger.info("[RAG-OPT] Using simple vector retrieval strategy")
                results = await self._vector_retrieve(
                    ctx.tenant_id,
                    embedding_result.embedding_full,
                    self._top_k,
                )

            logger.info(f"[RAG-OPT] Raw results count: {len(results)}")

            # Convert raw dicts to RetrievalHit, dropping hits below the
            # score threshold.
            retrieval_hits = [
                RetrievalHit(
                    text=result.get("payload", {}).get("text", ""),
                    score=result.get("score", 0.0),
                    source="optimized_rag",
                    metadata=result.get("payload", {}),
                )
                for result in results
                if result.get("score", 0.0) >= self._score_threshold
            ]

            filtered_count = len(results) - len(retrieval_hits)
            if filtered_count > 0:
                logger.info(
                    f"[RAG-OPT] Filtered out {filtered_count} results below threshold {self._score_threshold}"
                )

            is_insufficient = len(retrieval_hits) < self._min_hits

            diagnostics = {
                "query_length": len(ctx.query),
                "top_k": self._top_k,
                "score_threshold": self._score_threshold,
                "two_stage_enabled": self._two_stage_enabled,
                "hybrid_enabled": self._hybrid_enabled,
                "total_hits": len(retrieval_hits),
                "is_insufficient": is_insufficient,
                "max_score": max((h.score for h in retrieval_hits), default=0.0),
                "raw_results_count": len(results),
                "filtered_below_threshold": filtered_count,
            }

            logger.info(
                f"[RAG-OPT] Retrieval complete: {len(retrieval_hits)} hits, "
                f"insufficient={is_insufficient}, max_score={diagnostics['max_score']:.3f}"
            )

            if len(retrieval_hits) == 0:
                logger.warning(
                    f"[RAG-OPT] No hits found! tenant={ctx.tenant_id}, query={ctx.query[:50]}..., "
                    f"raw_results={len(results)}, threshold={self._score_threshold}"
                )

            return RetrievalResult(
                hits=retrieval_hits,
                diagnostics=diagnostics,
            )

        except Exception as e:
            logger.error(f"[RAG-OPT] Retrieval error: {e}", exc_info=True)
            return RetrievalResult(
                hits=[],
                diagnostics={"error": str(e), "is_insufficient": True},
            )

    async def _two_stage_retrieve(
        self,
        tenant_id: str,
        embedding_result: NomicEmbeddingResult,
        top_k: int,
    ) -> list[dict[str, Any]]:
        """
        Two-stage retrieval using Matryoshka dimensions.

        Stage 1: Fast retrieval with 256-dim vectors
        Stage 2: Precise reranking with 768-dim vectors

        Candidates whose payload lacks a stored full-dimension embedding
        keep their stage-1 score instead of being dropped, so a corpus
        indexed without "embedding_full" payloads still returns results.

        Reference: rag-optimization/spec.md Section 2.4
        """
        import time

        client = await self._get_client()

        stage1_start = time.perf_counter()
        # Over-fetch by the expand factor so stage 2 has a candidate pool
        # to rerank from.
        candidates = await self._search_with_dimension(
            client, tenant_id, embedding_result.embedding_256, "dim_256",
            top_k * self._two_stage_expand_factor
        )
        stage1_latency = (time.perf_counter() - stage1_start) * 1000

        logger.debug(
            f"[RAG-OPT] Stage 1: {len(candidates)} candidates in {stage1_latency:.2f}ms"
        )

        stage2_start = time.perf_counter()
        reranked = []
        for candidate in candidates:
            stored_full_embedding = candidate.get("payload", {}).get("embedding_full", [])
            if stored_full_embedding:
                similarity = self._cosine_similarity(
                    embedding_result.embedding_full,
                    stored_full_embedding
                )
                candidate["score"] = similarity
                candidate["stage"] = "reranked"
            else:
                # No stored 768-dim embedding: fall back to the stage-1
                # score rather than silently discarding the candidate.
                candidate["stage"] = "stage1"
            reranked.append(candidate)

        reranked.sort(key=lambda x: x.get("score", 0), reverse=True)
        results = reranked[:top_k]
        stage2_latency = (time.perf_counter() - stage2_start) * 1000

        logger.debug(
            f"[RAG-OPT] Stage 2: {len(results)} final results in {stage2_latency:.2f}ms"
        )

        return results

    async def _hybrid_retrieve(
        self,
        tenant_id: str,
        embedding_result: NomicEmbeddingResult,
        query: str,
        top_k: int,
    ) -> list[dict[str, Any]]:
        """
        Hybrid retrieval using RRF to combine vector and BM25 results.

        Both searches run concurrently; a failure in either branch is
        logged and treated as an empty result set rather than failing
        the whole retrieval.

        Reference: rag-optimization/spec.md Section 2.5
        """
        client = await self._get_client()

        # Fetch 2x top_k from each branch so RRF has overlap to fuse.
        vector_task = self._search_with_dimension(
            client, tenant_id, embedding_result.embedding_full, "full",
            top_k * 2
        )

        bm25_task = self._bm25_search(client, tenant_id, query, top_k * 2)

        vector_results, bm25_results = await asyncio.gather(
            vector_task, bm25_task, return_exceptions=True
        )

        if isinstance(vector_results, Exception):
            logger.warning(f"[RAG-OPT] Vector search failed: {vector_results}")
            vector_results = []

        if isinstance(bm25_results, Exception):
            logger.warning(f"[RAG-OPT] BM25 search failed: {bm25_results}")
            bm25_results = []

        combined = self._rrf_combiner.combine(
            vector_results,
            bm25_results,
            vector_weight=settings.rag_vector_weight,
            bm25_weight=settings.rag_bm25_weight,
        )

        return combined[:top_k]

    async def _vector_retrieve(
        self,
        tenant_id: str,
        embedding: list[float],
        top_k: int,
    ) -> list[dict[str, Any]]:
        """Simple vector retrieval against the full-dimension vector."""
        client = await self._get_client()
        return await self._search_with_dimension(
            client, tenant_id, embedding, "full", top_k
        )

    async def _search_with_dimension(
        self,
        client: QdrantClient,
        tenant_id: str,
        query_vector: list[float],
        vector_name: str,
        limit: int,
    ) -> list[dict[str, Any]]:
        """
        Search using specified vector dimension.

        Returns a list of {"id", "score", "payload"} dicts; any search
        failure is logged and yields an empty list.
        """
        try:
            qdrant = await client.get_client()
            collection_name = client.get_collection_name(tenant_id)

            logger.info(
                f"[RAG-OPT] Searching collection={collection_name}, "
                f"vector_name={vector_name}, limit={limit}, vector_dim={len(query_vector)}"
            )

            results = await qdrant.search(
                collection_name=collection_name,
                query_vector=(vector_name, query_vector),
                limit=limit,
            )

            logger.info(
                f"[RAG-OPT] Search returned {len(results)} results from collection={collection_name}"
            )

            if len(results) > 0:
                # Log only the top few hits to keep debug output bounded.
                for i, r in enumerate(results[:3]):
                    logger.debug(
                        f"[RAG-OPT] Result {i+1}: id={r.id}, score={r.score:.4f}"
                    )

            return [
                {
                    "id": str(result.id),
                    "score": result.score,
                    "payload": result.payload or {},
                }
                for result in results
            ]
        except Exception as e:
            logger.error(
                f"[RAG-OPT] Search with {vector_name} failed: {e}, "
                f"collection_name={client.get_collection_name(tenant_id)}",
                exc_info=True
            )
            return []

    async def _bm25_search(
        self,
        client: QdrantClient,
        tenant_id: str,
        query: str,
        limit: int,
    ) -> list[dict[str, Any]]:
        """
        BM25-like search using Qdrant's sparse vectors or fallback to text matching.
        This is a simplified implementation; for production, use Elasticsearch.

        Scores candidates by Jaccard overlap between query terms and
        document terms over a scrolled sample of the collection.
        """
        try:
            qdrant = await client.get_client()
            collection_name = client.get_collection_name(tenant_id)

            query_terms = set(re.findall(r'\w+', query.lower()))

            # Scroll a bounded sample (3x limit) to score lexically; this
            # is not a full-collection scan.
            results = await qdrant.scroll(
                collection_name=collection_name,
                limit=limit * 3,
                with_payload=True,
            )

            scored_results = []
            # scroll() returns (points, next_offset); only points are used.
            for point in results[0]:
                # Guard against a None payload, consistent with the
                # `payload or {}` handling below.
                payload = point.payload or {}
                text = payload.get("text", "").lower()
                text_terms = set(re.findall(r'\w+', text))
                overlap = len(query_terms & text_terms)
                if overlap > 0:
                    # Jaccard similarity: |A ∩ B| / |A ∪ B|.
                    score = overlap / (len(query_terms) + len(text_terms) - overlap)
                    scored_results.append({
                        "id": str(point.id),
                        "score": score,
                        "payload": payload,
                    })

            scored_results.sort(key=lambda x: x["score"], reverse=True)
            return scored_results[:limit]

        except Exception as e:
            logger.debug(f"[RAG-OPT] BM25 search failed: {e}")
            return []

    def _cosine_similarity(self, vec1: list[float], vec2: list[float]) -> float:
        """
        Calculate cosine similarity between two vectors.

        Returns 0.0 when either vector has zero norm to avoid a
        division-by-zero NaN.
        """
        import numpy as np
        a = np.asarray(vec1, dtype=float)
        b = np.asarray(vec2, dtype=float)
        denom = float(np.linalg.norm(a) * np.linalg.norm(b))
        if denom == 0.0:
            return 0.0
        return float(np.dot(a, b) / denom)

    async def health_check(self) -> bool:
        """Check if retriever is healthy (Qdrant reachable)."""
        try:
            client = await self._get_client()
            qdrant = await client.get_client()
            await qdrant.get_collections()
            return True
        except Exception as e:
            logger.error(f"[RAG-OPT] Health check failed: {e}")
            return False
|
|
|
|
|
|
# Module-level singleton, created lazily on first request.
_optimized_retriever: OptimizedRetriever | None = None


async def get_optimized_retriever() -> OptimizedRetriever:
    """Get or create OptimizedRetriever instance."""
    global _optimized_retriever
    if _optimized_retriever is not None:
        return _optimized_retriever
    _optimized_retriever = OptimizedRetriever()
    return _optimized_retriever
|