"""
|
||
Orchestrator service for AI Service.
|
||
[AC-AISVC-01, AC-AISVC-02, AC-AISVC-06, AC-AISVC-07] Core orchestration logic for chat generation.
|
||
|
||
Design reference: design.md Section 2.2 - 关键数据流
|
||
1. Memory.load(tenantId, sessionId)
|
||
2. merge_context(local_history, external_history)
|
||
3. Retrieval.retrieve(query, tenantId, channelType, metadata)
|
||
4. build_prompt(merged_history, retrieved_docs, currentMessage)
|
||
5. LLM.generate(...) (non-streaming) or LLM.stream_generate(...) (streaming)
|
||
6. compute_confidence(...)
|
||
7. Memory.append(tenantId, sessionId, user/assistant messages)
|
||
8. Return ChatResponse (or output via SSE)
|
||
|
||
RAG Optimization (rag-optimization/spec.md):
|
||
- Two-stage retrieval with Matryoshka dimensions
|
||
- RRF hybrid ranking
|
||
- Optimized prompt engineering
|
||
"""
|
||
|
||
import logging
from dataclasses import dataclass, field
from typing import Any, AsyncGenerator

from sse_starlette.sse import ServerSentEvent

from app.core.config import get_settings
from app.core.sse import (
    create_error_event,
    create_final_event,
    create_message_event,
    SSEStateMachine,
)
from app.models import ChatRequest, ChatResponse
from app.services.confidence import ConfidenceCalculator, ConfidenceResult
from app.services.context import ContextMerger, MergedContext
from app.services.llm.base import LLMClient, LLMConfig, LLMResponse
from app.services.memory import MemoryService
from app.services.retrieval.base import BaseRetriever, RetrievalContext, RetrievalResult

logger = logging.getLogger(__name__)


# System prompt (Chinese) for the school customer-service assistant. It instructs the
# model to answer strictly from the supplied knowledge-base content without fabricating
# information, to tell the user and suggest a human handoff (or retrying later) when the
# knowledge base has nothing relevant, to keep a professional, friendly, concise tone,
# to cite sources such as [文档1], and to remind users about document validity for
# time-sensitive questions.
OPTIMIZED_SYSTEM_PROMPT = """你是学校智能客服助手,基于提供的知识库内容回答用户问题。

回答要求:
1. 严格基于提供的知识库内容回答,不要编造信息
2. 如果知识库中没有相关信息,明确告知用户并建议转人工或稍后重试
3. 保持专业、友好的语气,回答简洁明了,突出重点
4. 如果引用知识库内容,请注明来源(如:根据[文档1]...)
5. 对于时效性问题,请提醒用户注意文档的有效期"""


@dataclass
class OrchestratorConfig:
    """
    Configuration for OrchestratorService.
    [AC-AISVC-01] Centralized configuration for orchestration.
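
    Example (illustrative): passing config=OrchestratorConfig(enable_rag=False) to
    OrchestratorService() builds an instance with RAG retrieval disabled.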
"""
|
||
max_history_tokens: int = 4000
|
||
max_evidence_tokens: int = 2000
|
||
system_prompt: str = OPTIMIZED_SYSTEM_PROMPT
|
||
enable_rag: bool = True
|
||
use_optimized_retriever: bool = True
|
||
|
||
|
||
@dataclass
class GenerationContext:
    """
    [AC-AISVC-01, AC-AISVC-02] Context accumulated during the generation pipeline.
    Contains all intermediate results for diagnostics and response building.
    """
    tenant_id: str
    session_id: str
    current_message: str
    channel_type: str
    request_metadata: dict[str, Any] | None = None

    local_history: list[dict[str, str]] = field(default_factory=list)
    merged_context: MergedContext | None = None
    retrieval_result: RetrievalResult | None = None
    llm_response: LLMResponse | None = None
    confidence_result: ConfidenceResult | None = None

    diagnostics: dict[str, Any] = field(default_factory=dict)


class OrchestratorService:
    """
    [AC-AISVC-01, AC-AISVC-02, AC-AISVC-06, AC-AISVC-07] Orchestrator for chat generation.
    Coordinates memory, retrieval, and LLM components.

    SSE Event Flow (per design.md Section 6.2):
    - message* (0 or more) -> final (exactly 1) -> close
    - OR message* (0 or more) -> error (exactly 1) -> close
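
    Illustrative wire sequence (payloads are built by the app.core.sse helpers; the
    field names shown here are a sketch, not a contract):
        event: message  data: {"delta": "图书馆"}
        event: message  data: {"delta": "开放时间..."}
        event: final    data: {"reply": "...", "confidence": 0.82, "should_transfer": false}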
"""
|
||
|
||
def __init__(
|
||
self,
|
||
llm_client: LLMClient | None = None,
|
||
memory_service: MemoryService | None = None,
|
||
retriever: BaseRetriever | None = None,
|
||
context_merger: ContextMerger | None = None,
|
||
confidence_calculator: ConfidenceCalculator | None = None,
|
||
config: OrchestratorConfig | None = None,
|
||
):
|
||
"""
|
||
Initialize orchestrator with optional dependencies for DI.
|
||
|
||
Args:
|
||
llm_client: LLM client for generation
|
||
memory_service: Memory service for session history
|
||
retriever: Retriever for RAG
|
||
context_merger: Context merger for history deduplication
|
||
confidence_calculator: Confidence calculator for response scoring
|
||
config: Orchestrator configuration
|
||
"""
|
||
settings = get_settings()
|
||
self._llm_client = llm_client
|
||
self._memory_service = memory_service
|
||
self._retriever = retriever
|
||
self._context_merger = context_merger or ContextMerger(
|
||
max_history_tokens=getattr(settings, "max_history_tokens", 4000)
|
||
)
|
||
self._confidence_calculator = confidence_calculator or ConfidenceCalculator()
|
||
self._config = config or OrchestratorConfig(
|
||
max_history_tokens=getattr(settings, "max_history_tokens", 4000),
|
||
max_evidence_tokens=getattr(settings, "rag_max_evidence_tokens", 2000),
|
||
enable_rag=True,
|
||
)
|
||
self._llm_config = LLMConfig(
|
||
model=getattr(settings, "llm_model", "gpt-4o-mini"),
|
||
max_tokens=getattr(settings, "llm_max_tokens", 2048),
|
||
temperature=getattr(settings, "llm_temperature", 0.7),
|
||
timeout_seconds=getattr(settings, "llm_timeout_seconds", 30),
|
||
max_retries=getattr(settings, "llm_max_retries", 3),
|
||
)
|
||
|
||
    async def generate(
        self,
        tenant_id: str,
        request: ChatRequest,
    ) -> ChatResponse:
        """
        Generate a non-streaming response.
        [AC-AISVC-01, AC-AISVC-02] Complete generation pipeline.

        Pipeline (per design.md Section 2.2):
        1. Load local history from Memory
        2. Merge with external history (dedup + truncate)
        3. RAG retrieval (optional)
        4. Build prompt with context and evidence
        5. LLM generation
        6. Calculate confidence
        7. Save messages to Memory
        8. Return ChatResponse
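
        Example (illustrative call, assuming a ChatRequest built per app.models):
            response = await get_orchestrator_service().generate("tenant-1", request)
            if response.should_transfer:
                ...  # hand the session over to a human agent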
"""
|
||
logger.info(
|
||
f"[AC-AISVC-01] Starting generation for tenant={tenant_id}, "
|
||
f"session={request.session_id}, channel_type={request.channel_type}, "
|
||
f"current_message={request.current_message[:100]}..."
|
||
)
|
||
logger.info(
|
||
f"[AC-AISVC-01] Config: enable_rag={self._config.enable_rag}, "
|
||
f"use_optimized_retriever={self._config.use_optimized_retriever}, "
|
||
f"llm_client={'configured' if self._llm_client else 'NOT configured'}, "
|
||
f"retriever={'configured' if self._retriever else 'NOT configured'}"
|
||
)
|
||
|
||
ctx = GenerationContext(
|
||
tenant_id=tenant_id,
|
||
session_id=request.session_id,
|
||
current_message=request.current_message,
|
||
channel_type=request.channel_type.value,
|
||
request_metadata=request.metadata,
|
||
)
|
||
|
||
try:
|
||
await self._load_local_history(ctx)
|
||
|
||
await self._merge_context(ctx, request.history)
|
||
|
||
if self._config.enable_rag and self._retriever:
|
||
await self._retrieve_evidence(ctx)
|
||
|
||
await self._generate_response(ctx)
|
||
|
||
self._calculate_confidence(ctx)
|
||
|
||
await self._save_messages(ctx)
|
||
|
||
return self._build_response(ctx)
|
||
|
||
except Exception as e:
|
||
logger.error(f"[AC-AISVC-01] Generation failed: {e}")
|
||
return ChatResponse(
|
||
reply="抱歉,服务暂时不可用,请稍后重试或联系人工客服。",
|
||
confidence=0.0,
|
||
should_transfer=True,
|
||
transfer_reason=f"服务异常: {str(e)}",
|
||
metadata={"error": str(e), "diagnostics": ctx.diagnostics},
|
||
)
|
||
|
||
    async def _load_local_history(self, ctx: GenerationContext) -> None:
        """
        [AC-AISVC-13] Load local history from the Memory service.
        Step 1 of the generation pipeline.
        """
        if not self._memory_service:
            logger.info("[AC-AISVC-13] No memory service configured, skipping history load")
            ctx.diagnostics["memory_enabled"] = False
            return

        try:
            messages = await self._memory_service.load_history(
                tenant_id=ctx.tenant_id,
                session_id=ctx.session_id,
            )

            ctx.local_history = [
                {"role": msg.role, "content": msg.content}
                for msg in messages
            ]

            ctx.diagnostics["memory_enabled"] = True
            ctx.diagnostics["local_history_count"] = len(ctx.local_history)

            logger.info(
                f"[AC-AISVC-13] Loaded {len(ctx.local_history)} messages from memory "
                f"for tenant={ctx.tenant_id}, session={ctx.session_id}"
            )

        except Exception as e:
            logger.warning(f"[AC-AISVC-13] Failed to load history: {e}")
            ctx.diagnostics["memory_error"] = str(e)

    async def _merge_context(
        self,
        ctx: GenerationContext,
        external_history: list | None,
    ) -> None:
        """
        [AC-AISVC-14, AC-AISVC-15] Merge local and external history.
        Step 2 of the generation pipeline.

        Design reference: design.md Section 7
        - Deduplication based on fingerprint
        - Truncation to fit token budget
        """
        external_messages = None
        if external_history:
            external_messages = [
                {"role": msg.role.value, "content": msg.content}
                for msg in external_history
            ]

        ctx.merged_context = self._context_merger.merge_and_truncate(
            local_history=ctx.local_history,
            external_history=external_messages,
            max_tokens=self._config.max_history_tokens,
        )

        ctx.diagnostics["merged_context"] = {
            "local_count": ctx.merged_context.local_count,
            "external_count": ctx.merged_context.external_count,
            "duplicates_skipped": ctx.merged_context.duplicates_skipped,
            "truncated_count": ctx.merged_context.truncated_count,
            "total_tokens": ctx.merged_context.total_tokens,
        }

        logger.info(
            f"[AC-AISVC-14, AC-AISVC-15] Context merged: "
            f"local={ctx.merged_context.local_count}, "
            f"external={ctx.merged_context.external_count}, "
            f"tokens={ctx.merged_context.total_tokens}"
        )

    async def _retrieve_evidence(self, ctx: GenerationContext) -> None:
        """
        [AC-AISVC-16, AC-AISVC-17] RAG retrieval for evidence.
        Step 3 of the generation pipeline.
        """
        logger.info(
            f"[AC-AISVC-16] Starting retrieval: tenant={ctx.tenant_id}, "
            f"query={ctx.current_message[:100]}..., retriever={type(self._retriever).__name__ if self._retriever else 'None'}"
        )
        try:
            retrieval_ctx = RetrievalContext(
                tenant_id=ctx.tenant_id,
                query=ctx.current_message,
                session_id=ctx.session_id,
                channel_type=ctx.channel_type,
                metadata=ctx.request_metadata,
            )

            ctx.retrieval_result = await self._retriever.retrieve(retrieval_ctx)

            ctx.diagnostics["retrieval"] = {
                "hit_count": ctx.retrieval_result.hit_count,
                "max_score": ctx.retrieval_result.max_score,
                "is_empty": ctx.retrieval_result.is_empty,
            }

            logger.info(
                f"[AC-AISVC-16, AC-AISVC-17] Retrieval complete: "
                f"hits={ctx.retrieval_result.hit_count}, "
                f"max_score={ctx.retrieval_result.max_score:.3f}, "
                f"is_empty={ctx.retrieval_result.is_empty}"
            )

            if ctx.retrieval_result.hit_count > 0:
                for i, hit in enumerate(ctx.retrieval_result.hits[:3]):
                    logger.info(
                        f"[AC-AISVC-16] Hit {i+1}: score={hit.score:.3f}, "
                        f"text_preview={hit.text[:100]}..."
                    )

        except Exception as e:
            logger.error(f"[AC-AISVC-16] Retrieval failed with exception: {e}", exc_info=True)
            # Degrade to an empty retrieval result so generation can still proceed.
            ctx.retrieval_result = RetrievalResult(
                hits=[],
                diagnostics={"error": str(e)},
            )
            ctx.diagnostics["retrieval_error"] = str(e)

    async def _generate_response(self, ctx: GenerationContext) -> None:
        """
        [AC-AISVC-02] Generate response using LLM.
        Steps 4-5 of the generation pipeline.
        """
        messages = self._build_llm_messages(ctx)
        logger.info(
            f"[AC-AISVC-02] Building LLM messages: count={len(messages)}, "
            f"has_retrieval_result={ctx.retrieval_result is not None}, "
            f"retrieval_is_empty={ctx.retrieval_result.is_empty if ctx.retrieval_result else 'N/A'}, "
            f"llm_client={'configured' if self._llm_client else 'NOT configured'}"
        )

        if not self._llm_client:
            logger.warning(
                f"[AC-AISVC-02] No LLM client configured, using fallback. "
                f"retrieval_is_empty={ctx.retrieval_result.is_empty if ctx.retrieval_result else True}"
            )
            ctx.llm_response = LLMResponse(
                content=self._fallback_response(ctx),
                model="fallback",
                usage={},
                finish_reason="fallback",
            )
            ctx.diagnostics["llm_mode"] = "fallback"
            ctx.diagnostics["fallback_reason"] = "no_llm_client"
            return

        try:
            ctx.llm_response = await self._llm_client.generate(
                messages=messages,
                config=self._llm_config,
            )
            ctx.diagnostics["llm_mode"] = "live"
            ctx.diagnostics["llm_model"] = ctx.llm_response.model
            ctx.diagnostics["llm_usage"] = ctx.llm_response.usage

            logger.info(
                f"[AC-AISVC-02] LLM response generated: "
                f"model={ctx.llm_response.model}, "
                f"tokens={ctx.llm_response.usage}, "
                f"content_preview={ctx.llm_response.content[:100]}..."
            )

        except Exception as e:
            logger.error(
                f"[AC-AISVC-02] LLM generation failed: {e}, "
                f"retrieval_is_empty={ctx.retrieval_result.is_empty if ctx.retrieval_result else True}",
                exc_info=True
            )
            ctx.llm_response = LLMResponse(
                content=self._fallback_response(ctx),
                model="fallback",
                usage={},
                finish_reason="error",
                metadata={"error": str(e)},
            )
            ctx.diagnostics["llm_error"] = str(e)
            ctx.diagnostics["llm_mode"] = "fallback"
            ctx.diagnostics["fallback_reason"] = f"llm_error: {str(e)}"

    def _build_llm_messages(self, ctx: GenerationContext) -> list[dict[str, str]]:
        """
        [AC-AISVC-02] Build messages for the LLM, including system prompt and evidence.
        """
        messages = []

        system_content = self._config.system_prompt

        if ctx.retrieval_result and not ctx.retrieval_result.is_empty:
            evidence_text = self._format_evidence(ctx.retrieval_result)
            # Append the evidence under a "knowledge-base reference content" heading.
            system_content += f"\n\n知识库参考内容:\n{evidence_text}"

        messages.append({"role": "system", "content": system_content})

        if ctx.merged_context and ctx.merged_context.messages:
            messages.extend(ctx.merged_context.messages)

        messages.append({"role": "user", "content": ctx.current_message})

        return messages

    def _format_evidence(self, retrieval_result: RetrievalResult) -> str:
        """
        [AC-AISVC-17] Format retrieval hits as evidence text.
        Optimized format with source attribution and metadata.
        """
        evidence_parts = []
        for i, hit in enumerate(retrieval_result.hits[:5], 1):
            # Hit metadata nests document attributes under a "metadata" key.
            metadata = hit.metadata or {}
            doc_meta = metadata.get("metadata", {})
            source = doc_meta.get("source_doc", "知识库")
            category = doc_meta.get("category", "")
            department = doc_meta.get("department", "")

            # Header like "[文档1] 来源:... | 类别:... | 部门:..."
            # ("[Document 1] source | category | department").
            header = f"[文档{i}]"
            if source and source != "知识库":
                header += f" 来源:{source}"
            if category:
                header += f" | 类别:{category}"
            if department:
                header += f" | 部门:{department}"

            # Each block: header, relevance score (相关度), and chunk text (内容).
            evidence_parts.append(f"{header}\n相关度:{hit.score:.2f}\n内容:{hit.text}")

        return "\n\n".join(evidence_parts)

    def _fallback_response(self, ctx: GenerationContext) -> str:
        """
        [AC-AISVC-17] Generate fallback response when the LLM is unavailable.
        """
        if ctx.retrieval_result and not ctx.retrieval_result.is_empty:
            # Evidence was found but no reply could be generated: ask the user to
            # retry later or contact a human agent.
            return (
                "根据知识库信息,我找到了一些相关内容,"
                "但暂时无法生成完整回复。建议您稍后重试或联系人工客服。"
            )
        # Generic apology: the request cannot be handled right now; retry later or
        # contact a human agent.
        return (
            "抱歉,我暂时无法处理您的请求。"
            "请稍后重试或联系人工客服获取帮助。"
        )

    def _calculate_confidence(self, ctx: GenerationContext) -> None:
        """
        [AC-AISVC-17, AC-AISVC-18, AC-AISVC-19] Calculate confidence score.
        Step 6 of the generation pipeline.
        """
        if ctx.retrieval_result:
            evidence_tokens = 0
            if not ctx.retrieval_result.is_empty:
                # Rough token estimate for the evidence: whitespace-delimited word
                # count doubled.
                evidence_tokens = sum(
                    len(hit.text.split()) * 2
                    for hit in ctx.retrieval_result.hits
                )

            ctx.confidence_result = self._confidence_calculator.calculate_confidence(
                retrieval_result=ctx.retrieval_result,
                evidence_tokens=evidence_tokens,
            )
        else:
            ctx.confidence_result = self._confidence_calculator.calculate_confidence_no_retrieval()

        ctx.diagnostics["confidence"] = {
            "score": ctx.confidence_result.confidence,
            "should_transfer": ctx.confidence_result.should_transfer,
            "is_insufficient": ctx.confidence_result.is_retrieval_insufficient,
        }

        logger.info(
            f"[AC-AISVC-17, AC-AISVC-18] Confidence calculated: "
            f"{ctx.confidence_result.confidence:.3f}, "
            f"should_transfer={ctx.confidence_result.should_transfer}"
        )

    async def _save_messages(self, ctx: GenerationContext) -> None:
        """
        [AC-AISVC-13] Save user and assistant messages to Memory.
        Step 7 of the generation pipeline.
        """
        if not self._memory_service:
            logger.info("[AC-AISVC-13] No memory service configured, skipping save")
            return

        try:
            await self._memory_service.get_or_create_session(
                tenant_id=ctx.tenant_id,
                session_id=ctx.session_id,
                channel_type=ctx.channel_type,
                metadata=ctx.request_metadata,
            )

            messages_to_save = [
                {"role": "user", "content": ctx.current_message},
            ]

            if ctx.llm_response:
                messages_to_save.append({
                    "role": "assistant",
                    "content": ctx.llm_response.content,
                })

            await self._memory_service.append_messages(
                tenant_id=ctx.tenant_id,
                session_id=ctx.session_id,
                messages=messages_to_save,
            )

            ctx.diagnostics["messages_saved"] = len(messages_to_save)

            logger.info(
                f"[AC-AISVC-13] Saved {len(messages_to_save)} messages "
                f"for tenant={ctx.tenant_id}, session={ctx.session_id}"
            )

        except Exception as e:
            logger.warning(f"[AC-AISVC-13] Failed to save messages: {e}")
            ctx.diagnostics["save_error"] = str(e)

    def _build_response(self, ctx: GenerationContext) -> ChatResponse:
        """
        [AC-AISVC-02] Build final ChatResponse from generation context.
        Step 8 of the generation pipeline.
        """
        reply = ctx.llm_response.content if ctx.llm_response else self._fallback_response(ctx)

        confidence = ctx.confidence_result.confidence if ctx.confidence_result else 0.5
        should_transfer = ctx.confidence_result.should_transfer if ctx.confidence_result else True
        transfer_reason = ctx.confidence_result.transfer_reason if ctx.confidence_result else None

        response_metadata = {
            "session_id": ctx.session_id,
            "channel_type": ctx.channel_type,
            "diagnostics": ctx.diagnostics,
        }

        return ChatResponse(
            reply=reply,
            confidence=confidence,
            should_transfer=should_transfer,
            transfer_reason=transfer_reason,
            metadata=response_metadata,
        )

    async def generate_stream(
        self,
        tenant_id: str,
        request: ChatRequest,
    ) -> AsyncGenerator[ServerSentEvent, None]:
        """
        Generate a streaming response.
        [AC-AISVC-06, AC-AISVC-07, AC-AISVC-08] Yields SSE events in proper sequence.

        SSE Event Sequence (per design.md Section 6.2):
        1. message events (multiple) - each with incremental delta
        2. final event (exactly 1) - with complete response
        3. connection close

        OR on error:
        1. message events (0 or more)
        2. error event (exactly 1)
        3. connection close
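
        Example (illustrative consumer):
            async for event in orchestrator.generate_stream(tenant_id, request):
                if event.event == "final":
                    ...  # parse the complete reply from event.data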
"""
|
||
logger.info(
|
||
f"[AC-AISVC-06] Starting streaming generation for tenant={tenant_id}, "
|
||
f"session={request.session_id}"
|
||
)
|
||
|
||
state_machine = SSEStateMachine()
|
||
await state_machine.transition_to_streaming()
|
||
|
||
ctx = GenerationContext(
|
||
tenant_id=tenant_id,
|
||
session_id=request.session_id,
|
||
current_message=request.current_message,
|
||
channel_type=request.channel_type.value,
|
||
request_metadata=request.metadata,
|
||
)
|
||
|
||
try:
|
||
await self._load_local_history(ctx)
|
||
await self._merge_context(ctx, request.history)
|
||
|
||
if self._config.enable_rag and self._retriever:
|
||
await self._retrieve_evidence(ctx)
|
||
|
||
full_reply = ""
|
||
|
||
if self._llm_client:
|
||
async for event in self._stream_from_llm(ctx, state_machine):
|
||
if event.event == "message":
|
||
full_reply += self._extract_delta_from_event(event)
|
||
yield event
|
||
else:
|
||
async for event in self._stream_mock_response(ctx, state_machine):
|
||
if event.event == "message":
|
||
full_reply += self._extract_delta_from_event(event)
|
||
yield event
|
||
|
||
if ctx.llm_response is None:
|
||
ctx.llm_response = LLMResponse(
|
||
content=full_reply,
|
||
model="streaming",
|
||
usage={},
|
||
finish_reason="stop",
|
||
)
|
||
|
||
self._calculate_confidence(ctx)
|
||
|
||
await self._save_messages(ctx)
|
||
|
||
if await state_machine.transition_to_final():
|
||
yield create_final_event(
|
||
reply=full_reply,
|
||
confidence=ctx.confidence_result.confidence if ctx.confidence_result else 0.5,
|
||
should_transfer=ctx.confidence_result.should_transfer if ctx.confidence_result else False,
|
||
transfer_reason=ctx.confidence_result.transfer_reason if ctx.confidence_result else None,
|
||
)
|
||
|
||
except Exception as e:
|
||
logger.error(f"[AC-AISVC-09] Error during streaming: {e}")
|
||
if await state_machine.transition_to_error():
|
||
yield create_error_event(
|
||
code="GENERATION_ERROR",
|
||
message=str(e),
|
||
)
|
||
finally:
|
||
await state_machine.close()
|
||
|
||
    async def _stream_from_llm(
        self,
        ctx: GenerationContext,
        state_machine: SSEStateMachine,
    ) -> AsyncGenerator[ServerSentEvent, None]:
        """
        [AC-AISVC-07] Stream from the LLM client, wrapping each chunk as a message event.
        """
        messages = self._build_llm_messages(ctx)

        async for chunk in self._llm_client.stream_generate(messages, self._llm_config):
            if not state_machine.can_send_message():
                break

            if chunk.delta:
                logger.debug(f"[AC-AISVC-07] Yielding message event with delta: {chunk.delta[:50]}...")
                yield create_message_event(delta=chunk.delta)

            if chunk.finish_reason:
                logger.info(f"[AC-AISVC-07] LLM stream finished with reason: {chunk.finish_reason}")
                break

    async def _stream_mock_response(
        self,
        ctx: GenerationContext,
        state_machine: SSEStateMachine,
    ) -> AsyncGenerator[ServerSentEvent, None]:
        """
        [AC-AISVC-07] Mock streaming response for demo/testing purposes.
        Simulates LLM-style incremental output.
        """
        import asyncio

        # Mock reply: "收到您的消息: <current message>" ("Received your message: ...").
        reply_parts = ["收到", "您的", "消息:", f" {ctx.current_message}"]

        for part in reply_parts:
            if not state_machine.can_send_message():
                break

            logger.debug(f"[AC-AISVC-07] Yielding mock message event with delta: {part}")
            yield create_message_event(delta=part)
            await asyncio.sleep(0.05)

    def _extract_delta_from_event(self, event: ServerSentEvent) -> str:
        """Extract delta content from a message event."""
        import json
        try:
            if event.data:
                data = json.loads(event.data)
                return data.get("delta", "")
        except (json.JSONDecodeError, TypeError):
            pass
        return ""


# Module-level singleton; tests can swap in a preconfigured instance via
# set_orchestrator_service().
_orchestrator_service: OrchestratorService | None = None


def get_orchestrator_service() -> OrchestratorService:
    """Get or create the orchestrator service instance."""
    global _orchestrator_service
    if _orchestrator_service is None:
        _orchestrator_service = OrchestratorService()
    return _orchestrator_service


def set_orchestrator_service(service: OrchestratorService) -> None:
    """Set the orchestrator service instance for testing."""
    global _orchestrator_service
    _orchestrator_service = service
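

if __name__ == "__main__":  # pragma: no cover
    # Illustrative smoke sketch only (not part of the service contract). It wires an
    # orchestrator with no injected dependencies and inspects the prompt messages and
    # fallback reply using helpers defined above; the tenant/session/channel values
    # below are made-up examples.
    service = OrchestratorService()
    demo_ctx = GenerationContext(
        tenant_id="tenant-demo",
        session_id="session-demo",
        current_message="图书馆的开放时间是什么?",
        channel_type="web",
    )
    for message in service._build_llm_messages(demo_ctx):
        print(f"{message['role']}: {message['content'][:80]}")
    print("fallback:", service._fallback_response(demo_ctx))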