# ai-robot-core/ai-service/tests/test_memory.py

"""
Unit tests for Memory service.
[AC-AISVC-10, AC-AISVC-11, AC-AISVC-13] Tests for multi-tenant session and message management.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.entities import ChatMessage, ChatSession
from app.services.memory import MemoryService
@pytest.fixture
def mock_session():
    """Provide a mock standing in for an SQLAlchemy ``AsyncSession``.

    ``add`` is replaced with a plain (non-awaitable) ``MagicMock``, while
    ``flush`` and ``delete`` are replaced with awaitable ``AsyncMock`` objects.
    """
    fake = AsyncMock(spec=AsyncSession)
    # Synchronous API on the real session.
    fake.add = MagicMock()
    # Awaitable APIs on the real session.
    fake.flush, fake.delete = AsyncMock(), AsyncMock()
    return fake
@pytest.fixture
def memory_service(mock_session):
    """Build the ``MemoryService`` under test on top of the mocked session."""
    service = MemoryService(mock_session)
    return service
class TestMemoryServiceTenantIsolation:
    """
    [AC-AISVC-10, AC-AISVC-11] Tests for multi-tenant isolation in memory service.
    """

    @pytest.mark.asyncio
    async def test_get_or_create_session_tenant_isolation(self, memory_service, mock_session):
        """
        [AC-AISVC-11] Different tenants with same session_id should have separate sessions.
        """
        # No existing row is found, so each call has to build a fresh session.
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = None
        mock_session.execute = AsyncMock(return_value=mock_result)

        session1 = await memory_service.get_or_create_session(
            tenant_id="tenant_a",
            session_id="session_123",
        )
        session2 = await memory_service.get_or_create_session(
            tenant_id="tenant_b",
            session_id="session_123",
        )

        # Sharing a session_id must not make two tenants share one session object.
        assert session1 is not session2
        assert session1.tenant_id == "tenant_a"
        assert session2.tenant_id == "tenant_b"
        assert session1.session_id == "session_123"
        assert session2.session_id == "session_123"

    @pytest.mark.asyncio
    async def test_load_history_tenant_isolation(self, memory_service, mock_session):
        """
        [AC-AISVC-11] Loading history should only return messages for the specific tenant.
        """
        mock_result = MagicMock()
        mock_scalars = MagicMock()
        mock_scalars.all.return_value = [
            ChatMessage(tenant_id="tenant_a", session_id="session_123", role="user", content="Hello"),
        ]
        mock_result.scalars.return_value = mock_scalars
        mock_session.execute = AsyncMock(return_value=mock_result)

        messages = await memory_service.load_history(
            tenant_id="tenant_a",
            session_id="session_123",
        )

        assert len(messages) == 1
        assert messages[0].tenant_id == "tenant_a"

        # The mock returns tenant_a rows no matter what query is issued, so the
        # assertions above cannot catch a missing tenant filter. Inspect the query.
        # Assumes load_history passes a SQLAlchemy statement positionally to
        # session.execute — adjust if the service's call shape differs.
        statement = mock_session.execute.await_args.args[0]
        bound_values = statement.compile().params.values()
        assert "tenant_a" in bound_values
        assert "session_123" in bound_values

    @pytest.mark.asyncio
    async def test_append_message_tenant_scoped(self, memory_service, mock_session):
        """
        [AC-AISVC-10, AC-AISVC-13] Appended messages should be scoped to tenant.
        """
        message = await memory_service.append_message(
            tenant_id="tenant_a",
            session_id="session_123",
            role="user",
            content="Test message",
        )

        assert message.tenant_id == "tenant_a"
        assert message.session_id == "session_123"
        assert message.role == "user"
        assert message.content == "Test message"
class TestMemoryServiceSessionManagement:
    """
    [AC-AISVC-13] Tests for session-based memory management.
    """

    @pytest.mark.asyncio
    async def test_get_existing_session(self, memory_service, mock_session):
        """
        [AC-AISVC-13] Should return existing session if it exists.
        """
        existing_session = ChatSession(
            tenant_id="tenant_a",
            session_id="session_123",
        )
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = existing_session
        mock_session.execute = AsyncMock(return_value=mock_result)

        session = await memory_service.get_or_create_session(
            tenant_id="tenant_a",
            session_id="session_123",
        )

        # The stored row itself should be handed back, not a look-alike copy.
        assert session is existing_session
        assert session.tenant_id == "tenant_a"
        assert session.session_id == "session_123"

    @pytest.mark.asyncio
    async def test_create_new_session(self, memory_service, mock_session):
        """
        [AC-AISVC-13] Should create new session if it doesn't exist.
        """
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = None
        mock_session.execute = AsyncMock(return_value=mock_result)

        session = await memory_service.get_or_create_session(
            tenant_id="tenant_a",
            session_id="session_new",
            channel_type="wechat",
            metadata={"user_id": "user_123"},
        )

        assert session.tenant_id == "tenant_a"
        assert session.session_id == "session_new"
        assert session.channel_type == "wechat"

    @pytest.mark.asyncio
    async def test_append_multiple_messages(self, memory_service, mock_session):
        """
        [AC-AISVC-13] Should append multiple messages in batch.
        """
        messages_data = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
        ]

        messages = await memory_service.append_messages(
            tenant_id="tenant_a",
            session_id="session_123",
            messages=messages_data,
        )

        assert len(messages) == 2
        assert [m.role for m in messages] == ["user", "assistant"]
        assert [m.content for m in messages] == ["Hello", "Hi there!"]
        # Every message in the batch must carry the caller's tenant and session.
        assert all(m.tenant_id == "tenant_a" for m in messages)
        assert all(m.session_id == "session_123" for m in messages)

    @pytest.mark.asyncio
    async def test_load_history_with_limit(self, memory_service, mock_session):
        """
        [AC-AISVC-13] Should limit the number of messages returned.
        """
        limit = 3
        # The database applies the LIMIT, so the mocked result mirrors what a
        # limited query would return: exactly ``limit`` rows.
        mock_result = MagicMock()
        mock_scalars = MagicMock()
        mock_scalars.all.return_value = [
            ChatMessage(tenant_id="tenant_a", session_id="session_123", role="user", content=f"Msg {i}")
            for i in range(limit)
        ]
        mock_result.scalars.return_value = mock_scalars
        mock_session.execute = AsyncMock(return_value=mock_result)

        messages = await memory_service.load_history(
            tenant_id="tenant_a",
            session_id="session_123",
            limit=limit,
        )

        assert len(messages) == limit
        # A short mocked result proves nothing on its own; verify the requested
        # limit actually reached the issued query. Assumes load_history passes a
        # SQLAlchemy statement positionally to session.execute.
        statement = mock_session.execute.await_args.args[0]
        assert limit in statement.compile().params.values()
class TestMemoryServiceClearHistory:
    """
    [AC-AISVC-13] Tests for clearing session history.
    """

    @pytest.mark.asyncio
    async def test_clear_history_tenant_scoped(self, memory_service, mock_session):
        """
        [AC-AISVC-11] Clearing history should only affect the specified tenant's messages.
        """
        mock_result = MagicMock()
        mock_scalars = MagicMock()
        mock_scalars.all.return_value = [
            ChatMessage(tenant_id="tenant_a", session_id="session_123", role="user", content="Msg 1"),
            ChatMessage(tenant_id="tenant_a", session_id="session_123", role="assistant", content="Msg 2"),
        ]
        mock_result.scalars.return_value = mock_scalars
        mock_session.execute = AsyncMock(return_value=mock_result)

        count = await memory_service.clear_history(
            tenant_id="tenant_a",
            session_id="session_123",
        )

        assert count == 2
        # The mock yields tenant_a rows whatever query is issued, so also check
        # that the first statement sent to the database was scoped to the
        # requesting tenant and session. Assumes clear_history passes a
        # SQLAlchemy statement positionally to session.execute.
        first_statement = mock_session.execute.await_args_list[0].args[0]
        bound_values = first_statement.compile().params.values()
        assert "tenant_a" in bound_values
        assert "session_123" in bound_values