"""
Retrieval layer for the AI Service.

[AC-AISVC-16] Core retrieval types: the request context passed to retrievers,
the unified hit/result structures they return, and the abstract base class
that retriever plugins implement.
"""
|
|
|
|
import logging
|
|
from abc import ABC, abstractmethod
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, TYPE_CHECKING
|
|
|
|
if TYPE_CHECKING:
    # Typing-only import: TagFilter is referenced solely through the string
    # annotation on RetrievalContext.tag_filter, so it is never loaded at runtime.
    from app.services.retrieval.metadata import TagFilter

# Logger named after this module.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
|
|
class RetrievalContext:
|
|
"""
|
|
[AC-AISVC-16] Context for retrieval operations.
|
|
Contains all necessary information for retrieval plugins.
|
|
"""
|
|
|
|
tenant_id: str
|
|
query: str
|
|
session_id: str | None = None
|
|
channel_type: str | None = None
|
|
metadata: dict[str, Any] | None = None
|
|
tag_filter: "TagFilter | None" = None
|
|
kb_ids: list[str] | None = None
|
|
|
|
def get_tag_filter_dict(self) -> dict[str, str | list[str] | None] | None:
|
|
"""获取标签过滤器的字典表示"""
|
|
if self.tag_filter and not self.tag_filter.is_empty():
|
|
return self.tag_filter.fields
|
|
return None
|
|
|
|
|
|
@dataclass
class RetrievalHit:
    """
    [AC-AISVC-16] One item returned by a retriever.

    Shared shape produced by every retriever implementation, so downstream
    code can treat hits uniformly regardless of the retrieval strategy.
    """

    # Text content of the hit.
    text: str
    # Score assigned by the retriever (scale depends on the retriever — confirm per implementation).
    score: float
    # Label for where the hit came from (presumably retriever/collection name — confirm).
    source: str
    # Extra per-hit attributes; each instance gets its own fresh empty dict by default.
    metadata: dict[str, Any] = field(default_factory=dict)
|
|
|
|
|
|
@dataclass
class RetrievalResult:
    """
    [AC-AISVC-16] Outcome of a single retrieval call.

    Bundles the list of hits with an optional diagnostics mapping.
    """

    hits: list[RetrievalHit] = field(default_factory=list)
    diagnostics: dict[str, Any] | None = None

    @property
    def hit_count(self) -> int:
        """Number of hits in this result."""
        return len(self.hits)

    @property
    def is_empty(self) -> bool:
        """True when the result contains no hits."""
        return self.hit_count == 0

    @property
    def max_score(self) -> float:
        """Highest score among the hits, or 0.0 when there are none."""
        if self.hits:
            return max(hit.score for hit in self.hits)
        return 0.0
|
|
|
|
|
|
class BaseRetriever(ABC):
    """
    [AC-AISVC-16] Interface every retriever plugin implements.

    Concrete strategies (e.g. vector, graph, hybrid) subclass this and
    provide both abstract coroutines below.
    """

    @abstractmethod
    async def retrieve(self, ctx: RetrievalContext) -> RetrievalResult:
        """
        [AC-AISVC-16] Fetch documents relevant to ``ctx``.

        Args:
            ctx: Request context holding tenant_id, query, and any optional
                filters/metadata.

        Returns:
            A RetrievalResult with the hits and optional diagnostics.
        """
        ...

    @abstractmethod
    async def health_check(self) -> bool:
        """
        Report whether this retriever can currently serve requests.

        Returns:
            True when healthy, False otherwise.
        """
        ...
|