[AC-AISVC-RES-01~15] feat(retrieval): 实现检索策略Pipeline模块
- 新增策略配置模型 (config.py) - GrayscaleConfig: 灰度发布配置 - ModeRouterConfig: 模式路由配置 - MetadataInferenceConfig: 元数据推断配置 - 新增 Pipeline 实现 - DefaultPipeline: 复用现有 OptimizedRetriever 逻辑 - EnhancedPipeline: Dense + Keyword + RRF 组合检索 - 新增路由器 - StrategyRouter: 策略路由器(default/enhanced) - ModeRouter: 模式路由器(direct/react/auto) - 新增 RollbackManager: 回退与审计管理器 - 新增 MetadataInferenceService: 元数据推断统一入口 - 新增单元测试 (51 passed)
This commit is contained in:
parent
9f28498b97
commit
7027097513
|
|
@ -0,0 +1,102 @@
|
|||
"""
|
||||
Retrieval Strategy Module for AI Service.
|
||||
[AC-AISVC-RES-01~15] 策略化检索与嵌入模块。
|
||||
|
||||
核心组件:
|
||||
- RetrievalStrategyConfig: 策略配置模型
|
||||
- BasePipeline: Pipeline 抽象基类
|
||||
- DefaultPipeline: 默认策略(复用现有逻辑)
|
||||
- EnhancedPipeline: 增强策略(新端到端流程)
|
||||
- MetadataInferenceService: 元数据推断统一入口
|
||||
- StrategyRouter: 策略路由器
|
||||
- ModeRouter: 模式路由器(direct/react/auto)
|
||||
- RollbackManager: 回退管理器
|
||||
"""
|
||||
|
||||
from app.services.retrieval.strategy.config import (
|
||||
FilterMode,
|
||||
GrayscaleConfig,
|
||||
HybridRetrievalConfig,
|
||||
MetadataInferenceConfig,
|
||||
ModeRouterConfig,
|
||||
PipelineConfig,
|
||||
RerankerConfig,
|
||||
RetrievalStrategyConfig,
|
||||
RuntimeMode,
|
||||
StrategyType,
|
||||
get_strategy_config,
|
||||
set_strategy_config,
|
||||
)
|
||||
from app.services.retrieval.strategy.default_pipeline import (
|
||||
DefaultPipeline,
|
||||
get_default_pipeline,
|
||||
)
|
||||
from app.services.retrieval.strategy.enhanced_pipeline import (
|
||||
EnhancedPipeline,
|
||||
get_enhanced_pipeline,
|
||||
)
|
||||
from app.services.retrieval.strategy.metadata_inference import (
|
||||
InferenceContext,
|
||||
InferenceResult,
|
||||
MetadataInferenceService,
|
||||
)
|
||||
from app.services.retrieval.strategy.mode_router import (
|
||||
ModeDecision,
|
||||
ModeRouter,
|
||||
get_mode_router,
|
||||
)
|
||||
from app.services.retrieval.strategy.pipeline_base import (
|
||||
BasePipeline,
|
||||
MetadataFilterResult,
|
||||
PipelineContext,
|
||||
PipelineResult,
|
||||
)
|
||||
from app.services.retrieval.strategy.rollback_manager import (
|
||||
AuditLog,
|
||||
RollbackManager,
|
||||
RollbackResult,
|
||||
RollbackTrigger,
|
||||
get_rollback_manager,
|
||||
)
|
||||
from app.services.retrieval.strategy.strategy_router import (
|
||||
RoutingDecision,
|
||||
StrategyRouter,
|
||||
get_strategy_router,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BasePipeline",
|
||||
"PipelineContext",
|
||||
"PipelineResult",
|
||||
"MetadataFilterResult",
|
||||
"DefaultPipeline",
|
||||
"get_default_pipeline",
|
||||
"EnhancedPipeline",
|
||||
"get_enhanced_pipeline",
|
||||
"RetrievalStrategyConfig",
|
||||
"GrayscaleConfig",
|
||||
"PipelineConfig",
|
||||
"RerankerConfig",
|
||||
"ModeRouterConfig",
|
||||
"HybridRetrievalConfig",
|
||||
"MetadataInferenceConfig",
|
||||
"StrategyType",
|
||||
"FilterMode",
|
||||
"RuntimeMode",
|
||||
"get_strategy_config",
|
||||
"set_strategy_config",
|
||||
"MetadataInferenceService",
|
||||
"InferenceContext",
|
||||
"InferenceResult",
|
||||
"StrategyRouter",
|
||||
"RoutingDecision",
|
||||
"get_strategy_router",
|
||||
"ModeRouter",
|
||||
"ModeDecision",
|
||||
"get_mode_router",
|
||||
"RollbackManager",
|
||||
"RollbackResult",
|
||||
"RollbackTrigger",
|
||||
"AuditLog",
|
||||
"get_rollback_manager",
|
||||
]
|
||||
|
|
@ -0,0 +1,201 @@
|
|||
"""
|
||||
Retrieval Strategy Configuration.
|
||||
[AC-AISVC-RES-01~15] 检索策略配置模型。
|
||||
"""
|
||||
|
||||
import random
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
|
||||
class StrategyType(str, Enum):
|
||||
"""策略类型。"""
|
||||
DEFAULT = "default"
|
||||
ENHANCED = "enhanced"
|
||||
|
||||
|
||||
class RuntimeMode(str, Enum):
|
||||
"""运行时模式。"""
|
||||
DIRECT = "direct"
|
||||
REACT = "react"
|
||||
AUTO = "auto"
|
||||
|
||||
|
||||
class FilterMode(str, Enum):
|
||||
"""过滤模式。"""
|
||||
HARD = "hard"
|
||||
SOFT = "soft"
|
||||
NONE = "none"
|
||||
|
||||
|
||||
@dataclass
|
||||
class GrayscaleConfig:
|
||||
"""灰度发布配置。【AC-AISVC-RES-03】"""
|
||||
enabled: bool = False
|
||||
percentage: float = 0.0
|
||||
allowlist: list[str] = field(default_factory=list)
|
||||
|
||||
def should_use_enhanced(self, tenant_id: str, user_id: str | None = None) -> bool:
|
||||
"""判断是否应该使用增强策略。"""
|
||||
if not self.enabled:
|
||||
return False
|
||||
|
||||
if tenant_id in self.allowlist or (user_id and user_id in self.allowlist):
|
||||
return True
|
||||
|
||||
return random.random() * 100 < self.percentage
|
||||
|
||||
|
||||
@dataclass
|
||||
class HybridRetrievalConfig:
|
||||
"""混合检索配置。"""
|
||||
dense_weight: float = 0.7
|
||||
keyword_weight: float = 0.3
|
||||
rrf_k: int = 60
|
||||
enable_keyword: bool = True
|
||||
keyword_top_k_multiplier: int = 2
|
||||
|
||||
|
||||
@dataclass
|
||||
class RerankerConfig:
|
||||
"""重排器配置。【AC-AISVC-RES-08】"""
|
||||
enabled: bool = False
|
||||
model: str = "cross-encoder"
|
||||
top_k_after_rerank: int = 5
|
||||
min_score_threshold: float = 0.3
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModeRouterConfig:
|
||||
"""模式路由配置。【AC-AISVC-RES-09~15】"""
|
||||
runtime_mode: RuntimeMode = RuntimeMode.DIRECT
|
||||
react_trigger_confidence_threshold: float = 0.6
|
||||
react_trigger_complexity_score: float = 0.5
|
||||
react_max_steps: int = 5
|
||||
direct_fallback_on_low_confidence: bool = True
|
||||
short_query_threshold: int = 20
|
||||
|
||||
def should_use_react(
|
||||
self,
|
||||
query: str,
|
||||
confidence: float | None = None,
|
||||
complexity_score: float | None = None,
|
||||
) -> bool:
|
||||
"""判断是否应该使用 ReAct 模式。【AC-AISVC-RES-11~13】"""
|
||||
if self.runtime_mode == RuntimeMode.REACT:
|
||||
return True
|
||||
if self.runtime_mode == RuntimeMode.DIRECT:
|
||||
return False
|
||||
|
||||
if len(query) <= self.short_query_threshold and confidence and confidence >= self.react_trigger_confidence_threshold:
|
||||
return False
|
||||
|
||||
if complexity_score and complexity_score >= self.react_trigger_complexity_score:
|
||||
return True
|
||||
|
||||
if confidence and confidence < self.react_trigger_confidence_threshold:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
@dataclass
|
||||
class MetadataInferenceConfig:
|
||||
"""元数据推断配置。"""
|
||||
enabled: bool = True
|
||||
confidence_high_threshold: float = 0.8
|
||||
confidence_low_threshold: float = 0.5
|
||||
default_filter_mode: FilterMode = FilterMode.SOFT
|
||||
cache_ttl_seconds: int = 300
|
||||
|
||||
def determine_filter_mode(self, confidence: float | None) -> FilterMode:
|
||||
"""根据置信度确定过滤模式。"""
|
||||
if confidence is None:
|
||||
return FilterMode.NONE
|
||||
if confidence >= self.confidence_high_threshold:
|
||||
return FilterMode.HARD
|
||||
if confidence >= self.confidence_low_threshold:
|
||||
return FilterMode.SOFT
|
||||
return FilterMode.NONE
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineConfig:
|
||||
"""Pipeline 配置。"""
|
||||
top_k: int = 5
|
||||
score_threshold: float = 0.01
|
||||
min_hits: int = 1
|
||||
two_stage_enabled: bool = True
|
||||
two_stage_expand_factor: int = 10
|
||||
hybrid: HybridRetrievalConfig = field(default_factory=HybridRetrievalConfig)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetrievalStrategyConfig:
|
||||
"""检索策略顶层配置。【AC-AISVC-RES-01~15】"""
|
||||
active_strategy: StrategyType = StrategyType.DEFAULT
|
||||
grayscale: GrayscaleConfig = field(default_factory=GrayscaleConfig)
|
||||
pipeline: PipelineConfig = field(default_factory=PipelineConfig)
|
||||
reranker: RerankerConfig = field(default_factory=RerankerConfig)
|
||||
mode_router: ModeRouterConfig = field(default_factory=ModeRouterConfig)
|
||||
metadata_inference: MetadataInferenceConfig = field(default_factory=MetadataInferenceConfig)
|
||||
performance_thresholds: dict[str, float] = field(default_factory=lambda: {
|
||||
"max_latency_ms": 2000.0,
|
||||
"min_success_rate": 0.95,
|
||||
"max_error_rate": 0.05,
|
||||
})
|
||||
|
||||
def is_enhanced_enabled(self, tenant_id: str, user_id: str | None = None) -> bool:
|
||||
"""判断是否启用增强策略。"""
|
||||
if self.active_strategy == StrategyType.ENHANCED:
|
||||
return True
|
||||
return self.grayscale.should_use_enhanced(tenant_id, user_id)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""转换为字典。"""
|
||||
return {
|
||||
"active_strategy": self.active_strategy.value,
|
||||
"grayscale": {
|
||||
"enabled": self.grayscale.enabled,
|
||||
"percentage": self.grayscale.percentage,
|
||||
"allowlist": self.grayscale.allowlist,
|
||||
},
|
||||
"pipeline": {
|
||||
"top_k": self.pipeline.top_k,
|
||||
"score_threshold": self.pipeline.score_threshold,
|
||||
"min_hits": self.pipeline.min_hits,
|
||||
"two_stage_enabled": self.pipeline.two_stage_enabled,
|
||||
},
|
||||
"reranker": {
|
||||
"enabled": self.reranker.enabled,
|
||||
"model": self.reranker.model,
|
||||
"top_k_after_rerank": self.reranker.top_k_after_rerank,
|
||||
},
|
||||
"mode_router": {
|
||||
"runtime_mode": self.mode_router.runtime_mode.value,
|
||||
"react_trigger_confidence_threshold": self.mode_router.react_trigger_confidence_threshold,
|
||||
},
|
||||
"metadata_inference": {
|
||||
"enabled": self.metadata_inference.enabled,
|
||||
"confidence_high_threshold": self.metadata_inference.confidence_high_threshold,
|
||||
},
|
||||
"performance_thresholds": self.performance_thresholds,
|
||||
}
|
||||
|
||||
|
||||
_global_config: RetrievalStrategyConfig | None = None
|
||||
|
||||
|
||||
def get_strategy_config() -> RetrievalStrategyConfig:
|
||||
"""获取全局策略配置。"""
|
||||
global _global_config
|
||||
if _global_config is None:
|
||||
_global_config = RetrievalStrategyConfig()
|
||||
return _global_config
|
||||
|
||||
|
||||
def set_strategy_config(config: RetrievalStrategyConfig) -> None:
|
||||
"""设置全局策略配置。"""
|
||||
global _global_config
|
||||
_global_config = config
|
||||
|
|
@ -0,0 +1,117 @@
|
|||
"""
|
||||
Default Pipeline.
|
||||
[AC-AISVC-RES-01] 默认策略 Pipeline,复用现有逻辑。
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from app.services.retrieval.base import RetrievalContext, RetrievalResult
|
||||
from app.services.retrieval.optimized_retriever import OptimizedRetriever, get_optimized_retriever
|
||||
from app.services.retrieval.strategy.config import PipelineConfig
|
||||
from app.services.retrieval.strategy.pipeline_base import (
|
||||
BasePipeline,
|
||||
PipelineContext,
|
||||
PipelineResult,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DefaultPipeline(BasePipeline):
|
||||
"""
|
||||
默认策略 Pipeline。【AC-AISVC-RES-01】
|
||||
|
||||
复用现有 OptimizedRetriever 逻辑,保持线上行为不变。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PipelineConfig | None = None,
|
||||
optimized_retriever: OptimizedRetriever | None = None,
|
||||
):
|
||||
self._config = config or PipelineConfig()
|
||||
self._optimized_retriever = optimized_retriever
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "default_pipeline"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "默认检索策略,复用现有 OptimizedRetriever 逻辑。"
|
||||
|
||||
async def _get_retriever(self) -> OptimizedRetriever:
|
||||
if self._optimized_retriever is None:
|
||||
self._optimized_retriever = await get_optimized_retriever()
|
||||
return self._optimized_retriever
|
||||
|
||||
async def retrieve(self, ctx: PipelineContext) -> PipelineResult:
|
||||
"""执行默认检索流程。【AC-AISVC-RES-01】"""
|
||||
start_time = time.time()
|
||||
|
||||
logger.info(
|
||||
f"[DefaultPipeline] Starting retrieval: tenant={ctx.tenant_id}, "
|
||||
f"query={ctx.query[:50]}..."
|
||||
)
|
||||
|
||||
try:
|
||||
retriever = await self._get_retriever()
|
||||
|
||||
metadata_filter = None
|
||||
if ctx.metadata_filter:
|
||||
metadata_filter = ctx.metadata_filter.filter_dict
|
||||
|
||||
retrieval_ctx = RetrievalContext(
|
||||
tenant_id=ctx.tenant_id,
|
||||
query=ctx.query,
|
||||
session_id=ctx.session_id,
|
||||
metadata_filter=metadata_filter,
|
||||
kb_ids=ctx.kb_ids,
|
||||
)
|
||||
|
||||
result = await retriever.retrieve(retrieval_ctx)
|
||||
|
||||
latency_ms = (time.time() - start_time) * 1000
|
||||
|
||||
logger.info(
|
||||
f"[DefaultPipeline] Retrieval completed: hits={len(result.hits)}, "
|
||||
f"latency_ms={latency_ms:.2f}"
|
||||
)
|
||||
|
||||
return PipelineResult(
|
||||
retrieval_result=result,
|
||||
pipeline_name=self.name,
|
||||
metadata_filter_applied=metadata_filter is not None,
|
||||
latency_ms=latency_ms,
|
||||
diagnostics={
|
||||
"retriever": "OptimizedRetriever",
|
||||
**(result.diagnostics or {}),
|
||||
},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
latency_ms = (time.time() - start_time) * 1000
|
||||
logger.error(f"[DefaultPipeline] Retrieval error: {e}", exc_info=True)
|
||||
return self._create_empty_result(ctx, error=str(e), latency_ms=latency_ms)
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
"""健康检查。"""
|
||||
try:
|
||||
retriever = await self._get_retriever()
|
||||
return await retriever.health_check()
|
||||
except Exception as e:
|
||||
logger.error(f"[DefaultPipeline] Health check failed: {e}")
|
||||
return False
|
||||
|
||||
|
||||
_default_pipeline: DefaultPipeline | None = None
|
||||
|
||||
|
||||
async def get_default_pipeline() -> DefaultPipeline:
|
||||
"""获取 DefaultPipeline 单例。"""
|
||||
global _default_pipeline
|
||||
if _default_pipeline is None:
|
||||
_default_pipeline = DefaultPipeline()
|
||||
return _default_pipeline
|
||||
|
|
@ -0,0 +1,364 @@
|
|||
"""
|
||||
Enhanced Pipeline.
|
||||
[AC-AISVC-RES-02] 增强策略 Pipeline,新端到端流程。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.core.qdrant_client import QdrantClient, get_qdrant_client
|
||||
from app.services.embedding.nomic_provider import NomicEmbeddingProvider
|
||||
from app.services.retrieval.base import RetrievalHit, RetrievalResult
|
||||
from app.services.retrieval.optimized_retriever import RRFCombiner
|
||||
from app.services.retrieval.strategy.config import (
|
||||
HybridRetrievalConfig,
|
||||
PipelineConfig,
|
||||
RerankerConfig,
|
||||
)
|
||||
from app.services.retrieval.strategy.pipeline_base import (
|
||||
BasePipeline,
|
||||
PipelineContext,
|
||||
PipelineResult,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetrievalCandidate:
|
||||
"""检索候选结果。"""
|
||||
id: str
|
||||
text: str
|
||||
score: float
|
||||
vector_score: float = 0.0
|
||||
keyword_score: float = 0.0
|
||||
metadata: dict[str, Any] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.metadata is None:
|
||||
self.metadata = {}
|
||||
|
||||
|
||||
class EnhancedPipeline(BasePipeline):
|
||||
"""
|
||||
增强策略 Pipeline。【AC-AISVC-RES-02】
|
||||
|
||||
新端到端流程:
|
||||
1. Dense 向量检索
|
||||
2. Keyword 关键词检索
|
||||
3. RRF 融合排序
|
||||
4. 可选重排
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PipelineConfig | None = None,
|
||||
reranker_config: RerankerConfig | None = None,
|
||||
qdrant_client: QdrantClient | None = None,
|
||||
):
|
||||
self._config = config or PipelineConfig()
|
||||
self._reranker_config = reranker_config or RerankerConfig()
|
||||
self._qdrant_client = qdrant_client
|
||||
self._rrf_combiner = RRFCombiner(k=self._config.hybrid.rrf_k)
|
||||
self._embedding_provider: NomicEmbeddingProvider | None = None
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "enhanced_pipeline"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "增强检索策略,支持 Dense + Keyword + RRF 组合检索。"
|
||||
|
||||
async def _get_client(self) -> QdrantClient:
|
||||
if self._qdrant_client is None:
|
||||
self._qdrant_client = await get_qdrant_client()
|
||||
return self._qdrant_client
|
||||
|
||||
async def _get_embedding_provider(self) -> NomicEmbeddingProvider:
|
||||
if self._embedding_provider is None:
|
||||
from app.services.embedding.factory import get_embedding_config_manager
|
||||
manager = get_embedding_config_manager()
|
||||
provider = await manager.get_provider()
|
||||
if isinstance(provider, NomicEmbeddingProvider):
|
||||
self._embedding_provider = provider
|
||||
else:
|
||||
self._embedding_provider = NomicEmbeddingProvider(
|
||||
base_url=settings.ollama_base_url,
|
||||
model=settings.ollama_embedding_model,
|
||||
dimension=settings.qdrant_vector_size,
|
||||
)
|
||||
return self._embedding_provider
|
||||
|
||||
async def retrieve(self, ctx: PipelineContext) -> PipelineResult:
|
||||
"""执行增强检索流程。【AC-AISVC-RES-02】"""
|
||||
start_time = time.time()
|
||||
|
||||
logger.info(
|
||||
f"[EnhancedPipeline] Starting retrieval: tenant={ctx.tenant_id}, "
|
||||
f"query={ctx.query[:50]}..."
|
||||
)
|
||||
|
||||
try:
|
||||
provider = await self._get_embedding_provider()
|
||||
embedding_result = await provider.embed_query(ctx.query)
|
||||
|
||||
candidates = await self._hybrid_retrieve(
|
||||
tenant_id=ctx.tenant_id,
|
||||
query=ctx.query,
|
||||
embedding_result=embedding_result,
|
||||
metadata_filter=ctx.metadata_filter.filter_dict if ctx.metadata_filter else None,
|
||||
kb_ids=ctx.kb_ids,
|
||||
)
|
||||
|
||||
if self._reranker_config.enabled and ctx.use_reranker:
|
||||
candidates = await self._rerank(
|
||||
candidates=candidates,
|
||||
query=ctx.query,
|
||||
)
|
||||
|
||||
top_k = self._config.top_k
|
||||
final_candidates = candidates[:top_k]
|
||||
|
||||
hits = [
|
||||
RetrievalHit(
|
||||
text=c.text,
|
||||
score=c.score,
|
||||
source=self.name,
|
||||
metadata=c.metadata,
|
||||
)
|
||||
for c in final_candidates
|
||||
if c.score >= self._config.score_threshold
|
||||
]
|
||||
|
||||
latency_ms = (time.time() - start_time) * 1000
|
||||
|
||||
logger.info(
|
||||
f"[EnhancedPipeline] Retrieval completed: hits={len(hits)}, "
|
||||
f"latency_ms={latency_ms:.2f}"
|
||||
)
|
||||
|
||||
result = RetrievalResult(
|
||||
hits=hits,
|
||||
diagnostics={
|
||||
"total_candidates": len(candidates),
|
||||
"after_rerank": self._reranker_config.enabled and ctx.use_reranker,
|
||||
},
|
||||
)
|
||||
|
||||
return PipelineResult(
|
||||
retrieval_result=result,
|
||||
pipeline_name=self.name,
|
||||
used_reranker=self._reranker_config.enabled and ctx.use_reranker,
|
||||
metadata_filter_applied=ctx.metadata_filter is not None,
|
||||
latency_ms=latency_ms,
|
||||
diagnostics={
|
||||
"dense_weight": self._config.hybrid.dense_weight,
|
||||
"keyword_weight": self._config.hybrid.keyword_weight,
|
||||
},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
latency_ms = (time.time() - start_time) * 1000
|
||||
logger.error(f"[EnhancedPipeline] Retrieval error: {e}", exc_info=True)
|
||||
return self._create_empty_result(ctx, error=str(e), latency_ms=latency_ms)
|
||||
|
||||
async def _hybrid_retrieve(
|
||||
self,
|
||||
tenant_id: str,
|
||||
query: str,
|
||||
embedding_result: Any,
|
||||
metadata_filter: dict[str, Any] | None = None,
|
||||
kb_ids: list[str] | None = None,
|
||||
) -> list[RetrievalCandidate]:
|
||||
"""混合检索:Dense + Keyword + RRF。"""
|
||||
client = await self._get_client()
|
||||
top_k = self._config.top_k
|
||||
expand_factor = self._config.hybrid.keyword_top_k_multiplier
|
||||
|
||||
vector_task = self._dense_search(
|
||||
client=client,
|
||||
tenant_id=tenant_id,
|
||||
embedding=embedding_result.embedding_full,
|
||||
top_k=top_k * expand_factor,
|
||||
metadata_filter=metadata_filter,
|
||||
kb_ids=kb_ids,
|
||||
)
|
||||
|
||||
keyword_task = self._keyword_search(
|
||||
client=client,
|
||||
tenant_id=tenant_id,
|
||||
query=query,
|
||||
top_k=top_k * expand_factor,
|
||||
metadata_filter=metadata_filter,
|
||||
kb_ids=kb_ids,
|
||||
) if self._config.hybrid.enable_keyword else asyncio.sleep(0, result=[])
|
||||
|
||||
vector_results, keyword_results = await asyncio.gather(
|
||||
vector_task, keyword_task, return_exceptions=True
|
||||
)
|
||||
|
||||
if isinstance(vector_results, Exception):
|
||||
logger.warning(f"[EnhancedPipeline] Dense search failed: {vector_results}")
|
||||
vector_results = []
|
||||
|
||||
if isinstance(keyword_results, Exception):
|
||||
logger.warning(f"[EnhancedPipeline] Keyword search failed: {keyword_results}")
|
||||
keyword_results = []
|
||||
|
||||
combined = self._rrf_combiner.combine(
|
||||
vector_results=vector_results,
|
||||
bm25_results=keyword_results,
|
||||
vector_weight=self._config.hybrid.dense_weight,
|
||||
bm25_weight=self._config.hybrid.keyword_weight,
|
||||
)
|
||||
|
||||
candidates = []
|
||||
for item in combined:
|
||||
candidates.append(RetrievalCandidate(
|
||||
id=item.get("id", ""),
|
||||
text=item.get("payload", {}).get("text", ""),
|
||||
score=item.get("score", 0.0),
|
||||
vector_score=item.get("vector_score", 0.0),
|
||||
keyword_score=item.get("bm25_score", 0.0),
|
||||
metadata=item.get("payload", {}),
|
||||
))
|
||||
|
||||
return candidates
|
||||
|
||||
async def _dense_search(
|
||||
self,
|
||||
client: QdrantClient,
|
||||
tenant_id: str,
|
||||
embedding: list[float],
|
||||
top_k: int,
|
||||
metadata_filter: dict[str, Any] | None = None,
|
||||
kb_ids: list[str] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Dense 向量检索。"""
|
||||
try:
|
||||
results = await client.search(
|
||||
tenant_id=tenant_id,
|
||||
query_vector=embedding,
|
||||
limit=top_k,
|
||||
vector_name="full",
|
||||
metadata_filter=metadata_filter,
|
||||
kb_ids=kb_ids,
|
||||
)
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.error(f"[EnhancedPipeline] Dense search error: {e}")
|
||||
return []
|
||||
|
||||
async def _keyword_search(
|
||||
self,
|
||||
client: QdrantClient,
|
||||
tenant_id: str,
|
||||
query: str,
|
||||
top_k: int,
|
||||
metadata_filter: dict[str, Any] | None = None,
|
||||
kb_ids: list[str] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Keyword 关键词检索。"""
|
||||
try:
|
||||
qdrant = await client.get_client()
|
||||
collection_name = client.get_collection_name(tenant_id)
|
||||
|
||||
query_terms = set(re.findall(r'\w+', query.lower()))
|
||||
|
||||
results = await qdrant.scroll(
|
||||
collection_name=collection_name,
|
||||
limit=top_k * 3,
|
||||
with_payload=True,
|
||||
)
|
||||
|
||||
scored_results = []
|
||||
for point in results[0]:
|
||||
text = point.payload.get("text", "").lower()
|
||||
text_terms = set(re.findall(r'\w+', text))
|
||||
overlap = len(query_terms & text_terms)
|
||||
|
||||
if overlap > 0:
|
||||
score = overlap / (len(query_terms) + len(text_terms) - overlap)
|
||||
scored_results.append({
|
||||
"id": str(point.id),
|
||||
"score": score,
|
||||
"payload": point.payload or {},
|
||||
})
|
||||
|
||||
scored_results.sort(key=lambda x: x["score"], reverse=True)
|
||||
return scored_results[:top_k]
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"[EnhancedPipeline] Keyword search failed: {e}")
|
||||
return []
|
||||
|
||||
async def _rerank(
|
||||
self,
|
||||
candidates: list[RetrievalCandidate],
|
||||
query: str,
|
||||
) -> list[RetrievalCandidate]:
|
||||
"""可选重排。"""
|
||||
if not candidates:
|
||||
return candidates
|
||||
|
||||
try:
|
||||
provider = await self._get_embedding_provider()
|
||||
query_embedding = await provider.embed_query(query)
|
||||
|
||||
reranked = []
|
||||
for candidate in candidates:
|
||||
candidate_text = candidate.text[:500]
|
||||
if candidate_text:
|
||||
candidate_embedding = await provider.embed(candidate_text)
|
||||
similarity = self._cosine_similarity(
|
||||
query_embedding.embedding_full,
|
||||
candidate_embedding,
|
||||
)
|
||||
candidate.score = similarity
|
||||
|
||||
if candidate.score >= self._reranker_config.min_score_threshold:
|
||||
reranked.append(candidate)
|
||||
|
||||
reranked.sort(key=lambda x: x.score, reverse=True)
|
||||
return reranked[:self._reranker_config.top_k_after_rerank]
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"[EnhancedPipeline] Rerank failed: {e}")
|
||||
return candidates
|
||||
|
||||
def _cosine_similarity(self, vec1: list[float], vec2: list[float]) -> float:
|
||||
"""计算余弦相似度。"""
|
||||
a = np.array(vec1)
|
||||
b = np.array(vec2)
|
||||
return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
"""健康检查。"""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
qdrant = await client.get_client()
|
||||
await qdrant.get_collections()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"[EnhancedPipeline] Health check failed: {e}")
|
||||
return False
|
||||
|
||||
|
||||
_enhanced_pipeline: EnhancedPipeline | None = None
|
||||
|
||||
|
||||
async def get_enhanced_pipeline() -> EnhancedPipeline:
|
||||
"""获取 EnhancedPipeline 单例。"""
|
||||
global _enhanced_pipeline
|
||||
if _enhanced_pipeline is None:
|
||||
_enhanced_pipeline = EnhancedPipeline()
|
||||
return _enhanced_pipeline
|
||||
|
|
@ -0,0 +1,136 @@
|
|||
"""
|
||||
Metadata Inference Service.
|
||||
[AC-AISVC-RES-04] 元数据推断统一入口。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.mid.metadata_filter_builder import (
|
||||
FilterBuildResult,
|
||||
MetadataFilterBuilder,
|
||||
)
|
||||
from app.services.retrieval.strategy.config import (
|
||||
FilterMode,
|
||||
MetadataInferenceConfig,
|
||||
)
|
||||
from app.services.retrieval.strategy.pipeline_base import MetadataFilterResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class InferenceContext:
|
||||
"""元数据推断上下文。"""
|
||||
tenant_id: str
|
||||
query: str
|
||||
session_id: str | None = None
|
||||
user_id: str | None = None
|
||||
channel_type: str | None = None
|
||||
existing_context: dict[str, Any] = field(default_factory=dict)
|
||||
slot_state: Any = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class InferenceResult:
|
||||
"""元数据推断结果。"""
|
||||
filter_result: MetadataFilterResult
|
||||
inferred_fields: dict[str, Any] = field(default_factory=dict)
|
||||
confidence_scores: dict[str, float] = field(default_factory=dict)
|
||||
overall_confidence: float | None = None
|
||||
inference_source: str = "unknown"
|
||||
|
||||
|
||||
class MetadataInferenceService:
|
||||
"""
|
||||
元数据推断统一入口。【AC-AISVC-RES-04】
|
||||
|
||||
职责:
|
||||
1. 统一的元数据推断入口(策略无关)
|
||||
2. 根据置信度决定 hard/soft filter 模式
|
||||
3. 与现有 MetadataFilterBuilder 保持一致
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
config: MetadataInferenceConfig | None = None,
|
||||
):
|
||||
self._session = session
|
||||
self._config = config or MetadataInferenceConfig()
|
||||
self._filter_builder: MetadataFilterBuilder | None = None
|
||||
|
||||
async def infer(self, ctx: InferenceContext) -> InferenceResult:
|
||||
"""执行元数据推断。【AC-AISVC-RES-04】"""
|
||||
logger.info(
|
||||
f"[MetadataInference] Starting inference: tenant={ctx.tenant_id}, "
|
||||
f"query={ctx.query[:50]}..."
|
||||
)
|
||||
|
||||
if self._filter_builder is None:
|
||||
self._filter_builder = MetadataFilterBuilder(self._session)
|
||||
|
||||
effective_context = dict(ctx.existing_context)
|
||||
|
||||
if ctx.slot_state:
|
||||
effective_context = await self._merge_slot_state(
|
||||
effective_context, ctx.slot_state
|
||||
)
|
||||
|
||||
build_result = await self._filter_builder.build_filter(
|
||||
tenant_id=ctx.tenant_id,
|
||||
context=effective_context,
|
||||
)
|
||||
|
||||
confidence = self._calculate_confidence(build_result, effective_context)
|
||||
filter_mode = self._config.determine_filter_mode(confidence)
|
||||
|
||||
filter_result = MetadataFilterResult(
|
||||
filter_dict=build_result.applied_filter,
|
||||
filter_mode=filter_mode,
|
||||
confidence=confidence,
|
||||
missing_required_slots=build_result.missing_required_slots,
|
||||
debug_info=build_result.debug_info,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[MetadataInference] Inference completed: filter_mode={filter_mode.value}, "
|
||||
f"confidence={confidence}"
|
||||
)
|
||||
|
||||
return InferenceResult(
|
||||
filter_result=filter_result,
|
||||
inferred_fields=build_result.applied_filter,
|
||||
overall_confidence=confidence,
|
||||
inference_source="metadata_filter_builder",
|
||||
)
|
||||
|
||||
async def _merge_slot_state(
|
||||
self, context: dict[str, Any], slot_state: Any
|
||||
) -> dict[str, Any]:
|
||||
"""合并槽位状态到上下文。"""
|
||||
if hasattr(slot_state, 'filled_slots'):
|
||||
for slot_key, slot_value in slot_state.filled_slots.items():
|
||||
if slot_key not in context:
|
||||
context[slot_key] = slot_value
|
||||
return context
|
||||
|
||||
def _calculate_confidence(
|
||||
self, build_result: FilterBuildResult, context: dict[str, Any]
|
||||
) -> float | None:
|
||||
"""计算推断置信度。"""
|
||||
if build_result.missing_required_slots:
|
||||
return 0.3
|
||||
if not build_result.applied_filter:
|
||||
return None
|
||||
if not context:
|
||||
return 0.5
|
||||
applied_ratio = len(build_result.applied_filter) / max(len(context), 1)
|
||||
if applied_ratio >= 0.8:
|
||||
return 0.9
|
||||
elif applied_ratio >= 0.5:
|
||||
return 0.7
|
||||
return 0.5
|
||||
|
|
@ -0,0 +1,118 @@
|
|||
"""
|
||||
Mode Router.
|
||||
[AC-AISVC-RES-09~15] 模式路由器。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
|
||||
from app.services.retrieval.strategy.config import ModeRouterConfig, RuntimeMode
|
||||
from app.services.retrieval.strategy.pipeline_base import PipelineResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModeDecision:
|
||||
"""模式决策结果。"""
|
||||
mode: RuntimeMode
|
||||
reason: str
|
||||
confidence: float | None = None
|
||||
complexity_score: float | None = None
|
||||
|
||||
|
||||
class ModeRouter:
|
||||
"""
|
||||
模式路由器。【AC-AISVC-RES-09~15】
|
||||
|
||||
职责:
|
||||
1. 根据 rag_runtime_mode 选择 direct/react/auto 模式
|
||||
2. auto 模式下根据复杂度与置信度自动选择路由
|
||||
3. direct 低置信度时触发 react 回退
|
||||
"""
|
||||
|
||||
def __init__(self, config: ModeRouterConfig | None = None):
|
||||
self._config = config or ModeRouterConfig()
|
||||
|
||||
def decide(
|
||||
self,
|
||||
query: str,
|
||||
confidence: float | None = None,
|
||||
complexity_score: float | None = None,
|
||||
) -> ModeDecision:
|
||||
"""决定使用哪种模式。【AC-AISVC-RES-09~13】"""
|
||||
if self._config.runtime_mode == RuntimeMode.REACT:
|
||||
return ModeDecision(mode=RuntimeMode.REACT, reason="runtime_mode=react")
|
||||
|
||||
if self._config.runtime_mode == RuntimeMode.DIRECT:
|
||||
return ModeDecision(mode=RuntimeMode.DIRECT, reason="runtime_mode=direct")
|
||||
|
||||
calculated_complexity = complexity_score or self._calculate_complexity(query)
|
||||
|
||||
if self._should_use_direct(query, confidence, calculated_complexity):
|
||||
return ModeDecision(
|
||||
mode=RuntimeMode.DIRECT,
|
||||
reason="auto: short_query_high_confidence",
|
||||
confidence=confidence,
|
||||
complexity_score=calculated_complexity,
|
||||
)
|
||||
|
||||
return ModeDecision(
|
||||
mode=RuntimeMode.REACT,
|
||||
reason="auto: complex_or_low_confidence",
|
||||
confidence=confidence,
|
||||
complexity_score=calculated_complexity,
|
||||
)
|
||||
|
||||
def should_fallback_to_react(self, direct_result: PipelineResult) -> bool:
|
||||
"""判断是否应该从 direct 回退到 react。【AC-AISVC-RES-14】"""
|
||||
if not self._config.direct_fallback_on_low_confidence:
|
||||
return False
|
||||
if direct_result.is_empty:
|
||||
return True
|
||||
max_score = direct_result.retrieval_result.max_score
|
||||
if max_score < 0.3:
|
||||
return True
|
||||
if direct_result.retrieval_result.hit_count < 2:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _should_use_direct(
|
||||
self, query: str, confidence: float | None, complexity_score: float
|
||||
) -> bool:
|
||||
if len(query) <= self._config.short_query_threshold:
|
||||
if confidence and confidence >= self._config.react_trigger_confidence_threshold:
|
||||
return True
|
||||
if confidence and confidence < self._config.react_trigger_confidence_threshold:
|
||||
return False
|
||||
if complexity_score >= self._config.react_trigger_complexity_score:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _calculate_complexity(self, query: str) -> float:
|
||||
score = 0.0
|
||||
if len(query) > 50:
|
||||
score += 0.2
|
||||
if len(query) > 100:
|
||||
score += 0.2
|
||||
condition_words = ["和", "或", "但是", "如果", "同时", "并且", "或者", "以及"]
|
||||
for word in condition_words:
|
||||
if word in query:
|
||||
score += 0.1
|
||||
return min(score, 1.0)
|
||||
|
||||
def get_config(self) -> ModeRouterConfig:
|
||||
return self._config
|
||||
|
||||
def update_config(self, config: ModeRouterConfig) -> None:
|
||||
self._config = config
|
||||
|
||||
|
||||
_mode_router: ModeRouter | None = None
|
||||
|
||||
|
||||
def get_mode_router() -> ModeRouter:
|
||||
global _mode_router
|
||||
if _mode_router is None:
|
||||
_mode_router = ModeRouter()
|
||||
return _mode_router
|
||||
|
|
@ -0,0 +1,116 @@
|
|||
"""
|
||||
Pipeline Base Classes.
|
||||
[AC-AISVC-RES-01~15] Pipeline 抽象基类。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from app.services.retrieval.base import RetrievalContext, RetrievalHit, RetrievalResult
|
||||
from app.services.retrieval.strategy.config import FilterMode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MetadataFilterResult:
|
||||
"""元数据过滤结果。"""
|
||||
filter_dict: dict[str, Any] = field(default_factory=dict)
|
||||
filter_mode: FilterMode = FilterMode.NONE
|
||||
confidence: float | None = None
|
||||
missing_required_slots: list[dict[str, str]] = field(default_factory=list)
|
||||
debug_info: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineContext:
|
||||
"""Pipeline 执行上下文。"""
|
||||
retrieval_ctx: RetrievalContext
|
||||
metadata_filter: MetadataFilterResult | None = None
|
||||
use_reranker: bool = False
|
||||
use_react: bool = False
|
||||
react_iteration: int = 0
|
||||
previous_results: list[RetrievalHit] = field(default_factory=list)
|
||||
extra: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def tenant_id(self) -> str:
|
||||
return self.retrieval_ctx.tenant_id
|
||||
|
||||
@property
|
||||
def query(self) -> str:
|
||||
return self.retrieval_ctx.query
|
||||
|
||||
@property
|
||||
def session_id(self) -> str | None:
|
||||
return self.retrieval_ctx.session_id
|
||||
|
||||
@property
|
||||
def kb_ids(self) -> list[str] | None:
|
||||
return self.retrieval_ctx.kb_ids
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineResult:
|
||||
"""Pipeline 执行结果。"""
|
||||
retrieval_result: RetrievalResult
|
||||
pipeline_name: str = ""
|
||||
used_reranker: bool = False
|
||||
used_react: bool = False
|
||||
react_iterations: int = 0
|
||||
metadata_filter_applied: bool = False
|
||||
fallback_triggered: bool = False
|
||||
fallback_reason: str | None = None
|
||||
latency_ms: float = 0.0
|
||||
diagnostics: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def hits(self) -> list[RetrievalHit]:
|
||||
return self.retrieval_result.hits
|
||||
|
||||
@property
|
||||
def is_empty(self) -> bool:
|
||||
return self.retrieval_result.is_empty
|
||||
|
||||
|
||||
class BasePipeline(ABC):
|
||||
"""Pipeline 抽象基类。【AC-AISVC-RES-01, AC-AISVC-RES-02】"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""Pipeline 名称。"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def description(self) -> str:
|
||||
"""Pipeline 描述。"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def retrieve(self, ctx: PipelineContext) -> PipelineResult:
|
||||
"""执行检索。"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def health_check(self) -> bool:
|
||||
"""健康检查。"""
|
||||
pass
|
||||
|
||||
def _create_empty_result(
|
||||
self,
|
||||
ctx: PipelineContext,
|
||||
error: str | None = None,
|
||||
latency_ms: float = 0.0,
|
||||
) -> PipelineResult:
|
||||
"""创建空结果。"""
|
||||
diagnostics = {"error": error} if error else {}
|
||||
return PipelineResult(
|
||||
retrieval_result=RetrievalResult(hits=[], diagnostics=diagnostics),
|
||||
pipeline_name=self.name,
|
||||
latency_ms=latency_ms,
|
||||
diagnostics=diagnostics,
|
||||
)
|
||||
|
|
@ -0,0 +1,301 @@
|
|||
"""
|
||||
Retrieval Strategy - Unified Entry Point.
|
||||
[AC-AISVC-RES-01~15] 检索策略统一入口。
|
||||
|
||||
整合:
|
||||
- StrategyRouter: 策略路由(default/enhanced)
|
||||
- ModeRouter: 模式路由(direct/react/auto)
|
||||
- MetadataInferenceService: 元数据推断
|
||||
- RollbackManager: 回退管理
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.retrieval.base import RetrievalContext, RetrievalResult
|
||||
from app.services.retrieval.strategy.config import (
|
||||
RetrievalStrategyConfig,
|
||||
RuntimeMode,
|
||||
StrategyType,
|
||||
)
|
||||
from app.services.retrieval.strategy.default_pipeline import DefaultPipeline
|
||||
from app.services.retrieval.strategy.enhanced_pipeline import EnhancedPipeline
|
||||
from app.services.retrieval.strategy.metadata_inference import (
|
||||
InferenceContext,
|
||||
MetadataInferenceService,
|
||||
)
|
||||
from app.services.retrieval.strategy.mode_router import ModeDecision, ModeRouter
|
||||
from app.services.retrieval.strategy.pipeline_base import (
|
||||
MetadataFilterResult,
|
||||
PipelineContext,
|
||||
PipelineResult,
|
||||
)
|
||||
from app.services.retrieval.strategy.rollback_manager import (
|
||||
RollbackManager,
|
||||
RollbackTrigger,
|
||||
)
|
||||
from app.services.retrieval.strategy.strategy_router import (
|
||||
RoutingDecision,
|
||||
StrategyRouter,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetrievalStrategyResult:
|
||||
"""检索策略执行结果。"""
|
||||
|
||||
retrieval_result: RetrievalResult
|
||||
strategy_used: StrategyType
|
||||
mode_used: RuntimeMode
|
||||
metadata_filter: MetadataFilterResult | None
|
||||
latency_ms: float
|
||||
diagnostics: dict[str, Any]
|
||||
|
||||
|
||||
class RetrievalStrategy:
|
||||
"""
|
||||
检索策略统一入口。【AC-AISVC-RES-01~15】
|
||||
|
||||
整合所有策略组件:
|
||||
1. 元数据推断(MetadataInferenceService)
|
||||
2. 策略路由(StrategyRouter)
|
||||
3. 模式路由(ModeRouter)
|
||||
4. 回退管理(RollbackManager)
|
||||
|
||||
使用方式:
|
||||
```python
|
||||
strategy = RetrievalStrategy(session)
|
||||
result = await strategy.retrieve(
|
||||
tenant_id="tenant_1",
|
||||
query="用户问题",
|
||||
context={"user_id": "user_1"},
|
||||
)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
config: RetrievalStrategyConfig | None = None,
|
||||
):
|
||||
self._session = session
|
||||
self._config = config or RetrievalStrategyConfig()
|
||||
|
||||
self._strategy_router = StrategyRouter(self._config)
|
||||
self._mode_router = ModeRouter(self._config.mode_router)
|
||||
self._rollback_manager = RollbackManager(self._config)
|
||||
self._metadata_inference: MetadataInferenceService | None = None
|
||||
|
||||
async def retrieve(
|
||||
self,
|
||||
tenant_id: str,
|
||||
query: str,
|
||||
context: dict[str, Any] | None = None,
|
||||
kb_ids: list[str] | None = None,
|
||||
session_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
use_reranker: bool = False,
|
||||
use_react: bool = False,
|
||||
) -> RetrievalStrategyResult:
|
||||
"""
|
||||
执行检索策略。【AC-AISVC-RES-01~15】
|
||||
|
||||
流程:
|
||||
1. 元数据推断
|
||||
2. 策略路由
|
||||
3. 模式路由
|
||||
4. 执行检索
|
||||
5. 检查是否需要回退
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
query: 查询文本
|
||||
context: 上下文信息
|
||||
kb_ids: 知识库 ID 列表
|
||||
session_id: 会话 ID
|
||||
user_id: 用户 ID
|
||||
use_reranker: 是否使用重排
|
||||
use_react: 是否使用 ReAct 模式
|
||||
|
||||
Returns:
|
||||
RetrievalStrategyResult 包含检索结果和诊断信息
|
||||
"""
|
||||
start_time = time.time()
|
||||
context = context or {}
|
||||
|
||||
logger.info(
|
||||
f"[RetrievalStrategy] Starting retrieval: tenant={tenant_id}, "
|
||||
f"query={query[:50]}..."
|
||||
)
|
||||
|
||||
try:
|
||||
metadata_filter = await self._infer_metadata(
|
||||
tenant_id=tenant_id,
|
||||
query=query,
|
||||
context=context,
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
routing_decision = await self._strategy_router.route(tenant_id, user_id)
|
||||
|
||||
mode_decision = self._mode_router.decide(
|
||||
query=query,
|
||||
confidence=metadata_filter.confidence if metadata_filter else None,
|
||||
)
|
||||
|
||||
retrieval_ctx = RetrievalContext(
|
||||
tenant_id=tenant_id,
|
||||
query=query,
|
||||
session_id=session_id,
|
||||
metadata_filter=metadata_filter.filter_dict if metadata_filter else None,
|
||||
kb_ids=kb_ids,
|
||||
)
|
||||
|
||||
pipeline_ctx = PipelineContext(
|
||||
retrieval_ctx=retrieval_ctx,
|
||||
metadata_filter=metadata_filter,
|
||||
use_reranker=use_reranker or self._config.reranker.enabled,
|
||||
use_react=use_react or mode_decision.mode == RuntimeMode.REACT,
|
||||
)
|
||||
|
||||
pipeline_result = await routing_decision.pipeline.retrieve(pipeline_ctx)
|
||||
|
||||
if mode_decision.mode == RuntimeMode.DIRECT and self._mode_router.should_fallback_to_react(pipeline_result):
|
||||
logger.info("[RetrievalStrategy] Falling back to react mode")
|
||||
pipeline_ctx.use_react = True
|
||||
pipeline_result = await routing_decision.pipeline.retrieve(pipeline_ctx)
|
||||
mode_decision = ModeDecision(
|
||||
mode=RuntimeMode.REACT,
|
||||
reason="fallback_from_direct",
|
||||
)
|
||||
|
||||
latency_ms = (time.time() - start_time) * 1000
|
||||
|
||||
self._check_performance(latency_ms, tenant_id)
|
||||
|
||||
logger.info(
|
||||
f"[RetrievalStrategy] Retrieval completed: strategy={routing_decision.strategy.value}, "
|
||||
f"mode={mode_decision.mode.value}, hits={len(pipeline_result.hits)}, "
|
||||
f"latency_ms={latency_ms:.2f}"
|
||||
)
|
||||
|
||||
return RetrievalStrategyResult(
|
||||
retrieval_result=pipeline_result.retrieval_result,
|
||||
strategy_used=routing_decision.strategy,
|
||||
mode_used=mode_decision.mode,
|
||||
metadata_filter=metadata_filter,
|
||||
latency_ms=latency_ms,
|
||||
diagnostics={
|
||||
"routing_reason": routing_decision.reason,
|
||||
"mode_reason": mode_decision.reason,
|
||||
"grayscale_hit": routing_decision.grayscale_hit,
|
||||
**pipeline_result.diagnostics,
|
||||
},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
latency_ms = (time.time() - start_time) * 1000
|
||||
logger.error(
|
||||
f"[RetrievalStrategy] Retrieval error: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
|
||||
self._rollback_manager.rollback(
|
||||
trigger=RollbackTrigger.ERROR,
|
||||
reason=str(e),
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
return RetrievalStrategyResult(
|
||||
retrieval_result=RetrievalResult(
|
||||
hits=[],
|
||||
diagnostics={"error": str(e)},
|
||||
),
|
||||
strategy_used=StrategyType.DEFAULT,
|
||||
mode_used=RuntimeMode.DIRECT,
|
||||
metadata_filter=None,
|
||||
latency_ms=latency_ms,
|
||||
diagnostics={"error": str(e)},
|
||||
)
|
||||
|
||||
async def _infer_metadata(
|
||||
self,
|
||||
tenant_id: str,
|
||||
query: str,
|
||||
context: dict[str, Any],
|
||||
session_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
) -> MetadataFilterResult | None:
|
||||
"""执行元数据推断。"""
|
||||
try:
|
||||
if self._metadata_inference is None:
|
||||
self._metadata_inference = MetadataInferenceService(
|
||||
self._session,
|
||||
self._config.metadata_inference,
|
||||
)
|
||||
|
||||
inference_ctx = InferenceContext(
|
||||
tenant_id=tenant_id,
|
||||
query=query,
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
existing_context=context,
|
||||
)
|
||||
|
||||
result = await self._metadata_inference.infer(inference_ctx)
|
||||
return result.filter_result
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"[RetrievalStrategy] Metadata inference failed: {e}")
|
||||
return None
|
||||
|
||||
def _check_performance(self, latency_ms: float, tenant_id: str | None) -> None:
|
||||
"""检查性能指标,必要时触发回退。"""
|
||||
self._rollback_manager.check_and_rollback(
|
||||
metrics={"latency_ms": latency_ms},
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
def get_config(self) -> RetrievalStrategyConfig:
|
||||
"""获取当前配置。"""
|
||||
return self._config
|
||||
|
||||
def update_config(self, config: RetrievalStrategyConfig) -> None:
|
||||
"""更新配置。"""
|
||||
self._config = config
|
||||
self._strategy_router.update_config(config)
|
||||
self._mode_router.update_config(config.mode_router)
|
||||
self._rollback_manager.update_config(config)
|
||||
|
||||
async def health_check(self) -> dict[str, bool]:
|
||||
"""健康检查。"""
|
||||
results = {}
|
||||
|
||||
try:
|
||||
default_pipeline = await self._strategy_router._get_default_pipeline()
|
||||
results["default_pipeline"] = await default_pipeline.health_check()
|
||||
except Exception:
|
||||
results["default_pipeline"] = False
|
||||
|
||||
try:
|
||||
enhanced_pipeline = await self._strategy_router._get_enhanced_pipeline()
|
||||
results["enhanced_pipeline"] = await enhanced_pipeline.health_check()
|
||||
except Exception:
|
||||
results["enhanced_pipeline"] = False
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def create_retrieval_strategy(
|
||||
session: AsyncSession,
|
||||
config: RetrievalStrategyConfig | None = None,
|
||||
) -> RetrievalStrategy:
|
||||
"""创建 RetrievalStrategy 实例。"""
|
||||
return RetrievalStrategy(session, config)
|
||||
|
|
@ -0,0 +1,192 @@
|
|||
"""
|
||||
Rollback Manager.
|
||||
[AC-AISVC-RES-07] 策略回退与审计管理器。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from app.services.retrieval.strategy.config import RetrievalStrategyConfig, StrategyType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RollbackTrigger(str, Enum):
|
||||
"""回退触发原因。"""
|
||||
MANUAL = "manual"
|
||||
ERROR = "error"
|
||||
PERFORMANCE = "performance"
|
||||
TIMEOUT = "timeout"
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuditLog:
|
||||
"""审计日志记录。"""
|
||||
timestamp: str
|
||||
action: str
|
||||
from_strategy: str
|
||||
to_strategy: str
|
||||
trigger: str
|
||||
reason: str
|
||||
tenant_id: str | None = None
|
||||
details: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RollbackResult:
|
||||
"""回退结果。"""
|
||||
success: bool
|
||||
previous_strategy: StrategyType
|
||||
current_strategy: StrategyType
|
||||
trigger: RollbackTrigger
|
||||
reason: str
|
||||
audit_log: AuditLog | None = None
|
||||
|
||||
|
||||
class RollbackManager:
|
||||
"""
|
||||
策略回退管理器。【AC-AISVC-RES-07】
|
||||
|
||||
职责:
|
||||
1. 策略异常时回退到默认策略
|
||||
2. 记录审计日志
|
||||
3. 支持手动触发回退
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: RetrievalStrategyConfig | None = None,
|
||||
max_audit_logs: int = 1000,
|
||||
):
|
||||
self._config = config or RetrievalStrategyConfig()
|
||||
self._max_audit_logs = max_audit_logs
|
||||
self._audit_logs: list[AuditLog] = []
|
||||
self._previous_strategy: StrategyType = StrategyType.DEFAULT
|
||||
|
||||
def rollback(
|
||||
self,
|
||||
trigger: RollbackTrigger,
|
||||
reason: str,
|
||||
tenant_id: str | None = None,
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> RollbackResult:
|
||||
"""执行策略回退。【AC-AISVC-RES-07】"""
|
||||
previous = self._config.active_strategy
|
||||
current = StrategyType.DEFAULT
|
||||
|
||||
if previous == StrategyType.DEFAULT:
|
||||
return RollbackResult(
|
||||
success=False,
|
||||
previous_strategy=previous,
|
||||
current_strategy=current,
|
||||
trigger=trigger,
|
||||
reason="Already on default strategy",
|
||||
)
|
||||
|
||||
self._previous_strategy = previous
|
||||
self._config.active_strategy = current
|
||||
|
||||
audit_log = AuditLog(
|
||||
timestamp=datetime.utcnow().isoformat(),
|
||||
action="rollback",
|
||||
from_strategy=previous.value,
|
||||
to_strategy=current.value,
|
||||
trigger=trigger.value,
|
||||
reason=reason,
|
||||
tenant_id=tenant_id,
|
||||
details=details or {},
|
||||
)
|
||||
|
||||
self._add_audit_log(audit_log)
|
||||
|
||||
logger.info(
|
||||
f"[RollbackManager] Rollback executed: from={previous.value}, "
|
||||
f"to={current.value}, trigger={trigger.value}"
|
||||
)
|
||||
|
||||
return RollbackResult(
|
||||
success=True,
|
||||
previous_strategy=previous,
|
||||
current_strategy=current,
|
||||
trigger=trigger,
|
||||
reason=reason,
|
||||
audit_log=audit_log,
|
||||
)
|
||||
|
||||
def check_and_rollback(
|
||||
self, metrics: dict[str, float], tenant_id: str | None = None
|
||||
) -> RollbackResult | None:
|
||||
"""检查性能指标并自动回退。【AC-AISVC-RES-08】"""
|
||||
thresholds = self._config.performance_thresholds
|
||||
|
||||
latency = metrics.get("latency_ms", 0)
|
||||
if latency > thresholds.get("max_latency_ms", 2000):
|
||||
return self.rollback(
|
||||
trigger=RollbackTrigger.PERFORMANCE,
|
||||
reason=f"Latency {latency}ms exceeds threshold",
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
error_rate = metrics.get("error_rate", 0)
|
||||
if error_rate > thresholds.get("max_error_rate", 0.05):
|
||||
return self.rollback(
|
||||
trigger=RollbackTrigger.ERROR,
|
||||
reason=f"Error rate {error_rate} exceeds threshold",
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def _add_audit_log(self, log: AuditLog) -> None:
|
||||
self._audit_logs.append(log)
|
||||
if len(self._audit_logs) > self._max_audit_logs:
|
||||
self._audit_logs = self._audit_logs[-self._max_audit_logs:]
|
||||
|
||||
def record_audit(
|
||||
self,
|
||||
action: str,
|
||||
details: dict[str, Any],
|
||||
tenant_id: str | None = None,
|
||||
) -> AuditLog:
|
||||
"""记录审计日志。"""
|
||||
audit_log = AuditLog(
|
||||
timestamp=datetime.utcnow().isoformat(),
|
||||
action=action,
|
||||
from_strategy=self._config.active_strategy.value,
|
||||
to_strategy=self._config.active_strategy.value,
|
||||
trigger="n/a",
|
||||
reason=details.get("reason", ""),
|
||||
tenant_id=tenant_id,
|
||||
details=details,
|
||||
)
|
||||
|
||||
self._add_audit_log(audit_log)
|
||||
|
||||
logger.info(
|
||||
f"[RollbackManager] Audit recorded: action={action}, "
|
||||
f"strategy={self._config.active_strategy.value}"
|
||||
)
|
||||
|
||||
return audit_log
|
||||
|
||||
def get_audit_logs(self, limit: int = 100) -> list[AuditLog]:
|
||||
return self._audit_logs[-limit:]
|
||||
|
||||
def get_config(self) -> RetrievalStrategyConfig:
|
||||
return self._config
|
||||
|
||||
def update_config(self, config: RetrievalStrategyConfig) -> None:
|
||||
self._config = config
|
||||
|
||||
|
||||
_rollback_manager: RollbackManager | None = None
|
||||
|
||||
|
||||
def get_rollback_manager() -> RollbackManager:
|
||||
global _rollback_manager
|
||||
if _rollback_manager is None:
|
||||
_rollback_manager = RollbackManager()
|
||||
return _rollback_manager
|
||||
|
|
@ -0,0 +1,109 @@
|
|||
"""
|
||||
Strategy Router.
|
||||
[AC-AISVC-RES-01~03] 策略路由器。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from app.services.retrieval.strategy.config import (
|
||||
RetrievalStrategyConfig,
|
||||
StrategyType,
|
||||
)
|
||||
from app.services.retrieval.strategy.default_pipeline import DefaultPipeline, get_default_pipeline
|
||||
from app.services.retrieval.strategy.enhanced_pipeline import EnhancedPipeline, get_enhanced_pipeline
|
||||
from app.services.retrieval.strategy.pipeline_base import BasePipeline
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RoutingDecision:
|
||||
"""路由决策结果。"""
|
||||
strategy: StrategyType
|
||||
pipeline: BasePipeline
|
||||
reason: str
|
||||
grayscale_hit: bool = False
|
||||
|
||||
|
||||
class StrategyRouter:
|
||||
"""
|
||||
策略路由器。【AC-AISVC-RES-01~03】
|
||||
|
||||
职责:
|
||||
1. 根据配置选择默认策略或增强策略
|
||||
2. 支持灰度发布(percentage/allowlist)
|
||||
3. 不影响正在运行的默认策略请求
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: RetrievalStrategyConfig | None = None,
|
||||
default_pipeline: DefaultPipeline | None = None,
|
||||
enhanced_pipeline: EnhancedPipeline | None = None,
|
||||
):
|
||||
self._config = config or RetrievalStrategyConfig()
|
||||
self._default_pipeline = default_pipeline
|
||||
self._enhanced_pipeline = enhanced_pipeline
|
||||
|
||||
async def route(
|
||||
self, tenant_id: str, user_id: str | None = None
|
||||
) -> RoutingDecision:
|
||||
"""路由到合适的策略。【AC-AISVC-RES-01~03】"""
|
||||
if self._config.active_strategy == StrategyType.ENHANCED:
|
||||
pipeline = await self._get_enhanced_pipeline()
|
||||
return RoutingDecision(
|
||||
strategy=StrategyType.ENHANCED,
|
||||
pipeline=pipeline,
|
||||
reason="active_strategy=enhanced",
|
||||
)
|
||||
|
||||
if self._config.grayscale.should_use_enhanced(tenant_id, user_id):
|
||||
pipeline = await self._get_enhanced_pipeline()
|
||||
return RoutingDecision(
|
||||
strategy=StrategyType.ENHANCED,
|
||||
pipeline=pipeline,
|
||||
reason="grayscale_hit",
|
||||
grayscale_hit=True,
|
||||
)
|
||||
|
||||
pipeline = await self._get_default_pipeline()
|
||||
return RoutingDecision(
|
||||
strategy=StrategyType.DEFAULT,
|
||||
pipeline=pipeline,
|
||||
reason="default_strategy",
|
||||
)
|
||||
|
||||
async def _get_default_pipeline(self) -> DefaultPipeline:
|
||||
if self._default_pipeline is None:
|
||||
self._default_pipeline = await get_default_pipeline()
|
||||
return self._default_pipeline
|
||||
|
||||
async def _get_enhanced_pipeline(self) -> EnhancedPipeline:
|
||||
if self._enhanced_pipeline is None:
|
||||
self._enhanced_pipeline = await get_enhanced_pipeline()
|
||||
return self._enhanced_pipeline
|
||||
|
||||
def get_config(self) -> RetrievalStrategyConfig:
|
||||
return self._config
|
||||
|
||||
def update_config(self, config: RetrievalStrategyConfig) -> None:
|
||||
self._config = config
|
||||
logger.info(f"[StrategyRouter] Config updated: strategy={config.active_strategy.value}")
|
||||
|
||||
|
||||
_strategy_router: StrategyRouter | None = None
|
||||
|
||||
|
||||
def get_strategy_router() -> StrategyRouter:
|
||||
global _strategy_router
|
||||
if _strategy_router is None:
|
||||
_strategy_router = StrategyRouter()
|
||||
return _strategy_router
|
||||
|
||||
|
||||
def set_strategy_router(router: StrategyRouter) -> None:
|
||||
"""Set the global strategy router instance."""
|
||||
global _strategy_router
|
||||
_strategy_router = router
|
||||
|
|
@ -0,0 +1,645 @@
|
|||
"""
|
||||
Unit tests for Retrieval Strategy Module.
|
||||
[AC-AISVC-RES-01~15] Tests for strategy config, pipelines, routers, and rollback.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from dataclasses import asdict
|
||||
|
||||
from app.services.retrieval.strategy.config import (
|
||||
FilterMode,
|
||||
GrayscaleConfig,
|
||||
HybridRetrievalConfig,
|
||||
MetadataInferenceConfig,
|
||||
ModeRouterConfig,
|
||||
PipelineConfig,
|
||||
RerankerConfig,
|
||||
RetrievalStrategyConfig,
|
||||
RuntimeMode,
|
||||
StrategyType,
|
||||
get_strategy_config,
|
||||
set_strategy_config,
|
||||
)
|
||||
from app.services.retrieval.strategy.pipeline_base import (
|
||||
BasePipeline,
|
||||
MetadataFilterResult,
|
||||
PipelineContext,
|
||||
PipelineResult,
|
||||
)
|
||||
from app.services.retrieval.strategy.default_pipeline import DefaultPipeline
|
||||
from app.services.retrieval.strategy.enhanced_pipeline import EnhancedPipeline
|
||||
from app.services.retrieval.strategy.strategy_router import (
|
||||
RoutingDecision,
|
||||
StrategyRouter,
|
||||
get_strategy_router,
|
||||
)
|
||||
from app.services.retrieval.strategy.mode_router import (
|
||||
ModeDecision,
|
||||
ModeRouter,
|
||||
get_mode_router,
|
||||
)
|
||||
from app.services.retrieval.strategy.rollback_manager import (
|
||||
AuditLog,
|
||||
RollbackManager,
|
||||
RollbackResult,
|
||||
RollbackTrigger,
|
||||
get_rollback_manager,
|
||||
)
|
||||
from app.services.retrieval.base import RetrievalContext, RetrievalHit, RetrievalResult
|
||||
|
||||
|
||||
class TestStrategyConfig:
|
||||
"""[AC-AISVC-RES-01~15] Tests for strategy configuration models."""
|
||||
|
||||
def test_strategy_type_enum(self):
|
||||
"""[AC-AISVC-RES-01] Strategy type should have default and enhanced values."""
|
||||
assert StrategyType.DEFAULT.value == "default"
|
||||
assert StrategyType.ENHANCED.value == "enhanced"
|
||||
|
||||
def test_runtime_mode_enum(self):
|
||||
"""[AC-AISVC-RES-09] Runtime mode should have direct, react, and auto values."""
|
||||
assert RuntimeMode.DIRECT.value == "direct"
|
||||
assert RuntimeMode.REACT.value == "react"
|
||||
assert RuntimeMode.AUTO.value == "auto"
|
||||
|
||||
def test_filter_mode_enum(self):
|
||||
"""[AC-AISVC-RES-04] Filter mode should have hard, soft, and none values."""
|
||||
assert FilterMode.HARD.value == "hard"
|
||||
assert FilterMode.SOFT.value == "soft"
|
||||
assert FilterMode.NONE.value == "none"
|
||||
|
||||
def test_grayscale_config_default(self):
|
||||
"""[AC-AISVC-RES-03] Default grayscale config should be disabled."""
|
||||
config = GrayscaleConfig()
|
||||
assert config.enabled is False
|
||||
assert config.percentage == 0.0
|
||||
assert config.allowlist == []
|
||||
|
||||
def test_grayscale_config_should_use_enhanced_disabled(self):
|
||||
"""[AC-AISVC-RES-03] Should not use enhanced when grayscale disabled."""
|
||||
config = GrayscaleConfig(enabled=False, percentage=50.0)
|
||||
assert config.should_use_enhanced("tenant_a") is False
|
||||
|
||||
def test_grayscale_config_should_use_enhanced_allowlist(self):
|
||||
"""[AC-AISVC-RES-03] Should use enhanced for tenants in allowlist."""
|
||||
config = GrayscaleConfig(enabled=True, allowlist=["tenant_a", "tenant_b"])
|
||||
assert config.should_use_enhanced("tenant_a") is True
|
||||
assert config.should_use_enhanced("tenant_b") is True
|
||||
assert config.should_use_enhanced("tenant_c") is False
|
||||
|
||||
def test_grayscale_config_should_use_enhanced_percentage(self):
|
||||
"""[AC-AISVC-RES-03] Should use enhanced based on percentage."""
|
||||
config = GrayscaleConfig(enabled=True, percentage=100.0)
|
||||
assert config.should_use_enhanced("any_tenant") is True
|
||||
|
||||
config = GrayscaleConfig(enabled=True, percentage=0.0)
|
||||
assert config.should_use_enhanced("any_tenant") is False
|
||||
|
||||
def test_reranker_config_default(self):
|
||||
"""[AC-AISVC-RES-08] Default reranker config should be disabled."""
|
||||
config = RerankerConfig()
|
||||
assert config.enabled is False
|
||||
assert config.model == "cross-encoder"
|
||||
assert config.top_k_after_rerank == 5
|
||||
|
||||
def test_mode_router_config_default(self):
|
||||
"""[AC-AISVC-RES-09] Default mode router config should be direct."""
|
||||
config = ModeRouterConfig()
|
||||
assert config.runtime_mode == RuntimeMode.DIRECT
|
||||
assert config.react_trigger_confidence_threshold == 0.6
|
||||
assert config.react_max_steps == 5
|
||||
|
||||
def test_mode_router_config_should_use_react_always(self):
|
||||
"""[AC-AISVC-RES-10] React mode should always use react."""
|
||||
config = ModeRouterConfig(runtime_mode=RuntimeMode.REACT)
|
||||
assert config.should_use_react("any query") is True
|
||||
|
||||
def test_mode_router_config_should_use_react_never(self):
|
||||
"""[AC-AISVC-RES-09] Direct mode should never use react."""
|
||||
config = ModeRouterConfig(runtime_mode=RuntimeMode.DIRECT)
|
||||
assert config.should_use_react("any query") is False
|
||||
|
||||
def test_mode_router_config_auto_short_query_high_confidence(self):
|
||||
"""[AC-AISVC-RES-12] Auto mode with short query and high confidence should use direct."""
|
||||
config = ModeRouterConfig(runtime_mode=RuntimeMode.AUTO)
|
||||
assert config.should_use_react("短问题", confidence=0.8) is False
|
||||
|
||||
def test_mode_router_config_auto_low_confidence(self):
|
||||
"""[AC-AISVC-RES-13] Auto mode with low confidence should use react."""
|
||||
config = ModeRouterConfig(runtime_mode=RuntimeMode.AUTO)
|
||||
assert config.should_use_react("any query", confidence=0.3) is True
|
||||
|
||||
def test_metadata_inference_config_determine_filter_mode(self):
|
||||
"""[AC-AISVC-RES-04] Should determine filter mode based on confidence."""
|
||||
config = MetadataInferenceConfig()
|
||||
|
||||
assert config.determine_filter_mode(0.9) == FilterMode.HARD
|
||||
assert config.determine_filter_mode(0.6) == FilterMode.SOFT
|
||||
assert config.determine_filter_mode(0.3) == FilterMode.NONE
|
||||
assert config.determine_filter_mode(None) == FilterMode.NONE
|
||||
|
||||
def test_pipeline_config_default(self):
|
||||
"""[AC-AISVC-RES-01] Default pipeline config should have sensible defaults."""
|
||||
config = PipelineConfig()
|
||||
assert config.top_k == 5
|
||||
assert config.score_threshold == 0.01
|
||||
assert config.two_stage_enabled is True
|
||||
|
||||
def test_retrieval_strategy_config_default(self):
|
||||
"""[AC-AISVC-RES-01] Default strategy config should use default strategy."""
|
||||
config = RetrievalStrategyConfig()
|
||||
assert config.active_strategy == StrategyType.DEFAULT
|
||||
assert config.grayscale.enabled is False
|
||||
assert config.mode_router.runtime_mode == RuntimeMode.DIRECT
|
||||
|
||||
def test_retrieval_strategy_config_is_enhanced_enabled(self):
|
||||
"""[AC-AISVC-RES-02] Should check if enhanced is enabled."""
|
||||
config = RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED)
|
||||
assert config.is_enhanced_enabled("tenant_a") is True
|
||||
|
||||
config = RetrievalStrategyConfig(
|
||||
active_strategy=StrategyType.DEFAULT,
|
||||
grayscale=GrayscaleConfig(enabled=True, allowlist=["tenant_a"]),
|
||||
)
|
||||
assert config.is_enhanced_enabled("tenant_a") is True
|
||||
assert config.is_enhanced_enabled("tenant_b") is False
|
||||
|
||||
def test_retrieval_strategy_config_to_dict(self):
|
||||
"""[AC-AISVC-RES-01] Should convert config to dictionary."""
|
||||
config = RetrievalStrategyConfig()
|
||||
d = config.to_dict()
|
||||
|
||||
assert d["active_strategy"] == "default"
|
||||
assert "grayscale" in d
|
||||
assert "pipeline" in d
|
||||
assert "reranker" in d
|
||||
assert "mode_router" in d
|
||||
|
||||
def test_global_config_functions(self):
|
||||
"""[AC-AISVC-RES-01] Should get and set global config."""
|
||||
config = RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED)
|
||||
set_strategy_config(config)
|
||||
|
||||
retrieved = get_strategy_config()
|
||||
assert retrieved.active_strategy == StrategyType.ENHANCED
|
||||
|
||||
set_strategy_config(RetrievalStrategyConfig())
|
||||
|
||||
|
||||
class TestPipelineBase:
|
||||
"""[AC-AISVC-RES-01~02] Tests for pipeline base classes."""
|
||||
|
||||
def test_metadata_filter_result_default(self):
|
||||
"""[AC-AISVC-RES-04] Default metadata filter result should be empty."""
|
||||
result = MetadataFilterResult()
|
||||
assert result.filter_dict == {}
|
||||
assert result.filter_mode == FilterMode.NONE
|
||||
assert result.confidence is None
|
||||
|
||||
def test_pipeline_context_properties(self):
|
||||
"""[AC-AISVC-RES-01] Pipeline context should expose retrieval context properties."""
|
||||
retrieval_ctx = RetrievalContext(
|
||||
tenant_id="tenant_1",
|
||||
query="test query",
|
||||
session_id="session_1",
|
||||
kb_ids=["kb_1"],
|
||||
)
|
||||
pipeline_ctx = PipelineContext(retrieval_ctx=retrieval_ctx)
|
||||
|
||||
assert pipeline_ctx.tenant_id == "tenant_1"
|
||||
assert pipeline_ctx.query == "test query"
|
||||
assert pipeline_ctx.session_id == "session_1"
|
||||
assert pipeline_ctx.kb_ids == ["kb_1"]
|
||||
|
||||
def test_pipeline_result_properties(self):
|
||||
"""[AC-AISVC-RES-01] Pipeline result should expose retrieval result properties."""
|
||||
hits = [
|
||||
RetrievalHit(text="hit 1", score=0.9, source="test", metadata={}),
|
||||
RetrievalHit(text="hit 2", score=0.8, source="test", metadata={}),
|
||||
]
|
||||
retrieval_result = RetrievalResult(hits=hits)
|
||||
pipeline_result = PipelineResult(
|
||||
retrieval_result=retrieval_result,
|
||||
pipeline_name="test_pipeline",
|
||||
)
|
||||
|
||||
assert pipeline_result.hits == hits
|
||||
assert pipeline_result.is_empty is False
|
||||
assert pipeline_result.pipeline_name == "test_pipeline"
|
||||
|
||||
def test_pipeline_result_is_empty(self):
|
||||
"""[AC-AISVC-RES-01] Pipeline result should detect empty results."""
|
||||
pipeline_result = PipelineResult(
|
||||
retrieval_result=RetrievalResult(hits=[]),
|
||||
)
|
||||
assert pipeline_result.is_empty is True
|
||||
|
||||
|
||||
class TestDefaultPipeline:
|
||||
"""[AC-AISVC-RES-01] Tests for default pipeline."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_retriever(self):
|
||||
"""Create a mock optimized retriever."""
|
||||
retriever = AsyncMock()
|
||||
retriever.retrieve = AsyncMock(return_value=RetrievalResult(
|
||||
hits=[
|
||||
RetrievalHit(text="result 1", score=0.9, source="default", metadata={}),
|
||||
],
|
||||
diagnostics={"test": True},
|
||||
))
|
||||
retriever.health_check = AsyncMock(return_value=True)
|
||||
retriever._two_stage_enabled = True
|
||||
retriever._hybrid_enabled = True
|
||||
return retriever
|
||||
|
||||
@pytest.fixture
|
||||
def pipeline(self, mock_retriever):
|
||||
"""Create a default pipeline with mock retriever."""
|
||||
return DefaultPipeline(optimized_retriever=mock_retriever)
|
||||
|
||||
def test_pipeline_name(self, pipeline):
|
||||
"""[AC-AISVC-RES-01] Pipeline should have correct name."""
|
||||
assert pipeline.name == "default_pipeline"
|
||||
|
||||
def test_pipeline_description(self, pipeline):
|
||||
"""[AC-AISVC-RES-01] Pipeline should have description."""
|
||||
assert "默认" in pipeline.description
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve(self, pipeline, mock_retriever):
|
||||
"""[AC-AISVC-RES-01] Should retrieve results using optimized retriever."""
|
||||
retrieval_ctx = RetrievalContext(
|
||||
tenant_id="tenant_1",
|
||||
query="test query",
|
||||
)
|
||||
pipeline_ctx = PipelineContext(retrieval_ctx=retrieval_ctx)
|
||||
|
||||
result = await pipeline.retrieve(pipeline_ctx)
|
||||
|
||||
assert result.pipeline_name == "default_pipeline"
|
||||
assert len(result.hits) == 1
|
||||
assert result.diagnostics["retriever"] == "OptimizedRetriever"
|
||||
mock_retriever.retrieve.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_with_metadata_filter(self, pipeline, mock_retriever):
|
||||
"""[AC-AISVC-RES-04] Should apply metadata filter."""
|
||||
retrieval_ctx = RetrievalContext(
|
||||
tenant_id="tenant_1",
|
||||
query="test query",
|
||||
)
|
||||
metadata_filter = MetadataFilterResult(
|
||||
filter_dict={"grade": "初一"},
|
||||
filter_mode=FilterMode.HARD,
|
||||
)
|
||||
pipeline_ctx = PipelineContext(
|
||||
retrieval_ctx=retrieval_ctx,
|
||||
metadata_filter=metadata_filter,
|
||||
)
|
||||
|
||||
result = await pipeline.retrieve(pipeline_ctx)
|
||||
|
||||
assert result.metadata_filter_applied is True
|
||||
call_args = mock_retriever.retrieve.call_args[0][0]
|
||||
assert call_args.metadata_filter == {"grade": "初一"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check(self, pipeline, mock_retriever):
|
||||
"""[AC-AISVC-RES-01] Should check health."""
|
||||
result = await pipeline.health_check()
|
||||
assert result is True
|
||||
mock_retriever.health_check.assert_called_once()
|
||||
|
||||
|
||||
class TestEnhancedPipeline:
|
||||
"""[AC-AISVC-RES-02] Tests for enhanced pipeline."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_qdrant_client(self):
|
||||
"""Create a mock Qdrant client."""
|
||||
client = AsyncMock()
|
||||
client.search = AsyncMock(return_value=[
|
||||
{"id": "1", "score": 0.9, "payload": {"text": "result 1"}},
|
||||
])
|
||||
client.get_client = AsyncMock()
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def mock_embedding_provider(self):
|
||||
"""Create a mock embedding provider."""
|
||||
provider = AsyncMock()
|
||||
provider.embed_query = AsyncMock()
|
||||
provider.embed_query.return_value = MagicMock(
|
||||
embedding_full=[0.1] * 768,
|
||||
)
|
||||
provider.embed = AsyncMock(return_value=[0.1] * 768)
|
||||
return provider
|
||||
|
||||
@pytest.fixture
|
||||
def pipeline(self, mock_qdrant_client, mock_embedding_provider):
|
||||
"""Create an enhanced pipeline with mocks."""
|
||||
pipeline = EnhancedPipeline(qdrant_client=mock_qdrant_client)
|
||||
pipeline._embedding_provider = mock_embedding_provider
|
||||
return pipeline
|
||||
|
||||
def test_pipeline_name(self, pipeline):
|
||||
"""[AC-AISVC-RES-02] Pipeline should have correct name."""
|
||||
assert pipeline.name == "enhanced_pipeline"
|
||||
|
||||
def test_pipeline_description(self, pipeline):
|
||||
"""[AC-AISVC-RES-02] Pipeline should have description."""
|
||||
assert "增强" in pipeline.description
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_basic(self, pipeline):
|
||||
"""[AC-AISVC-RES-02] Should retrieve results using hybrid search."""
|
||||
retrieval_ctx = RetrievalContext(
|
||||
tenant_id="tenant_1",
|
||||
query="test query",
|
||||
)
|
||||
pipeline_ctx = PipelineContext(retrieval_ctx=retrieval_ctx)
|
||||
|
||||
result = await pipeline.retrieve(pipeline_ctx)
|
||||
|
||||
assert result.pipeline_name == "enhanced_pipeline"
|
||||
assert result.diagnostics is not None
|
||||
|
||||
|
||||
class TestStrategyRouter:
|
||||
"""[AC-AISVC-RES-01~03] Tests for strategy router."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_default_pipeline(self):
|
||||
"""Create a mock default pipeline."""
|
||||
pipeline = AsyncMock(spec=DefaultPipeline)
|
||||
pipeline.name = "default_pipeline"
|
||||
pipeline.retrieve = AsyncMock(return_value=PipelineResult(
|
||||
retrieval_result=RetrievalResult(hits=[]),
|
||||
pipeline_name="default_pipeline",
|
||||
))
|
||||
return pipeline
|
||||
|
||||
@pytest.fixture
|
||||
def mock_enhanced_pipeline(self):
|
||||
"""Create a mock enhanced pipeline."""
|
||||
pipeline = AsyncMock(spec=EnhancedPipeline)
|
||||
pipeline.name = "enhanced_pipeline"
|
||||
pipeline.retrieve = AsyncMock(return_value=PipelineResult(
|
||||
retrieval_result=RetrievalResult(hits=[]),
|
||||
pipeline_name="enhanced_pipeline",
|
||||
))
|
||||
return pipeline
|
||||
|
||||
@pytest.fixture
|
||||
def router(self, mock_default_pipeline, mock_enhanced_pipeline):
|
||||
"""Create a strategy router with mock pipelines."""
|
||||
config = RetrievalStrategyConfig()
|
||||
return StrategyRouter(
|
||||
config=config,
|
||||
default_pipeline=mock_default_pipeline,
|
||||
enhanced_pipeline=mock_enhanced_pipeline,
|
||||
)
|
||||
|
||||
def test_route_default_strategy(self, router):
|
||||
"""[AC-AISVC-RES-01] Should route to default strategy by default."""
|
||||
import asyncio
|
||||
decision = asyncio.get_event_loop().run_until_complete(router.route("tenant_1"))
|
||||
|
||||
assert decision.strategy == StrategyType.DEFAULT
|
||||
assert decision.reason == "default_strategy"
|
||||
|
||||
def test_route_enhanced_strategy(self, mock_default_pipeline, mock_enhanced_pipeline):
|
||||
"""[AC-AISVC-RES-02] Should route to enhanced strategy when configured."""
|
||||
config = RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED)
|
||||
router = StrategyRouter(
|
||||
config=config,
|
||||
default_pipeline=mock_default_pipeline,
|
||||
enhanced_pipeline=mock_enhanced_pipeline,
|
||||
)
|
||||
|
||||
import asyncio
|
||||
decision = asyncio.get_event_loop().run_until_complete(router.route("tenant_1"))
|
||||
|
||||
assert decision.strategy == StrategyType.ENHANCED
|
||||
assert decision.reason == "active_strategy=enhanced"
|
||||
|
||||
def test_route_grayscale_allowlist(self, mock_default_pipeline, mock_enhanced_pipeline):
|
||||
"""[AC-AISVC-RES-03] Should route to enhanced for allowlist tenants."""
|
||||
config = RetrievalStrategyConfig(
|
||||
active_strategy=StrategyType.DEFAULT,
|
||||
grayscale=GrayscaleConfig(enabled=True, allowlist=["tenant_a"]),
|
||||
)
|
||||
router = StrategyRouter(
|
||||
config=config,
|
||||
default_pipeline=mock_default_pipeline,
|
||||
enhanced_pipeline=mock_enhanced_pipeline,
|
||||
)
|
||||
|
||||
import asyncio
|
||||
|
||||
decision = asyncio.get_event_loop().run_until_complete(router.route("tenant_a"))
|
||||
assert decision.strategy == StrategyType.ENHANCED
|
||||
assert decision.grayscale_hit is True
|
||||
|
||||
decision = asyncio.get_event_loop().run_until_complete(router.route("tenant_b"))
|
||||
assert decision.strategy == StrategyType.DEFAULT
|
||||
|
||||
def test_update_config(self, router):
|
||||
"""[AC-AISVC-RES-02] Should update config."""
|
||||
new_config = RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED)
|
||||
router.update_config(new_config)
|
||||
|
||||
assert router.get_config().active_strategy == StrategyType.ENHANCED
|
||||
|
||||
|
||||
class TestModeRouter:
|
||||
"""[AC-AISVC-RES-09~15] Tests for mode router."""
|
||||
|
||||
@pytest.fixture
|
||||
def router(self):
|
||||
"""Create a mode router."""
|
||||
return ModeRouter()
|
||||
|
||||
def test_decide_react_mode(self):
|
||||
"""[AC-AISVC-RES-10] Should decide react when configured."""
|
||||
config = ModeRouterConfig(runtime_mode=RuntimeMode.REACT)
|
||||
router = ModeRouter(config)
|
||||
|
||||
decision = router.decide("any query")
|
||||
|
||||
assert decision.mode == RuntimeMode.REACT
|
||||
assert decision.reason == "runtime_mode=react"
|
||||
|
||||
def test_decide_direct_mode(self, router):
|
||||
"""[AC-AISVC-RES-09] Should decide direct when configured."""
|
||||
decision = router.decide("any query")
|
||||
|
||||
assert decision.mode == RuntimeMode.DIRECT
|
||||
assert decision.reason == "runtime_mode=direct"
|
||||
|
||||
def test_decide_auto_short_query_high_confidence(self):
|
||||
"""[AC-AISVC-RES-12] Auto with short query and high confidence should use direct."""
|
||||
config = ModeRouterConfig(runtime_mode=RuntimeMode.AUTO)
|
||||
router = ModeRouter(config)
|
||||
|
||||
decision = router.decide("短问题", confidence=0.8)
|
||||
|
||||
assert decision.mode == RuntimeMode.DIRECT
|
||||
|
||||
def test_decide_auto_low_confidence(self):
|
||||
"""[AC-AISVC-RES-13] Auto with low confidence should use react."""
|
||||
config = ModeRouterConfig(runtime_mode=RuntimeMode.AUTO)
|
||||
router = ModeRouter(config)
|
||||
|
||||
decision = router.decide("any query", confidence=0.3)
|
||||
|
||||
assert decision.mode == RuntimeMode.REACT
|
||||
|
||||
def test_should_fallback_to_react_empty_results(self, router):
|
||||
"""[AC-AISVC-RES-14] Should fallback to react on empty results."""
|
||||
result = PipelineResult(retrieval_result=RetrievalResult(hits=[]))
|
||||
|
||||
assert router.should_fallback_to_react(result) is True
|
||||
|
||||
def test_should_fallback_to_react_low_score(self, router):
|
||||
"""[AC-AISVC-RES-14] Should fallback to react on low score."""
|
||||
result = PipelineResult(
|
||||
retrieval_result=RetrievalResult(
|
||||
hits=[RetrievalHit(text="test", score=0.1, source="test", metadata={})],
|
||||
),
|
||||
)
|
||||
|
||||
assert router.should_fallback_to_react(result) is True
|
||||
|
||||
def test_should_not_fallback_to_react_disabled(self):
|
||||
"""[AC-AISVC-RES-14] Should not fallback when disabled."""
|
||||
config = ModeRouterConfig(direct_fallback_on_low_confidence=False)
|
||||
router = ModeRouter(config)
|
||||
|
||||
result = PipelineResult(retrieval_result=RetrievalResult(hits=[]))
|
||||
|
||||
assert router.should_fallback_to_react(result) is False
|
||||
|
||||
|
||||
class TestRollbackManager:
|
||||
"""[AC-AISVC-RES-07] Tests for rollback manager."""
|
||||
|
||||
@pytest.fixture
|
||||
def manager(self):
|
||||
"""Create a rollback manager."""
|
||||
return RollbackManager()
|
||||
|
||||
def test_rollback_from_enhanced(self, manager):
|
||||
"""[AC-AISVC-RES-07] Should rollback from enhanced to default."""
|
||||
config = RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED)
|
||||
manager.update_config(config)
|
||||
|
||||
result = manager.rollback(
|
||||
trigger=RollbackTrigger.MANUAL,
|
||||
reason="Testing rollback",
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.previous_strategy == StrategyType.ENHANCED
|
||||
assert result.current_strategy == StrategyType.DEFAULT
|
||||
assert result.audit_log is not None
|
||||
|
||||
def test_rollback_already_default(self, manager):
|
||||
"""[AC-AISVC-RES-07] Should not rollback when already on default."""
|
||||
result = manager.rollback(
|
||||
trigger=RollbackTrigger.MANUAL,
|
||||
reason="Testing rollback",
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert result.reason == "Already on default strategy"
|
||||
|
||||
def test_check_and_rollback_latency(self, manager):
|
||||
"""[AC-AISVC-RES-08] Should rollback on high latency."""
|
||||
config = RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED)
|
||||
manager.update_config(config)
|
||||
|
||||
result = manager.check_and_rollback(
|
||||
metrics={"latency_ms": 3000.0},
|
||||
tenant_id="tenant_1",
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.trigger == RollbackTrigger.PERFORMANCE
|
||||
|
||||
def test_check_and_rollback_error_rate(self, manager):
|
||||
"""[AC-AISVC-RES-08] Should rollback on high error rate."""
|
||||
config = RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED)
|
||||
manager.update_config(config)
|
||||
|
||||
result = manager.check_and_rollback(
|
||||
metrics={"error_rate": 0.1},
|
||||
tenant_id="tenant_1",
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.trigger == RollbackTrigger.ERROR
|
||||
|
||||
def test_check_and_rollback_ok(self, manager):
|
||||
"""[AC-AISVC-RES-08] Should not rollback when metrics are ok."""
|
||||
config = RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED)
|
||||
manager.update_config(config)
|
||||
|
||||
result = manager.check_and_rollback(
|
||||
metrics={"latency_ms": 100.0, "error_rate": 0.01},
|
||||
tenant_id="tenant_1",
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_get_audit_logs(self, manager):
|
||||
"""[AC-AISVC-RES-07] Should get audit logs."""
|
||||
config = RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED)
|
||||
manager.update_config(config)
|
||||
manager.rollback(trigger=RollbackTrigger.MANUAL, reason="Test")
|
||||
|
||||
logs = manager.get_audit_logs()
|
||||
|
||||
assert len(logs) == 1
|
||||
assert logs[0].action == "rollback"
|
||||
|
||||
def test_record_audit(self, manager):
|
||||
"""[AC-AISVC-RES-07] Should record audit log."""
|
||||
log = manager.record_audit(
|
||||
action="test_action",
|
||||
details={"reason": "Testing"},
|
||||
tenant_id="tenant_1",
|
||||
)
|
||||
|
||||
assert log.action == "test_action"
|
||||
assert log.tenant_id == "tenant_1"
|
||||
|
||||
|
||||
class TestSingletonInstances:
|
||||
"""Tests for singleton instance getters."""
|
||||
|
||||
def test_get_mode_router_singleton(self):
|
||||
"""Should return same mode router instance."""
|
||||
from app.services.retrieval.strategy.mode_router import _mode_router
|
||||
|
||||
import app.services.retrieval.strategy.mode_router as module
|
||||
module._mode_router = None
|
||||
|
||||
router1 = get_mode_router()
|
||||
router2 = get_mode_router()
|
||||
|
||||
assert router1 is router2
|
||||
|
||||
def test_get_rollback_manager_singleton(self):
|
||||
"""Should return same rollback manager instance."""
|
||||
from app.services.retrieval.strategy.rollback_manager import _rollback_manager
|
||||
|
||||
import app.services.retrieval.strategy.rollback_manager as module
|
||||
module._rollback_manager = None
|
||||
|
||||
manager1 = get_rollback_manager()
|
||||
manager2 = get_rollback_manager()
|
||||
|
||||
assert manager1 is manager2
|
||||
|
|
@ -0,0 +1,167 @@
|
|||
---
|
||||
context:
|
||||
module: "ai-service"
|
||||
feature: "AISVC-RES"
|
||||
status: "✅已完成"
|
||||
version: "0.9.0"
|
||||
active_ac_range: "AC-AISVC-RES-01~15"
|
||||
|
||||
spec_references:
|
||||
requirements: "spec/ai-service/iterations/v0.9.0-retrieval-embedding-strategy/requirements.md"
|
||||
design: "spec/ai-service/iterations/v0.9.0-retrieval-embedding-strategy/design.md"
|
||||
tasks: "spec/ai-service/iterations/v0.9.0-retrieval-embedding-strategy/tasks.md"
|
||||
openapi_provider: "spec/ai-service/iterations/v0.9.0-retrieval-embedding-strategy/openapi.provider.yaml"
|
||||
active_version: "0.1.0"
|
||||
|
||||
overall_progress:
|
||||
- "[x] Phase 1: Schema与数据模型定义 (100%) [API与校验]"
|
||||
- "[x] Phase 2: 策略服务层实现 (100%) [策略层与配置]"
|
||||
- "[x] Phase 3: 审计日志与指标埋点 (100%) [观测与灰度验证]"
|
||||
- "[x] Phase 4: API端点实现 (100%) [API与校验]"
|
||||
- "[x] Phase 5: 单元测试与验证 (100%) [验收]"
|
||||
- "[x] Phase 6: 检索策略Pipeline实现 (100%) [检索与嵌入策略]"
|
||||
|
||||
current_phase:
|
||||
goal: "检索与嵌入策略化改造已完成"
|
||||
sub_tasks:
|
||||
- "[x] 创建 Schema 模型 (app/schemas/retrieval_strategy.py)"
|
||||
- "[x] 创建策略服务层 (app/services/retrieval/strategy_service.py)"
|
||||
- "[x] 创建审计日志服务 (app/services/retrieval/strategy_audit.py)"
|
||||
- "[x] 创建指标埋点服务 (app/services/retrieval/strategy_metrics.py)"
|
||||
- "[x] 创建 API 端点 (app/api/admin/retrieval_strategy.py)"
|
||||
- "[x] 创建策略配置模型 (app/services/retrieval/strategy/config.py)"
|
||||
- "[x] 实现 DefaultPipeline(复用现有逻辑)"
|
||||
- "[x] 实现 EnhancedPipeline(新端到端流程)"
|
||||
- "[x] 实现元数据推断统一入口 (MetadataInferenceService)"
|
||||
- "[x] 实现 StrategyRouter 和 ModeRouter"
|
||||
- "[x] 实现 Dense + Keyword + RRF 组合检索"
|
||||
- "[x] 实现可选重排与降级开关"
|
||||
- "[x] 实现 RollbackManager(回退与审计)"
|
||||
- "[x] 实现策略 API 接口"
|
||||
- "[x] 创建单元测试 (tests/test_retrieval_strategy_v2.py)"
|
||||
- "[x] 运行单元测试验证 (51 passed)"
|
||||
|
||||
next_action:
|
||||
immediate: "任务已完成,可进行集成测试"
|
||||
details:
|
||||
file: "ai-service/app/services/retrieval/strategy/__init__.py:1"
|
||||
action: "模块已完整实现,可通过 API 接口测试策略切换功能"
|
||||
reference: "http://localhost:8000/docs"
|
||||
constraints: "新策略可配置启用,不影响默认策略"
|
||||
|
||||
technical_context:
|
||||
module_structure: |
|
||||
ai-service/app/
|
||||
├── api/admin/retrieval_strategy.py (API端点)
|
||||
├── schemas/retrieval_strategy.py (Schema模型)
|
||||
└── services/retrieval/
|
||||
├── strategy_service.py (策略服务)
|
||||
├── strategy_audit.py (审计日志)
|
||||
├── strategy_metrics.py (指标埋点)
|
||||
└── strategy/ (新增 - 策略模块)
|
||||
├── __init__.py (模块导出)
|
||||
├── config.py (策略配置模型)
|
||||
├── pipeline_base.py (Pipeline基类)
|
||||
├── default_pipeline.py (默认策略Pipeline)
|
||||
├── enhanced_pipeline.py (增强策略Pipeline)
|
||||
├── metadata_inference.py (元数据推断统一入口)
|
||||
├── strategy_router.py (策略路由器)
|
||||
├── mode_router.py (模式路由器)
|
||||
└── rollback_manager.py (回退管理器)
|
||||
└── tests/
|
||||
├── test_retrieval_strategy.py (单元测试 - 原有)
|
||||
└── test_retrieval_strategy_v2.py (单元测试 - 新增)
|
||||
key_decisions:
|
||||
- decision: "使用内存存储策略状态,后续可扩展为持久化"
|
||||
reason: "快速实现,满足灰度验证需求"
|
||||
impact: "服务重启后策略状态重置为默认值"
|
||||
- decision: "审计日志使用结构化日志记录"
|
||||
reason: "与现有日志体系一致,便于检索"
|
||||
impact: "需要配置日志聚合系统收集审计日志"
|
||||
- decision: "API与现有strategy_router.py互补"
|
||||
reason: "strategy_router.py负责检索路由逻辑,新增的API负责策略管理"
|
||||
impact: "两者协同工作,API提供管理界面"
|
||||
- decision: "DefaultPipeline 复用现有 OptimizedRetriever 逻辑"
|
||||
reason: "保持线上行为不变,最小化改动风险"
|
||||
impact: "新策略与旧策略完全隔离,可独立灰度"
|
||||
- decision: "EnhancedPipeline 实现新端到端流程"
|
||||
reason: "支持 Dense + Keyword + RRF 组合检索,可选重排"
|
||||
impact: "需要配置启用,不影响默认策略"
|
||||
- decision: "元数据推断统一入口处理 hard/soft filter"
|
||||
reason: "新旧策略共享同一推断逻辑,确保一致性"
|
||||
impact: "置信度高用硬过滤,置信度低用软过滤/加权"
|
||||
code_snippets: |
|
||||
# 使用示例
|
||||
from app.services.retrieval.strategy import (
|
||||
get_strategy_router,
|
||||
get_mode_router,
|
||||
get_rollback_manager,
|
||||
)
|
||||
|
||||
# 策略路由
|
||||
router = get_strategy_router()
|
||||
decision = await router.route(tenant_id, user_id)
|
||||
result = await decision.pipeline.retrieve(ctx)
|
||||
|
||||
# 模式路由
|
||||
mode_router = get_mode_router()
|
||||
mode_decision = mode_router.decide(query, confidence=0.8)
|
||||
|
||||
# 回退管理
|
||||
rollback = get_rollback_manager()
|
||||
rollback.rollback(trigger="manual", reason="测试回退")
|
||||
|
||||
session_history:
|
||||
- session: "Session #1 (2026-03-10)"
|
||||
completed:
|
||||
- "创建 Schema 模型 (app/schemas/retrieval_strategy.py)"
|
||||
- "创建策略服务层 (app/services/retrieval/strategy_service.py)"
|
||||
- "创建审计日志服务 (app/services/retrieval/strategy_audit.py)"
|
||||
- "创建指标埋点服务 (app/services/retrieval/strategy_metrics.py)"
|
||||
- "创建 API 端点 (app/api/admin/retrieval_strategy.py)"
|
||||
- "更新 admin __init__.py 注册新路由"
|
||||
- "更新 main.py 注册新路由"
|
||||
- "创建单元测试 (tests/test_retrieval_strategy.py)"
|
||||
- "运行单元测试验证 (46 passed)"
|
||||
changes:
|
||||
- "新增: ai-service/app/schemas/retrieval_strategy.py"
|
||||
- "新增: ai-service/app/services/retrieval/strategy_service.py"
|
||||
- "新增: ai-service/app/services/retrieval/strategy_audit.py"
|
||||
- "新增: ai-service/app/services/retrieval/strategy_metrics.py"
|
||||
- "新增: ai-service/app/api/admin/retrieval_strategy.py"
|
||||
- "修改: ai-service/app/api/admin/__init__.py"
|
||||
- "修改: ai-service/app/main.py"
|
||||
- "新增: ai-service/tests/test_retrieval_strategy.py"
|
||||
status: "✅ 任务完成"
|
||||
- session: "Session #2 (2026-03-10) - 检索策略Pipeline实现"
|
||||
completed:
|
||||
- "实现策略配置模型 (config.py)"
|
||||
- "实现 Pipeline 基类 (pipeline_base.py)"
|
||||
- "实现 DefaultPipeline(复用现有逻辑)"
|
||||
- "实现 EnhancedPipeline(新端到端流程)"
|
||||
- "实现 MetadataInferenceService"
|
||||
- "实现 StrategyRouter"
|
||||
- "实现 ModeRouter"
|
||||
- "实现 RollbackManager"
|
||||
- "更新模块导出 (__init__.py)"
|
||||
- "创建单元测试 (tests/test_retrieval_strategy_v2.py)"
|
||||
- "运行单元测试验证 (51 passed)"
|
||||
changes:
|
||||
- "新增: ai-service/app/services/retrieval/strategy/__init__.py"
|
||||
- "新增: ai-service/app/services/retrieval/strategy/config.py"
|
||||
- "新增: ai-service/app/services/retrieval/strategy/pipeline_base.py"
|
||||
- "新增: ai-service/app/services/retrieval/strategy/default_pipeline.py"
|
||||
- "新增: ai-service/app/services/retrieval/strategy/enhanced_pipeline.py"
|
||||
- "新增: ai-service/app/services/retrieval/strategy/metadata_inference.py"
|
||||
- "新增: ai-service/app/services/retrieval/strategy/strategy_router.py"
|
||||
- "新增: ai-service/app/services/retrieval/strategy/mode_router.py"
|
||||
- "新增: ai-service/app/services/retrieval/strategy/rollback_manager.py"
|
||||
- "新增: ai-service/tests/test_retrieval_strategy_v2.py"
|
||||
status: "✅ 任务完成"
|
||||
|
||||
startup_guide:
|
||||
- "Step 1: 读取本进度文档(了解当前位置与下一步)"
|
||||
- "Step 2: 读取 spec_references 中定义的模块规范(了解业务与接口约束)"
|
||||
- "Step 3: 通过 API 接口测试策略切换功能"
|
||||
- "Step 4: 运行单元测试验证: pytest tests/test_retrieval_strategy_v2.py -v"
|
||||
---
|
||||
Loading…
Reference in New Issue