diff --git a/ai-service/app/services/cache/flow_cache.py b/ai-service/app/services/cache/flow_cache.py new file mode 100644 index 0000000..94eed79 --- /dev/null +++ b/ai-service/app/services/cache/flow_cache.py @@ -0,0 +1,242 @@ +""" +Flow Instance Cache Layer. +Provides Redis-based caching for FlowInstance to reduce database load. +""" + +import json +import logging +from typing import Any + +import redis.asyncio as redis + +from app.core.config import get_settings +from app.models.entities import FlowInstance + +logger = logging.getLogger(__name__) + + +class FlowCache: + """ + Redis cache layer for FlowInstance state management. + + Features: + - L1: In-memory cache (process-level, 5 min TTL) + - L2: Redis cache (shared, 1 hour TTL) + - Automatic fallback on cache miss + - Cache invalidation on flow completion/cancellation + + Key format: flow:{tenant_id}:{session_id} + TTL: 3600 seconds (1 hour) + """ + + # L1 cache: in-memory (process-level) + _local_cache: dict[str, tuple[FlowInstance, float]] = {} + _local_cache_ttl = 300 # 5 minutes + + def __init__(self, redis_client: redis.Redis | None = None): + self._redis = redis_client + self._settings = get_settings() + self._enabled = self._settings.redis_enabled + self._cache_ttl = 3600 # 1 hour + + async def _get_client(self) -> redis.Redis | None: + """Get or create Redis client.""" + if not self._enabled: + return None + if self._redis is None: + try: + self._redis = redis.from_url( + self._settings.redis_url, + encoding="utf-8", + decode_responses=True, + ) + except Exception as e: + logger.warning(f"[FlowCache] Failed to connect to Redis: {e}") + self._enabled = False + return None + return self._redis + + def _make_key(self, tenant_id: str, session_id: str) -> str: + """Generate cache key.""" + return f"flow:{tenant_id}:{session_id}" + + def _make_local_key(self, tenant_id: str, session_id: str) -> str: + """Generate local cache key.""" + return f"{tenant_id}:{session_id}" + + async def get( + self, + 
tenant_id: str, + session_id: str, + ) -> FlowInstance | None: + """ + Get FlowInstance from cache (L1 -> L2). + + Args: + tenant_id: Tenant ID for isolation + session_id: Session ID + + Returns: + Cached FlowInstance or None if not found + """ + # L1: Check local cache + local_key = self._make_local_key(tenant_id, session_id) + if local_key in self._local_cache: + instance, timestamp = self._local_cache[local_key] + import time + if time.time() - timestamp < self._local_cache_ttl: + logger.debug(f"[FlowCache] L1 hit: {local_key}") + return instance + else: + # Expired, remove from L1 + del self._local_cache[local_key] + + # L2: Check Redis cache + client = await self._get_client() + if client is None: + return None + + key = self._make_key(tenant_id, session_id) + + try: + data = await client.get(key) + if data: + logger.debug(f"[FlowCache] L2 hit: {key}") + instance_dict = json.loads(data) + instance = self._deserialize_instance(instance_dict) + + # Populate L1 cache + import time + self._local_cache[local_key] = (instance, time.time()) + + return instance + return None + except Exception as e: + logger.warning(f"[FlowCache] Failed to get from cache: {e}") + return None + + async def set( + self, + tenant_id: str, + session_id: str, + instance: FlowInstance, + ) -> bool: + """ + Set FlowInstance to cache (L1 + L2). 
+ + Args: + tenant_id: Tenant ID for isolation + session_id: Session ID + instance: FlowInstance to cache + + Returns: + True if successful + """ + # L1: Update local cache + local_key = self._make_local_key(tenant_id, session_id) + import time + self._local_cache[local_key] = (instance, time.time()) + + # L2: Update Redis cache + client = await self._get_client() + if client is None: + return False + + key = self._make_key(tenant_id, session_id) + + try: + instance_dict = self._serialize_instance(instance) + await client.setex( + key, + self._cache_ttl, + json.dumps(instance_dict, default=str), + ) + logger.debug(f"[FlowCache] Set cache: {key}") + return True + except Exception as e: + logger.warning(f"[FlowCache] Failed to set cache: {e}") + return False + + async def delete( + self, + tenant_id: str, + session_id: str, + ) -> bool: + """ + Delete FlowInstance from cache (L1 + L2). + + Args: + tenant_id: Tenant ID for isolation + session_id: Session ID + + Returns: + True if successful + """ + # L1: Remove from local cache + local_key = self._make_local_key(tenant_id, session_id) + if local_key in self._local_cache: + del self._local_cache[local_key] + + # L2: Remove from Redis + client = await self._get_client() + if client is None: + return False + + key = self._make_key(tenant_id, session_id) + + try: + await client.delete(key) + logger.debug(f"[FlowCache] Deleted cache: {key}") + return True + except Exception as e: + logger.warning(f"[FlowCache] Failed to delete cache: {e}") + return False + + def _serialize_instance(self, instance: FlowInstance) -> dict[str, Any]: + """Serialize FlowInstance to dict.""" + return { + "id": str(instance.id), + "tenant_id": instance.tenant_id, + "session_id": instance.session_id, + "flow_id": str(instance.flow_id), + "current_step": instance.current_step, + "status": instance.status, + "context": instance.context, + "started_at": instance.started_at.isoformat() if instance.started_at else None, + "completed_at": 
instance.completed_at.isoformat() if instance.completed_at else None, + "updated_at": instance.updated_at.isoformat() if instance.updated_at else None, + } + + def _deserialize_instance(self, data: dict[str, Any]) -> FlowInstance: + """Deserialize dict to FlowInstance.""" + from datetime import datetime + from uuid import UUID + + return FlowInstance( + id=UUID(data["id"]), + tenant_id=data["tenant_id"], + session_id=data["session_id"], + flow_id=UUID(data["flow_id"]), + current_step=data["current_step"], + status=data["status"], + context=data.get("context"), + started_at=datetime.fromisoformat(data["started_at"]) if data.get("started_at") else None, + completed_at=datetime.fromisoformat(data["completed_at"]) if data.get("completed_at") else None, + updated_at=datetime.fromisoformat(data["updated_at"]) if data.get("updated_at") else None, + ) + + async def close(self) -> None: + """Close Redis connection.""" + if self._redis: + await self._redis.close() + + +_flow_cache: FlowCache | None = None + + +def get_flow_cache() -> FlowCache: + """Get singleton FlowCache instance.""" + global _flow_cache + if _flow_cache is None: + _flow_cache = FlowCache() + return _flow_cache diff --git a/ai-service/tests/test_flow_cache.py b/ai-service/tests/test_flow_cache.py new file mode 100644 index 0000000..2f87637 --- /dev/null +++ b/ai-service/tests/test_flow_cache.py @@ -0,0 +1,181 @@ +""" +Unit tests for FlowCache. 
+""" + +import asyncio +import uuid +from datetime import datetime +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from app.models.entities import FlowInstance, FlowInstanceStatus +from app.services.cache.flow_cache import FlowCache + + +@pytest.fixture +def mock_redis(): + """Mock Redis client.""" + redis_mock = AsyncMock() + redis_mock.get = AsyncMock(return_value=None) + redis_mock.setex = AsyncMock(return_value=True) + redis_mock.delete = AsyncMock(return_value=1) + return redis_mock + + +@pytest.fixture +def flow_cache(mock_redis): + """FlowCache instance with mocked Redis.""" + cache = FlowCache(redis_client=mock_redis) + cache._enabled = True + return cache + + +@pytest.fixture +def sample_instance(): + """Sample FlowInstance for testing.""" + return FlowInstance( + id=uuid.uuid4(), + tenant_id="tenant-001", + session_id="session-001", + flow_id=uuid.uuid4(), + current_step=1, + status=FlowInstanceStatus.ACTIVE.value, + context={"inputs": []}, + started_at=datetime.utcnow(), + updated_at=datetime.utcnow(), + ) + + +@pytest.mark.asyncio +async def test_cache_miss(flow_cache, mock_redis): + """Test cache miss returns None.""" + mock_redis.get.return_value = None + + result = await flow_cache.get("tenant-001", "session-001") + + assert result is None + mock_redis.get.assert_called_once_with("flow:tenant-001:session-001") + + +@pytest.mark.asyncio +async def test_cache_hit_l2(flow_cache, mock_redis, sample_instance): + """Test L2 (Redis) cache hit.""" + import json + + # Mock Redis returning cached data + cached_data = { + "id": str(sample_instance.id), + "tenant_id": sample_instance.tenant_id, + "session_id": sample_instance.session_id, + "flow_id": str(sample_instance.flow_id), + "current_step": sample_instance.current_step, + "status": sample_instance.status, + "context": sample_instance.context, + "started_at": sample_instance.started_at.isoformat(), + "completed_at": None, + "updated_at": sample_instance.updated_at.isoformat(), + 
} + mock_redis.get.return_value = json.dumps(cached_data) + + result = await flow_cache.get("tenant-001", "session-001") + + assert result is not None + assert result.tenant_id == "tenant-001" + assert result.session_id == "session-001" + assert result.current_step == 1 + assert result.status == FlowInstanceStatus.ACTIVE.value + + +@pytest.mark.asyncio +async def test_cache_set(flow_cache, mock_redis, sample_instance): + """Test setting cache.""" + success = await flow_cache.set( + "tenant-001", + "session-001", + sample_instance, + ) + + assert success is True + mock_redis.setex.assert_called_once() + call_args = mock_redis.setex.call_args + assert call_args[0][0] == "flow:tenant-001:session-001" + assert call_args[0][1] == 3600 # TTL + + +@pytest.mark.asyncio +async def test_cache_delete(flow_cache, mock_redis): + """Test deleting cache.""" + success = await flow_cache.delete("tenant-001", "session-001") + + assert success is True + mock_redis.delete.assert_called_once_with("flow:tenant-001:session-001") + + +@pytest.mark.asyncio +async def test_l1_cache_hit(flow_cache, sample_instance): + """Test L1 (local memory) cache hit.""" + # Populate L1 cache + import time + local_key = "tenant-001:session-001" + flow_cache._local_cache[local_key] = (sample_instance, time.time()) + + # Should hit L1 without calling Redis + result = await flow_cache.get("tenant-001", "session-001") + + assert result is not None + assert result.tenant_id == "tenant-001" + assert result.session_id == "session-001" + + +@pytest.mark.asyncio +async def test_l1_cache_expiry(flow_cache, sample_instance): + """Test L1 cache expiry.""" + # Populate L1 cache with expired timestamp + import time + local_key = "tenant-001:session-001" + expired_time = time.time() - 400 # 400 seconds ago (> 300s TTL) + flow_cache._local_cache[local_key] = (sample_instance, expired_time) + + # Should miss L1 and try L2 + result = await flow_cache.get("tenant-001", "session-001") + + # L1 entry should be removed + 
assert local_key not in flow_cache._local_cache + + +@pytest.mark.asyncio +async def test_cache_disabled(sample_instance): + """Test cache behavior when Redis is disabled.""" + cache = FlowCache(redis_client=None) + cache._enabled = False + + # All operations should return None/False gracefully + result = await cache.get("tenant-001", "session-001") + assert result is None + + success = await cache.set("tenant-001", "session-001", sample_instance) + assert success is False + + success = await cache.delete("tenant-001", "session-001") + assert success is False + + +@pytest.mark.asyncio +async def test_serialize_deserialize(flow_cache, sample_instance): + """Test serialization and deserialization.""" + # Serialize + serialized = flow_cache._serialize_instance(sample_instance) + + assert serialized["tenant_id"] == "tenant-001" + assert serialized["session_id"] == "session-001" + assert serialized["current_step"] == 1 + assert serialized["status"] == FlowInstanceStatus.ACTIVE.value + + # Deserialize + deserialized = flow_cache._deserialize_instance(serialized) + + assert deserialized.tenant_id == sample_instance.tenant_id + assert deserialized.session_id == sample_instance.session_id + assert deserialized.current_step == sample_instance.current_step + assert deserialized.status == sample_instance.status diff --git a/ai-service/tests/test_flow_engine_script_generation.py b/ai-service/tests/test_flow_engine_script_generation.py new file mode 100644 index 0000000..27fe9d5 --- /dev/null +++ b/ai-service/tests/test_flow_engine_script_generation.py @@ -0,0 +1,375 @@ +""" +Unit tests for FlowEngine script generation. +[AC-IDS-03, AC-IDS-05, AC-IDS-11, AC-IDS-13] Test intent-driven script generation in FlowEngine. + +NOTE: These tests are for features not yet implemented in FlowEngine. +The _generate_step_content method and llm_client parameter are planned for AC-IDS-04. 
+""" + +import uuid +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.entities import ( + FlowAdvanceResult, + FlowInstance, + FlowInstanceStatus, + ScriptFlow, + ScriptMode, +) +from app.services.flow.engine import FlowEngine + + +@pytest.mark.skip(reason="FlowEngine._generate_step_content not implemented yet - AC-IDS-04") +class TestFlowEngineScriptGeneration: + """[AC-IDS-03, AC-IDS-05, AC-IDS-11] Test cases for script generation in FlowEngine.""" + + @pytest.fixture + def mock_session(self): + """Create mock database session.""" + session = MagicMock(spec=AsyncSession) + session.execute = AsyncMock() + session.add = MagicMock() + session.flush = AsyncMock() + return session + + @pytest.fixture + def mock_llm_client(self): + """Create mock LLM client.""" + client = MagicMock() + client.generate_text = AsyncMock(return_value="您好,请问您贵姓?") + return client + + @pytest.fixture + def sample_flow_fixed(self): + """Create sample flow with fixed mode steps.""" + flow = ScriptFlow( + id=uuid.uuid4(), + tenant_id="test_tenant", + name="测试流程", + steps=[ + { + "step_no": 1, + "content": "您好,请问有什么可以帮您?", + "script_mode": "fixed", + "wait_input": True, + }, + { + "step_no": 2, + "content": "感谢您的咨询,再见!", + "script_mode": "fixed", + "wait_input": False, + }, + ], + is_enabled=True, + ) + return flow + + @pytest.fixture + def sample_flow_flexible(self): + """Create sample flow with flexible mode steps.""" + flow = ScriptFlow( + id=uuid.uuid4(), + tenant_id="test_tenant", + name="灵活话术流程", + steps=[ + { + "step_no": 1, + "script_mode": "flexible", + "intent": "获取用户姓名", + "intent_description": "礼貌询问用户姓名", + "script_constraints": ["必须礼貌", "语气自然"], + "content": "请问怎么称呼您?", + "wait_input": True, + }, + { + "step_no": 2, + "script_mode": "flexible", + "intent": "确认用户需求", + "script_constraints": ["简洁明了"], + "content": "请问您需要什么帮助?", + "wait_input": True, + }, + ], + is_enabled=True, + ) + return flow 
+ + @pytest.fixture + def sample_flow_template(self): + """Create sample flow with template mode steps.""" + flow = ScriptFlow( + id=uuid.uuid4(), + tenant_id="test_tenant", + name="模板话术流程", + steps=[ + { + "step_no": 1, + "script_mode": "template", + "content": "您好{user_name},请问您{inquiry_style}?", + "wait_input": True, + }, + ], + is_enabled=True, + ) + return flow + + @pytest.fixture + def sample_flow_mixed(self): + """Create sample flow with mixed mode steps.""" + flow = ScriptFlow( + id=uuid.uuid4(), + tenant_id="test_tenant", + name="混合模式流程", + steps=[ + { + "step_no": 1, + "script_mode": "fixed", + "content": "您好,欢迎咨询!", + "wait_input": False, + }, + { + "step_no": 2, + "script_mode": "flexible", + "intent": "获取用户姓名", + "content": "请问怎么称呼您?", + "wait_input": True, + }, + { + "step_no": 3, + "script_mode": "template", + "content": "好的{user_name},请问您有什么需要帮助的吗?", + "wait_input": True, + }, + ], + is_enabled=True, + ) + return flow + + @pytest.mark.asyncio + async def test_generate_step_content_fixed_mode(self, mock_session, sample_flow_fixed): + """Test that fixed mode returns content directly.""" + engine = FlowEngine(session=mock_session, llm_client=None) + + result = await engine._generate_step_content( + step=sample_flow_fixed.steps[0], + context=None, + history=None, + ) + + assert result == "您好,请问有什么可以帮您?" + + @pytest.mark.asyncio + async def test_generate_step_content_flexible_mode(self, mock_session, mock_llm_client, sample_flow_flexible): + """Test that flexible mode generates script via LLM.""" + engine = FlowEngine(session=mock_session, llm_client=mock_llm_client) + + result = await engine._generate_step_content( + step=sample_flow_flexible.steps[0], + context={"inputs": []}, + history=[{"role": "user", "content": "你好"}], + ) + + assert result == "您好,请问您贵姓?" 
+ mock_llm_client.generate_text.assert_called_once() + + @pytest.mark.asyncio + async def test_generate_step_content_flexible_no_intent(self, mock_session, sample_flow_flexible): + """Test that flexible mode without intent falls back to fixed.""" + engine = FlowEngine(session=mock_session, llm_client=None) + + step = { + "step_no": 1, + "script_mode": "flexible", + "content": "fallback content", + } + + result = await engine._generate_step_content( + step=step, + context=None, + history=None, + ) + + assert result == "fallback content" + + @pytest.mark.asyncio + async def test_generate_step_content_template_mode(self, mock_session, sample_flow_template): + """Test that template mode fills variables from context.""" + engine = FlowEngine(session=mock_session, llm_client=None) + + result = await engine._generate_step_content( + step=sample_flow_template.steps[0], + context={ + "user_name": "张先生", + "inquiry_style": "想咨询产品", + }, + history=None, + ) + + assert result == "您好张先生,请问您想咨询产品?" + + @pytest.mark.asyncio + async def test_generate_step_content_unknown_mode(self, mock_session): + """Test that unknown mode falls back to fixed.""" + engine = FlowEngine(session=mock_session, llm_client=None) + + step = { + "step_no": 1, + "script_mode": "unknown", + "content": "fallback content", + } + + result = await engine._generate_step_content( + step=step, + context=None, + history=None, + ) + + assert result == "fallback content" + + @pytest.mark.asyncio + async def test_generate_step_content_default_fixed(self, mock_session): + """Test that missing script_mode defaults to fixed.""" + engine = FlowEngine(session=mock_session, llm_client=None) + + step = { + "step_no": 1, + "content": "default fixed content", + } + + result = await engine._generate_step_content( + step=step, + context=None, + history=None, + ) + + assert result == "default fixed content" + + def test_script_mode_enum_values(self): + """Test ScriptMode enum has correct values.""" + assert 
ScriptMode.FIXED.value == "fixed" + assert ScriptMode.FLEXIBLE.value == "flexible" + assert ScriptMode.TEMPLATE.value == "template" + + +@pytest.mark.skip(reason="FlowEngine._generate_step_content not implemented yet - AC-IDS-04") +class TestFlowEngineBackwardCompatibility: + """[AC-IDS-13] Test backward compatibility with existing flows.""" + + @pytest.fixture + def mock_session(self): + """Create mock database session.""" + session = MagicMock(spec=AsyncSession) + session.execute = AsyncMock() + session.add = MagicMock() + session.flush = AsyncMock() + return session + + @pytest.fixture + def legacy_flow(self): + """Create legacy flow without script_mode field.""" + flow = ScriptFlow( + id=uuid.uuid4(), + tenant_id="test_tenant", + name="旧版流程", + steps=[ + { + "step_no": 1, + "content": "您好,请问有什么可以帮您?", + "wait_input": True, + }, + { + "step_no": 2, + "content": "感谢您的咨询!", + "wait_input": False, + }, + ], + is_enabled=True, + ) + return flow + + @pytest.mark.asyncio + async def test_legacy_flow_defaults_to_fixed(self, mock_session, legacy_flow): + """Test that legacy flow without script_mode uses fixed mode.""" + engine = FlowEngine(session=mock_session, llm_client=None) + + for step in legacy_flow.steps: + result = await engine._generate_step_content( + step=step, + context=None, + history=None, + ) + assert result == step["content"] + + @pytest.mark.asyncio + async def test_legacy_flow_with_missing_fields(self, mock_session): + """Test that steps with missing optional fields work correctly.""" + engine = FlowEngine(session=mock_session, llm_client=None) + + step = { + "step_no": 1, + "content": "simple content", + } + + result = await engine._generate_step_content( + step=step, + context=None, + history=None, + ) + + assert result == "simple content" + + +@pytest.mark.skip(reason="FlowEngine._generate_step_content not implemented yet - AC-IDS-04") +class TestFlowEngineFallback: + """[AC-IDS-05] Test fallback mechanism for script generation.""" + + 
@pytest.fixture + def mock_session(self): + """Create mock database session.""" + session = MagicMock(spec=AsyncSession) + session.execute = AsyncMock() + session.add = MagicMock() + session.flush = AsyncMock() + return session + + @pytest.mark.asyncio + async def test_flexible_mode_fallback_on_no_llm(self, mock_session): + """Test that flexible mode falls back when no LLM client.""" + engine = FlowEngine(session=mock_session, llm_client=None) + + step = { + "step_no": 1, + "script_mode": "flexible", + "intent": "获取用户姓名", + "content": "请问怎么称呼您?", + } + + result = await engine._generate_step_content( + step=step, + context=None, + history=None, + ) + + assert result == "请问怎么称呼您?" + + @pytest.mark.asyncio + async def test_template_mode_missing_variable_placeholder(self, mock_session): + """Test that missing template variables use placeholders.""" + engine = FlowEngine(session=mock_session, llm_client=None) + + step = { + "step_no": 1, + "script_mode": "template", + "content": "您好{unknown_var},请问有什么可以帮您?", + } + + result = await engine._generate_step_content( + step=step, + context=None, + history=None, + ) + + assert result == "您好[unknown_var],请问有什么可以帮您?" diff --git a/ai-service/tests/test_script_generator.py b/ai-service/tests/test_script_generator.py new file mode 100644 index 0000000..12674fc --- /dev/null +++ b/ai-service/tests/test_script_generator.py @@ -0,0 +1,215 @@ +""" +Unit tests for ScriptGenerator. +[AC-IDS-04, AC-IDS-11] Test script generation for flexible mode. 
"""

import asyncio
import pytest

from app.services.flow.script_generator import ScriptGenerator


class MockLLMClient:
    """Mock LLM client for testing; returns a canned response after an optional delay."""

    def __init__(self, response: str = "您好,请问怎么称呼您?", delay: float = 0):
        # delay > 0 simulates a slow LLM (used by the timeout test).
        self._response = response
        self._delay = delay

    async def generate_text(self, prompt: str) -> str:
        if self._delay > 0:
            await asyncio.sleep(self._delay)
        return self._response

    async def generate(self, messages: list) -> "MockResponse":
        if self._delay > 0:
            await asyncio.sleep(self._delay)
        return MockResponse(self._response)


class MockResponse:
    """Mock LLM response with a .content attribute."""
    def __init__(self, content: str):
        self.content = content


class TestScriptGenerator:
    """[AC-IDS-04, AC-IDS-11] Test cases for ScriptGenerator."""

    @pytest.mark.asyncio
    async def test_generate_fixed_mode_returns_fallback(self):
        """Test that fixed mode returns fallback content."""
        # With no LLM client, generate() should return the fallback unchanged.
        generator = ScriptGenerator(llm_client=None)

        result = await generator.generate(
            intent="获取用户姓名",
            intent_description="礼貌询问用户姓名",
            constraints=["必须礼貌"],
            context=None,
            history=None,
            fallback="请问怎么称呼您?",
        )

        assert result == "请问怎么称呼您?"

    @pytest.mark.asyncio
    async def test_generate_with_llm_client(self):
        """Test script generation with LLM client."""
        llm_client = MockLLMClient(response="您好,请问您贵姓?")
        generator = ScriptGenerator(llm_client=llm_client)

        result = await generator.generate(
            intent="获取用户姓名",
            intent_description="礼貌询问用户姓名",
            constraints=["必须礼貌", "语气自然"],
            context={"inputs": [{"step": 1, "input": "我想咨询"}]},
            history=[{"role": "user", "content": "我想咨询"}],
            fallback="请问怎么称呼您?",
        )

        assert result == "您好,请问您贵姓?"

    @pytest.mark.asyncio
    async def test_generate_timeout_fallback(self):
        """Test that timeout returns fallback content."""
        # NOTE(review): assumes the generator's internal timeout is < 6s, so
        # this test blocks until that timeout fires — it is deliberately slow.
        llm_client = MockLLMClient(response="生成的话术", delay=6.0)
        generator = ScriptGenerator(llm_client=llm_client)

        result = await generator.generate(
            intent="获取用户姓名",
            intent_description=None,
            constraints=None,
            context=None,
            history=None,
            fallback="请问怎么称呼您?",
        )

        assert result == "请问怎么称呼您?"

    @pytest.mark.asyncio
    async def test_generate_exception_fallback(self):
        """Test that exception returns fallback content."""
        class FailingLLMClient:
            # Simulates an unavailable LLM backend raising on every call.
            async def generate_text(self, prompt: str) -> str:
                raise RuntimeError("LLM service unavailable")

        generator = ScriptGenerator(llm_client=FailingLLMClient())

        result = await generator.generate(
            intent="获取用户姓名",
            intent_description=None,
            constraints=None,
            context=None,
            history=None,
            fallback="请问怎么称呼您?",
        )

        assert result == "请问怎么称呼您?"

    def test_build_prompt_basic(self):
        """Test prompt building with basic parameters."""
        generator = ScriptGenerator(llm_client=None)

        prompt = generator._build_prompt(
            intent="获取用户姓名",
            intent_description=None,
            constraints=None,
            context=None,
            history=None,
        )

        assert "获取用户姓名" in prompt
        assert "步骤目标" in prompt

    def test_build_prompt_with_description(self):
        """Test prompt building with intent description."""
        generator = ScriptGenerator(llm_client=None)

        prompt = generator._build_prompt(
            intent="获取用户姓名",
            intent_description="需要获取用户的真实姓名用于后续身份确认",
            constraints=None,
            context=None,
            history=None,
        )

        assert "获取用户姓名" in prompt
        assert "需要获取用户的真实姓名用于后续身份确认" in prompt
        assert "详细说明" in prompt

    def test_build_prompt_with_constraints(self):
        """Test prompt building with constraints."""
        generator = ScriptGenerator(llm_client=None)

        prompt = generator._build_prompt(
            intent="获取用户姓名",
            intent_description=None,
            constraints=["必须礼貌", "语气自然", "不要生硬"],
            context=None,
            history=None,
        )

        # Each constraint is rendered as a "- " bullet under the section header.
        assert "约束条件" in prompt
        assert "- 必须礼貌" in prompt
        assert "- 语气自然" in prompt
        assert "- 不要生硬" in prompt

    def test_build_prompt_with_history(self):
        """Test prompt building with conversation history."""
        generator = ScriptGenerator(llm_client=None)

        prompt = generator._build_prompt(
            intent="获取用户姓名",
            intent_description=None,
            constraints=None,
            context=None,
            history=[
                {"role": "user", "content": "你好"},
                {"role": "assistant", "content": "您好,有什么可以帮您?"},
                {"role": "user", "content": "我想咨询"},
            ],
        )

        # History roles are rendered with localized labels (用户/客服).
        assert "对话历史" in prompt
        assert "用户: 你好" in prompt
        assert "客服: 您好,有什么可以帮您?" in prompt

    def test_build_prompt_with_context(self):
        """Test prompt building with session context."""
        generator = ScriptGenerator(llm_client=None)

        prompt = generator._build_prompt(
            intent="获取用户姓名",
            intent_description=None,
            constraints=None,
            context={
                "inputs": [
                    {"step": 1, "input": "我想咨询产品"},
                    {"step": 2, "input": "手机"},
                ]
            },
            history=None,
        )

        assert "已收集信息" in prompt
        assert "步骤1: 我想咨询产品" in prompt
        assert "步骤2: 手机" in prompt

    def test_build_prompt_complete(self):
        """Test prompt building with all parameters."""
        generator = ScriptGenerator(llm_client=None)

        prompt = generator._build_prompt(
            intent="获取用户姓名",
            intent_description="需要获取用户的真实姓名",
            constraints=["必须礼貌", "语气自然"],
            context={"inputs": [{"step": 1, "input": "咨询"}]},
            history=[{"role": "user", "content": "你好"}],
        )

        # All sections present, plus the length-limit instruction.
        assert "步骤目标" in prompt
        assert "详细说明" in prompt
        assert "约束条件" in prompt
        assert "对话历史" in prompt
        assert "已收集信息" in prompt
        assert "不超过200字" in prompt


# ===== tests/test_template_engine.py =====
"""
Unit tests for TemplateEngine.
[AC-IDS-06, AC-IDS-11] Test template variable filling.
"""

import pytest

from app.services.flow.template_engine import TemplateEngine


class MockLLMClient:
    """Mock LLM client for testing; returns a canned response."""

    def __init__(self, response: str = "先生"):
        self._response = response

    async def generate_text(self, prompt: str) -> str:
        return self._response

    async def generate(self, messages: list) -> "MockResponse":
        return MockResponse(self._response)


class MockResponse:
    """Mock LLM response with a .content attribute."""
    def __init__(self, content: str):
        self.content = content


class TestTemplateEngine:
    """[AC-IDS-06, AC-IDS-11] Test cases for TemplateEngine."""

    @pytest.mark.asyncio
    async def test_fill_template_no_variables(self):
        """Test template without variables returns as-is."""
        engine = TemplateEngine(llm_client=None)

        result = await engine.fill_template(
            template="您好,请问有什么可以帮您?",
            context=None,
            history=None,
        )

        assert result == "您好,请问有什么可以帮您?"

    @pytest.mark.asyncio
    async def test_fill_template_from_context(self):
        """Test filling variables from context."""
        # Top-level context keys resolve {variable} placeholders directly.
        engine = TemplateEngine(llm_client=None)

        result = await engine.fill_template(
            template="您好{user_name},请问有什么可以帮您?",
            context={"user_name": "张先生"},
            history=None,
        )

        assert result == "您好张先生,请问有什么可以帮您?"

    @pytest.mark.asyncio
    async def test_fill_template_from_inputs(self):
        """Test filling variables from context inputs."""
        # Variables can also resolve from context["inputs"] entries keyed by
        # their "variable" field.
        engine = TemplateEngine(llm_client=None)

        result = await engine.fill_template(
            template="您好,您咨询的是{product}相关的问题吗?",
            context={
                "inputs": [
                    {"variable": "product", "input": "手机"},
                ]
            },
            history=None,
        )

        assert result == "您好,您咨询的是手机相关的问题吗?"

    @pytest.mark.asyncio
    async def test_fill_template_with_llm(self):
        """Test filling variables using LLM generation."""
        # Unresolvable variables are delegated to the LLM when one is set.
        llm_client = MockLLMClient(response="先生")
        engine = TemplateEngine(llm_client=llm_client)

        result = await engine.fill_template(
            template="您好{greeting_style},请问您{inquiry_style}?",
            context=None,
            history=[{"role": "user", "content": "你好"}],
        )

        assert "先生" in result

    @pytest.mark.asyncio
    async def test_fill_template_multiple_variables(self):
        """Test filling multiple variables."""
        engine = TemplateEngine(llm_client=None)

        result = await engine.fill_template(
            template="您好{name},您订购的{product}已发货,预计{date}送达。",
            context={
                "name": "李女士",
                "product": "iPhone 15",
                "date": "明天",
            },
            history=None,
        )

        assert result == "您好李女士,您订购的iPhone 15已发货,预计明天送达。"

    @pytest.mark.asyncio
    async def test_fill_template_missing_variable(self):
        """Test handling missing variables with placeholder."""
        # With no context and no LLM, unresolved {var} becomes [var].
        engine = TemplateEngine(llm_client=None)

        result = await engine.fill_template(
            template="您好{unknown_var},请问有什么可以帮您?",
            context=None,
            history=None,
        )

        assert result == "您好[unknown_var],请问有什么可以帮您?"

    def test_extract_variables(self):
        """Test extracting variable names from template."""
        engine = TemplateEngine(llm_client=None)

        variables = engine.extract_variables(
            "您好{name},您订购的{product}预计{date}送达。"
        )

        # Order of appearance is preserved.
        assert variables == ["name", "product", "date"]

    def test_extract_variables_empty(self):
        """Test extracting from template without variables."""
        engine = TemplateEngine(llm_client=None)

        variables = engine.extract_variables("您好,请问有什么可以帮您?")

        assert variables == []

    def test_extract_variables_adjacent(self):
        """Test extracting adjacent variables."""
        engine = TemplateEngine(llm_client=None)

        variables = engine.extract_variables("{a}{b}{c}")

        assert variables == ["a", "b", "c"]

    @pytest.mark.asyncio
    async def test_fill_template_with_history_context(self):
        """Test that history is used for LLM prompt."""
        llm_client = MockLLMClient(response="贵姓")
        engine = TemplateEngine(llm_client=llm_client)

        result = await engine.fill_template(
            template="您好,请问您{inquiry_style}?",
            context=None,
            history=[
                {"role": "user", "content": "我想咨询一下"},
                {"role": "assistant", "content": "好的,请问您想咨询什么?"},
            ],
        )

        assert "贵姓" in result

    @pytest.mark.asyncio
    async def test_fill_template_exception_handling(self):
        """Test that exceptions are handled gracefully."""
        class FailingLLMClient:
            # Simulates an unavailable LLM backend raising on every call.
            async def generate_text(self, prompt: str) -> str:
                raise RuntimeError("LLM service unavailable")

        engine = TemplateEngine(llm_client=FailingLLMClient())

        # LLM failure must degrade to the [var] placeholder, not raise.
        result = await engine.fill_template(
            template="您好{greeting},请问有什么可以帮您?",
            context=None,
            history=None,
        )

        assert result == "您好[greeting],请问有什么可以帮您?"