"""
Tests for OrchestratorService.
[AC-AISVC-01, AC-AISVC-02] Test complete generation pipeline integration.
"""

from typing import AsyncGenerator
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from app.models import ChatRequest, ChatResponse, ChannelType, ChatMessage, Role
from app.services.orchestrator import (
    OrchestratorService,
    OrchestratorConfig,
    GenerationContext,
    get_orchestrator_service,
    set_orchestrator_service,
)
from app.services.llm.base import LLMClient, LLMConfig, LLMResponse, LLMStreamChunk
from app.services.memory import MemoryService
from app.services.retrieval.base import (
    BaseRetriever,
    RetrievalContext,
    RetrievalResult,
    RetrievalHit,
)
from app.services.confidence import ConfidenceCalculator, ConfidenceConfig
from app.services.context import ContextMerger, MergedContext
from app.models.entities import ChatMessage as ChatMessageEntity


class MockLLMClient(LLMClient):
    """Mock LLM client for testing.

    Returns a fixed response from ``generate`` and a fixed chunk sequence
    from ``stream_generate``; records which of the two was called.
    """

    def __init__(self, response_content: str = "Mock LLM response") -> None:
        self._response_content = response_content
        # Call flags inspected by the tests.
        self._generate_called = False
        self._stream_generate_called = False

    async def generate(
        self,
        messages: list[dict[str, str]],
        config: LLMConfig | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Return the canned response and record the call."""
        self._generate_called = True
        return LLMResponse(
            content=self._response_content,
            model="mock-model",
            usage={"prompt_tokens": 100, "completion_tokens": 50},
            finish_reason="stop",
        )

    async def stream_generate(
        self,
        messages: list[dict[str, str]],
        config: LLMConfig | None = None,
        **kwargs,
    ) -> AsyncGenerator[LLMStreamChunk, None]:
        """Yield four fixed text chunks followed by an empty "stop" chunk."""
        self._stream_generate_called = True
        chunks = ["Hello", " from", " mock", " LLM"]
        for chunk in chunks:
            yield LLMStreamChunk(delta=chunk, model="mock-model")
        yield LLMStreamChunk(delta="", model="mock-model", finish_reason="stop")

    async def close(self) -> None:
        """No-op: the mock holds no resources."""
        pass


class MockRetriever(BaseRetriever):
    """Mock retriever for testing; always returns the hits it was built with."""

    def __init__(self, hits: list[RetrievalHit] | None = None) -> None:
        self._hits = hits or []

    async def retrieve(self, ctx: RetrievalContext) -> RetrievalResult:
        """Return the configured hits regardless of ``ctx``."""
        return RetrievalResult(
            hits=self._hits,
            diagnostics={"mock": True},
        )

    async def health_check(self) -> bool:
        """Always report healthy."""
        return True


class MockMemoryService:
    """Mock memory service for testing.

    Serves a fixed history and collects appended messages in
    ``_saved_messages`` so tests can assert on what was persisted.
    """

    def __init__(self, history: list[ChatMessageEntity] | None = None) -> None:
        self._history = history or []
        self._saved_messages: list[dict] = []
        self._session_created = False

    async def get_or_create_session(
        self,
        tenant_id: str,
        session_id: str,
        channel_type: str | None = None,
        metadata: dict | None = None,
    ):
        """Record the call and return a MagicMock carrying the given ids."""
        self._session_created = True
        return MagicMock(tenant_id=tenant_id, session_id=session_id)

    async def load_history(
        self,
        tenant_id: str,
        session_id: str,
        limit: int | None = None,
    ):
        """Return the configured history; all arguments are ignored."""
        return self._history

    async def append_message(
        self,
        tenant_id: str,
        session_id: str,
        role: str,
        content: str,
    ) -> None:
        """Store a single message as a ``{"role", "content"}`` dict."""
        self._saved_messages.append({"role": role, "content": content})

    async def append_messages(
        self,
        tenant_id: str,
        session_id: str,
        messages: list[dict[str, str]],
    ) -> None:
        """Store several messages at once, preserving their order."""
        self._saved_messages.extend(messages)


def create_chat_request(
    message: str = "Hello",
    session_id: str = "test-session",
    history: list[ChatMessage] | None = None,
    metadata: dict | None = None,
) -> ChatRequest:
    """Helper to create ChatRequest (channel type is always WECHAT)."""
    return ChatRequest(
        session_id=session_id,
        current_message=message,
        channel_type=ChannelType.WECHAT,
        history=history,
        metadata=metadata,
    )


class TestOrchestratorServiceGenerate:
    """Tests for OrchestratorService.generate() method."""

    @pytest.mark.asyncio
    async def test_generate_basic_without_dependencies(self):
        """
        [AC-AISVC-01, AC-AISVC-02] Test basic generation without external dependencies.
        Should return fallback response with low confidence.
        """
        orchestrator = OrchestratorService(
            config=OrchestratorConfig(enable_rag=False),
        )
        request = create_chat_request(message="What is the price?")

        response = await orchestrator.generate(
            tenant_id="tenant-1",
            request=request,
        )

        assert isinstance(response, ChatResponse)
        assert response.reply is not None
        assert 0.0 <= response.confidence <= 1.0
        assert isinstance(response.should_transfer, bool)
        assert "diagnostics" in response.metadata

    @pytest.mark.asyncio
    async def test_generate_with_llm_client(self):
        """
        [AC-AISVC-02] Test generation with LLM client.
        Should use LLM response.
        """
        mock_llm = MockLLMClient(response_content="This is the AI response.")
        orchestrator = OrchestratorService(
            llm_client=mock_llm,
            config=OrchestratorConfig(enable_rag=False),
        )
        request = create_chat_request(message="Hello")

        response = await orchestrator.generate(
            tenant_id="tenant-1",
            request=request,
        )

        assert response.reply == "This is the AI response."
        assert mock_llm._generate_called is True

    @pytest.mark.asyncio
    async def test_generate_with_memory_service(self):
        """
        [AC-AISVC-13] Test generation with memory service.
        Should load history and save messages.
        """
        mock_memory = MockMemoryService(
            history=[
                ChatMessageEntity(
                    tenant_id="tenant-1",
                    session_id="test-session",
                    role="user",
                    content="Previous message",
                )
            ]
        )
        mock_llm = MockLLMClient()
        orchestrator = OrchestratorService(
            llm_client=mock_llm,
            memory_service=mock_memory,
            config=OrchestratorConfig(enable_rag=False),
        )
        request = create_chat_request(message="New message")

        # Only the side effects on the memory mock matter here, so the
        # response itself is not bound.
        await orchestrator.generate(
            tenant_id="tenant-1",
            request=request,
        )

        assert len(mock_memory._saved_messages) == 2
        assert mock_memory._saved_messages[0]["role"] == "user"
        assert mock_memory._saved_messages[1]["role"] == "assistant"

    @pytest.mark.asyncio
    async def test_generate_with_retrieval(self):
        """
        [AC-AISVC-16, AC-AISVC-17] Test generation with RAG retrieval.
        Should include evidence in LLM prompt.
        """
        mock_retriever = MockRetriever(
            hits=[
                RetrievalHit(
                    text="Product price is $100",
                    score=0.85,
                    source="kb",
                )
            ]
        )
        mock_llm = MockLLMClient()
        orchestrator = OrchestratorService(
            llm_client=mock_llm,
            retriever=mock_retriever,
            config=OrchestratorConfig(enable_rag=True),
        )
        request = create_chat_request(message="What is the price?")

        response = await orchestrator.generate(
            tenant_id="tenant-1",
            request=request,
        )

        assert "retrieval" in response.metadata["diagnostics"]
        assert response.metadata["diagnostics"]["retrieval"]["hit_count"] == 1

    @pytest.mark.asyncio
    async def test_generate_with_context_merging(self):
        """
        [AC-AISVC-14, AC-AISVC-15] Test context merging with external history.
        Should merge local and external history.
        """
        mock_memory = MockMemoryService(
            history=[
                ChatMessageEntity(
                    tenant_id="tenant-1",
                    session_id="test-session",
                    role="user",
                    content="Local message",
                )
            ]
        )
        mock_llm = MockLLMClient()
        orchestrator = OrchestratorService(
            llm_client=mock_llm,
            memory_service=mock_memory,
            config=OrchestratorConfig(enable_rag=False),
        )
        request = create_chat_request(
            message="New message",
            history=[
                ChatMessage(role=Role.USER, content="External message"),
                ChatMessage(role=Role.ASSISTANT, content="External response"),
            ],
        )

        response = await orchestrator.generate(
            tenant_id="tenant-1",
            request=request,
        )

        assert "merged_context" in response.metadata["diagnostics"]
        merged = response.metadata["diagnostics"]["merged_context"]
        assert merged["local_count"] == 1
        assert merged["external_count"] == 2

    @pytest.mark.asyncio
    async def test_generate_with_confidence_calculation(self):
        """
        [AC-AISVC-17, AC-AISVC-18, AC-AISVC-19] Test confidence calculation.
        Should calculate confidence based on retrieval results.
        """
        mock_retriever = MockRetriever(
            hits=[
                RetrievalHit(text="High relevance content", score=0.9, source="kb"),
                RetrievalHit(text="Medium relevance", score=0.8, source="kb"),
            ]
        )
        mock_llm = MockLLMClient()
        orchestrator = OrchestratorService(
            llm_client=mock_llm,
            retriever=mock_retriever,
            config=OrchestratorConfig(enable_rag=True),
        )
        request = create_chat_request(message="Test query")

        response = await orchestrator.generate(
            tenant_id="tenant-1",
            request=request,
        )

        assert response.confidence > 0.5
        assert "confidence" in response.metadata["diagnostics"]

    @pytest.mark.asyncio
    async def test_generate_low_confidence_triggers_transfer(self):
        """
        [AC-AISVC-18, AC-AISVC-19] Test low confidence triggers transfer.
        Should set should_transfer=True when confidence is low.
        """
        mock_retriever = MockRetriever(hits=[])
        mock_llm = MockLLMClient()
        orchestrator = OrchestratorService(
            llm_client=mock_llm,
            retriever=mock_retriever,
            config=OrchestratorConfig(enable_rag=True),
        )
        request = create_chat_request(message="Unknown topic")

        response = await orchestrator.generate(
            tenant_id="tenant-1",
            request=request,
        )

        assert response.should_transfer is True
        assert response.transfer_reason is not None

    @pytest.mark.asyncio
    async def test_generate_handles_llm_error(self):
        """
        [AC-AISVC-02] Test handling of LLM errors.
        Should return fallback response on error.
        """
        mock_llm = MagicMock()
        mock_llm.generate = AsyncMock(side_effect=Exception("LLM unavailable"))
        orchestrator = OrchestratorService(
            llm_client=mock_llm,
            config=OrchestratorConfig(enable_rag=False),
        )
        request = create_chat_request(message="Hello")

        response = await orchestrator.generate(
            tenant_id="tenant-1",
            request=request,
        )

        assert response.reply is not None
        assert "llm_error" in response.metadata["diagnostics"]

    @pytest.mark.asyncio
    async def test_generate_handles_retrieval_error(self):
        """
        [AC-AISVC-16] Test handling of retrieval errors.
        Should continue with empty retrieval result.
        """
        mock_retriever = MagicMock()
        mock_retriever.retrieve = AsyncMock(side_effect=Exception("Qdrant unavailable"))
        mock_llm = MockLLMClient()
        orchestrator = OrchestratorService(
            llm_client=mock_llm,
            retriever=mock_retriever,
            config=OrchestratorConfig(enable_rag=True),
        )
        request = create_chat_request(message="Hello")

        response = await orchestrator.generate(
            tenant_id="tenant-1",
            request=request,
        )

        assert response.reply == "Mock LLM response"
        assert "retrieval_error" in response.metadata["diagnostics"]

    @pytest.mark.asyncio
    async def test_generate_full_pipeline_integration(self):
        """
        [AC-AISVC-01, AC-AISVC-02] Test complete pipeline integration.
        All components working together.
        """
        mock_memory = MockMemoryService(
            history=[
                ChatMessageEntity(
                    tenant_id="tenant-1",
                    session_id="test-session",
                    role="user",
                    content="Previous question",
                ),
                ChatMessageEntity(
                    tenant_id="tenant-1",
                    session_id="test-session",
                    role="assistant",
                    content="Previous answer",
                ),
            ]
        )
        mock_retriever = MockRetriever(
            hits=[
                RetrievalHit(text="Knowledge base content", score=0.85, source="kb"),
            ]
        )
        mock_llm = MockLLMClient(response_content="AI generated response")
        orchestrator = OrchestratorService(
            llm_client=mock_llm,
            memory_service=mock_memory,
            retriever=mock_retriever,
            config=OrchestratorConfig(enable_rag=True),
        )
        request = create_chat_request(
            message="New question",
            history=[
                ChatMessage(role=Role.USER, content="External history"),
            ],
        )

        response = await orchestrator.generate(
            tenant_id="tenant-1",
            request=request,
        )

        assert response.reply == "AI generated response"
        assert response.confidence > 0.0
        assert len(mock_memory._saved_messages) == 2

        diagnostics = response.metadata["diagnostics"]
        assert diagnostics["memory_enabled"] is True
        assert diagnostics["retrieval"]["hit_count"] == 1
        assert diagnostics["llm_mode"] == "live"


class TestOrchestratorServiceGenerationContext:
    """Tests for GenerationContext dataclass."""

    def test_generation_context_initialization(self):
        """Test GenerationContext initialization."""
        ctx = GenerationContext(
            tenant_id="tenant-1",
            session_id="session-1",
            current_message="Hello",
            channel_type="wechat",
        )

        assert ctx.tenant_id == "tenant-1"
        assert ctx.session_id == "session-1"
        assert ctx.current_message == "Hello"
        assert ctx.channel_type == "wechat"
        assert ctx.local_history == []
        assert ctx.diagnostics == {}

    def test_generation_context_with_metadata(self):
        """Test GenerationContext with metadata."""
        ctx = GenerationContext(
            tenant_id="tenant-1",
            session_id="session-1",
            current_message="Hello",
            channel_type="wechat",
            request_metadata={"user_id": "user-123"},
        )

        assert ctx.request_metadata == {"user_id": "user-123"}


class TestOrchestratorConfig:
    """Tests for OrchestratorConfig dataclass."""

    def test_default_config(self):
        """Test default configuration values."""
        config = OrchestratorConfig()

        assert config.max_history_tokens == 4000
        assert config.max_evidence_tokens == 2000
        assert config.enable_rag is True
        assert "智能客服" in config.system_prompt

    def test_custom_config(self):
        """Test custom configuration values."""
        config = OrchestratorConfig(
            max_history_tokens=8000,
            enable_rag=False,
            system_prompt="Custom prompt",
        )

        assert config.max_history_tokens == 8000
        assert config.enable_rag is False
        assert config.system_prompt == "Custom prompt"


class TestOrchestratorServiceHelperMethods:
    """Tests for OrchestratorService helper methods."""

    def test_build_llm_messages_basic(self):
        """Test _build_llm_messages with basic context."""
        orchestrator = OrchestratorService(
            config=OrchestratorConfig(enable_rag=False),
        )
        ctx = GenerationContext(
            tenant_id="tenant-1",
            session_id="session-1",
            current_message="Hello",
            channel_type="wechat",
        )

        messages = orchestrator._build_llm_messages(ctx)

        assert len(messages) == 2
        assert messages[0]["role"] == "system"
        assert messages[1]["role"] == "user"
        assert messages[1]["content"] == "Hello"

    def test_build_llm_messages_with_evidence(self):
        """Test _build_llm_messages includes evidence from retrieval."""
        orchestrator = OrchestratorService(
            config=OrchestratorConfig(enable_rag=True),
        )
        ctx = GenerationContext(
            tenant_id="tenant-1",
            session_id="session-1",
            current_message="What is the price?",
            channel_type="wechat",
            retrieval_result=RetrievalResult(
                hits=[
                    RetrievalHit(text="Price is $100", score=0.9, source="kb"),
                ]
            ),
        )

        messages = orchestrator._build_llm_messages(ctx)

        assert "知识库参考内容" in messages[0]["content"]
        assert "Price is $100" in messages[0]["content"]

    def test_build_llm_messages_with_history(self):
        """Test _build_llm_messages includes merged history."""
        orchestrator = OrchestratorService(
            config=OrchestratorConfig(enable_rag=False),
        )
        ctx = GenerationContext(
            tenant_id="tenant-1",
            session_id="session-1",
            current_message="New question",
            channel_type="wechat",
            merged_context=MergedContext(
                messages=[
                    {"role": "user", "content": "Previous question"},
                    {"role": "assistant", "content": "Previous answer"},
                ]
            ),
        )

        messages = orchestrator._build_llm_messages(ctx)

        # Expected order: system prompt, two history turns, current message.
        assert len(messages) == 4
        assert messages[1]["role"] == "user"
        assert messages[1]["content"] == "Previous question"
        assert messages[2]["role"] == "assistant"
        assert messages[3]["role"] == "user"
        assert messages[3]["content"] == "New question"

    def test_fallback_response_with_evidence(self):
        """Test _fallback_response when retrieval has evidence."""
        orchestrator = OrchestratorService()
        ctx = GenerationContext(
            tenant_id="tenant-1",
            session_id="session-1",
            current_message="Question",
            channel_type="wechat",
            retrieval_result=RetrievalResult(
                hits=[RetrievalHit(text="Evidence", score=0.8, source="kb")]
            ),
        )

        fallback = orchestrator._fallback_response(ctx)

        assert "知识库" in fallback

    def test_fallback_response_without_evidence(self):
        """Test _fallback_response when no retrieval evidence."""
        orchestrator = OrchestratorService()
        ctx = GenerationContext(
            tenant_id="tenant-1",
            session_id="session-1",
            current_message="Question",
            channel_type="wechat",
            retrieval_result=RetrievalResult(hits=[]),
        )

        fallback = orchestrator._fallback_response(ctx)

        assert "无法处理" in fallback or "人工客服" in fallback

    def test_format_evidence(self):
        """Test _format_evidence formats hits correctly."""
        orchestrator = OrchestratorService()
        result = RetrievalResult(
            hits=[
                RetrievalHit(text="First result", score=0.9, source="kb"),
                RetrievalHit(text="Second result", score=0.8, source="kb"),
            ]
        )

        formatted = orchestrator._format_evidence(result)

        assert "[1]" in formatted
        assert "[2]" in formatted
        assert "First result" in formatted
        assert "Second result" in formatted


class TestOrchestratorServiceSetInstance:
    """Tests for set_orchestrator_service function."""

    def test_set_orchestrator_service(self):
        """Test setting orchestrator service instance."""
        custom_orchestrator = OrchestratorService(
            config=OrchestratorConfig(enable_rag=False),
        )

        # NOTE(review): the instance set here is never restored, so it
        # presumably stays installed for tests that run afterwards —
        # confirm whether teardown/restoration is needed for isolation.
        set_orchestrator_service(custom_orchestrator)

        instance = get_orchestrator_service()
        assert instance is custom_orchestrator