"""
Orchestrator service for AI Service.

[AC-AISVC-01, AC-AISVC-02, AC-AISVC-06, AC-AISVC-07]
Core orchestration logic for chat generation.
"""
import asyncio
import json
import logging
from typing import AsyncGenerator

from sse_starlette.sse import ServerSentEvent

from app.core.sse import (
    SSEStateMachine,
    create_error_event,
    create_final_event,
    create_message_event,
)
from app.models import ChatRequest, ChatResponse

logger = logging.getLogger(__name__)


class OrchestratorService:
    """
    [AC-AISVC-01, AC-AISVC-02, AC-AISVC-06, AC-AISVC-07]
    Orchestrator for chat generation.

    Coordinates memory, retrieval, and LLM components.

    SSE Event Flow (per design.md Section 6.2):
    - message* (0 or more) -> final (exactly 1) -> close
    - OR message* (0 or more) -> error (exactly 1) -> close
    """

    def __init__(self, llm_client=None):
        """
        Initialize orchestrator with optional LLM client.

        Args:
            llm_client: Optional LLM client for dependency injection.
                If None, a mock implementation is used for demo.
        """
        self._llm_client = llm_client

    async def generate(self, tenant_id: str, request: ChatRequest) -> ChatResponse:
        """
        Generate a non-streaming response. [AC-AISVC-02]

        Args:
            tenant_id: Tenant identifier (used here for logging only).
            request: Chat request carrying history and the current message.

        Returns:
            ChatResponse with reply, confidence, shouldTransfer.
        """
        # Lazy %-style args: the message is only formatted if INFO is enabled.
        logger.info(
            "[AC-AISVC-01] Generating response for tenant=%s, session=%s",
            tenant_id,
            request.session_id,
        )

        if self._llm_client:
            messages = self._build_messages(request)
            response = await self._llm_client.generate(messages)
            # NOTE(review): confidence is a fixed placeholder until the LLM
            # client surfaces a real score — confirm before relying on it.
            return ChatResponse(
                reply=response.content,
                confidence=0.85,
                should_transfer=False,
            )

        # Mock path for demo/testing when no LLM client is injected.
        reply = f"Received your message: {request.current_message}"
        return ChatResponse(
            reply=reply,
            confidence=0.85,
            should_transfer=False,
        )

    async def generate_stream(
        self, tenant_id: str, request: ChatRequest
    ) -> AsyncGenerator[ServerSentEvent, None]:
        """
        Generate a streaming response. [AC-AISVC-06, AC-AISVC-07, AC-AISVC-08]

        Yields SSE events in proper sequence (per design.md Section 6.2):
        1. message events (multiple) - each with an incremental delta
        2. final event (exactly 1) - with the complete response
        3. connection close
        OR on error:
        1. message events (0 or more)
        2. error event (exactly 1)
        3. connection close
        """
        logger.info(
            "[AC-AISVC-06] Starting streaming generation for tenant=%s, session=%s",
            tenant_id,
            request.session_id,
        )

        state_machine = SSEStateMachine()
        await state_machine.transition_to_streaming()

        try:
            # Select the event source once; both generators honor the same
            # message-event contract, so the accumulation loop is written once
            # instead of duplicated per branch.
            if self._llm_client:
                source = self._stream_from_llm(request, state_machine)
            else:
                source = self._stream_mock_response(request, state_machine)

            full_reply = ""
            async for event in source:
                if event.event == "message":
                    # Accumulate deltas so the final event carries the full text.
                    full_reply += self._extract_delta_from_event(event)
                yield event

            # The state machine guards against emitting final after an error
            # or more than once.
            if await state_machine.transition_to_final():
                yield create_final_event(
                    reply=full_reply,
                    confidence=0.85,
                    should_transfer=False,
                )
        except Exception as e:
            logger.error("[AC-AISVC-09] Error during streaming: %s", e)
            if await state_machine.transition_to_error():
                yield create_error_event(
                    code="GENERATION_ERROR",
                    message=str(e),
                )
        finally:
            await state_machine.close()

    async def _stream_from_llm(
        self, request: ChatRequest, state_machine: SSEStateMachine
    ) -> AsyncGenerator[ServerSentEvent, None]:
        """
        [AC-AISVC-07] Stream from the LLM client, wrapping each chunk as a
        message event.

        Stops early when the state machine no longer allows message events,
        or when the LLM reports a finish reason.
        """
        messages = self._build_messages(request)
        async for chunk in self._llm_client.stream_generate(messages):
            if not state_machine.can_send_message():
                break
            if chunk.delta:
                logger.debug(
                    "[AC-AISVC-07] Yielding message event with delta: %s...",
                    chunk.delta[:50],
                )
                yield create_message_event(delta=chunk.delta)
            if chunk.finish_reason:
                logger.info(
                    "[AC-AISVC-07] LLM stream finished with reason: %s",
                    chunk.finish_reason,
                )
                break

    async def _stream_mock_response(
        self, request: ChatRequest, state_machine: SSEStateMachine
    ) -> AsyncGenerator[ServerSentEvent, None]:
        """
        [AC-AISVC-07] Mock streaming response for demo/testing purposes.

        Simulates LLM-style incremental output.
        """
        reply_parts = ["Received", " your", " message:", f" {request.current_message}"]
        for part in reply_parts:
            if not state_machine.can_send_message():
                break
            logger.debug(
                "[AC-AISVC-07] Yielding mock message event with delta: %s", part
            )
            yield create_message_event(delta=part)
            # Small delay to mimic token-by-token generation.
            await asyncio.sleep(0.05)

    def _build_messages(self, request: ChatRequest) -> list[dict[str, str]]:
        """Build the messages list for the LLM: prior history, then the user turn."""
        messages = [
            {"role": msg.role.value, "content": msg.content}
            for msg in (request.history or [])
        ]
        messages.append({
            "role": "user",
            "content": request.current_message,
        })
        return messages

    def _extract_delta_from_event(self, event: ServerSentEvent) -> str:
        """
        Extract the delta content from a message event.

        Returns "" when the event carries no parseable JSON payload.
        """
        try:
            if event.data:
                data = json.loads(event.data)
                return data.get("delta", "")
        except (json.JSONDecodeError, TypeError):
            pass
        return ""


# Module-level singleton instance, created lazily on first access.
_orchestrator_service: OrchestratorService | None = None


def get_orchestrator_service() -> OrchestratorService:
    """Get or create the module-level orchestrator service instance."""
    global _orchestrator_service
    if _orchestrator_service is None:
        _orchestrator_service = OrchestratorService()
    return _orchestrator_service