[AC-AISVC-02, AC-AISVC-16] 多个需求合并 #1

Merged
MerCry merged 45 commits from feature/prompt-unification-and-logging into main 2026-02-25 17:17:35 +00:00
1 changed files with 264 additions and 0 deletions
Showing only changes of commit 92cef20a86 - Show all commits

View File

@ -0,0 +1,264 @@
"""
Unit tests for Retrieval layer.
[AC-AISVC-10, AC-AISVC-16, AC-AISVC-17] Tests for vector retrieval with tenant isolation.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from app.services.retrieval.base import RetrievalContext, RetrievalHit, RetrievalResult
from app.services.retrieval.vector_retriever import VectorRetriever
@pytest.fixture
def mock_qdrant_client():
"""Create a mock QdrantClient."""
client = AsyncMock()
client.search = AsyncMock()
client.get_collection_name = MagicMock(side_effect=lambda tenant_id: f"kb_{tenant_id}")
return client
@pytest.fixture
def retrieval_context():
"""Create a sample RetrievalContext."""
return RetrievalContext(
tenant_id="tenant_a",
query="What is the product price?",
session_id="session_123",
channel_type="wechat",
metadata={"user_id": "user_123"},
)
class TestRetrievalContext:
"""
[AC-AISVC-16] Tests for retrieval context.
"""
def test_retrieval_context_creation(self):
"""
[AC-AISVC-16] Should create retrieval context with all fields.
"""
ctx = RetrievalContext(
tenant_id="tenant_a",
query="Test query",
session_id="session_123",
channel_type="wechat",
metadata={"key": "value"},
)
assert ctx.tenant_id == "tenant_a"
assert ctx.query == "Test query"
assert ctx.session_id == "session_123"
assert ctx.channel_type == "wechat"
assert ctx.metadata == {"key": "value"}
def test_retrieval_context_minimal(self):
"""
[AC-AISVC-16] Should create retrieval context with minimal fields.
"""
ctx = RetrievalContext(
tenant_id="tenant_a",
query="Test query",
)
assert ctx.tenant_id == "tenant_a"
assert ctx.query == "Test query"
assert ctx.session_id is None
assert ctx.channel_type is None
class TestRetrievalResult:
"""
[AC-AISVC-16, AC-AISVC-17] Tests for retrieval result.
"""
def test_empty_result(self):
"""
[AC-AISVC-17] Empty result should indicate insufficient retrieval.
"""
result = RetrievalResult(hits=[])
assert result.is_empty is True
assert result.max_score == 0.0
assert result.hit_count == 0
def test_result_with_hits(self):
"""
[AC-AISVC-16] Result with hits should calculate correct statistics.
"""
hits = [
RetrievalHit(text="Doc 1", score=0.9, source="vector"),
RetrievalHit(text="Doc 2", score=0.7, source="vector"),
]
result = RetrievalResult(hits=hits)
assert result.is_empty is False
assert result.max_score == 0.9
assert result.hit_count == 2
def test_result_max_score(self):
"""
[AC-AISVC-17] Max score should be the highest among hits.
"""
hits = [
RetrievalHit(text="Doc 1", score=0.5, source="vector"),
RetrievalHit(text="Doc 2", score=0.95, source="vector"),
RetrievalHit(text="Doc 3", score=0.3, source="vector"),
]
result = RetrievalResult(hits=hits)
assert result.max_score == 0.95
class TestVectorRetrieverTenantIsolation:
"""
[AC-AISVC-10, AC-AISVC-11] Tests for multi-tenant isolation in vector retrieval.
"""
@pytest.mark.asyncio
async def test_search_uses_tenant_collection(self, mock_qdrant_client, retrieval_context):
"""
[AC-AISVC-10] Search should use tenant-specific collection.
"""
mock_qdrant_client.search.return_value = [
{"id": "1", "score": 0.9, "payload": {"text": "Answer 1", "source": "kb"}}
]
retriever = VectorRetriever(qdrant_client=mock_qdrant_client)
with patch.object(retriever, "_get_embedding", return_value=[0.1] * 1536):
result = await retriever.retrieve(retrieval_context)
mock_qdrant_client.search.assert_called_once()
call_args = mock_qdrant_client.search.call_args
assert call_args.kwargs["tenant_id"] == "tenant_a"
@pytest.mark.asyncio
async def test_different_tenants_separate_results(self, mock_qdrant_client):
"""
[AC-AISVC-11] Different tenants should get separate results.
"""
mock_qdrant_client.search.side_effect = [
[{"id": "1", "score": 0.9, "payload": {"text": "Tenant A result"}}],
[{"id": "2", "score": 0.8, "payload": {"text": "Tenant B result"}}],
]
retriever = VectorRetriever(qdrant_client=mock_qdrant_client)
with patch.object(retriever, "_get_embedding", return_value=[0.1] * 1536):
ctx_a = RetrievalContext(tenant_id="tenant_a", query="query")
ctx_b = RetrievalContext(tenant_id="tenant_b", query="query")
result_a = await retriever.retrieve(ctx_a)
result_b = await retriever.retrieve(ctx_b)
assert result_a.hits[0].text == "Tenant A result"
assert result_b.hits[0].text == "Tenant B result"
class TestVectorRetrieverScoreThreshold:
"""
[AC-AISVC-17] Tests for score threshold filtering.
"""
@pytest.mark.asyncio
async def test_filter_by_score_threshold(self, mock_qdrant_client, retrieval_context):
"""
[AC-AISVC-17] Results below score threshold should be filtered.
"""
mock_qdrant_client.search.return_value = [
{"id": "1", "score": 0.9, "payload": {"text": "High score"}},
{"id": "2", "score": 0.5, "payload": {"text": "Low score"}},
{"id": "3", "score": 0.8, "payload": {"text": "Medium score"}},
]
retriever = VectorRetriever(
qdrant_client=mock_qdrant_client,
score_threshold=0.7,
)
with patch.object(retriever, "_get_embedding", return_value=[0.1] * 1536):
result = await retriever.retrieve(retrieval_context)
assert len(result.hits) == 2
assert all(hit.score >= 0.7 for hit in result.hits)
@pytest.mark.asyncio
async def test_insufficient_hits_detection(self, mock_qdrant_client, retrieval_context):
"""
[AC-AISVC-17] Should detect insufficient retrieval when hits < min_hits.
"""
mock_qdrant_client.search.return_value = [
{"id": "1", "score": 0.9, "payload": {"text": "Only one hit"}},
]
retriever = VectorRetriever(
qdrant_client=mock_qdrant_client,
score_threshold=0.7,
min_hits=2,
)
with patch.object(retriever, "_get_embedding", return_value=[0.1] * 1536):
result = await retriever.retrieve(retrieval_context)
assert result.diagnostics["is_insufficient"] is True
assert result.diagnostics["filtered_hits"] == 1
@pytest.mark.asyncio
async def test_sufficient_hits_detection(self, mock_qdrant_client, retrieval_context):
"""
[AC-AISVC-17] Should detect sufficient retrieval when hits >= min_hits.
"""
mock_qdrant_client.search.return_value = [
{"id": "1", "score": 0.9, "payload": {"text": "Hit 1"}},
{"id": "2", "score": 0.85, "payload": {"text": "Hit 2"}},
{"id": "3", "score": 0.8, "payload": {"text": "Hit 3"}},
]
retriever = VectorRetriever(
qdrant_client=mock_qdrant_client,
score_threshold=0.7,
min_hits=2,
)
with patch.object(retriever, "_get_embedding", return_value=[0.1] * 1536):
result = await retriever.retrieve(retrieval_context)
assert result.diagnostics["is_insufficient"] is False
assert result.diagnostics["filtered_hits"] == 3
class TestVectorRetrieverHealthCheck:
"""
[AC-AISVC-16] Tests for retriever health check.
"""
@pytest.mark.asyncio
async def test_health_check_success(self, mock_qdrant_client):
"""
[AC-AISVC-16] Health check should return True when Qdrant is available.
"""
mock_qdrant = AsyncMock()
mock_qdrant.get_collections = AsyncMock()
mock_qdrant_client.get_client = AsyncMock(return_value=mock_qdrant)
retriever = VectorRetriever(qdrant_client=mock_qdrant_client)
is_healthy = await retriever.health_check()
assert is_healthy is True
@pytest.mark.asyncio
async def test_health_check_failure(self, mock_qdrant_client):
"""
[AC-AISVC-16] Health check should return False when Qdrant is unavailable.
"""
mock_qdrant = AsyncMock()
mock_qdrant.get_collections = AsyncMock(side_effect=Exception("Connection failed"))
mock_qdrant_client.get_client = AsyncMock(return_value=mock_qdrant)
retriever = VectorRetriever(qdrant_client=mock_qdrant_client)
is_healthy = await retriever.health_check()
assert is_healthy is False