[AC-AISVC-02, AC-AISVC-16] 多个需求合并 #1
|
|
@ -0,0 +1,264 @@
|
|||
"""
|
||||
Unit tests for Retrieval layer.
|
||||
[AC-AISVC-10, AC-AISVC-16, AC-AISVC-17] Tests for vector retrieval with tenant isolation.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from app.services.retrieval.base import RetrievalContext, RetrievalHit, RetrievalResult
|
||||
from app.services.retrieval.vector_retriever import VectorRetriever
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_qdrant_client():
|
||||
"""Create a mock QdrantClient."""
|
||||
client = AsyncMock()
|
||||
client.search = AsyncMock()
|
||||
client.get_collection_name = MagicMock(side_effect=lambda tenant_id: f"kb_{tenant_id}")
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def retrieval_context():
|
||||
"""Create a sample RetrievalContext."""
|
||||
return RetrievalContext(
|
||||
tenant_id="tenant_a",
|
||||
query="What is the product price?",
|
||||
session_id="session_123",
|
||||
channel_type="wechat",
|
||||
metadata={"user_id": "user_123"},
|
||||
)
|
||||
|
||||
|
||||
class TestRetrievalContext:
|
||||
"""
|
||||
[AC-AISVC-16] Tests for retrieval context.
|
||||
"""
|
||||
|
||||
def test_retrieval_context_creation(self):
|
||||
"""
|
||||
[AC-AISVC-16] Should create retrieval context with all fields.
|
||||
"""
|
||||
ctx = RetrievalContext(
|
||||
tenant_id="tenant_a",
|
||||
query="Test query",
|
||||
session_id="session_123",
|
||||
channel_type="wechat",
|
||||
metadata={"key": "value"},
|
||||
)
|
||||
|
||||
assert ctx.tenant_id == "tenant_a"
|
||||
assert ctx.query == "Test query"
|
||||
assert ctx.session_id == "session_123"
|
||||
assert ctx.channel_type == "wechat"
|
||||
assert ctx.metadata == {"key": "value"}
|
||||
|
||||
def test_retrieval_context_minimal(self):
|
||||
"""
|
||||
[AC-AISVC-16] Should create retrieval context with minimal fields.
|
||||
"""
|
||||
ctx = RetrievalContext(
|
||||
tenant_id="tenant_a",
|
||||
query="Test query",
|
||||
)
|
||||
|
||||
assert ctx.tenant_id == "tenant_a"
|
||||
assert ctx.query == "Test query"
|
||||
assert ctx.session_id is None
|
||||
assert ctx.channel_type is None
|
||||
|
||||
|
||||
class TestRetrievalResult:
|
||||
"""
|
||||
[AC-AISVC-16, AC-AISVC-17] Tests for retrieval result.
|
||||
"""
|
||||
|
||||
def test_empty_result(self):
|
||||
"""
|
||||
[AC-AISVC-17] Empty result should indicate insufficient retrieval.
|
||||
"""
|
||||
result = RetrievalResult(hits=[])
|
||||
|
||||
assert result.is_empty is True
|
||||
assert result.max_score == 0.0
|
||||
assert result.hit_count == 0
|
||||
|
||||
def test_result_with_hits(self):
|
||||
"""
|
||||
[AC-AISVC-16] Result with hits should calculate correct statistics.
|
||||
"""
|
||||
hits = [
|
||||
RetrievalHit(text="Doc 1", score=0.9, source="vector"),
|
||||
RetrievalHit(text="Doc 2", score=0.7, source="vector"),
|
||||
]
|
||||
result = RetrievalResult(hits=hits)
|
||||
|
||||
assert result.is_empty is False
|
||||
assert result.max_score == 0.9
|
||||
assert result.hit_count == 2
|
||||
|
||||
def test_result_max_score(self):
|
||||
"""
|
||||
[AC-AISVC-17] Max score should be the highest among hits.
|
||||
"""
|
||||
hits = [
|
||||
RetrievalHit(text="Doc 1", score=0.5, source="vector"),
|
||||
RetrievalHit(text="Doc 2", score=0.95, source="vector"),
|
||||
RetrievalHit(text="Doc 3", score=0.3, source="vector"),
|
||||
]
|
||||
result = RetrievalResult(hits=hits)
|
||||
|
||||
assert result.max_score == 0.95
|
||||
|
||||
|
||||
class TestVectorRetrieverTenantIsolation:
|
||||
"""
|
||||
[AC-AISVC-10, AC-AISVC-11] Tests for multi-tenant isolation in vector retrieval.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_uses_tenant_collection(self, mock_qdrant_client, retrieval_context):
|
||||
"""
|
||||
[AC-AISVC-10] Search should use tenant-specific collection.
|
||||
"""
|
||||
mock_qdrant_client.search.return_value = [
|
||||
{"id": "1", "score": 0.9, "payload": {"text": "Answer 1", "source": "kb"}}
|
||||
]
|
||||
|
||||
retriever = VectorRetriever(qdrant_client=mock_qdrant_client)
|
||||
|
||||
with patch.object(retriever, "_get_embedding", return_value=[0.1] * 1536):
|
||||
result = await retriever.retrieve(retrieval_context)
|
||||
|
||||
mock_qdrant_client.search.assert_called_once()
|
||||
call_args = mock_qdrant_client.search.call_args
|
||||
assert call_args.kwargs["tenant_id"] == "tenant_a"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_different_tenants_separate_results(self, mock_qdrant_client):
|
||||
"""
|
||||
[AC-AISVC-11] Different tenants should get separate results.
|
||||
"""
|
||||
mock_qdrant_client.search.side_effect = [
|
||||
[{"id": "1", "score": 0.9, "payload": {"text": "Tenant A result"}}],
|
||||
[{"id": "2", "score": 0.8, "payload": {"text": "Tenant B result"}}],
|
||||
]
|
||||
|
||||
retriever = VectorRetriever(qdrant_client=mock_qdrant_client)
|
||||
|
||||
with patch.object(retriever, "_get_embedding", return_value=[0.1] * 1536):
|
||||
ctx_a = RetrievalContext(tenant_id="tenant_a", query="query")
|
||||
ctx_b = RetrievalContext(tenant_id="tenant_b", query="query")
|
||||
|
||||
result_a = await retriever.retrieve(ctx_a)
|
||||
result_b = await retriever.retrieve(ctx_b)
|
||||
|
||||
assert result_a.hits[0].text == "Tenant A result"
|
||||
assert result_b.hits[0].text == "Tenant B result"
|
||||
|
||||
|
||||
class TestVectorRetrieverScoreThreshold:
|
||||
"""
|
||||
[AC-AISVC-17] Tests for score threshold filtering.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_filter_by_score_threshold(self, mock_qdrant_client, retrieval_context):
|
||||
"""
|
||||
[AC-AISVC-17] Results below score threshold should be filtered.
|
||||
"""
|
||||
mock_qdrant_client.search.return_value = [
|
||||
{"id": "1", "score": 0.9, "payload": {"text": "High score"}},
|
||||
{"id": "2", "score": 0.5, "payload": {"text": "Low score"}},
|
||||
{"id": "3", "score": 0.8, "payload": {"text": "Medium score"}},
|
||||
]
|
||||
|
||||
retriever = VectorRetriever(
|
||||
qdrant_client=mock_qdrant_client,
|
||||
score_threshold=0.7,
|
||||
)
|
||||
|
||||
with patch.object(retriever, "_get_embedding", return_value=[0.1] * 1536):
|
||||
result = await retriever.retrieve(retrieval_context)
|
||||
|
||||
assert len(result.hits) == 2
|
||||
assert all(hit.score >= 0.7 for hit in result.hits)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_insufficient_hits_detection(self, mock_qdrant_client, retrieval_context):
|
||||
"""
|
||||
[AC-AISVC-17] Should detect insufficient retrieval when hits < min_hits.
|
||||
"""
|
||||
mock_qdrant_client.search.return_value = [
|
||||
{"id": "1", "score": 0.9, "payload": {"text": "Only one hit"}},
|
||||
]
|
||||
|
||||
retriever = VectorRetriever(
|
||||
qdrant_client=mock_qdrant_client,
|
||||
score_threshold=0.7,
|
||||
min_hits=2,
|
||||
)
|
||||
|
||||
with patch.object(retriever, "_get_embedding", return_value=[0.1] * 1536):
|
||||
result = await retriever.retrieve(retrieval_context)
|
||||
|
||||
assert result.diagnostics["is_insufficient"] is True
|
||||
assert result.diagnostics["filtered_hits"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sufficient_hits_detection(self, mock_qdrant_client, retrieval_context):
|
||||
"""
|
||||
[AC-AISVC-17] Should detect sufficient retrieval when hits >= min_hits.
|
||||
"""
|
||||
mock_qdrant_client.search.return_value = [
|
||||
{"id": "1", "score": 0.9, "payload": {"text": "Hit 1"}},
|
||||
{"id": "2", "score": 0.85, "payload": {"text": "Hit 2"}},
|
||||
{"id": "3", "score": 0.8, "payload": {"text": "Hit 3"}},
|
||||
]
|
||||
|
||||
retriever = VectorRetriever(
|
||||
qdrant_client=mock_qdrant_client,
|
||||
score_threshold=0.7,
|
||||
min_hits=2,
|
||||
)
|
||||
|
||||
with patch.object(retriever, "_get_embedding", return_value=[0.1] * 1536):
|
||||
result = await retriever.retrieve(retrieval_context)
|
||||
|
||||
assert result.diagnostics["is_insufficient"] is False
|
||||
assert result.diagnostics["filtered_hits"] == 3
|
||||
|
||||
|
||||
class TestVectorRetrieverHealthCheck:
|
||||
"""
|
||||
[AC-AISVC-16] Tests for retriever health check.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_success(self, mock_qdrant_client):
|
||||
"""
|
||||
[AC-AISVC-16] Health check should return True when Qdrant is available.
|
||||
"""
|
||||
mock_qdrant = AsyncMock()
|
||||
mock_qdrant.get_collections = AsyncMock()
|
||||
mock_qdrant_client.get_client = AsyncMock(return_value=mock_qdrant)
|
||||
|
||||
retriever = VectorRetriever(qdrant_client=mock_qdrant_client)
|
||||
is_healthy = await retriever.health_check()
|
||||
|
||||
assert is_healthy is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_failure(self, mock_qdrant_client):
|
||||
"""
|
||||
[AC-AISVC-16] Health check should return False when Qdrant is unavailable.
|
||||
"""
|
||||
mock_qdrant = AsyncMock()
|
||||
mock_qdrant.get_collections = AsyncMock(side_effect=Exception("Connection failed"))
|
||||
mock_qdrant_client.get_client = AsyncMock(return_value=mock_qdrant)
|
||||
|
||||
retriever = VectorRetriever(qdrant_client=mock_qdrant_client)
|
||||
is_healthy = await retriever.health_check()
|
||||
|
||||
assert is_healthy is False
|
||||
Loading…
Reference in New Issue