""" Orchestrator service for AI Service. [AC-AISVC-01, AC-AISVC-02, AC-AISVC-06, AC-AISVC-07] Core orchestration logic for chat generation. Design reference: design.md Section 2.2 - 关键数据流 1. Memory.load(tenantId, sessionId) 2. merge_context(local_history, external_history) 3. Retrieval.retrieve(query, tenantId, channelType, metadata) 4. build_prompt(merged_history, retrieved_docs, currentMessage) 5. LLM.generate(...) (non-streaming) or LLM.stream_generate(...) (streaming) 6. compute_confidence(...) 7. Memory.append(tenantId, sessionId, user/assistant messages) 8. Return ChatResponse (or output via SSE) """ import logging from dataclasses import dataclass, field from typing import Any, AsyncGenerator from sse_starlette.sse import ServerSentEvent from app.core.config import get_settings from app.core.sse import ( create_error_event, create_final_event, create_message_event, SSEStateMachine, ) from app.models import ChatRequest, ChatResponse from app.services.confidence import ConfidenceCalculator, ConfidenceResult from app.services.context import ContextMerger, MergedContext from app.services.llm.base import LLMClient, LLMConfig, LLMResponse from app.services.memory import MemoryService from app.services.retrieval.base import BaseRetriever, RetrievalContext, RetrievalResult logger = logging.getLogger(__name__) @dataclass class OrchestratorConfig: """ Configuration for OrchestratorService. [AC-AISVC-01] Centralized configuration for orchestration. """ max_history_tokens: int = 4000 max_evidence_tokens: int = 2000 system_prompt: str = "你是一个智能客服助手,请根据提供的知识库内容回答用户问题。" enable_rag: bool = True @dataclass class GenerationContext: """ [AC-AISVC-01, AC-AISVC-02] Context accumulated during generation pipeline. Contains all intermediate results for diagnostics and response building. """ tenant_id: str session_id: str current_message: str channel_type: str request_metadata: dict[str, Any] | None = None local_history: list[dict[str, str]] = field(default_factory=list) merged_context: MergedContext | None = None retrieval_result: RetrievalResult | None = None llm_response: LLMResponse | None = None confidence_result: ConfidenceResult | None = None diagnostics: dict[str, Any] = field(default_factory=dict) class OrchestratorService: """ [AC-AISVC-01, AC-AISVC-02, AC-AISVC-06, AC-AISVC-07] Orchestrator for chat generation. Coordinates memory, retrieval, and LLM components. SSE Event Flow (per design.md Section 6.2): - message* (0 or more) -> final (exactly 1) -> close - OR message* (0 or more) -> error (exactly 1) -> close """ def __init__( self, llm_client: LLMClient | None = None, memory_service: MemoryService | None = None, retriever: BaseRetriever | None = None, context_merger: ContextMerger | None = None, confidence_calculator: ConfidenceCalculator | None = None, config: OrchestratorConfig | None = None, ): """ Initialize orchestrator with optional dependencies for DI. Args: llm_client: LLM client for generation memory_service: Memory service for session history retriever: Retriever for RAG context_merger: Context merger for history deduplication confidence_calculator: Confidence calculator for response scoring config: Orchestrator configuration """ settings = get_settings() self._llm_client = llm_client self._memory_service = memory_service self._retriever = retriever self._context_merger = context_merger or ContextMerger( max_history_tokens=getattr(settings, "max_history_tokens", 4000) ) self._confidence_calculator = confidence_calculator or ConfidenceCalculator() self._config = config or OrchestratorConfig( max_history_tokens=getattr(settings, "max_history_tokens", 4000), max_evidence_tokens=getattr(settings, "rag_max_evidence_tokens", 2000), enable_rag=True, ) self._llm_config = LLMConfig( model=getattr(settings, "llm_model", "gpt-4o-mini"), max_tokens=getattr(settings, "llm_max_tokens", 2048), temperature=getattr(settings, "llm_temperature", 0.7), timeout_seconds=getattr(settings, "llm_timeout_seconds", 30), max_retries=getattr(settings, "llm_max_retries", 3), ) async def generate( self, tenant_id: str, request: ChatRequest, ) -> ChatResponse: """ Generate a non-streaming response. [AC-AISVC-01, AC-AISVC-02] Complete generation pipeline. Pipeline (per design.md Section 2.2): 1. Load local history from Memory 2. Merge with external history (dedup + truncate) 3. RAG retrieval (optional) 4. Build prompt with context and evidence 5. LLM generation 6. Calculate confidence 7. Save messages to Memory 8. Return ChatResponse """ logger.info( f"[AC-AISVC-01] Starting generation for tenant={tenant_id}, " f"session={request.session_id}" ) ctx = GenerationContext( tenant_id=tenant_id, session_id=request.session_id, current_message=request.current_message, channel_type=request.channel_type.value, request_metadata=request.metadata, ) try: await self._load_local_history(ctx) await self._merge_context(ctx, request.history) if self._config.enable_rag and self._retriever: await self._retrieve_evidence(ctx) await self._generate_response(ctx) self._calculate_confidence(ctx) await self._save_messages(ctx) return self._build_response(ctx) except Exception as e: logger.error(f"[AC-AISVC-01] Generation failed: {e}") return ChatResponse( reply="抱歉,服务暂时不可用,请稍后重试或联系人工客服。", confidence=0.0, should_transfer=True, transfer_reason=f"服务异常: {str(e)}", metadata={"error": str(e), "diagnostics": ctx.diagnostics}, ) async def _load_local_history(self, ctx: GenerationContext) -> None: """ [AC-AISVC-13] Load local history from Memory service. Step 1 of the generation pipeline. """ if not self._memory_service: logger.info("[AC-AISVC-13] No memory service configured, skipping history load") ctx.diagnostics["memory_enabled"] = False return try: messages = await self._memory_service.load_history( tenant_id=ctx.tenant_id, session_id=ctx.session_id, ) ctx.local_history = [ {"role": msg.role, "content": msg.content} for msg in messages ] ctx.diagnostics["memory_enabled"] = True ctx.diagnostics["local_history_count"] = len(ctx.local_history) logger.info( f"[AC-AISVC-13] Loaded {len(ctx.local_history)} messages from memory " f"for tenant={ctx.tenant_id}, session={ctx.session_id}" ) except Exception as e: logger.warning(f"[AC-AISVC-13] Failed to load history: {e}") ctx.diagnostics["memory_error"] = str(e) async def _merge_context( self, ctx: GenerationContext, external_history: list | None, ) -> None: """ [AC-AISVC-14, AC-AISVC-15] Merge local and external history. Step 2 of the generation pipeline. Design reference: design.md Section 7 - Deduplication based on fingerprint - Truncation to fit token budget """ external_messages = None if external_history: external_messages = [ {"role": msg.role.value, "content": msg.content} for msg in external_history ] ctx.merged_context = self._context_merger.merge_and_truncate( local_history=ctx.local_history, external_history=external_messages, max_tokens=self._config.max_history_tokens, ) ctx.diagnostics["merged_context"] = { "local_count": ctx.merged_context.local_count, "external_count": ctx.merged_context.external_count, "duplicates_skipped": ctx.merged_context.duplicates_skipped, "truncated_count": ctx.merged_context.truncated_count, "total_tokens": ctx.merged_context.total_tokens, } logger.info( f"[AC-AISVC-14, AC-AISVC-15] Context merged: " f"local={ctx.merged_context.local_count}, " f"external={ctx.merged_context.external_count}, " f"tokens={ctx.merged_context.total_tokens}" ) async def _retrieve_evidence(self, ctx: GenerationContext) -> None: """ [AC-AISVC-16, AC-AISVC-17] RAG retrieval for evidence. Step 3 of the generation pipeline. """ try: retrieval_ctx = RetrievalContext( tenant_id=ctx.tenant_id, query=ctx.current_message, session_id=ctx.session_id, channel_type=ctx.channel_type, metadata=ctx.request_metadata, ) ctx.retrieval_result = await self._retriever.retrieve(retrieval_ctx) ctx.diagnostics["retrieval"] = { "hit_count": ctx.retrieval_result.hit_count, "max_score": ctx.retrieval_result.max_score, "is_empty": ctx.retrieval_result.is_empty, } logger.info( f"[AC-AISVC-16, AC-AISVC-17] Retrieval complete: " f"hits={ctx.retrieval_result.hit_count}, " f"max_score={ctx.retrieval_result.max_score:.3f}" ) except Exception as e: logger.warning(f"[AC-AISVC-16] Retrieval failed: {e}") ctx.retrieval_result = RetrievalResult( hits=[], diagnostics={"error": str(e)}, ) ctx.diagnostics["retrieval_error"] = str(e) async def _generate_response(self, ctx: GenerationContext) -> None: """ [AC-AISVC-02] Generate response using LLM. Step 4-5 of the generation pipeline. """ messages = self._build_llm_messages(ctx) if not self._llm_client: logger.warning("[AC-AISVC-02] No LLM client configured, using fallback") ctx.llm_response = LLMResponse( content=self._fallback_response(ctx), model="fallback", usage={}, finish_reason="fallback", ) ctx.diagnostics["llm_mode"] = "fallback" return try: ctx.llm_response = await self._llm_client.generate( messages=messages, config=self._llm_config, ) ctx.diagnostics["llm_mode"] = "live" ctx.diagnostics["llm_model"] = ctx.llm_response.model ctx.diagnostics["llm_usage"] = ctx.llm_response.usage logger.info( f"[AC-AISVC-02] LLM response generated: " f"model={ctx.llm_response.model}, " f"tokens={ctx.llm_response.usage}" ) except Exception as e: logger.error(f"[AC-AISVC-02] LLM generation failed: {e}") ctx.llm_response = LLMResponse( content=self._fallback_response(ctx), model="fallback", usage={}, finish_reason="error", metadata={"error": str(e)}, ) ctx.diagnostics["llm_error"] = str(e) def _build_llm_messages(self, ctx: GenerationContext) -> list[dict[str, str]]: """ [AC-AISVC-02] Build messages for LLM including system prompt and evidence. """ messages = [] system_content = self._config.system_prompt if ctx.retrieval_result and not ctx.retrieval_result.is_empty: evidence_text = self._format_evidence(ctx.retrieval_result) system_content += f"\n\n知识库参考内容:\n{evidence_text}" messages.append({"role": "system", "content": system_content}) if ctx.merged_context and ctx.merged_context.messages: messages.extend(ctx.merged_context.messages) messages.append({"role": "user", "content": ctx.current_message}) return messages def _format_evidence(self, retrieval_result: RetrievalResult) -> str: """ [AC-AISVC-17] Format retrieval hits as evidence text. """ evidence_parts = [] for i, hit in enumerate(retrieval_result.hits[:5], 1): evidence_parts.append(f"[{i}] (相关度: {hit.score:.2f}) {hit.text}") return "\n".join(evidence_parts) def _fallback_response(self, ctx: GenerationContext) -> str: """ [AC-AISVC-17] Generate fallback response when LLM is unavailable. """ if ctx.retrieval_result and not ctx.retrieval_result.is_empty: return ( "根据知识库信息,我找到了一些相关内容," "但暂时无法生成完整回复。建议您稍后重试或联系人工客服。" ) return ( "抱歉,我暂时无法处理您的请求。" "请稍后重试或联系人工客服获取帮助。" ) def _calculate_confidence(self, ctx: GenerationContext) -> None: """ [AC-AISVC-17, AC-AISVC-18, AC-AISVC-19] Calculate confidence score. Step 6 of the generation pipeline. """ if ctx.retrieval_result: evidence_tokens = 0 if not ctx.retrieval_result.is_empty: evidence_tokens = sum( len(hit.text.split()) * 2 for hit in ctx.retrieval_result.hits ) ctx.confidence_result = self._confidence_calculator.calculate_confidence( retrieval_result=ctx.retrieval_result, evidence_tokens=evidence_tokens, ) else: ctx.confidence_result = self._confidence_calculator.calculate_confidence_no_retrieval() ctx.diagnostics["confidence"] = { "score": ctx.confidence_result.confidence, "should_transfer": ctx.confidence_result.should_transfer, "is_insufficient": ctx.confidence_result.is_retrieval_insufficient, } logger.info( f"[AC-AISVC-17, AC-AISVC-18] Confidence calculated: " f"{ctx.confidence_result.confidence:.3f}, " f"should_transfer={ctx.confidence_result.should_transfer}" ) async def _save_messages(self, ctx: GenerationContext) -> None: """ [AC-AISVC-13] Save user and assistant messages to Memory. Step 7 of the generation pipeline. """ if not self._memory_service: logger.info("[AC-AISVC-13] No memory service configured, skipping save") return try: await self._memory_service.get_or_create_session( tenant_id=ctx.tenant_id, session_id=ctx.session_id, channel_type=ctx.channel_type, metadata=ctx.request_metadata, ) messages_to_save = [ {"role": "user", "content": ctx.current_message}, ] if ctx.llm_response: messages_to_save.append({ "role": "assistant", "content": ctx.llm_response.content, }) await self._memory_service.append_messages( tenant_id=ctx.tenant_id, session_id=ctx.session_id, messages=messages_to_save, ) ctx.diagnostics["messages_saved"] = len(messages_to_save) logger.info( f"[AC-AISVC-13] Saved {len(messages_to_save)} messages " f"for tenant={ctx.tenant_id}, session={ctx.session_id}" ) except Exception as e: logger.warning(f"[AC-AISVC-13] Failed to save messages: {e}") ctx.diagnostics["save_error"] = str(e) def _build_response(self, ctx: GenerationContext) -> ChatResponse: """ [AC-AISVC-02] Build final ChatResponse from generation context. Step 8 of the generation pipeline. """ reply = ctx.llm_response.content if ctx.llm_response else self._fallback_response(ctx) confidence = ctx.confidence_result.confidence if ctx.confidence_result else 0.5 should_transfer = ctx.confidence_result.should_transfer if ctx.confidence_result else True transfer_reason = ctx.confidence_result.transfer_reason if ctx.confidence_result else None response_metadata = { "session_id": ctx.session_id, "channel_type": ctx.channel_type, "diagnostics": ctx.diagnostics, } return ChatResponse( reply=reply, confidence=confidence, should_transfer=should_transfer, transfer_reason=transfer_reason, metadata=response_metadata, ) async def generate_stream( self, tenant_id: str, request: ChatRequest, ) -> AsyncGenerator[ServerSentEvent, None]: """ Generate a streaming response. [AC-AISVC-06, AC-AISVC-07, AC-AISVC-08] Yields SSE events in proper sequence. SSE Event Sequence (per design.md Section 6.2): 1. message events (multiple) - each with incremental delta 2. final event (exactly 1) - with complete response 3. connection close OR on error: 1. message events (0 or more) 2. error event (exactly 1) 3. connection close """ logger.info( f"[AC-AISVC-06] Starting streaming generation for tenant={tenant_id}, " f"session={request.session_id}" ) state_machine = SSEStateMachine() await state_machine.transition_to_streaming() ctx = GenerationContext( tenant_id=tenant_id, session_id=request.session_id, current_message=request.current_message, channel_type=request.channel_type.value, request_metadata=request.metadata, ) try: await self._load_local_history(ctx) await self._merge_context(ctx, request.history) if self._config.enable_rag and self._retriever: await self._retrieve_evidence(ctx) full_reply = "" if self._llm_client: async for event in self._stream_from_llm(ctx, state_machine): if event.event == "message": full_reply += self._extract_delta_from_event(event) yield event else: async for event in self._stream_mock_response(ctx, state_machine): if event.event == "message": full_reply += self._extract_delta_from_event(event) yield event if ctx.llm_response is None: ctx.llm_response = LLMResponse( content=full_reply, model="streaming", usage={}, finish_reason="stop", ) self._calculate_confidence(ctx) await self._save_messages(ctx) if await state_machine.transition_to_final(): yield create_final_event( reply=full_reply, confidence=ctx.confidence_result.confidence if ctx.confidence_result else 0.5, should_transfer=ctx.confidence_result.should_transfer if ctx.confidence_result else False, transfer_reason=ctx.confidence_result.transfer_reason if ctx.confidence_result else None, ) except Exception as e: logger.error(f"[AC-AISVC-09] Error during streaming: {e}") if await state_machine.transition_to_error(): yield create_error_event( code="GENERATION_ERROR", message=str(e), ) finally: await state_machine.close() async def _stream_from_llm( self, ctx: GenerationContext, state_machine: SSEStateMachine, ) -> AsyncGenerator[ServerSentEvent, None]: """ [AC-AISVC-07] Stream from LLM client, wrapping each chunk as message event. """ messages = self._build_llm_messages(ctx) async for chunk in self._llm_client.stream_generate(messages, self._llm_config): if not state_machine.can_send_message(): break if chunk.delta: logger.debug(f"[AC-AISVC-07] Yielding message event with delta: {chunk.delta[:50]}...") yield create_message_event(delta=chunk.delta) if chunk.finish_reason: logger.info(f"[AC-AISVC-07] LLM stream finished with reason: {chunk.finish_reason}") break async def _stream_mock_response( self, ctx: GenerationContext, state_machine: SSEStateMachine, ) -> AsyncGenerator[ServerSentEvent, None]: """ [AC-AISVC-07] Mock streaming response for demo/testing purposes. Simulates LLM-style incremental output. """ import asyncio reply_parts = ["收到", "您的", "消息:", f" {ctx.current_message}"] for part in reply_parts: if not state_machine.can_send_message(): break logger.debug(f"[AC-AISVC-07] Yielding mock message event with delta: {part}") yield create_message_event(delta=part) await asyncio.sleep(0.05) def _extract_delta_from_event(self, event: ServerSentEvent) -> str: """Extract delta content from a message event.""" import json try: if event.data: data = json.loads(event.data) return data.get("delta", "") except (json.JSONDecodeError, TypeError): pass return "" _orchestrator_service: OrchestratorService | None = None def get_orchestrator_service() -> OrchestratorService: """Get or create orchestrator service instance.""" global _orchestrator_service if _orchestrator_service is None: _orchestrator_service = OrchestratorService() return _orchestrator_service def set_orchestrator_service(service: OrchestratorService) -> None: """Set orchestrator service instance for testing.""" global _orchestrator_service _orchestrator_service = service