# ai-robot-core/ai-service/app/services/orchestrator.py
"""
Orchestrator service for AI Service.
[AC-AISVC-01, AC-AISVC-02, AC-AISVC-06, AC-AISVC-07] Core orchestration logic for chat generation.
Design reference: design.md Section 10 - Orchestrator upgraded to a 12-step pipeline
1. InputScanner: Scan user input for forbidden words (logging only)
2. FlowEngine: Check if session has active script flow
3. IntentRouter: Match intent rules and route to appropriate handler
4. QueryRewriter: (Optional, skipped in MVP) Rewrite query for better retrieval
5. Multi-KB Retrieval: Retrieve from target knowledge bases
6. ResultRanker: Rank results by KB type priority
7. PromptBuilder: Load template + inject behavior rules
8. LLM.generate: Generate response
9. OutputFilter: Filter forbidden words in output
10. Confidence: Calculate confidence score
11. Memory: Save messages
12. Return: Build and return ChatResponse
RAG Optimization (rag-optimization/spec.md):
- Two-stage retrieval with Matryoshka dimensions
- RRF hybrid ranking
- Optimized prompt engineering
"""
import asyncio
import json
import logging
import time
from collections.abc import AsyncGenerator
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any
from sse_starlette.sse import ServerSentEvent
from app.core.config import get_settings
from app.core.database import get_session
from app.core.prompts import SYSTEM_PROMPT, format_evidence_for_prompt
from app.core.sse import (
SSEStateMachine,
create_error_event,
create_final_event,
create_message_event,
)
from app.models import ChatRequest, ChatResponse
from app.services.confidence import ConfidenceCalculator, ConfidenceResult
from app.services.context import ContextMerger, MergedContext
from app.services.flow.engine import FlowEngine
from app.services.guardrail.behavior_service import BehaviorRuleService
from app.services.guardrail.input_scanner import InputScanner
from app.services.guardrail.output_filter import OutputFilter
from app.services.guardrail.word_service import ForbiddenWordService
from app.services.intent.router import IntentRouter
from app.services.intent.rule_service import IntentRuleService
from app.services.llm.base import LLMClient, LLMConfig, LLMResponse
from app.services.memory import MemoryService
from app.services.prompt.template_service import PromptTemplateService
from app.services.prompt.variable_resolver import VariableResolver
from app.services.retrieval.base import BaseRetriever, RetrievalContext, RetrievalResult
logger = logging.getLogger(__name__)
@dataclass
class OrchestratorConfig:
"""
Configuration for OrchestratorService.
[AC-AISVC-01] Centralized configuration for orchestration.
"""
max_history_tokens: int = 4000
max_evidence_tokens: int = 2000
system_prompt: str = SYSTEM_PROMPT
enable_rag: bool = True
use_optimized_retriever: bool = True
@dataclass
class GenerationContext:
"""
[AC-AISVC-01, AC-AISVC-02] Context accumulated during generation pipeline.
Contains all intermediate results for diagnostics and response building.
12-Step Pipeline tracking:
1. input_scan_result: InputScanner result
2. active_flow: Active FlowInstance if exists
3. intent_match: IntentMatchResult if matched
4. query_rewritten: Rewritten query (optional)
5. retrieval_result: Multi-KB retrieval result
6. ranked_results: Ranked retrieval results
7. system_prompt: Built system prompt with template + behavior rules
8. llm_response: LLM generation result
9. filtered_reply: Output after forbidden word filtering
10. confidence_result: Confidence calculation result
11. messages_saved: Whether messages were saved
12. final_response: Final ChatResponse
"""
tenant_id: str
session_id: str
current_message: str
channel_type: str
request_metadata: dict[str, Any] | None = None
# Original pipeline fields
local_history: list[dict[str, str]] = field(default_factory=list)
merged_context: MergedContext | None = None
retrieval_result: RetrievalResult | None = None
llm_response: LLMResponse | None = None
confidence_result: ConfidenceResult | None = None
# Phase 10-14 new pipeline fields
input_scan_result: Any = None # InputScanResult
active_flow: Any = None # FlowInstance
intent_match: Any = None # IntentMatchResult
query_rewritten: str | None = None
ranked_results: list[Any] = field(default_factory=list)
system_prompt: str | None = None
filtered_reply: str | None = None
target_kb_ids: list[str] | None = None
behavior_rules: list[str] = field(default_factory=list)
diagnostics: dict[str, Any] = field(default_factory=dict)
execution_steps: list[dict[str, Any]] = field(default_factory=list)
class OrchestratorService:
"""
[AC-AISVC-01, AC-AISVC-02, AC-AISVC-06, AC-AISVC-07] Orchestrator for chat generation.
Coordinates memory, retrieval, LLM, and guardrail components.
12-Step Pipeline (design.md Section 10):
1. InputScanner: Scan user input for forbidden words
2. FlowEngine: Check if session has active script flow
3. IntentRouter: Match intent rules and route
4. QueryRewriter: (Optional, skipped in MVP)
5. Multi-KB Retrieval: Retrieve from target knowledge bases
6. ResultRanker: Rank results by KB type priority
7. PromptBuilder: Load template + inject behavior rules
8. LLM.generate: Generate response
9. OutputFilter: Filter forbidden words in output
10. Confidence: Calculate confidence score
11. Memory: Save messages
12. Return: Build and return ChatResponse
SSE Event Flow (per design.md Section 6.2):
- message* (0 or more) -> final (exactly 1) -> close
- OR message* (0 or more) -> error (exactly 1) -> close
"""
def __init__(
self,
llm_client: LLMClient | None = None,
memory_service: MemoryService | None = None,
retriever: BaseRetriever | None = None,
context_merger: ContextMerger | None = None,
confidence_calculator: ConfidenceCalculator | None = None,
config: OrchestratorConfig | None = None,
# Phase 10-14 new services
input_scanner: InputScanner | None = None,
intent_router: IntentRouter | None = None,
intent_rule_service: IntentRuleService | None = None,
flow_engine: FlowEngine | None = None,
prompt_template_service: PromptTemplateService | None = None,
variable_resolver: VariableResolver | None = None,
behavior_rule_service: BehaviorRuleService | None = None,
output_filter: OutputFilter | None = None,
):
"""
Initialize orchestrator with optional dependencies for DI.
Args:
llm_client: LLM client for generation
memory_service: Memory service for session history
retriever: Retriever for RAG
context_merger: Context merger for history deduplication
confidence_calculator: Confidence calculator for response scoring
config: Orchestrator configuration
input_scanner: Input scanner for forbidden word detection
intent_router: Intent router for rule matching
intent_rule_service: Intent rule service for loading rules
flow_engine: Flow engine for script flow execution
prompt_template_service: Prompt template service for template loading
variable_resolver: Variable resolver for template variable substitution
behavior_rule_service: Behavior rule service for loading behavior rules
output_filter: Output filter for forbidden word filtering
"""
settings = get_settings()
self._llm_client = llm_client
self._memory_service = memory_service
self._retriever = retriever
self._context_merger = context_merger or ContextMerger(
max_history_tokens=getattr(settings, "max_history_tokens", 4000)
)
self._confidence_calculator = confidence_calculator or ConfidenceCalculator()
self._config = config or OrchestratorConfig(
max_history_tokens=getattr(settings, "max_history_tokens", 4000),
max_evidence_tokens=getattr(settings, "rag_max_evidence_tokens", 2000),
enable_rag=True,
)
self._llm_config: LLMConfig | None = None
# Phase 10-14 services
self._input_scanner = input_scanner
self._intent_router = intent_router or IntentRouter()
self._intent_rule_service = intent_rule_service
self._flow_engine = flow_engine
self._prompt_template_service = prompt_template_service
self._variable_resolver = variable_resolver or VariableResolver()
self._behavior_rule_service = behavior_rule_service
self._output_filter = output_filter
def _record_step(
self,
ctx: GenerationContext,
step_no: int,
name: str,
status: str = "success",
duration_ms: int = 0,
input_data: Any = None,
output_data: Any = None,
error: str | None = None,
) -> None:
"""Record execution step for flow test visualization."""
ctx.execution_steps.append({
"step": step_no,
"name": name,
"status": status,
"duration_ms": duration_ms,
"input": input_data,
"output": output_data,
"error": error,
})
async def generate(
self,
tenant_id: str,
request: ChatRequest,
) -> ChatResponse:
"""
Generate a non-streaming response.
[AC-AISVC-01, AC-AISVC-02] Complete 12-step generation pipeline.
12-Step Pipeline (design.md Section 10):
1. InputScanner: Scan user input for forbidden words
2. FlowEngine: Check if session has active script flow
3. IntentRouter: Match intent rules and route
4. QueryRewriter: (Optional, skipped in MVP)
5. Multi-KB Retrieval: Retrieve from target knowledge bases
6. ResultRanker: Rank results by KB type priority
7. PromptBuilder: Load template + inject behavior rules
8. LLM.generate: Generate response
9. OutputFilter: Filter forbidden words in output
10. Confidence: Calculate confidence score
11. Memory: Save messages
12. Return: Build and return ChatResponse
"""
logger.info(
f"[AC-AISVC-01] Starting 12-step generation for tenant={tenant_id}, "
f"session={request.session_id}, channel_type={request.channel_type}, "
f"current_message={request.current_message[:100]}..."
)
ctx = GenerationContext(
tenant_id=tenant_id,
session_id=request.session_id,
current_message=request.current_message,
channel_type=request.channel_type.value,
request_metadata=request.metadata,
)
try:
# Step 1: InputScanner - Scan user input for forbidden words
step_start = time.time()
await self._scan_input(ctx)
self._record_step(ctx, 1, "InputScanner", "success", int((time.time() - step_start) * 1000),
input_data={"text": ctx.current_message[:200]},
output_data=ctx.diagnostics.get("input_scan"))
# Load local history and merge context (original pipeline)
await self._load_local_history(ctx)
await self._merge_context(ctx, request.history)
# Step 2: FlowEngine - Check if session has active script flow
step_start = time.time()
await self._check_active_flow(ctx)
self._record_step(ctx, 2, "FlowEngine", "success", int((time.time() - step_start) * 1000),
input_data={"session_id": ctx.session_id},
output_data={"active_flow": bool(ctx.active_flow), "flow_name": getattr(ctx.active_flow, 'flow_name', None) if ctx.active_flow else None})
# Step 3: IntentRouter - Match intent rules and route
step_start = time.time()
await self._match_intent(ctx)
intent_output = {"matched": bool(ctx.intent_match)}
if ctx.intent_match:
intent_output["rule_name"] = getattr(ctx.intent_match, 'rule_name', None)
intent_output["confidence"] = getattr(ctx.intent_match, 'confidence', None)
self._record_step(ctx, 3, "IntentRouter", "success", int((time.time() - step_start) * 1000),
input_data={"query": ctx.current_message[:100]},
output_data=intent_output)
# Step 4: QueryRewriter - (Optional, skipped in MVP)
self._record_step(ctx, 4, "QueryRewriter", "skipped", 0,
input_data={"query": ctx.current_message[:100]},
output_data={"note": "Skipped in MVP", "rewritten": None})
            # Steps 5-6: Multi-KB Retrieval + ResultRanker (recorded as a single step)
step_start = time.time()
if self._config.enable_rag and self._retriever:
await self._retrieve_evidence(ctx)
retrieval_output = {
"hit_count": len(ctx.retrieval_result.hits) if ctx.retrieval_result else 0,
"max_score": ctx.retrieval_result.max_score if ctx.retrieval_result else 0,
}
if ctx.retrieval_result and ctx.retrieval_result.hits:
retrieval_output["top_hits"] = [
{
"content": hit.text[:200] + "..." if len(hit.text) > 200 else hit.text,
"score": round(hit.score, 4),
"source": hit.source,
}
for hit in ctx.retrieval_result.hits[:5]
]
self._record_step(ctx, 5, "MultiKBRetrieval", "success", int((time.time() - step_start) * 1000),
input_data={"query": ctx.current_message[:100], "top_k": 3},
output_data=retrieval_output)
else:
self._record_step(ctx, 5, "MultiKBRetrieval", "skipped", 0,
input_data={"query": ctx.current_message[:100]},
output_data={"note": "RAG disabled or no retriever"})
# Step 7: PromptBuilder - Load template + inject behavior rules
step_start = time.time()
await self._build_system_prompt(ctx)
self._record_step(ctx, 7, "PromptBuilder", "success", int((time.time() - step_start) * 1000),
input_data={"template_id": getattr(ctx, 'template_id', None), "behavior_rules": ctx.behavior_rules[:3] if ctx.behavior_rules else []},
output_data={"prompt_length": len(ctx.system_prompt) if ctx.system_prompt else 0, "prompt_preview": ctx.system_prompt[:300] + "..." if ctx.system_prompt and len(ctx.system_prompt) > 300 else ctx.system_prompt})
# Step 8: LLM.generate - Generate response
step_start = time.time()
await self._generate_response(ctx)
llm_model = ctx.llm_response.model if ctx.llm_response else "unknown"
self._record_step(ctx, 8, "LLMGenerate", "success", int((time.time() - step_start) * 1000),
input_data={"model": llm_model, "messages_count": len(self._build_llm_messages(ctx)) if hasattr(self, '_build_llm_messages') else 1},
output_data={"reply_length": len(ctx.llm_response.content) if ctx.llm_response else 0, "reply_preview": ctx.llm_response.content[:200] + "..." if ctx.llm_response and len(ctx.llm_response.content) > 200 else (ctx.llm_response.content if ctx.llm_response else None)})
# Step 9: OutputFilter - Filter forbidden words in output
step_start = time.time()
await self._filter_output(ctx)
filter_output = {"filtered": ctx.filtered_reply != ctx.llm_response.content if ctx.llm_response else False}
if ctx.diagnostics.get("output_filter"):
filter_output["triggered_words"] = ctx.diagnostics.get("output_filter", {}).get("triggered_words", [])
self._record_step(ctx, 9, "OutputFilter", "success", int((time.time() - step_start) * 1000),
input_data={"text_length": len(ctx.llm_response.content) if ctx.llm_response else 0},
output_data=filter_output)
# Step 10: Confidence - Calculate confidence score
step_start = time.time()
self._calculate_confidence(ctx)
self._record_step(ctx, 10, "Confidence", "success", int((time.time() - step_start) * 1000),
input_data={"reply_length": len(ctx.filtered_reply) if ctx.filtered_reply else 0, "hit_count": len(ctx.retrieval_result.hits) if ctx.retrieval_result else 0},
output_data={"confidence": ctx.confidence_result.confidence if ctx.confidence_result else 0, "should_transfer": ctx.confidence_result.should_transfer if ctx.confidence_result else True})
# Step 11: Memory - Save messages
step_start = time.time()
await self._save_messages(ctx)
self._record_step(ctx, 11, "Memory", "success", int((time.time() - step_start) * 1000),
input_data={"session_id": ctx.session_id},
output_data={"saved": True})
# Step 12: Return - Build and return ChatResponse
self._record_step(ctx, 12, "Response", "success", 0,
input_data={"confidence": ctx.confidence_result.confidence if ctx.confidence_result else 0, "should_transfer": ctx.confidence_result.should_transfer if ctx.confidence_result else True},
output_data={"reply_length": len(ctx.filtered_reply) if ctx.filtered_reply else 0, "reply_preview": ctx.filtered_reply[:200] + "..." if ctx.filtered_reply and len(ctx.filtered_reply) > 200 else ctx.filtered_reply})
return self._build_response(ctx)
except Exception as e:
logger.error(f"[AC-AISVC-01] Generation failed: {e}")
return ChatResponse(
reply="抱歉,服务暂时不可用,请稍后重试或联系人工客服。",
confidence=0.0,
should_transfer=True,
transfer_reason=f"服务异常: {str(e)}",
metadata={"error": str(e), "diagnostics": ctx.diagnostics},
)
async def _load_local_history(self, ctx: GenerationContext) -> None:
"""
[AC-AISVC-13] Load local history from Memory service.
        Runs before context merging; part of the original pipeline rather than a numbered step of the 12-step pipeline.
"""
if not self._memory_service:
logger.info("[AC-AISVC-13] No memory service configured, skipping history load")
ctx.diagnostics["memory_enabled"] = False
return
try:
messages = await self._memory_service.load_history(
tenant_id=ctx.tenant_id,
session_id=ctx.session_id,
)
ctx.local_history = [
{"role": msg.role, "content": msg.content}
for msg in messages
]
ctx.diagnostics["memory_enabled"] = True
ctx.diagnostics["local_history_count"] = len(ctx.local_history)
logger.info(
f"[AC-AISVC-13] Loaded {len(ctx.local_history)} messages from memory "
f"for tenant={ctx.tenant_id}, session={ctx.session_id}"
)
except Exception as e:
logger.warning(f"[AC-AISVC-13] Failed to load history: {e}")
ctx.diagnostics["memory_error"] = str(e)
async def _scan_input(self, ctx: GenerationContext) -> None:
"""
[AC-AISVC-83] Step 1: Scan user input for forbidden words (logging only).
"""
if not self._input_scanner:
logger.debug("[AC-AISVC-83] No input scanner configured, skipping")
ctx.diagnostics["input_scan_enabled"] = False
return
try:
ctx.input_scan_result = await self._input_scanner.scan(
text=ctx.current_message,
tenant_id=ctx.tenant_id,
)
ctx.diagnostics["input_scan"] = {
"flagged": ctx.input_scan_result.flagged,
"matched_words": ctx.input_scan_result.matched_words,
"matched_categories": ctx.input_scan_result.matched_categories,
}
if ctx.input_scan_result.flagged:
logger.info(
f"[AC-AISVC-83] Input flagged: words={ctx.input_scan_result.matched_words}, "
f"categories={ctx.input_scan_result.matched_categories}"
)
except Exception as e:
logger.warning(f"[AC-AISVC-83] Input scan failed: {e}")
ctx.diagnostics["input_scan_error"] = str(e)
async def _check_active_flow(self, ctx: GenerationContext) -> None:
"""
[AC-AISVC-75] Step 2: Check if session has active script flow.
If active flow exists, advance it based on user input.
"""
if not self._flow_engine:
logger.debug("[AC-AISVC-75] No flow engine configured, skipping")
ctx.diagnostics["flow_check_enabled"] = False
return
try:
ctx.active_flow = await self._flow_engine.check_active_flow(
tenant_id=ctx.tenant_id,
session_id=ctx.session_id,
)
if ctx.active_flow:
logger.info(
f"[AC-AISVC-75] Active flow found: flow_id={ctx.active_flow.flow_id}, "
f"current_step={ctx.active_flow.current_step}"
)
# Advance the flow based on user input
advance_result = await self._flow_engine.advance(
instance=ctx.active_flow,
user_input=ctx.current_message,
)
ctx.diagnostics["flow_advance"] = {
"completed": advance_result.completed,
"has_message": advance_result.message is not None,
}
# If flow provides a message, use it as the reply and skip LLM
if advance_result.message:
ctx.llm_response = LLMResponse(
content=advance_result.message,
model="script_flow",
usage={},
finish_reason="flow_step",
)
ctx.diagnostics["flow_handled"] = True
logger.info(f"[AC-AISVC-75] Flow provided reply, skipping LLM")
else:
ctx.diagnostics["flow_check_enabled"] = True
ctx.diagnostics["active_flow"] = False
except Exception as e:
logger.warning(f"[AC-AISVC-75] Flow check failed: {e}")
ctx.diagnostics["flow_check_error"] = str(e)
async def _match_intent(self, ctx: GenerationContext) -> None:
"""
[AC-AISVC-69, AC-AISVC-70] Step 3: Match intent rules and route.
Routes to: fixed reply, RAG with target KBs, flow start, or transfer.
"""
# Skip if flow already handled the request
if ctx.diagnostics.get("flow_handled"):
logger.info("[AC-AISVC-69] Flow already handled, skipping intent matching")
return
if not self._intent_rule_service:
logger.debug("[AC-AISVC-69] No intent rule service configured, skipping")
ctx.diagnostics["intent_match_enabled"] = False
return
try:
# Load enabled rules ordered by priority
async with get_session() as session:
rule_service = IntentRuleService(session)
rules = await rule_service.get_enabled_rules_for_matching(ctx.tenant_id)
if not rules:
ctx.diagnostics["intent_match_enabled"] = True
ctx.diagnostics["intent_matched"] = False
return
# Match intent
ctx.intent_match = self._intent_router.match(
message=ctx.current_message,
rules=rules,
)
if ctx.intent_match:
logger.info(
f"[AC-AISVC-69] Intent matched: rule={ctx.intent_match.rule.name}, "
f"response_type={ctx.intent_match.rule.response_type}"
)
ctx.diagnostics["intent_match"] = ctx.intent_match.to_dict()
# Increment hit count
async with get_session() as session:
rule_service = IntentRuleService(session)
await rule_service.increment_hit_count(
tenant_id=ctx.tenant_id,
rule_id=ctx.intent_match.rule.id,
)
# Route based on response_type
if ctx.intent_match.rule.response_type == "fixed":
# Fixed reply - skip LLM
ctx.llm_response = LLMResponse(
content=ctx.intent_match.rule.fixed_reply or "收到您的消息。",
model="intent_fixed",
usage={},
finish_reason="intent_fixed",
)
ctx.diagnostics["intent_handled"] = True
logger.info("[AC-AISVC-70] Intent fixed reply, skipping LLM")
elif ctx.intent_match.rule.response_type == "rag":
# RAG with target KBs
ctx.target_kb_ids = ctx.intent_match.rule.target_kb_ids or []
logger.info(f"[AC-AISVC-70] Intent RAG, target_kb_ids={ctx.target_kb_ids}")
elif ctx.intent_match.rule.response_type == "flow":
# Start script flow
if ctx.intent_match.rule.flow_id and self._flow_engine:
async with get_session() as session:
flow_engine = FlowEngine(session)
instance, first_step = await flow_engine.start(
tenant_id=ctx.tenant_id,
session_id=ctx.session_id,
flow_id=ctx.intent_match.rule.flow_id,
)
if first_step:
ctx.llm_response = LLMResponse(
content=first_step,
model="script_flow",
usage={},
finish_reason="flow_start",
)
ctx.diagnostics["intent_handled"] = True
logger.info("[AC-AISVC-70] Intent flow started, skipping LLM")
elif ctx.intent_match.rule.response_type == "transfer":
# Transfer to human
ctx.llm_response = LLMResponse(
content=ctx.intent_match.rule.transfer_message or "正在为您转接人工客服...",
model="intent_transfer",
usage={},
finish_reason="intent_transfer",
)
ctx.confidence_result = ConfidenceResult(
confidence=0.0,
should_transfer=True,
transfer_reason="intent_rule_transfer",
is_retrieval_insufficient=False,
)
ctx.diagnostics["intent_handled"] = True
logger.info("[AC-AISVC-70] Intent transfer, skipping LLM")
else:
ctx.diagnostics["intent_match_enabled"] = True
ctx.diagnostics["intent_matched"] = False
except Exception as e:
logger.warning(f"[AC-AISVC-69] Intent matching failed: {e}")
ctx.diagnostics["intent_match_error"] = str(e)
async def _merge_context(
self,
ctx: GenerationContext,
external_history: list | None,
) -> None:
"""
[AC-AISVC-14, AC-AISVC-15] Merge local and external history.
        Runs after history loading; part of the original pipeline rather than a numbered step of the 12-step pipeline.
Design reference: design.md Section 7
- Deduplication based on fingerprint
- Truncation to fit token budget
"""
external_messages = None
if external_history:
external_messages = [
{"role": msg.role.value, "content": msg.content}
for msg in external_history
]
ctx.merged_context = self._context_merger.merge_and_truncate(
local_history=ctx.local_history,
external_history=external_messages,
max_tokens=self._config.max_history_tokens,
)
ctx.diagnostics["merged_context"] = {
"local_count": ctx.merged_context.local_count,
"external_count": ctx.merged_context.external_count,
"duplicates_skipped": ctx.merged_context.duplicates_skipped,
"truncated_count": ctx.merged_context.truncated_count,
"total_tokens": ctx.merged_context.total_tokens,
}
logger.info(
f"[AC-AISVC-14, AC-AISVC-15] Context merged: "
f"local={ctx.merged_context.local_count}, "
f"external={ctx.merged_context.external_count}, "
f"tokens={ctx.merged_context.total_tokens}"
)
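    # Dedup sketch (assumption: ContextMerger's actual fingerprint lives in
    # app/services/context.py and may differ). One plausible scheme:
    #
    #     fingerprint = hash((msg["role"], msg["content"].strip()))
    #
    # A message whose fingerprint already appeared in local history is skipped
    # and counted in duplicates_skipped.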
async def _retrieve_evidence(self, ctx: GenerationContext) -> None:
"""
[AC-AISVC-16, AC-AISVC-17] RAG retrieval for evidence.
Step 5-6: Multi-KB retrieval with target KBs from intent matching.
[AC-IDSMETA-19] Inject metadata filters (grade/subject/scene) from context.
"""
# Skip if flow or intent already handled
if ctx.diagnostics.get("flow_handled") or ctx.diagnostics.get("intent_handled"):
logger.info("[AC-AISVC-16] Request already handled, skipping retrieval")
return
logger.info(
f"[AC-AISVC-16] Starting retrieval: tenant={ctx.tenant_id}, "
f"query={ctx.current_message[:100]}..., retriever={type(self._retriever).__name__ if self._retriever else 'None'}"
)
try:
retrieval_ctx = RetrievalContext(
tenant_id=ctx.tenant_id,
query=ctx.current_message,
session_id=ctx.session_id,
channel_type=ctx.channel_type,
metadata=ctx.request_metadata,
)
# If intent matched with target KBs, pass them to retriever
if ctx.target_kb_ids:
retrieval_ctx.metadata = retrieval_ctx.metadata or {}
retrieval_ctx.metadata["target_kb_ids"] = ctx.target_kb_ids
logger.info(f"[AC-AISVC-16] Using target_kb_ids from intent: {ctx.target_kb_ids}")
# [AC-IDSMETA-19] Inject metadata filters from context
metadata_filters = await self._build_metadata_filters(ctx)
if metadata_filters:
retrieval_ctx.tag_filter = metadata_filters
logger.info(
f"[AC-IDSMETA-19] Injected metadata filters: "
f"intent_id={ctx.intent_match.rule.id if ctx.intent_match else None}, "
f"target_kbs={ctx.target_kb_ids}, "
f"applied_metadata_filters={metadata_filters.fields}"
)
ctx.retrieval_result = await self._retriever.retrieve(retrieval_ctx)
ctx.diagnostics["retrieval"] = {
"hit_count": ctx.retrieval_result.hit_count,
"max_score": ctx.retrieval_result.max_score,
"is_empty": ctx.retrieval_result.is_empty,
"applied_metadata_filters": metadata_filters.fields if metadata_filters else None,
}
logger.info(
f"[AC-AISVC-16, AC-AISVC-17] Retrieval complete: "
f"hits={ctx.retrieval_result.hit_count}, "
f"max_score={ctx.retrieval_result.max_score:.3f}, "
f"is_empty={ctx.retrieval_result.is_empty}"
)
if ctx.retrieval_result.hit_count > 0:
for i, hit in enumerate(ctx.retrieval_result.hits[:3]):
logger.info(
f"[AC-AISVC-16] Hit {i+1}: score={hit.score:.3f}, "
f"text_preview={hit.text[:100]}..."
)
except Exception as e:
logger.error(f"[AC-AISVC-16] Retrieval failed with exception: {e}", exc_info=True)
ctx.retrieval_result = RetrievalResult(
hits=[],
diagnostics={"error": str(e)},
)
ctx.diagnostics["retrieval_error"] = str(e)
async def _build_metadata_filters(self, ctx: GenerationContext):
"""
[AC-IDSMETA-19] Build metadata filters from context.
Sources:
1. Intent rule metadata (if matched)
2. Session metadata
3. Request metadata
4. Extracted slots from conversation
Returns:
TagFilter with at least grade, subject, scene if available
"""
from app.services.retrieval.metadata import TagFilter
filter_fields = {}
# 1. From intent rule metadata
if ctx.intent_match and hasattr(ctx.intent_match.rule, 'metadata_') and ctx.intent_match.rule.metadata_:
intent_metadata = ctx.intent_match.rule.metadata_
for key in ['grade', 'subject', 'scene']:
if key in intent_metadata:
filter_fields[key] = intent_metadata[key]
# 2. From session/request metadata
if ctx.request_metadata:
for key in ['grade', 'subject', 'scene']:
if key in ctx.request_metadata and key not in filter_fields:
filter_fields[key] = ctx.request_metadata[key]
# 3. From merged context (extracted slots)
if ctx.merged_context and hasattr(ctx.merged_context, 'slots'):
slots = ctx.merged_context.slots or {}
for key in ['grade', 'subject', 'scene']:
if key in slots and key not in filter_fields:
filter_fields[key] = slots[key]
if not filter_fields:
return None
return TagFilter(fields=filter_fields)
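    # Worked example (illustrative values): intent metadata {"grade": "G3"} plus
    # request metadata {"grade": "G5", "subject": "math"} yields
    # TagFilter(fields={"grade": "G3", "subject": "math"}); the first source to set a key wins.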
async def _build_system_prompt(self, ctx: GenerationContext) -> None:
"""
[AC-AISVC-56, AC-AISVC-84] Step 7: Build system prompt with template + behavior rules.
"""
# Skip if flow or intent already handled
if ctx.diagnostics.get("flow_handled") or ctx.diagnostics.get("intent_handled"):
logger.info("[AC-AISVC-56] Request already handled, using default prompt")
ctx.system_prompt = self._config.system_prompt
return
try:
# Try to load template from service
if self._prompt_template_service:
async with get_session() as session:
template_service = PromptTemplateService(session)
template_version = await template_service.get_published_template(
tenant_id=ctx.tenant_id,
scene="default", # TODO: Make scene configurable
)
if template_version:
# Resolve variables
variables = {
"persona_name": "AI助手",
"current_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
"channel_type": ctx.channel_type,
}
ctx.system_prompt = self._variable_resolver.resolve(
template=template_version.system_instruction,
variables=variables,
)
logger.info(f"[AC-AISVC-56] Loaded template: scene=default, version={template_version.version}")
else:
ctx.system_prompt = self._config.system_prompt
logger.info("[AC-AISVC-56] No published template found, using default")
else:
ctx.system_prompt = self._config.system_prompt
# Load and inject behavior rules
if self._behavior_rule_service:
async with get_session() as session:
behavior_service = BehaviorRuleService(session)
rules = await behavior_service.get_enabled_rules(ctx.tenant_id)
if rules:
ctx.behavior_rules = [rule.rule_text for rule in rules]
behavior_text = "\n".join([f"- {rule}" for rule in ctx.behavior_rules])
ctx.system_prompt += f"\n\n行为约束:\n{behavior_text}"
logger.info(f"[AC-AISVC-84] Injected {len(rules)} behavior rules")
ctx.diagnostics["prompt_template"] = {
"source": "template" if self._prompt_template_service else "default",
"behavior_rules_count": len(ctx.behavior_rules),
}
except Exception as e:
logger.warning(f"[AC-AISVC-56] Failed to build system prompt: {e}")
ctx.system_prompt = self._config.system_prompt
ctx.diagnostics["prompt_build_error"] = str(e)
async def _generate_response(self, ctx: GenerationContext) -> None:
"""
[AC-AISVC-02] Generate response using LLM.
Step 8 of the 12-step pipeline.
"""
# Skip if flow or intent already handled
if ctx.diagnostics.get("flow_handled") or ctx.diagnostics.get("intent_handled"):
logger.info("[AC-AISVC-02] Request already handled, skipping LLM generation")
return
messages = self._build_llm_messages(ctx)
logger.info(
f"[AC-AISVC-02] Building LLM messages: count={len(messages)}, "
f"has_retrieval_result={ctx.retrieval_result is not None}, "
f"retrieval_is_empty={ctx.retrieval_result.is_empty if ctx.retrieval_result else 'N/A'}, "
f"llm_client={'configured' if self._llm_client else 'NOT configured'}"
)
if not self._llm_client:
logger.warning(
f"[AC-AISVC-02] No LLM client configured, using fallback. "
f"retrieval_is_empty={ctx.retrieval_result.is_empty if ctx.retrieval_result else True}"
)
ctx.llm_response = LLMResponse(
content=self._fallback_response(ctx),
model="fallback",
usage={},
finish_reason="fallback",
)
ctx.diagnostics["llm_mode"] = "fallback"
ctx.diagnostics["fallback_reason"] = "no_llm_client"
return
try:
ctx.llm_response = await self._llm_client.generate(
messages=messages,
)
ctx.diagnostics["llm_mode"] = "live"
ctx.diagnostics["llm_model"] = ctx.llm_response.model
ctx.diagnostics["llm_usage"] = ctx.llm_response.usage
logger.info(
f"[AC-AISVC-02] LLM response generated: "
f"model={ctx.llm_response.model}, "
f"tokens={ctx.llm_response.usage}, "
f"content_preview={ctx.llm_response.content[:100]}..."
)
except Exception as e:
logger.error(
f"[AC-AISVC-02] LLM generation failed: {e}, "
f"retrieval_is_empty={ctx.retrieval_result.is_empty if ctx.retrieval_result else True}",
exc_info=True
)
ctx.llm_response = LLMResponse(
content=self._fallback_response(ctx),
model="fallback",
usage={},
finish_reason="error",
metadata={"error": str(e)},
)
ctx.diagnostics["llm_error"] = str(e)
ctx.diagnostics["llm_mode"] = "fallback"
ctx.diagnostics["fallback_reason"] = f"llm_error: {str(e)}"
async def _filter_output(self, ctx: GenerationContext) -> None:
"""
[AC-AISVC-82] Step 9: Filter forbidden words in output.
"""
if not ctx.llm_response:
logger.debug("[AC-AISVC-82] No LLM response to filter")
return
if not self._output_filter:
logger.debug("[AC-AISVC-82] No output filter configured, skipping")
ctx.filtered_reply = ctx.llm_response.content
ctx.diagnostics["output_filter_enabled"] = False
return
try:
filter_result = await self._output_filter.filter(
reply=ctx.llm_response.content,
tenant_id=ctx.tenant_id,
)
ctx.filtered_reply = filter_result.filtered_text
ctx.diagnostics["output_filter"] = {
"triggered": filter_result.triggered,
"matched_words": filter_result.matched_words,
"strategy_applied": filter_result.strategy_applied,
}
if filter_result.triggered:
logger.info(
f"[AC-AISVC-82] Output filtered: words={filter_result.matched_words}, "
f"strategy={filter_result.strategy_applied}"
)
# If blocked, override confidence
if filter_result.strategy_applied == "block":
ctx.confidence_result = ConfidenceResult(
confidence=0.0,
should_transfer=True,
transfer_reason="output_blocked_by_guardrail",
is_retrieval_insufficient=False,
)
except Exception as e:
logger.warning(f"[AC-AISVC-82] Output filtering failed: {e}")
ctx.filtered_reply = ctx.llm_response.content
ctx.diagnostics["output_filter_error"] = str(e)
def _build_llm_messages(self, ctx: GenerationContext) -> list[dict[str, str]]:
"""
[AC-AISVC-02] Build messages for LLM including system prompt and evidence.
Uses ctx.system_prompt from Step 7 (template + behavior rules).
"""
messages = []
# Use system prompt from Step 7 (template + behavior rules)
system_content = ctx.system_prompt or self._config.system_prompt
if ctx.retrieval_result and not ctx.retrieval_result.is_empty:
evidence_text = self._format_evidence(ctx.retrieval_result)
system_content += f"\n\n知识库参考内容:\n{evidence_text}"
messages.append({"role": "system", "content": system_content})
if ctx.merged_context and ctx.merged_context.messages:
messages.extend(ctx.merged_context.messages)
messages.append({"role": "user", "content": ctx.current_message})
logger.info(
f"[AC-AISVC-02] Built {len(messages)} messages for LLM: "
f"system_len={len(system_content)}, history_count={len(ctx.merged_context.messages) if ctx.merged_context else 0}"
)
logger.debug(f"[AC-AISVC-02] System prompt preview: {system_content[:500]}...")
return messages
def _format_evidence(self, retrieval_result: RetrievalResult) -> str:
"""
[AC-AISVC-17] Format retrieval hits as evidence text.
Uses shared prompt configuration for consistency.
"""
return format_evidence_for_prompt(retrieval_result.hits, max_results=5, max_content_length=500)
def _fallback_response(self, ctx: GenerationContext) -> str:
"""
[AC-AISVC-17] Generate fallback response when LLM is unavailable.
[AC-IDSMETA-20] Return fallback with structured reason code when no recall.
"""
if ctx.retrieval_result and not ctx.retrieval_result.is_empty:
return (
"根据知识库信息,我找到了一些相关内容,"
"但暂时无法生成完整回复。建议您稍后重试或联系人工客服。"
)
# [AC-IDSMETA-20] Record structured fallback reason code
fallback_reason_code = self._determine_fallback_reason_code(ctx)
ctx.diagnostics["fallback_reason_code"] = fallback_reason_code
logger.warning(
f"[AC-IDSMETA-20] No recall, using fallback: "
f"intent_id={ctx.intent_match.rule.id if ctx.intent_match else None}, "
f"target_kbs={ctx.target_kb_ids}, "
f"applied_metadata_filters={ctx.diagnostics.get('retrieval', {}).get('applied_metadata_filters')}, "
f"fallback_reason_code={fallback_reason_code}"
)
return (
"抱歉,我暂时无法处理您的请求。"
"请稍后重试或联系人工客服获取帮助。"
)
def _determine_fallback_reason_code(self, ctx: GenerationContext) -> str:
"""
[AC-IDSMETA-20] Determine structured fallback reason code.
Reason codes:
- no_recall_after_metadata_filter: No results after applying metadata filters
- no_recall_no_kb: No target knowledge bases configured
- no_recall_kb_empty: Knowledge base is empty
- no_recall_low_score: Results found but below threshold
        - no_recall_error: Retrieval error occurred
        - no_recall_unknown: None of the above conditions matched
"""
retrieval_diag = ctx.diagnostics.get("retrieval", {})
# Check for retrieval error
if ctx.diagnostics.get("retrieval_error"):
return "no_recall_error"
# Check if metadata filters were applied
if retrieval_diag.get("applied_metadata_filters"):
return "no_recall_after_metadata_filter"
# Check if target KBs were configured
if not ctx.target_kb_ids:
return "no_recall_no_kb"
# Check if KB is empty (no candidates at all)
if retrieval_diag.get("total_candidates", 0) == 0:
return "no_recall_kb_empty"
# Results found but filtered out by score threshold
if retrieval_diag.get("total_candidates", 0) > 0 and retrieval_diag.get("filtered_hits", 0) == 0:
return "no_recall_low_score"
return "no_recall_unknown"
def _calculate_confidence(self, ctx: GenerationContext) -> None:
"""
[AC-AISVC-17, AC-AISVC-18, AC-AISVC-19] Calculate confidence score.
        Step 10 of the 12-step pipeline.
"""
if ctx.retrieval_result:
evidence_tokens = 0
if not ctx.retrieval_result.is_empty:
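                # Rough heuristic: word count * 2 approximates token count
                # (subword tokenizers and CJK text yield more tokens than words).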
evidence_tokens = sum(
len(hit.text.split()) * 2
for hit in ctx.retrieval_result.hits
)
ctx.confidence_result = self._confidence_calculator.calculate_confidence(
retrieval_result=ctx.retrieval_result,
evidence_tokens=evidence_tokens,
)
else:
ctx.confidence_result = self._confidence_calculator.calculate_confidence_no_retrieval()
ctx.diagnostics["confidence"] = {
"score": ctx.confidence_result.confidence,
"should_transfer": ctx.confidence_result.should_transfer,
"is_insufficient": ctx.confidence_result.is_retrieval_insufficient,
}
logger.info(
f"[AC-AISVC-17, AC-AISVC-18] Confidence calculated: "
f"{ctx.confidence_result.confidence:.3f}, "
f"should_transfer={ctx.confidence_result.should_transfer}"
)
async def _save_messages(self, ctx: GenerationContext) -> None:
"""
[AC-AISVC-13] Save user and assistant messages to Memory.
Step 11 of the 12-step pipeline.
Uses filtered_reply from Step 9.
"""
if not self._memory_service:
logger.info("[AC-AISVC-13] No memory service configured, skipping save")
return
try:
await self._memory_service.get_or_create_session(
tenant_id=ctx.tenant_id,
session_id=ctx.session_id,
channel_type=ctx.channel_type,
metadata=ctx.request_metadata,
)
messages_to_save = [
{"role": "user", "content": ctx.current_message},
]
# Use filtered_reply if available, otherwise use llm_response.content
if ctx.filtered_reply:
messages_to_save.append({
"role": "assistant",
"content": ctx.filtered_reply,
})
elif ctx.llm_response:
messages_to_save.append({
"role": "assistant",
"content": ctx.llm_response.content,
})
await self._memory_service.append_messages(
tenant_id=ctx.tenant_id,
session_id=ctx.session_id,
messages=messages_to_save,
)
ctx.diagnostics["messages_saved"] = len(messages_to_save)
logger.info(
f"[AC-AISVC-13] Saved {len(messages_to_save)} messages "
f"for tenant={ctx.tenant_id}, session={ctx.session_id}"
)
except Exception as e:
logger.warning(f"[AC-AISVC-13] Failed to save messages: {e}")
ctx.diagnostics["save_error"] = str(e)
def _build_response(self, ctx: GenerationContext) -> ChatResponse:
"""
[AC-AISVC-02] Build final ChatResponse from generation context.
Step 12 of the 12-step pipeline.
Uses filtered_reply from Step 9.
"""
# Use filtered_reply if available, otherwise use llm_response.content
if ctx.filtered_reply:
reply = ctx.filtered_reply
elif ctx.llm_response:
reply = ctx.llm_response.content
else:
reply = self._fallback_response(ctx)
confidence = ctx.confidence_result.confidence if ctx.confidence_result else 0.5
should_transfer = ctx.confidence_result.should_transfer if ctx.confidence_result else True
transfer_reason = ctx.confidence_result.transfer_reason if ctx.confidence_result else None
response_metadata = {
"session_id": ctx.session_id,
"channel_type": ctx.channel_type,
"diagnostics": ctx.diagnostics,
"execution_steps": ctx.execution_steps,
}
return ChatResponse(
reply=reply,
confidence=confidence,
should_transfer=should_transfer,
transfer_reason=transfer_reason,
metadata=response_metadata,
)
async def generate_stream(
self,
tenant_id: str,
request: ChatRequest,
) -> AsyncGenerator[ServerSentEvent, None]:
"""
Generate a streaming response.
[AC-AISVC-06, AC-AISVC-07, AC-AISVC-08] Yields SSE events in proper sequence.
12-Step Pipeline (same as generate, but with streaming LLM output):
1-7: Same as generate() up to PromptBuilder
8: LLM.stream_generate (streaming)
9: OutputFilter with streaming support
10-12: Confidence, Memory, Return
SSE Event Sequence (per design.md Section 6.2):
1. message events (multiple) - each with incremental delta
2. final event (exactly 1) - with complete response
3. connection close
OR on error:
1. message events (0 or more)
2. error event (exactly 1)
3. connection close
"""
logger.info(
f"[AC-AISVC-06] Starting 12-step streaming generation for tenant={tenant_id}, "
f"session={request.session_id}"
)
state_machine = SSEStateMachine()
await state_machine.transition_to_streaming()
ctx = GenerationContext(
tenant_id=tenant_id,
session_id=request.session_id,
current_message=request.current_message,
channel_type=request.channel_type.value,
request_metadata=request.metadata,
)
try:
# Steps 1-7: Same as generate()
await self._scan_input(ctx)
await self._load_local_history(ctx)
await self._merge_context(ctx, request.history)
await self._check_active_flow(ctx)
await self._match_intent(ctx)
if self._config.enable_rag and self._retriever:
await self._retrieve_evidence(ctx)
await self._build_system_prompt(ctx)
# Step 8: LLM streaming generation
full_reply = ""
# If flow or intent already handled, stream the pre-determined response
if ctx.diagnostics.get("flow_handled") or ctx.diagnostics.get("intent_handled"):
if ctx.llm_response:
# Stream the pre-determined response character by character
for char in ctx.llm_response.content:
if not state_machine.can_send_message():
break
yield create_message_event(delta=char)
full_reply += char
await asyncio.sleep(0.01)
elif self._llm_client:
async for event in self._stream_from_llm(ctx, state_machine):
if event.event == "message":
full_reply += self._extract_delta_from_event(event)
yield event
else:
async for event in self._stream_mock_response(ctx, state_machine):
if event.event == "message":
full_reply += self._extract_delta_from_event(event)
yield event
if ctx.llm_response is None:
ctx.llm_response = LLMResponse(
content=full_reply,
model="streaming",
usage={},
finish_reason="stop",
)
# Step 9: OutputFilter (on complete reply)
await self._filter_output(ctx)
# Step 10: Confidence
self._calculate_confidence(ctx)
# Step 11: Memory
await self._save_messages(ctx)
# Step 12: Return final event
if await state_machine.transition_to_final():
final_reply = ctx.filtered_reply or full_reply
yield create_final_event(
reply=final_reply,
confidence=ctx.confidence_result.confidence if ctx.confidence_result else 0.5,
should_transfer=ctx.confidence_result.should_transfer if ctx.confidence_result else False,
transfer_reason=ctx.confidence_result.transfer_reason if ctx.confidence_result else None,
)
except Exception as e:
logger.error(f"[AC-AISVC-09] Error during streaming: {e}")
if await state_machine.transition_to_error():
yield create_error_event(
code="GENERATION_ERROR",
message=str(e),
)
finally:
await state_machine.close()
async def _stream_from_llm(
self,
ctx: GenerationContext,
state_machine: SSEStateMachine,
) -> AsyncGenerator[ServerSentEvent, None]:
"""
[AC-AISVC-07] Stream from LLM client, wrapping each chunk as message event.
"""
messages = self._build_llm_messages(ctx)
async for chunk in self._llm_client.stream_generate(messages):
if not state_machine.can_send_message():
break
if chunk.delta:
logger.debug(f"[AC-AISVC-07] Yielding message event with delta: {chunk.delta[:50]}...")
yield create_message_event(delta=chunk.delta)
if chunk.finish_reason:
logger.info(f"[AC-AISVC-07] LLM stream finished with reason: {chunk.finish_reason}")
break
async def _stream_mock_response(
self,
ctx: GenerationContext,
state_machine: SSEStateMachine,
) -> AsyncGenerator[ServerSentEvent, None]:
"""
[AC-AISVC-07] Mock streaming response for demo/testing purposes.
Simulates LLM-style incremental output.
"""
reply_parts = ["收到", "您的", "消息:", f" {ctx.current_message}"]
for part in reply_parts:
if not state_machine.can_send_message():
break
logger.debug(f"[AC-AISVC-07] Yielding mock message event with delta: {part}")
yield create_message_event(delta=part)
await asyncio.sleep(0.05)
def _extract_delta_from_event(self, event: ServerSentEvent) -> str:
"""Extract delta content from a message event."""
try:
if event.data:
data = json.loads(event.data)
return data.get("delta", "")
except (json.JSONDecodeError, TypeError):
pass
return ""
_orchestrator_service: OrchestratorService | None = None
def get_orchestrator_service() -> OrchestratorService:
"""Get or create orchestrator service instance."""
global _orchestrator_service
if _orchestrator_service is None:
_orchestrator_service = OrchestratorService()
return _orchestrator_service
def set_orchestrator_service(service: OrchestratorService) -> None:
"""Set orchestrator service instance for testing."""
global _orchestrator_service
_orchestrator_service = service
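# Test-injection sketch (illustrative; StubOrchestrator is a hypothetical test double):
#
#     class StubOrchestrator(OrchestratorService):
#         async def generate(self, tenant_id, request):
#             return ChatResponse(reply="stub", confidence=1.0, should_transfer=False)
#
#     set_orchestrator_service(StubOrchestrator())
#     ...  # exercise the code under test
#     set_orchestrator_service(OrchestratorService())  # restore the default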