Compare commits

20 Commits

Author SHA1 Message Date
MerCry 1b71c29ddb [AC-AISVC-RES-01~15] docs(progress): add strategy routing progress document
- record session progress
- record implemented modules and tests
- record change history
2026-03-10 21:09:23 +08:00
MerCry 4de51bb18a [AC-AISVC-RES-01~15] test(retrieval): add strategy routing unit tests
- add test_routing_config.py routing config tests
  - TestStrategyType: strategy type enum tests
  - TestRagRuntimeMode: runtime mode enum tests
  - TestRoutingConfig: routing config tests
  - TestStrategyContext: strategy context tests
  - TestStrategyResult: strategy result tests

- add test_strategy_router.py strategy router tests
  - TestRollbackRecord: rollback record tests
  - TestRollbackManager: rollback manager tests
  - TestDefaultPipeline: default pipeline tests
  - TestEnhancedPipeline: enhanced pipeline tests
  - TestStrategyRouter: strategy router tests

- add test_mode_router.py mode router tests
  - TestComplexityAnalyzer: complexity analyzer tests
  - TestModeRouteResult: mode route result tests
  - TestModeRouter: mode router tests
- add test_strategy_integration.py integration layer tests
  - TestRetrievalStrategyResult: integration result tests
  - TestRetrievalStrategyIntegration: integrator tests

- all 79 test cases pass
2026-03-10 21:08:49 +08:00
MerCry c0688c2b13 [AC-AISVC-RES-01~15] feat(api): add strategy management API endpoints
- add strategy.py strategy API endpoints
  - GET /strategy/retrieval/current - get the current strategy config
  - POST /strategy/retrieval/switch - switch the strategy config
  - POST /strategy/retrieval/validate - validate a risky config
  - POST /strategy/retrieval/rollback - force rollback to the default strategy

- update __init__.py to export the new module
- update main.py to register the API routes
2026-03-10 21:08:07 +08:00
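The four endpoints in this commit map naturally onto a small client module in the admin frontend's request-descriptor style. A sketch under stated assumptions: the URL paths are taken from the commit message, while the payload shape (a `strategy` plus an optional `mode`) is hypothetical and not shown in this diff.

```typescript
// Hypothetical request payload; field names are assumptions for illustration.
interface StrategySwitchRequest {
  strategy: 'default' | 'enhanced'
  mode?: 'direct' | 'react' | 'auto'
}

// Request descriptors in the same style as the repo's other api modules;
// a real client would pass these to its shared request() helper.
const strategyApi = {
  current: () => ({ method: 'GET', url: '/strategy/retrieval/current' }),
  switch: (data: StrategySwitchRequest) =>
    ({ method: 'POST', url: '/strategy/retrieval/switch', data }),
  validate: (data: StrategySwitchRequest) =>
    ({ method: 'POST', url: '/strategy/retrieval/validate', data }),
  rollback: () => ({ method: 'POST', url: '/strategy/retrieval/rollback' }),
}
```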
MerCry c628181623 [AC-AISVC-RES-01~15] feat(retrieval): implement retrieval strategy routing core modules
- add routing_config.py routing config models
  - StrategyType: DEFAULT/ENHANCED strategy types
  - RagRuntimeMode: DIRECT/REACT/AUTO runtime modes
  - RoutingConfig: routing configuration class
  - StrategyContext: strategy context
  - StrategyResult: strategy result

- add strategy_router.py strategy router
  - RollbackManager: rollback manager
  - DefaultPipeline: default retrieval pipeline
  - EnhancedPipeline: enhanced retrieval pipeline
  - StrategyRouter: strategy router

- add mode_router.py mode router
  - ComplexityAnalyzer: complexity analyzer
  - ModeRouter: mode router

- add strategy_integration.py unified integration layer
  - RetrievalStrategyIntegration: strategy integrator

- update __init__.py to export the new modules
2026-03-10 21:07:01 +08:00
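The DIRECT/REACT/AUTO split above implies a router where AUTO resolves to one of the other two modes via a complexity score. A minimal sketch of that idea; the length-and-keyword heuristics and the 0.5 threshold are illustrative assumptions, not the actual ComplexityAnalyzer implementation (which lives in the Python backend).

```typescript
// Illustrative sketch of DIRECT/REACT/AUTO mode routing.
type RagRuntimeMode = 'direct' | 'react' | 'auto'

// Hypothetical complexity score in [0, 1]: longer, multi-step questions score higher.
function analyzeComplexity(query: string): number {
  let score = Math.min(query.length / 200, 0.5)
  const multiStepHints = ['compare', 'step', 'why', 'and then']
  for (const hint of multiStepHints) {
    if (query.toLowerCase().includes(hint)) score += 0.25
  }
  return Math.min(score, 1)
}

// A fixed mode is honored as-is; AUTO picks REACT for complex queries.
function routeMode(
  configured: RagRuntimeMode,
  query: string,
  threshold = 0.5
): 'direct' | 'react' {
  if (configured !== 'auto') return configured
  return analyzeComplexity(query) >= threshold ? 'react' : 'direct'
}
```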
MerCry 2476da8957 [AC-AISVC-RES-01~15] docs(spec): add retrieval strategy routing spec documents
- add design.md design document
- add tasks.md task breakdown document
- update requirements.md requirements document
- update openapi.provider.yaml API definition
2026-03-10 21:05:53 +08:00
MerCry 7027097513 [AC-AISVC-RES-01~15] feat(retrieval): implement retrieval strategy pipeline modules
- add strategy config models (config.py)
  - GrayscaleConfig: gray-release config
  - ModeRouterConfig: mode router config
  - MetadataInferenceConfig: metadata inference config

- add pipeline implementations
  - DefaultPipeline: reuses the existing OptimizedRetriever logic
  - EnhancedPipeline: combined Dense + Keyword + RRF retrieval

- add routers
  - StrategyRouter: strategy router (default/enhanced)
  - ModeRouter: mode router (direct/react/auto)

- add RollbackManager: rollback and audit manager
- add MetadataInferenceService: unified entry point for metadata inference
- add unit tests (51 passed)
2026-03-10 20:50:16 +08:00
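The EnhancedPipeline's "Dense + Keyword + RRF" combination presumably fuses the two ranked lists with Reciprocal Rank Fusion, where each document scores the sum of 1/(k + rank) over the rankings it appears in. A minimal, generic sketch (k = 60 is the conventional constant; the pipeline's actual scoring is not shown in this diff):

```typescript
// Reciprocal Rank Fusion: score(d) = Σ_r 1 / (k + rank_r(d)).
function rrfFuse(rankings: string[][], k = 60): { id: string; score: number }[] {
  const scores = new Map<string, number>()
  for (const ranking of rankings) {
    ranking.forEach((id, idx) => {
      // ranks are 1-based: the top hit contributes 1 / (k + 1)
      scores.set(id, (scores.get(id) ?? 0) + 1 / (k + idx + 1))
    })
  }
  return [...scores.entries()]
    .map(([id, score]) => ({ id, score }))
    .sort((a, b) => b.score - a.score)
}

// A document ranked first by both dense and keyword retrieval wins.
const fused = rrfFuse([
  ['doc_a', 'doc_b', 'doc_c'], // dense ranking
  ['doc_a', 'doc_c'],          // keyword ranking
])
```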
MerCry 9f28498b97 docs: add v0.9.0 retrieval embedding strategy spec [AC-DOCS-V0.9] 2026-03-10 12:12:34 +08:00
MerCry 42f55ac4d1 chore: add utility scripts and tool definitions for KB search and metadata testing [AC-UTILS] 2026-03-10 12:11:55 +08:00
MerCry 3b354ba041 feat: add metadata discovery tool for dynamic metadata extraction [AC-METADATA-DISCOVERY] 2026-03-10 12:11:31 +08:00
MerCry 812af6c7a1 docs: update spec and docs for v0.8.0 intent hybrid routing and mid-platform features [AC-DOCS] 2026-03-10 12:10:50 +08:00
MerCry f4ca25b0d8 test: add unit tests and utility scripts for intent routing, slot management, and KB search [AC-TEST] 2026-03-10 12:10:22 +08:00
MerCry fe883cfff0 feat: update core backend services including LLM, embedding, KB, orchestrator and admin APIs [AC-AISVC-CORE] 2026-03-10 12:09:45 +08:00
MerCry 759eafb490 feat: update admin frontend with scene-slot-bundle, metadata schema, and mid-platform playground pages [AC-ADMIN-FE] 2026-03-10 12:09:00 +08:00
MerCry 9769f7ccf0 feat: add slot management system with validation, backfill, state aggregation and scene bundle support [AC-SLOT-MGMT] 2026-03-10 12:07:39 +08:00
MerCry 248a225436 feat: implement mid-platform dialogue and session management with memory recall and KB search tools [AC-IDMP-01~20] 2026-03-10 12:06:57 +08:00
MerCry d78b72ca93 feat: enhance agent orchestrator with runtime hardening and tool governance [AC-MARH-01~12] 2026-03-10 12:06:15 +08:00
MerCry 66902cd7c1 feat: implement hybrid intent routing with RuleMatcher, SemanticMatcher, LlmJudge and FusionPolicy [AC-AISVC-111~125] 2026-03-10 12:05:35 +08:00
MerCry 0dfc60935d feat: add knowledge base query result files for grade-specific subjects [AC-KB-DATA] 2026-03-10 11:57:42 +08:00
MerCry 3969322d34 fix: use correct attribute name system_instruction for version content [AC-IDSMETA-16] 2026-03-06 11:08:49 +08:00
MerCry b832f372d1 fix: resolve metadata field mapping and return current_content in prompt template update [AC-IDSMETA-16] 2026-03-06 11:06:08 +08:00
540 changed files with 40331 additions and 537 deletions

View File

@@ -13,6 +13,7 @@
<span class="logo-text">AI Robot</span>
</div>
<nav class="main-nav">
<div class="nav-row">
<router-link to="/dashboard" class="nav-item" :class="{ active: isActive('/dashboard') }">
<el-icon><Odometer /></el-icon>
<span>控制台</span>
@@ -29,7 +30,8 @@
<el-icon><Monitor /></el-icon>
<span>会话监控</span>
</router-link>
<div class="nav-divider"></div>
</div>
<div class="nav-row">
<router-link to="/admin/embedding" class="nav-item" :class="{ active: isActive('/admin/embedding') }">
<el-icon><Connection /></el-icon>
<span>嵌入模型</span>
@@ -66,6 +68,11 @@
<el-icon><Grid /></el-icon>
<span>槽位定义</span>
</router-link>
<router-link to="/admin/scene-slot-bundles" class="nav-item" :class="{ active: isActive('/admin/scene-slot-bundles') }">
<el-icon><Collection /></el-icon>
<span>场景槽位包</span>
</router-link>
</div>
</nav>
</div>
<div class="header-right">
@@ -98,7 +105,7 @@ import { ref, onMounted } from 'vue'
import { useRoute } from 'vue-router'
import { useTenantStore } from '@/stores/tenant'
import { getTenantList, type Tenant } from '@/api/tenant'
import { Odometer, FolderOpened, Cpu, Monitor, Connection, ChatDotSquare, Document, Aim, Share, Warning, Setting, ChatLineRound, Grid } from '@element-plus/icons-vue'
import { Odometer, FolderOpened, Cpu, Monitor, Connection, ChatDotSquare, Document, Aim, Share, Warning, Setting, ChatLineRound, Grid, Collection } from '@element-plus/icons-vue'
import { ElMessage } from 'element-plus'
const route = useRoute()
@@ -168,7 +175,8 @@ onMounted(() => {
align-items: center;
justify-content: space-between;
padding: 0 24px;
height: 60px;
height: auto;
min-height: 60px;
background-color: var(--bg-secondary, #FFFFFF);
border-bottom: 1px solid var(--border-color, #E2E8F0);
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.04);
@@ -208,6 +216,12 @@ onMounted(() => {
}
.main-nav {
display: flex;
flex-direction: column;
gap: 4px;
}
.nav-row {
display: flex;
align-items: center;
gap: 4px;

View File

@@ -36,3 +36,19 @@ export function deleteDocument(docId: string) {
method: 'delete'
})
}
export function batchUploadDocuments(data: FormData) {
return request({
url: '/admin/kb/documents/batch-upload',
method: 'post',
data
})
}
export function jsonBatchUpload(kbId: string, data: FormData) {
return request({
url: `/admin/kb/${kbId}/documents/json-batch`,
method: 'post',
data
})
}

View File

@@ -11,7 +11,7 @@ import type {
export const metadataSchemaApi = {
list: (status?: 'draft' | 'active' | 'deprecated', fieldRole?: FieldRole) =>
request<MetadataFieldListResponse>({
request<MetadataFieldDefinition[]>({
method: 'GET',
url: '/admin/metadata-schemas',
params: {

View File

@@ -35,6 +35,8 @@ export interface ToolCallTrace {
error_code?: string
args_digest?: string
result_digest?: string
arguments?: Record<string, unknown>
result?: unknown
}
export interface DialogueResponse {
@@ -219,3 +221,16 @@ export function sendSharedMessage(shareToken: string, userMessage: string): Prom
data: { user_message: userMessage }
})
}
export interface CancelFlowResponse {
success: boolean
message: string
session_id: string
}
export function cancelActiveFlow(sessionId: string): Promise<CancelFlowResponse> {
return request({
url: `/mid/sessions/${sessionId}/cancel-flow`,
method: 'post'
})
}

View File

@@ -0,0 +1,48 @@
import request from '@/utils/request'
import type {
SceneSlotBundle,
SceneSlotBundleWithDetails,
SceneSlotBundleCreateRequest,
SceneSlotBundleUpdateRequest,
} from '@/types/scene-slot-bundle'
export const sceneSlotBundleApi = {
list: (status?: string) =>
request<SceneSlotBundle[]>({
method: 'GET',
url: '/admin/scene-slot-bundles',
params: status ? { status } : {},
}),
get: (id: string) =>
request<SceneSlotBundleWithDetails>({
method: 'GET',
url: `/admin/scene-slot-bundles/${id}`,
}),
getBySceneKey: (sceneKey: string) =>
request<SceneSlotBundle>({
method: 'GET',
url: `/admin/scene-slot-bundles/by-scene/${sceneKey}`,
}),
create: (data: SceneSlotBundleCreateRequest) =>
request<SceneSlotBundle>({
method: 'POST',
url: '/admin/scene-slot-bundles',
data,
}),
update: (id: string, data: SceneSlotBundleUpdateRequest) =>
request<SceneSlotBundle>({
method: 'PUT',
url: `/admin/scene-slot-bundles/${id}`,
data,
}),
delete: (id: string) =>
request<void>({
method: 'DELETE',
url: `/admin/scene-slot-bundles/${id}`,
}),
}

View File

@@ -99,17 +99,19 @@ const deprecatedFields = computed(() => {
})
const visibleFields = computed(() => {
// Show all fields except deprecated in edit mode, so users can change draft to active
if (props.isNewObject) {
return activeFields.value
}
return allFields.value.filter(f => f.status !== 'draft')
return allFields.value.filter(f => f.status !== 'deprecated')
})
const loadFields = async () => {
loading.value = true
try {
const res = await metadataSchemaApi.getByScope(props.scope, props.showDeprecated)
allFields.value = res.items || []
// Handle both formats: direct array or { items: array }
allFields.value = Array.isArray(res) ? res : (res.items || [])
emit('fields-loaded', allFields.value)
applyDefaults()

View File

@@ -17,12 +17,6 @@ const routes: Array<RouteRecordRaw> = [
component: () => import('@/views/dashboard/index.vue'),
meta: { title: '控制台' }
},
{
path: '/kb',
name: 'KBManagement',
component: () => import('@/views/kb/index.vue'),
meta: { title: '知识库管理' }
},
{
path: '/rag-lab',
name: 'RagLab',
@@ -71,6 +65,12 @@ const routes: Array<RouteRecordRaw> = [
component: () => import('@/views/admin/slot-definition/index.vue'),
meta: { title: '槽位定义管理' }
},
{
path: '/admin/scene-slot-bundles',
name: 'SceneSlotBundle',
component: () => import('@/views/admin/scene-slot-bundle/index.vue'),
meta: { title: '场景槽位包管理' }
},
{
path: '/admin/intent-rules',
name: 'IntentRule',

View File

@@ -15,6 +15,7 @@ export interface MetadataFieldDefinition {
scope: MetadataScope[]
is_filterable: boolean
is_rank_feature: boolean
usage_description?: string
field_roles: FieldRole[]
status: MetadataFieldStatus
created_at?: string
@@ -31,6 +32,7 @@ export interface MetadataFieldCreateRequest {
scope: MetadataScope[]
is_filterable?: boolean
is_rank_feature?: boolean
usage_description?: string
field_roles?: FieldRole[]
status: MetadataFieldStatus
}
@@ -43,6 +45,7 @@ export interface MetadataFieldUpdateRequest {
scope?: MetadataScope[]
is_filterable?: boolean
is_rank_feature?: boolean
usage_description?: string
field_roles?: FieldRole[]
status?: MetadataFieldStatus
}

View File

@@ -0,0 +1,80 @@
/**
* Scene Slot Bundle Types.
* [AC-SCENE-SLOT-01]
*/
export type SceneSlotBundleStatus = 'draft' | 'active' | 'deprecated'
export type AskBackOrder = 'priority' | 'required_first' | 'parallel'
export interface SceneSlotBundle {
id: string
tenant_id: string
scene_key: string
scene_name: string
description: string | null
required_slots: string[]
optional_slots: string[]
slot_priority: string[] | null
completion_threshold: number
ask_back_order: AskBackOrder
status: SceneSlotBundleStatus
version: number
created_at: string | null
updated_at: string | null
}
export interface SlotDetail {
slot_key: string
type: string
required: boolean
ask_back_prompt: string | null
linked_field_id: string | null
}
export interface SceneSlotBundleWithDetails extends SceneSlotBundle {
required_slot_details: SlotDetail[]
optional_slot_details: SlotDetail[]
}
export interface SceneSlotBundleCreateRequest {
scene_key: string
scene_name: string
description?: string
required_slots?: string[]
optional_slots?: string[]
slot_priority?: string[]
completion_threshold?: number
ask_back_order?: AskBackOrder
status?: SceneSlotBundleStatus
}
export interface SceneSlotBundleUpdateRequest {
scene_name?: string
description?: string
required_slots?: string[]
optional_slots?: string[]
slot_priority?: string[]
completion_threshold?: number
ask_back_order?: AskBackOrder
status?: SceneSlotBundleStatus
}
export const SCENE_SLOT_BUNDLE_STATUS_OPTIONS = [
{ value: 'draft', label: '草稿', description: '配置中,未启用' },
{ value: 'active', label: '已启用', description: '运行时可用' },
{ value: 'deprecated', label: '已废弃', description: '不再使用' },
]
export const ASK_BACK_ORDER_OPTIONS = [
{ value: 'priority', label: '按优先级', description: '按 slot_priority 定义的顺序追问' },
{ value: 'required_first', label: '必填优先', description: '优先追问必填槽位' },
{ value: 'parallel', label: '并行追问', description: '一次追问多个缺失槽位' },
]
export function getStatusLabel(status: SceneSlotBundleStatus): string {
return SCENE_SLOT_BUNDLE_STATUS_OPTIONS.find(o => o.value === status)?.label || status
}
export function getAskBackOrderLabel(order: AskBackOrder): string {
return ASK_BACK_ORDER_OPTIONS.find(o => o.value === order)?.label || order
}

View File

@@ -39,6 +39,10 @@ export interface FlowStep {
intent_description?: string
script_constraints?: string[]
expected_variables?: string[]
allowed_kb_ids?: string[]
preferred_kb_ids?: string[]
kb_query_hint?: string
max_kb_calls_per_step?: number
}
export interface NextCondition {

View File

@@ -2,13 +2,31 @@ export type SlotType = 'string' | 'number' | 'boolean' | 'enum' | 'array_enum'
export type ExtractStrategy = 'rule' | 'llm' | 'user_input'
export type SlotSource = 'user_confirmed' | 'rule_extracted' | 'llm_inferred' | 'default'
/**
* [AC-MRS-07-UPGRADE]
*/
export type ExtractFailureType =
| 'EXTRACT_EMPTY'
| 'EXTRACT_PARSE_FAIL'
| 'EXTRACT_VALIDATION_FAIL'
| 'EXTRACT_RUNTIME_ERROR'
export interface SlotDefinition {
id: string
tenant_id: string
slot_key: string
display_name?: string
description?: string
type: SlotType
required: boolean
/**
* [AC-MRS-07-UPGRADE]
*/
extract_strategy?: ExtractStrategy
/**
* [AC-MRS-07-UPGRADE]
*/
extract_strategies?: ExtractStrategy[]
validation_rule?: string
ask_back_prompt?: string
default_value?: string | number | boolean | string[]
@@ -29,8 +47,17 @@ export interface LinkedField {
export interface SlotDefinitionCreateRequest {
tenant_id?: string
slot_key: string
display_name?: string
description?: string
type: SlotType
required: boolean
/**
* [AC-MRS-07-UPGRADE]
*/
extract_strategies?: ExtractStrategy[]
/**
* [AC-MRS-07-UPGRADE]
*/
extract_strategy?: ExtractStrategy
validation_rule?: string
ask_back_prompt?: string
@@ -39,8 +66,17 @@ export interface SlotDefinitionCreateRequest {
}
export interface SlotDefinitionUpdateRequest {
display_name?: string
description?: string
type?: SlotType
required?: boolean
/**
* [AC-MRS-07-UPGRADE]
*/
extract_strategies?: ExtractStrategy[]
/**
* [AC-MRS-07-UPGRADE]
*/
extract_strategy?: ExtractStrategy
validation_rule?: string
ask_back_prompt?: string
@@ -56,6 +92,31 @@ export interface RuntimeSlotValue {
updated_at?: string
}
/**
* [AC-MRS-07-UPGRADE]
*/
export interface StrategyStepResult {
strategy: ExtractStrategy
success: boolean
value?: string | number | boolean | string[]
failure_type?: ExtractFailureType
failure_reason?: string
execution_time_ms: number
}
/**
* [AC-MRS-07-UPGRADE]
*/
export interface StrategyChainResult {
slot_key: string
success: boolean
final_value?: string | number | boolean | string[]
final_strategy?: ExtractStrategy
steps: StrategyStepResult[]
total_execution_time_ms: number
ask_back_prompt?: string
}
export const SLOT_TYPE_OPTIONS = [
{ value: 'string', label: '文本' },
{ value: 'number', label: '数字' },
@@ -69,3 +130,42 @@ export const EXTRACT_STRATEGY_OPTIONS = [
{ value: 'llm', label: 'LLM 推断', description: '通过大语言模型推断槽位值' },
{ value: 'user_input', label: '用户输入', description: '通过追问提示语让用户主动输入' }
]
/**
* [AC-MRS-07-UPGRADE]
*/
export function validateExtractStrategies(strategies: ExtractStrategy[]): { valid: boolean; error?: string } {
if (!strategies || strategies.length === 0) {
return { valid: true } // an empty array is considered valid
}
const validStrategies = new Set(['rule', 'llm', 'user_input'])
// check for duplicate strategies
const uniqueStrategies = new Set(strategies)
if (uniqueStrategies.size !== strategies.length) {
return { valid: false, error: '提取策略链中不允许重复的策略' }
}
// check that each strategy is valid
const invalid = strategies.filter(s => !validStrategies.has(s))
if (invalid.length > 0) {
return { valid: false, error: `无效的提取策略: ${invalid.join(', ')}` }
}
return { valid: true }
}
/**
* [AC-MRS-07-UPGRADE]
* Prefer extract_strategies, falling back to the legacy extract_strategy field
*/
export function getEffectiveStrategies(slot: SlotDefinition): ExtractStrategy[] {
if (slot.extract_strategies && slot.extract_strategies.length > 0) {
return slot.extract_strategies
}
if (slot.extract_strategy) {
return [slot.extract_strategy]
}
return []
}
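The StrategyStepResult/StrategyChainResult shapes above imply an executor that tries each strategy in order and stops at the first success, recording every attempt. A minimal sketch under that assumption; the extractor functions and the simplified result shape are hypothetical (the real chain runs server-side).

```typescript
type ExtractStrategy = 'rule' | 'llm' | 'user_input'

interface StepResult {
  strategy: ExtractStrategy
  success: boolean
  value?: string
}

// Try each strategy in order; short-circuit on the first success.
// Mirrors the StrategyChainResult type above in simplified form.
function runStrategyChain(
  strategies: ExtractStrategy[],
  extractors: Record<ExtractStrategy, (input: string) => string | undefined>,
  input: string
): { success: boolean; finalValue?: string; steps: StepResult[] } {
  const steps: StepResult[] = []
  for (const strategy of strategies) {
    const value = extractors[strategy](input)
    const success = value !== undefined
    steps.push({ strategy, success, value })
    if (success) return { success: true, finalValue: value, steps }
  }
  return { success: false, steps }
}
```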

View File

@@ -4,7 +4,7 @@ import { useTenantStore } from '@/stores/tenant'
const service = axios.create({
baseURL: import.meta.env.VITE_APP_BASE_API || '/api',
timeout: 60000
timeout: 180000
})
service.interceptors.request.use(

View File

@@ -5,6 +5,14 @@
<el-icon><Upload /></el-icon>
上传文档
</el-button>
<el-button type="success" @click="handleBatchUploadClick">
<el-icon><Upload /></el-icon>
批量上传
</el-button>
<el-button type="warning" @click="handleJsonUploadClick">
<el-icon><Upload /></el-icon>
JSON上传
</el-button>
<input
ref="fileInputRef"
type="file"
@@ -12,6 +20,20 @@
style="display: none"
@change="handleFileSelect"
/>
<input
ref="batchFileInputRef"
type="file"
accept=".zip"
style="display: none"
@change="handleBatchFileSelect"
/>
<input
ref="jsonFileInputRef"
type="file"
accept=".json,.jsonl,.txt"
style="display: none"
@change="handleJsonFileSelect"
/>
</div>
<el-table :data="documents" v-loading="loading" stripe>
@@ -165,6 +187,8 @@ const pagination = ref({
})
const fileInputRef = ref<HTMLInputElement>()
const batchFileInputRef = ref<HTMLInputElement>()
const jsonFileInputRef = ref<HTMLInputElement>()
const uploadDialogVisible = ref(false)
const editDialogVisible = ref(false)
const uploading = ref(false)
@@ -261,6 +285,14 @@ const handleUploadClick = () => {
fileInputRef.value?.click()
}
const handleBatchUploadClick = () => {
batchFileInputRef.value?.click()
}
const handleJsonUploadClick = () => {
jsonFileInputRef.value?.click()
}
const handleFileSelect = (event: Event) => {
const target = event.target as HTMLInputElement
const file = target.files?.[0]
@@ -289,6 +321,140 @@ const handleFileSelect = (event: Event) => {
}
}
const handleBatchFileSelect = async (event: Event) => {
const target = event.target as HTMLInputElement
const file = target.files?.[0]
if (!file) return
if (!file.name.toLowerCase().endsWith('.zip')) {
ElMessage.error('请上传 .zip 格式的压缩包')
if (batchFileInputRef.value) {
batchFileInputRef.value.value = ''
}
return
}
const maxSize = 100 * 1024 * 1024
if (file.size > maxSize) {
ElMessage.error('压缩包大小不能超过 100MB')
if (batchFileInputRef.value) {
batchFileInputRef.value.value = ''
}
return
}
try {
loading.value = true
const formData = new FormData()
formData.append('file', file)
formData.append('kb_id', props.kbId)
const baseUrl = import.meta.env.VITE_APP_BASE_API || '/api'
const apiKey = import.meta.env.VITE_APP_API_KEY || ''
const response = await fetch(`${baseUrl}/admin/kb/documents/batch-upload`, {
method: 'POST',
headers: {
'X-Tenant-Id': tenantStore.currentTenantId,
'X-API-Key': apiKey
},
body: formData
})
const result = await response.json()
if (result.success) {
const { total, succeeded, failed, results } = result
ElMessage.success(`批量上传完成!成功: ${succeeded}, 失败: ${failed}, 总计: ${total}`)
// collect job IDs of successfully created documents
const jobIds: string[] = []
for (const item of results) {
if (item.status === 'created' && item.jobId) {
jobIds.push(item.jobId)
}
}
// poll the created jobs for completion
if (jobIds.length > 0) {
pollJobStatusSequential(jobIds)
}
emit('upload-success')
loadDocuments()
} else {
ElMessage.error(result.message || '批量上传失败')
}
} catch (error) {
ElMessage.error('批量上传失败,请重试')
console.error('Batch upload error:', error)
} finally {
loading.value = false
if (batchFileInputRef.value) {
batchFileInputRef.value.value = ''
}
}
}
const handleJsonFileSelect = async (event: Event) => {
const target = event.target as HTMLInputElement
const file = target.files?.[0]
if (!file) return
const fileName = file.name.toLowerCase()
if (!fileName.endsWith('.json') && !fileName.endsWith('.jsonl') && !fileName.endsWith('.txt')) {
ElMessage.error('请上传 .json、.jsonl 或 .txt 格式的文件')
if (jsonFileInputRef.value) {
jsonFileInputRef.value.value = ''
}
return
}
try {
loading.value = true
ElMessage.info('正在上传 JSON 文件...')
const formData = new FormData()
formData.append('file', file)
const baseUrl = import.meta.env.VITE_APP_BASE_API || '/api'
const apiKey = import.meta.env.VITE_APP_API_KEY || ''
const response = await fetch(`${baseUrl}/admin/kb/${props.kbId}/documents/json-batch?tenant_id=${tenantStore.currentTenantId}`, {
method: 'POST',
headers: {
'X-Tenant-Id': tenantStore.currentTenantId,
'X-API-Key': apiKey
},
body: formData
})
const result = await response.json()
if (result.success) {
ElMessage.success(`JSON 批量上传成功!成功: ${result.succeeded}, 失败: ${result.failed}`)
if (result.valid_metadata_fields) {
console.log('有效的元数据字段:', result.valid_metadata_fields)
}
if (result.failed > 0) {
const failedItems = result.results.filter((r: any) => !r.success)
console.warn('失败的项目:', failedItems)
}
emit('upload-success')
loadDocuments()
} else {
ElMessage.error(result.message || 'JSON 批量上传失败')
}
} catch (error) {
ElMessage.error('JSON 批量上传失败,请重试')
console.error('JSON batch upload error:', error)
} finally {
loading.value = false
if (jsonFileInputRef.value) {
jsonFileInputRef.value.value = ''
}
}
}
const handleUpload = async () => {
if (!selectedFile.value) {
ElMessage.warning('请选择文件')
@@ -310,11 +476,13 @@ const handleUpload = async () => {
formData.append('kb_id', props.kbId)
formData.append('metadata', JSON.stringify(uploadForm.value.metadata))
const baseUrl = import.meta.env.VITE_API_BASE_URL || ''
const baseUrl = import.meta.env.VITE_APP_BASE_API || '/api'
const apiKey = import.meta.env.VITE_APP_API_KEY || ''
const response = await fetch(`${baseUrl}/admin/kb/documents`, {
method: 'POST',
headers: {
'X-Tenant-Id': tenantStore.currentTenantId
'X-Tenant-Id': tenantStore.currentTenantId,
'X-API-Key': apiKey
},
body: formData
})
@@ -356,12 +524,14 @@ const handleSaveMetadata = async () => {
saving.value = true
try {
const baseUrl = import.meta.env.VITE_API_BASE_URL || ''
const baseUrl = import.meta.env.VITE_APP_BASE_API || '/api'
const apiKey = import.meta.env.VITE_APP_API_KEY || ''
const response = await fetch(`${baseUrl}/admin/kb/documents/${currentEditDoc.value.docId}/metadata`, {
method: 'PUT',
headers: {
'Content-Type': 'application/json',
'X-Tenant-Id': tenantStore.currentTenantId
'X-Tenant-Id': tenantStore.currentTenantId,
'X-API-Key': apiKey
},
body: JSON.stringify({ metadata: editForm.value.metadata })
})
@@ -398,14 +568,64 @@ const pollJobStatus = async (jobId: string) => {
ElMessage.error(`文档处理失败: ${job.errorMsg || '未知错误'}`)
loadDocuments()
} else {
setTimeout(poll, 3000)
setTimeout(poll, 10000) // poll again in 10 s
}
} catch (error) {
} catch (error: any) {
if (error.status === 429) {
console.warn('请求过于频繁,稍后重试')
setTimeout(poll, 15000) // back off and retry in 15 s
} else {
console.error('轮询任务状态失败', error)
}
}
}
setTimeout(poll, 2000)
setTimeout(poll, 5000) // first poll after 5 s
}
// poll multiple jobs in one timer loop to avoid request bursts
const pollJobStatusSequential = async (jobIds: string[]) => {
const pendingJobs = new Set(jobIds)
const maxPolls = 60
let pollCount = 0
const pollAll = async () => {
if (pollCount >= maxPolls || pendingJobs.size === 0) return
pollCount++
const completedJobs: string[] = []
for (const jobId of pendingJobs) {
try {
const job: IndexJob = await getIndexJob(jobId)
if (job.status === 'completed') {
completedJobs.push(jobId)
} else if (job.status === 'failed') {
completedJobs.push(jobId)
console.error(`文档处理失败: ${job.errorMsg || '未知错误'}`)
}
} catch (error: any) {
if (error.status === 429) {
console.warn('请求过于频繁,稍后重试')
break // stop this round on rate limiting; retry next round
} else {
console.error('轮询任务状态失败', error)
}
}
}
// remove completed jobs from the pending set
completedJobs.forEach(jobId => pendingJobs.delete(jobId))
if (pendingJobs.size > 0) {
setTimeout(pollAll, 10000) // next round in 10 s
} else {
ElMessage.success('所有文档处理完成')
loadDocuments()
}
}
setTimeout(pollAll, 5000) // first round after 5 s
}
const handleDelete = async (row: DocumentWithMetadata) => {
@@ -438,6 +658,8 @@ onMounted(() => {
.list-header {
margin-bottom: 16px;
display: flex;
gap: 12px;
}
.pagination-wrapper {

View File

@@ -211,6 +211,14 @@
</el-form-item>
</el-col>
</el-row>
<el-form-item label="用途说明">
<el-input
v-model="formData.usage_description"
type="textarea"
:rows="2"
placeholder="描述该字段的业务用途和使用场景"
/>
</el-form-item>
<el-form-item label="默认值">
<el-input v-model="formData.default_value" placeholder="可选默认值" />
</el-form-item>
@@ -308,6 +316,7 @@ const formData = reactive({
scope: [] as MetadataScope[],
is_filterable: true,
is_rank_feature: false,
usage_description: '',
field_roles: [] as FieldRole[],
status: 'draft' as MetadataFieldStatus
})
@@ -354,8 +363,11 @@ const getRoleLabel = (role: FieldRole) => {
const fetchFields = async () => {
loading.value = true
try {
const res = await metadataSchemaApi.list(filterStatus.value || undefined, filterRole.value || undefined)
fields.value = res.items || []
const res: any = await metadataSchemaApi.list(filterStatus.value || undefined, filterRole.value || undefined)
// tolerate any of: [] / { items: [] } / { schemas: [] } / { data: [] }
fields.value = Array.isArray(res)
? res
: (res?.items || res?.schemas || res?.data || [])
} catch (error: any) {
ElMessage.error(error.response?.data?.message || '获取元数据字段失败')
} finally {
@@ -376,6 +388,7 @@ const handleCreate = () => {
scope: [],
is_filterable: true,
is_rank_feature: false,
usage_description: '',
field_roles: [],
status: 'draft'
})
@@ -395,6 +408,7 @@ const handleEdit = (field: MetadataFieldDefinition) => {
scope: [...field.scope],
is_filterable: field.is_filterable,
is_rank_feature: field.is_rank_feature,
usage_description: field.usage_description || '',
field_roles: field.field_roles || [],
status: field.status
})
@@ -472,6 +486,7 @@ const handleSubmit = async () => {
scope: formData.scope,
is_filterable: formData.is_filterable,
is_rank_feature: formData.is_rank_feature,
usage_description: formData.usage_description || undefined,
field_roles: formData.field_roles,
status: formData.status
}

View File

@@ -30,6 +30,12 @@
<el-form-item>
<el-button :loading="switchingMode" @click="handleSwitchMode">应用模式</el-button>
</el-form-item>
<el-form-item>
<el-button type="warning" :loading="cancellingFlow" @click="handleCancelFlow">取消流程</el-button>
</el-form-item>
<el-form-item>
<el-button type="primary" @click="handleNewSession">新建会话</el-button>
</el-form-item>
</el-form>
</el-card>
@@ -98,14 +104,36 @@
<div class="json-panel">
<div class="json-title">tool_calls</div>
<div class="json-content">
<el-button
class="copy-btn"
size="small"
text
type="primary"
@click="copyJson(lastTrace.tool_calls || [], 'tool_calls')"
>
复制
</el-button>
<pre>{{ JSON.stringify(lastTrace.tool_calls || [], null, 2) }}</pre>
</div>
</div>
<div class="json-panel">
<div class="json-title">metrics_snapshot</div>
<div class="json-content">
<el-button
class="copy-btn"
size="small"
text
type="primary"
@click="copyJson(lastTrace.metrics_snapshot || {}, 'metrics_snapshot')"
>
复制
</el-button>
<pre>{{ JSON.stringify(lastTrace.metrics_snapshot || {}, null, 2) }}</pre>
</div>
</div>
</div>
<el-empty v-else description="暂无 Trace" :image-size="90" />
</el-card>
@@ -180,6 +208,7 @@ import {
createPublicShareToken,
listShares,
deleteShare,
cancelActiveFlow,
type DialogueMessage,
type DialogueResponse,
type SessionMode,
@@ -205,6 +234,7 @@ const rollbackToLegacy = ref(false)
const sending = ref(false)
const switchingMode = ref(false)
const reporting = ref(false)
const cancellingFlow = ref(false)
const userInput = ref('')
const conversation = ref<ChatItem[]>([])
@@ -221,11 +251,11 @@ const shareForm = ref({
const now = () => new Date().toLocaleTimeString()
const tagType = (role: ChatRole) => {
const tagType = (role: ChatRole): 'primary' | 'success' | 'info' | 'warning' | 'danger' => {
if (role === 'user') return 'info'
if (role === 'assistant') return 'success'
if (role === 'human') return 'warning'
return ''
return 'primary'
}
const toHistory = (): DialogueMessage[] => {
@@ -289,6 +319,32 @@ const handleSwitchMode = async () => {
}
}
const handleCancelFlow = async () => {
cancellingFlow.value = true
try {
const result = await cancelActiveFlow(sessionId.value)
if (result.success) {
ElMessage.success(result.message)
} else {
ElMessage.warning(result.message)
}
} catch {
ElMessage.error('取消流程失败')
} finally {
cancellingFlow.value = false
}
}
const handleNewSession = () => {
const newSessionId = `sess_${Date.now()}`
const newUserId = `user_${Date.now()}`
sessionId.value = newSessionId
userId.value = newUserId
conversation.value = []
lastTrace.value = null
ElMessage.success(`新会话已创建: ${newSessionId}`)
}
const handleReportMessages = async () => {
if (!conversation.value.length) {
ElMessage.warning('暂无可上报消息')
@@ -370,6 +426,12 @@ const copyShareUrl = (url: string) => {
ElMessage.success('链接已复制到剪贴板')
}
const copyJson = (data: unknown, name: string) => {
const jsonStr = JSON.stringify(data, null, 2)
navigator.clipboard.writeText(jsonStr)
ElMessage.success(`${name} 已复制到剪贴板`)
}
const formatShareExpires = (expiresAt: string) => {
const expires = new Date(expiresAt)
return expires.toLocaleDateString()
@@ -462,6 +524,23 @@ onMounted(() => {
margin-bottom: 4px;
}
.json-content {
position: relative;
}
.json-content .copy-btn {
position: absolute;
top: 4px;
right: 4px;
z-index: 10;
background: var(--el-bg-color);
opacity: 0.8;
}
.json-content .copy-btn:hover {
opacity: 1;
}
pre {
background: var(--el-fill-color-lighter);
padding: 8px;

View File

@@ -0,0 +1,660 @@
<template>
<div class="scene-slot-bundle-page">
<div class="page-header">
<div class="header-content">
<div class="title-section">
<h1 class="page-title">场景槽位包管理</h1>
<p class="page-desc">配置场景与槽位的映射关系,定义每个场景需要采集的槽位集合 [AC-SCENE-SLOT-01]</p>
</div>
<div class="header-actions">
<el-select v-model="filterStatus" placeholder="按状态筛选" clearable style="width: 140px;">
<el-option
v-for="opt in SCENE_SLOT_BUNDLE_STATUS_OPTIONS"
:key="opt.value"
:label="opt.label"
:value="opt.value"
/>
</el-select>
<el-button type="primary" @click="handleCreate">
<el-icon><Plus /></el-icon>
新建场景槽位包
</el-button>
</div>
</div>
</div>
<el-card shadow="hover" class="bundle-card" v-loading="loading">
<el-table :data="bundles" stripe style="width: 100%">
<el-table-column prop="scene_key" label="场景标识" min-width="140">
<template #default="{ row }">
<code class="scene-key">{{ row.scene_key }}</code>
</template>
</el-table-column>
<el-table-column prop="scene_name" label="场景名称" min-width="120" />
<el-table-column prop="status" label="状态" width="100">
<template #default="{ row }">
<el-tag :type="getStatusTagType(row.status)" size="small">
{{ getStatusLabel(row.status) }}
</el-tag>
</template>
</el-table-column>
<el-table-column prop="required_slots" label="必填槽位" min-width="160">
<template #default="{ row }">
<div class="slot-tags">
<el-tag
v-for="slotKey in row.required_slots.slice(0, 3)"
:key="slotKey"
size="small"
type="danger"
class="slot-tag"
>
{{ slotKey }}
</el-tag>
<el-tag v-if="row.required_slots.length > 3" size="small" type="info">
+{{ row.required_slots.length - 3 }}
</el-tag>
</div>
</template>
</el-table-column>
<el-table-column prop="optional_slots" label="可选槽位" min-width="140">
<template #default="{ row }">
<div class="slot-tags" v-if="row.optional_slots.length > 0">
<el-tag
v-for="slotKey in row.optional_slots.slice(0, 3)"
:key="slotKey"
size="small"
type="info"
class="slot-tag"
>
{{ slotKey }}
</el-tag>
<el-tag v-if="row.optional_slots.length > 3" size="small" type="info">
+{{ row.optional_slots.length - 3 }}
</el-tag>
</div>
<span v-else class="no-value">-</span>
</template>
</el-table-column>
<el-table-column prop="completion_threshold" label="完成阈值" width="100">
<template #default="{ row }">
<span>{{ (row.completion_threshold * 100).toFixed(0) }}%</span>
</template>
</el-table-column>
<el-table-column prop="ask_back_order" label="追问策略" width="100">
<template #default="{ row }">
<span>{{ getAskBackOrderLabel(row.ask_back_order) }}</span>
</template>
</el-table-column>
<el-table-column label="操作" width="150" fixed="right">
<template #default="{ row }">
<el-button type="primary" link size="small" @click="handleEdit(row)">
<el-icon><Edit /></el-icon>
编辑
</el-button>
<el-button type="danger" link size="small" @click="handleDelete(row)">
<el-icon><Delete /></el-icon>
删除
</el-button>
</template>
</el-table-column>
</el-table>
<el-empty v-if="!loading && bundles.length === 0" description="暂无场景槽位包" />
</el-card>
<el-dialog
v-model="dialogVisible"
:title="isEdit ? '编辑场景槽位包' : '新建场景槽位包'"
width="800px"
:close-on-click-modal="false"
destroy-on-close
>
<el-form :model="formData" :rules="formRules" ref="formRef" label-width="100px">
<el-row :gutter="20">
<el-col :span="12">
<el-form-item label="场景标识" prop="scene_key">
<el-input
v-model="formData.scene_key"
placeholder="如open_consult, refund_apply"
:disabled="isEdit"
/>
<div class="field-hint">唯一标识创建后不可修改</div>
</el-form-item>
</el-col>
<el-col :span="12">
<el-form-item label="场景名称" prop="scene_name">
<el-input v-model="formData.scene_name" placeholder="如:开放咨询、退款申请" />
</el-form-item>
</el-col>
</el-row>
<el-form-item label="场景描述" prop="description">
<el-input
v-model="formData.description"
type="textarea"
:rows="2"
placeholder="描述该场景的业务背景和用途"
/>
</el-form-item>
<el-row :gutter="20">
<el-col :span="12">
<el-form-item label="状态" prop="status">
<el-select v-model="formData.status" style="width: 100%;">
<el-option
v-for="opt in SCENE_SLOT_BUNDLE_STATUS_OPTIONS"
:key="opt.value"
:label="opt.label"
:value="opt.value"
>
<div class="status-option">
<span>{{ opt.label }}</span>
<span class="status-desc">{{ opt.description }}</span>
</div>
</el-option>
</el-select>
</el-form-item>
</el-col>
<el-col :span="12">
<el-form-item label="完成阈值" prop="completion_threshold">
<el-slider
v-model="formData.completion_threshold"
:min="0"
:max="1"
:step="0.1"
:format-tooltip="(val: number) => `${(val * 100).toFixed(0)}%`"
/>
<div class="field-hint">必填槽位填充比例达到此值视为完成</div>
</el-form-item>
</el-col>
</el-row>
<el-form-item label="必填槽位" prop="required_slots">
<el-select
v-model="formData.required_slots"
multiple
filterable
style="width: 100%;"
placeholder="选择必填槽位"
>
<el-option
v-for="slot in availableSlots"
:key="slot.id"
:label="`${slot.slot_key} (${slot.ask_back_prompt || '无追问提示'})`"
:value="slot.slot_key"
:disabled="formData.optional_slots.includes(slot.slot_key)"
>
<div class="slot-option">
<span class="slot-key-label">{{ slot.slot_key }}</span>
<el-tag size="small" type="info">{{ slot.type }}</el-tag>
</div>
</el-option>
</el-select>
<div class="field-hint">缺失必填槽位时会触发追问不会直接进行 KB 检索</div>
</el-form-item>
<el-form-item label="可选槽位" prop="optional_slots">
<el-select
v-model="formData.optional_slots"
multiple
filterable
style="width: 100%;"
placeholder="选择可选槽位"
>
<el-option
v-for="slot in availableSlots"
:key="slot.id"
:label="slot.slot_key"
:value="slot.slot_key"
:disabled="formData.required_slots.includes(slot.slot_key)"
>
<div class="slot-option">
<span class="slot-key-label">{{ slot.slot_key }}</span>
<el-tag size="small" type="info">{{ slot.type }}</el-tag>
</div>
</el-option>
</el-select>
<div class="field-hint">可选槽位用于增强检索效果缺失时不阻塞流程</div>
</el-form-item>
<el-form-item label="槽位优先级" prop="slot_priority">
<div class="priority-editor">
<div class="priority-header">
<span class="priority-hint">拖拽调整槽位采集和追问的优先级顺序</span>
<el-button
v-if="formData.slot_priority && formData.slot_priority.length > 0"
type="danger"
link
size="small"
@click="formData.slot_priority = null"
>
清空
</el-button>
</div>
<div v-if="allSlotsForPriority.length > 0" class="priority-list-wrapper">
<draggable
v-model="allSlotsForPriority"
item-key="slot_key"
handle=".drag-handle"
class="priority-list"
ghost-class="ghost"
>
<template #item="{ element, index }">
<div class="priority-item" :class="{ 'in-priority': isInPriority(element.slot_key) }">
<el-icon class="drag-handle"><Rank /></el-icon>
<span class="priority-order">{{ index + 1 }}</span>
<el-tag
size="small"
:type="formData.required_slots.includes(element.slot_key) ? 'danger' : 'info'"
>
{{ element.slot_key }}
</el-tag>
<el-tag size="small" type="warning" v-if="formData.required_slots.includes(element.slot_key)">
必填
</el-tag>
</div>
</template>
</draggable>
</div>
<div v-else class="priority-empty">
<el-text type="info">请先添加必填或可选槽位</el-text>
</div>
</div>
</el-form-item>
<el-form-item label="追问策略" prop="ask_back_order">
<el-radio-group v-model="formData.ask_back_order">
<el-radio-button
v-for="opt in ASK_BACK_ORDER_OPTIONS"
:key="opt.value"
:value="opt.value"
>
{{ opt.label }}
</el-radio-button>
</el-radio-group>
<div class="field-hint">
{{ ASK_BACK_ORDER_OPTIONS.find(o => o.value === formData.ask_back_order)?.description }}
</div>
</el-form-item>
</el-form>
<template #footer>
<el-button @click="dialogVisible = false">取消</el-button>
<el-button type="primary" :loading="submitting" @click="handleSubmit">
{{ isEdit ? '保存' : '创建' }}
</el-button>
</template>
</el-dialog>
</div>
</template>
<script setup lang="ts">
import { ref, reactive, onMounted, watch, computed } from 'vue'
import { ElMessage, ElMessageBox } from 'element-plus'
import { Plus, Edit, Delete, Rank } from '@element-plus/icons-vue'
import type { FormInstance, FormRules } from 'element-plus'
import draggable from 'vuedraggable'
import { sceneSlotBundleApi } from '@/api/scene-slot-bundle'
import { slotDefinitionApi } from '@/api/slot-definition'
import {
SCENE_SLOT_BUNDLE_STATUS_OPTIONS,
ASK_BACK_ORDER_OPTIONS,
type SceneSlotBundle,
type SceneSlotBundleCreateRequest,
type SceneSlotBundleUpdateRequest,
type SceneSlotBundleStatus,
type AskBackOrder,
getStatusLabel,
getAskBackOrderLabel,
} from '@/types/scene-slot-bundle'
import type { SlotDefinition } from '@/types/slot-definition'
const loading = ref(false)
const bundles = ref<SceneSlotBundle[]>([])
const availableSlots = ref<SlotDefinition[]>([])
const filterStatus = ref<SceneSlotBundleStatus | ''>('')
const dialogVisible = ref(false)
const isEdit = ref(false)
const submitting = ref(false)
const formRef = ref<FormInstance>()
const formData = reactive({
id: '',
scene_key: '',
scene_name: '',
description: '',
required_slots: [] as string[],
optional_slots: [] as string[],
slot_priority: null as string[] | null,
completion_threshold: 1.0,
ask_back_order: 'priority' as AskBackOrder,
status: 'draft' as SceneSlotBundleStatus,
})
const formRules: FormRules = {
scene_key: [
{ required: true, message: '请输入场景标识', trigger: 'blur' },
{ pattern: /^[a-z][a-z0-9_]*$/, message: '以小写字母开头,仅允许小写字母、数字、下划线', trigger: 'blur' }
],
scene_name: [{ required: true, message: '请输入场景名称', trigger: 'blur' }],
status: [{ required: true, message: '请选择状态', trigger: 'change' }],
}
const getStatusTagType = (status: SceneSlotBundleStatus): 'success' | 'warning' | 'info' => {
const typeMap: Record<SceneSlotBundleStatus, 'success' | 'warning' | 'info'> = {
'active': 'success',
'draft': 'warning',
'deprecated': 'info',
}
return typeMap[status] || 'info'
}
const allSlotsForPriority = computed({
get: () => {
const allKeys = [...formData.required_slots, ...formData.optional_slots]
const priority = formData.slot_priority || allKeys
const orderedSlots = priority
.filter(key => allKeys.includes(key))
.map(key => availableSlots.value.find(s => s.slot_key === key) || { slot_key: key })
const remainingKeys = allKeys.filter(key => !priority.includes(key))
const remainingSlots = remainingKeys
.map(key => availableSlots.value.find(s => s.slot_key === key) || { slot_key: key })
return [...orderedSlots, ...remainingSlots]
},
set: (value: { slot_key: string }[]) => {
formData.slot_priority = value.map(s => s.slot_key)
}
})
const isInPriority = (slotKey: string): boolean => {
if (!formData.slot_priority) return true
return formData.slot_priority.includes(slotKey)
}
const fetchBundles = async () => {
loading.value = true
try {
const res = await sceneSlotBundleApi.list(filterStatus.value || undefined)
bundles.value = res || []
} catch (error: any) {
ElMessage.error(error.response?.data?.message || '获取场景槽位包列表失败')
} finally {
loading.value = false
}
}
const fetchAvailableSlots = async () => {
try {
const res = await slotDefinitionApi.list()
availableSlots.value = res || []
} catch (error: any) {
console.error('获取槽位定义失败', error)
}
}
const handleCreate = () => {
isEdit.value = false
Object.assign(formData, {
id: '',
scene_key: '',
scene_name: '',
description: '',
required_slots: [],
optional_slots: [],
slot_priority: null,
completion_threshold: 1.0,
ask_back_order: 'priority',
status: 'draft',
})
dialogVisible.value = true
}
const handleEdit = (bundle: SceneSlotBundle) => {
isEdit.value = true
Object.assign(formData, {
id: bundle.id,
scene_key: bundle.scene_key,
scene_name: bundle.scene_name,
description: bundle.description || '',
required_slots: [...bundle.required_slots],
optional_slots: [...bundle.optional_slots],
slot_priority: bundle.slot_priority ? [...bundle.slot_priority] : null,
completion_threshold: bundle.completion_threshold,
ask_back_order: bundle.ask_back_order,
status: bundle.status,
})
dialogVisible.value = true
}
const handleDelete = async (bundle: SceneSlotBundle) => {
try {
await ElMessageBox.confirm(
`确定要删除场景槽位包「${bundle.scene_name}」吗?`,
'删除确认',
{ type: 'warning' }
)
await sceneSlotBundleApi.delete(bundle.id)
ElMessage.success('删除成功')
fetchBundles()
} catch (error: any) {
if (error !== 'cancel') {
ElMessage.error(error.response?.data?.message || '删除失败')
}
}
}
const handleSubmit = async () => {
if (!formRef.value) return
await formRef.value.validate(async (valid) => {
if (!valid) return
submitting.value = true
try {
const data: SceneSlotBundleCreateRequest | SceneSlotBundleUpdateRequest = {
scene_name: formData.scene_name,
description: formData.description || undefined,
required_slots: formData.required_slots.length > 0 ? formData.required_slots : undefined,
optional_slots: formData.optional_slots.length > 0 ? formData.optional_slots : undefined,
slot_priority: formData.slot_priority || undefined,
completion_threshold: formData.completion_threshold,
ask_back_order: formData.ask_back_order,
status: formData.status,
}
if (isEdit.value) {
await sceneSlotBundleApi.update(formData.id, data as SceneSlotBundleUpdateRequest)
ElMessage.success('更新成功')
} else {
const createData = data as SceneSlotBundleCreateRequest
createData.scene_key = formData.scene_key
await sceneSlotBundleApi.create(createData)
ElMessage.success('创建成功')
}
dialogVisible.value = false
fetchBundles()
} catch (error: any) {
ElMessage.error(error.response?.data?.message || '操作失败')
} finally {
submitting.value = false
}
})
}
watch(filterStatus, () => {
fetchBundles()
})
onMounted(() => {
fetchBundles()
fetchAvailableSlots()
})
</script>
<style scoped lang="scss">
.scene-slot-bundle-page {
padding: 20px;
}
.page-header {
margin-bottom: 24px;
}
.header-content {
display: flex;
justify-content: space-between;
align-items: flex-start;
}
.title-section {
.page-title {
font-size: 24px;
font-weight: 600;
margin: 0 0 8px 0;
color: var(--el-text-color-primary);
}
.page-desc {
font-size: 14px;
color: var(--el-text-color-secondary);
margin: 0;
}
}
.header-actions {
display: flex;
align-items: center;
gap: 12px;
}
.bundle-card {
border-radius: 8px;
}
.scene-key {
padding: 2px 6px;
background-color: var(--el-fill-color-light);
border-radius: 4px;
font-family: monospace;
font-size: 12px;
}
.slot-tags {
display: flex;
flex-wrap: wrap;
gap: 4px;
}
.slot-tag {
margin: 0;
}
.no-value {
color: var(--el-text-color-placeholder);
}
.field-hint {
margin-top: 4px;
font-size: 12px;
color: var(--el-text-color-secondary);
}
.status-option {
display: flex;
flex-direction: column;
gap: 2px;
.status-desc {
font-size: 12px;
color: var(--el-text-color-secondary);
}
}
.slot-option {
display: flex;
align-items: center;
gap: 8px;
.slot-key-label {
font-weight: 500;
font-family: monospace;
}
}
.priority-editor {
border: 1px solid var(--el-border-color);
border-radius: 4px;
padding: 12px;
background-color: var(--el-fill-color-light);
}
.priority-header {
display: flex;
justify-content: space-between;
align-items: center;
margin-bottom: 12px;
}
.priority-hint {
font-size: 12px;
color: var(--el-text-color-secondary);
}
.priority-list-wrapper {
max-height: 300px;
overflow-y: auto;
}
.priority-list {
display: flex;
flex-direction: column;
gap: 8px;
}
.priority-item {
display: flex;
align-items: center;
gap: 8px;
padding: 8px 12px;
background-color: var(--el-bg-color);
border-radius: 4px;
border: 1px solid var(--el-border-color-light);
cursor: move;
transition: all 0.2s;
&:hover {
border-color: var(--el-color-primary-light-5);
}
&.in-priority {
background-color: var(--el-color-primary-light-9);
}
}
.drag-handle {
cursor: move;
color: var(--el-text-color-secondary);
font-size: 16px;
}
.priority-order {
width: 20px;
text-align: center;
font-weight: 600;
color: var(--el-text-color-secondary);
}
.priority-empty {
text-align: center;
padding: 20px;
color: var(--el-text-color-secondary);
}
.ghost {
opacity: 0.5;
background: var(--el-color-primary-light-8);
}
</style>
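The bundle form above configures `required_slots`, `slot_priority`, `completion_threshold`, and `ask_back_order`, but the ask-back decision itself happens server-side. A minimal TypeScript sketch of the implied logic — field names match the form, while `isComplete` and `nextSlotToAsk` are hypothetical helpers inferred from the UI hints, not functions in this repo:

```typescript
type AskBackOrder = 'priority' | 'declared'

interface SceneSlotBundle {
  required_slots: string[]
  optional_slots: string[]
  slot_priority: string[] | null
  completion_threshold: number // ratio of filled required slots that counts as done
  ask_back_order: AskBackOrder
}

// A bundle is complete once the filled-required ratio reaches the threshold.
function isComplete(bundle: SceneSlotBundle, filled: Set<string>): boolean {
  if (bundle.required_slots.length === 0) return true
  const hit = bundle.required_slots.filter(k => filled.has(k)).length
  return hit / bundle.required_slots.length >= bundle.completion_threshold
}

// Pick the next slot to ask back about: missing required slots, ordered by
// the drag-sorted slot_priority list when ask_back_order is 'priority'.
function nextSlotToAsk(bundle: SceneSlotBundle, filled: Set<string>): string | null {
  const missing = bundle.required_slots.filter(k => !filled.has(k))
  if (missing.length === 0) return null
  if (bundle.ask_back_order === 'priority' && bundle.slot_priority) {
    const rank = new Map<string, number>(
      bundle.slot_priority.map((k, i) => [k, i] as [string, number])
    )
    missing.sort((a, b) => (rank.get(a) ?? Infinity) - (rank.get(b) ?? Infinity))
  }
  return missing[0]
}
```

With `slot_priority = ['subject', 'grade']` and nothing filled, `nextSlotToAsk` returns `'subject'` even though `grade` was declared first — which is exactly what the drag-sort editor promises.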


@ -185,11 +185,23 @@
v-model="element.expected_variables"
multiple
filterable
allow-create
default-first-option
placeholder="输入变量名后回车添加"
placeholder="选择期望提取的槽位"
style="width: 100%"
/>
>
<el-option
v-for="slot in availableSlots"
:key="slot.id"
:label="`${slot.slot_key} (${slot.type})`"
:value="slot.slot_key"
>
<div class="slot-option">
<span class="slot-key">{{ slot.slot_key }}</span>
<el-tag size="small" type="info">{{ slot.type }}</el-tag>
<el-tag size="small" v-if="slot.required" type="danger">必填</el-tag>
</div>
</el-option>
</el-select>
<div class="field-hint">期望变量必须引用已定义的槽位步骤完成时会检查这些槽位是否已填充</div>
</el-form-item>
</template>
@ -218,13 +230,100 @@
v-model="element.expected_variables"
multiple
filterable
allow-create
default-first-option
placeholder="输入变量名后回车添加"
placeholder="选择期望提取的槽位"
style="width: 100%"
>
<el-option
v-for="slot in availableSlots"
:key="slot.id"
:label="`${slot.slot_key} (${slot.type})`"
:value="slot.slot_key"
>
<div class="slot-option">
<span class="slot-key">{{ slot.slot_key }}</span>
<el-tag size="small" type="info">{{ slot.type }}</el-tag>
<el-tag size="small" v-if="slot.required" type="danger">必填</el-tag>
</div>
</el-option>
</el-select>
<div class="field-hint">期望变量必须引用已定义的槽位步骤完成时会检查这些槽位是否已填充</div>
</el-form-item>
</template>
<el-divider content-position="left">知识库范围</el-divider>
<div class="kb-binding-section">
<el-form-item label="允许的知识库">
<el-select
v-model="element.allowed_kb_ids"
multiple
filterable
clearable
placeholder="选择允许检索的知识库(为空则使用默认策略)"
style="width: 100%"
>
<el-option
v-for="kb in availableKnowledgeBases"
:key="kb.id"
:label="kb.name"
:value="kb.id"
>
<div class="kb-option">
<span>{{ kb.name }}</span>
<el-tag size="small" :type="getKbTypeTagType(kb.kbType)">{{ getKbTypeLabel(kb.kbType) }}</el-tag>
</div>
</el-option>
</el-select>
<div class="field-hint">限制该步骤只能从选定的知识库中检索信息为空则使用默认检索策略</div>
</el-form-item>
<el-form-item label="优先知识库">
<el-select
v-model="element.preferred_kb_ids"
multiple
filterable
clearable
placeholder="选择优先检索的知识库"
style="width: 100%"
>
<el-option
v-for="kb in availableKnowledgeBases"
:key="kb.id"
:label="kb.name"
:value="kb.id"
:disabled="element.allowed_kb_ids && element.allowed_kb_ids.length > 0 && !element.allowed_kb_ids.includes(kb.id)"
>
<div class="kb-option">
<span>{{ kb.name }}</span>
<el-tag size="small" :type="getKbTypeTagType(kb.kbType)">{{ getKbTypeLabel(kb.kbType) }}</el-tag>
</div>
</el-option>
</el-select>
<div class="field-hint">检索时优先搜索这些知识库</div>
</el-form-item>
<el-row :gutter="16">
<el-col :span="16">
<el-form-item label="检索提示">
<el-input
v-model="element.kb_query_hint"
placeholder="可选:描述本步骤的检索意图,帮助提高检索准确性"
/>
</el-form-item>
</el-col>
<el-col :span="8">
<el-form-item label="最大检索次数">
<el-input-number
v-model="element.max_kb_calls_per_step"
:min="1"
:max="5"
placeholder="默认1"
style="width: 100%"
/>
</el-form-item>
</template>
</el-col>
</el-row>
</div>
<el-row :gutter="16">
<el-col :span="8">
@ -362,17 +461,24 @@ import {
createScriptFlow,
updateScriptFlow,
deleteScriptFlow,
getScriptFlow
getScriptFlow,
} from '@/api/script-flow'
import { slotDefinitionApi } from '@/api/slot-definition'
import { listKnowledgeBases } from '@/api/knowledge-base'
import { MetadataForm } from '@/components/metadata'
import { TIMEOUT_ACTION_OPTIONS, SCRIPT_MODE_OPTIONS } from '@/types/script-flow'
import type { ScriptFlow, ScriptFlowDetail, ScriptFlowCreate, ScriptFlowUpdate, FlowStep, ScriptMode } from '@/types/script-flow'
import type { SlotDefinition } from '@/types/slot-definition'
import type { KnowledgeBase } from '@/types/knowledge-base'
import { KB_TYPE_MAP } from '@/types/knowledge-base'
import FlowPreview from './components/FlowPreview.vue'
import SimulateDialog from './components/SimulateDialog.vue'
import ConstraintManager from './components/ConstraintManager.vue'
const loading = ref(false)
const flows = ref<ScriptFlow[]>([])
const availableSlots = ref<SlotDefinition[]>([])
const availableKnowledgeBases = ref<KnowledgeBase[]>([])
const dialogVisible = ref(false)
const previewDrawer = ref(false)
const simulateDialogVisible = ref(false)
@ -425,16 +531,55 @@ const loadFlows = async () => {
}
}
const loadAvailableSlots = async () => {
try {
const res = await slotDefinitionApi.list()
availableSlots.value = res || []
} catch (error) {
console.error('加载槽位定义失败', error)
availableSlots.value = []
}
}
const loadAvailableKnowledgeBases = async () => {
try {
const res = await listKnowledgeBases({ is_enabled: true })
availableKnowledgeBases.value = res.data || []
} catch (error) {
console.error('加载知识库列表失败', error)
availableKnowledgeBases.value = []
}
}
const getKbTypeLabel = (kbType: string): string => {
return KB_TYPE_MAP[kbType]?.label || kbType
}
const getKbTypeTagType = (kbType: string): string => {
const typeMap: Record<string, string> = {
product: 'primary',
faq: 'success',
script: 'warning',
policy: 'danger',
general: 'info'
}
return typeMap[kbType] || 'info'
}
const handleCreate = () => {
isEdit.value = false
currentEditId.value = ''
formData.value = defaultFormData()
loadAvailableSlots()
loadAvailableKnowledgeBases()
dialogVisible.value = true
}
const handleEdit = async (row: ScriptFlow) => {
isEdit.value = true
currentEditId.value = row.id
await loadAvailableSlots()
await loadAvailableKnowledgeBases()
try {
const detail = await getScriptFlow(row.id)
formData.value = {
@ -444,7 +589,11 @@ const handleEdit = async (row: ScriptFlow) => {
...step,
script_mode: step.script_mode || 'fixed',
script_constraints: step.script_constraints || [],
expected_variables: step.expected_variables || []
expected_variables: step.expected_variables || [],
allowed_kb_ids: step.allowed_kb_ids || [],
preferred_kb_ids: step.preferred_kb_ids || [],
kb_query_hint: step.kb_query_hint || '',
max_kb_calls_per_step: step.max_kb_calls_per_step || null
})),
is_enabled: detail.is_enabled,
metadata: detail.metadata || {}
@ -739,4 +888,35 @@ onMounted(() => {
border-radius: 4px;
margin-bottom: 8px;
}
.slot-option {
display: flex;
align-items: center;
gap: 8px;
}
.slot-key {
font-weight: 500;
font-family: monospace;
}
.field-hint {
margin-top: 4px;
font-size: 12px;
color: var(--el-text-color-secondary);
}
.kb-binding-section {
padding: 12px;
background-color: var(--el-fill-color-lighter);
border-radius: 6px;
margin-bottom: 12px;
}
.kb-option {
display: flex;
align-items: center;
justify-content: space-between;
width: 100%;
}
</style>
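The new per-step KB-binding fields (`allowed_kb_ids`, `preferred_kb_ids`, `max_kb_calls_per_step`) restrict and order retrieval for one flow step. A hedged sketch of how a retriever might honor them — `kbSearchOrder` and `canCallKb` are illustrative helpers, not part of this repo, and the "empty allow-list means default policy" semantics is taken from the field hints above:

```typescript
interface KbScope {
  allowed_kb_ids: string[]      // empty ⇒ no restriction (default policy)
  preferred_kb_ids: string[]    // searched first
  max_kb_calls_per_step: number | null
}

// Restrict candidates to the allow-list when present, then put preferred KBs first.
function kbSearchOrder(allKbIds: string[], scope: KbScope): string[] {
  const allowed = scope.allowed_kb_ids.length > 0
    ? allKbIds.filter(id => scope.allowed_kb_ids.includes(id))
    : [...allKbIds]
  const preferred = allowed.filter(id => scope.preferred_kb_ids.includes(id))
  const rest = allowed.filter(id => !scope.preferred_kb_ids.includes(id))
  return [...preferred, ...rest]
}

// Cap retrieval rounds for the step; the form defaults to 1 when unset.
function canCallKb(callsSoFar: number, scope: KbScope): boolean {
  return callsSoFar < (scope.max_kb_calls_per_step ?? 1)
}
```

Note how the form's `:disabled` binding on 优先知识库 is mirrored here: a preferred KB outside the allow-list is simply filtered out before ordering.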


@ -23,7 +23,16 @@
<el-table :data="slots" stripe style="width: 100%">
<el-table-column prop="slot_key" label="槽位标识" min-width="140">
<template #default="{ row }">
<div class="slot-key-cell">
<code class="slot-key">{{ row.slot_key }}</code>
<span v-if="row.display_name" class="display-name">{{ row.display_name }}</span>
</div>
</template>
</el-table-column>
<el-table-column prop="description" label="说明" min-width="180">
<template #default="{ row }">
<span v-if="row.description" class="slot-description">{{ row.description }}</span>
<span v-else class="no-value">-</span>
</template>
</el-table-column>
<el-table-column prop="type" label="类型" width="100">
@ -38,11 +47,20 @@
</el-tag>
</template>
</el-table-column>
<el-table-column prop="extract_strategy" label="提取策略" width="120">
<!-- [AC-MRS-07-UPGRADE] 策略链展示 -->
<el-table-column prop="extract_strategies" label="提取策略链" min-width="180">
<template #default="{ row }">
<el-tag v-if="row.extract_strategy" size="small">
{{ getExtractStrategyLabel(row.extract_strategy) }}
<div v-if="getEffectiveStrategies(row).length > 0" class="strategy-chain">
<el-tag
v-for="(strategy, idx) in getEffectiveStrategies(row)"
:key="idx"
size="small"
:type="getStrategyTagType(strategy)"
class="strategy-tag"
>
{{ idx + 1 }}. {{ getExtractStrategyLabel(strategy) }}
</el-tag>
</div>
<span v-else class="no-value">-</span>
</template>
</el-table-column>
@ -80,7 +98,7 @@
<el-dialog
v-model="dialogVisible"
:title="isEdit ? '编辑槽位定义' : '新建槽位定义'"
width="650px"
width="700px"
:close-on-click-modal="false"
destroy-on-close
>
@ -96,6 +114,25 @@
<div class="field-hint">仅允许小写字母数字下划线</div>
</el-form-item>
</el-col>
<el-col :span="12">
<el-form-item label="槽位名称" prop="display_name">
<el-input
v-model="formData.display_name"
placeholder="如:当前年级"
/>
<div class="field-hint">给运营/教研看的中文名</div>
</el-form-item>
</el-col>
</el-row>
<el-form-item label="槽位说明" prop="description">
<el-input
v-model="formData.description"
type="textarea"
:rows="2"
placeholder="解释这个槽位采集什么、用于哪里,如:用于课程分层推荐和知识库过滤"
/>
</el-form-item>
<el-row :gutter="20">
<el-col :span="12">
<el-form-item label="槽位类型" prop="type">
<el-select v-model="formData.type" style="width: 100%;">
@ -108,21 +145,74 @@
</el-select>
</el-form-item>
</el-col>
</el-row>
<el-row :gutter="20">
<el-col :span="12">
<el-form-item label="是否必填" prop="required">
<el-switch v-model="formData.required" />
</el-form-item>
</el-col>
<el-col :span="12">
<el-form-item label="提取策略" prop="extract_strategy">
<el-select v-model="formData.extract_strategy" style="width: 100%;" clearable placeholder="选择提取策略">
</el-row>
<!-- [AC-MRS-07-UPGRADE] 提取策略链配置 -->
<el-form-item label="提取策略链" prop="extract_strategies">
<div class="strategy-chain-editor">
<div class="strategy-chain-header">
<span class="chain-hint">按优先级排序系统将依次尝试直到成功</span>
<el-button
v-if="formData.extract_strategies.length > 0"
type="danger"
link
size="small"
@click="clearStrategies"
>
清空
</el-button>
</div>
<div v-if="formData.extract_strategies.length > 0" class="strategy-chain-list">
<draggable
v-model="formData.extract_strategies"
:item-key="(item: ExtractStrategy) => item"
handle=".drag-handle"
class="strategy-list"
>
<template #item="{ element, index }">
<div class="strategy-item">
<el-icon class="drag-handle"><Rank /></el-icon>
<span class="strategy-order">{{ index + 1 }}</span>
<el-tag size="small" :type="getStrategyTagType(element)">
{{ getExtractStrategyLabel(element) }}
</el-tag>
<el-button
type="danger"
link
size="small"
class="remove-btn"
@click="removeStrategy(index)"
>
<el-icon><Close /></el-icon>
</el-button>
</div>
</template>
</draggable>
</div>
<div v-else class="strategy-empty">
<el-text type="info">暂无策略请从下方添加</el-text>
</div>
<div class="strategy-add-section">
<el-select
v-model="selectedStrategy"
placeholder="选择要添加的策略"
style="width: 200px;"
clearable
>
<el-option
v-for="opt in EXTRACT_STRATEGY_OPTIONS"
v-for="opt in availableStrategies"
:key="opt.value"
:label="opt.label"
:value="opt.value"
:value="opt.value as ExtractStrategy"
:disabled="isStrategySelected(opt.value as ExtractStrategy)"
>
<div class="extract-option">
<span>{{ opt.label }}</span>
@ -130,9 +220,22 @@
</div>
</el-option>
</el-select>
<el-button
type="primary"
:disabled="!selectedStrategy"
@click="addStrategy"
>
<el-icon><Plus /></el-icon>
添加
</el-button>
</div>
<div v-if="strategyError" class="strategy-error">
<el-text type="danger">{{ strategyError }}</el-text>
</div>
</div>
</el-form-item>
</el-col>
</el-row>
<el-form-item label="关联字段" prop="linked_field_id">
<el-select
v-model="formData.linked_field_id"
@ -186,10 +289,11 @@
</template>
<script setup lang="ts">
import { ref, reactive, onMounted, watch } from 'vue'
import { ref, reactive, onMounted, watch, computed } from 'vue'
import { ElMessage, ElMessageBox } from 'element-plus'
import { Plus, Edit, Delete } from '@element-plus/icons-vue'
import { Plus, Edit, Delete, Rank, Close } from '@element-plus/icons-vue'
import type { FormInstance, FormRules } from 'element-plus'
import draggable from 'vuedraggable'
import { slotDefinitionApi } from '@/api/slot-definition'
import { metadataSchemaApi } from '@/api/metadata-schema'
import {
@ -199,7 +303,9 @@ import {
type SlotDefinitionCreateRequest,
type SlotDefinitionUpdateRequest,
type SlotType,
type ExtractStrategy
type ExtractStrategy,
validateExtractStrategies,
getEffectiveStrategies
} from '@/types/slot-definition'
import type { MetadataFieldDefinition } from '@/types/metadata'
@ -212,12 +318,19 @@ const isEdit = ref(false)
const submitting = ref(false)
const formRef = ref<FormInstance>()
// [AC-MRS-07-UPGRADE]
const selectedStrategy = ref<ExtractStrategy | ''>('')
const strategyError = ref('')
const formData = reactive({
id: '',
slot_key: '',
display_name: '',
description: '',
type: 'string' as SlotType,
required: false,
extract_strategy: '' as ExtractStrategy | '',
// [AC-MRS-07-UPGRADE] 使用策略链替代单一 extract_strategy
extract_strategies: [] as ExtractStrategy[],
validation_rule: '',
ask_back_prompt: '',
default_value: '',
@ -233,6 +346,42 @@ const formRules: FormRules = {
required: [{ required: true, message: '请选择是否必填', trigger: 'change' }]
}
// [AC-MRS-07-UPGRADE]
const availableStrategies = computed(() => {
return EXTRACT_STRATEGY_OPTIONS
})
// [AC-MRS-07-UPGRADE]
const isStrategySelected = (strategy: ExtractStrategy) => {
return formData.extract_strategies.includes(strategy)
}
// [AC-MRS-07-UPGRADE]
const addStrategy = () => {
if (!selectedStrategy.value) return
if (isStrategySelected(selectedStrategy.value)) {
strategyError.value = '该策略已存在于链中'
return
}
formData.extract_strategies.push(selectedStrategy.value)
selectedStrategy.value = ''
strategyError.value = ''
}
// [AC-MRS-07-UPGRADE]
const removeStrategy = (index: number) => {
formData.extract_strategies.splice(index, 1)
strategyError.value = ''
}
// [AC-MRS-07-UPGRADE]
const clearStrategies = () => {
formData.extract_strategies = []
strategyError.value = ''
}
const getTypeLabel = (type: SlotType) => {
return SLOT_TYPE_OPTIONS.find(o => o.value === type)?.label || type
}
@ -241,6 +390,16 @@ const getExtractStrategyLabel = (strategy: ExtractStrategy) => {
return EXTRACT_STRATEGY_OPTIONS.find(o => o.value === strategy)?.label || strategy
}
// [AC-MRS-07-UPGRADE]
const getStrategyTagType = (strategy: ExtractStrategy): 'success' | 'warning' | 'info' => {
const typeMap: Record<ExtractStrategy, 'success' | 'warning' | 'info'> = {
'rule': 'success',
'llm': 'warning',
'user_input': 'info'
}
return typeMap[strategy] || 'info'
}
const fetchSlots = async () => {
loading.value = true
try {
@ -256,7 +415,8 @@ const fetchSlots = async () => {
const fetchSlotFields = async () => {
try {
const res = await metadataSchemaApi.getByRole('slot', true)
slotFields.value = res.items || []
// 兼容数组和 { items } 两种响应格式
slotFields.value = Array.isArray(res) ? res : (res.items || [])
} catch (error: any) {
console.error('获取槽位角色字段失败', error)
}
@ -267,14 +427,19 @@ const handleCreate = () => {
Object.assign(formData, {
id: '',
slot_key: '',
display_name: '',
description: '',
type: 'string',
required: false,
extract_strategy: '',
// [AC-MRS-07-UPGRADE]
extract_strategies: [],
validation_rule: '',
ask_back_prompt: '',
default_value: '',
linked_field_id: ''
})
selectedStrategy.value = ''
strategyError.value = ''
dialogVisible.value = true
}
@ -283,14 +448,19 @@ const handleEdit = (slot: SlotDefinition) => {
Object.assign(formData, {
id: slot.id,
slot_key: slot.slot_key,
display_name: slot.display_name || '',
description: slot.description || '',
type: slot.type,
required: slot.required,
extract_strategy: slot.extract_strategy || '',
// [AC-MRS-07-UPGRADE] 使用有效策略链(兼容旧 extract_strategy 字段)
extract_strategies: [...getEffectiveStrategies(slot)],
validation_rule: slot.validation_rule || '',
ask_back_prompt: slot.ask_back_prompt || '',
default_value: slot.default_value ?? '',
linked_field_id: slot.linked_field_id || ''
})
selectedStrategy.value = ''
strategyError.value = ''
dialogVisible.value = true
}
@ -317,13 +487,25 @@ const handleSubmit = async () => {
await formRef.value.validate(async (valid) => {
if (!valid) return
// [AC-MRS-07-UPGRADE]
if (formData.extract_strategies.length > 0) {
const validation = validateExtractStrategies(formData.extract_strategies)
if (!validation.valid) {
strategyError.value = validation.error || '策略链校验失败'
return
}
}
submitting.value = true
try {
const data: SlotDefinitionCreateRequest | SlotDefinitionUpdateRequest = {
slot_key: formData.slot_key,
display_name: formData.display_name || undefined,
description: formData.description || undefined,
type: formData.type,
required: formData.required,
extract_strategy: formData.extract_strategy || undefined,
// [AC-MRS-07-UPGRADE]
extract_strategies: formData.extract_strategies.length > 0 ? formData.extract_strategies : undefined,
validation_rule: formData.validation_rule || undefined,
ask_back_prompt: formData.ask_back_prompt || undefined,
linked_field_id: formData.linked_field_id || undefined
@ -409,6 +591,27 @@ onMounted(() => {
font-size: 12px;
}
.slot-key-cell {
display: flex;
flex-direction: column;
gap: 2px;
}
.display-name {
font-size: 12px;
color: var(--el-text-color-secondary);
}
.slot-description {
color: var(--el-text-color-regular);
font-size: 13px;
overflow: hidden;
text-overflow: ellipsis;
white-space: nowrap;
display: block;
max-width: 180px;
}
.no-value {
color: var(--el-text-color-placeholder);
}
@ -439,6 +642,91 @@ onMounted(() => {
color: var(--el-text-color-secondary);
}
// [AC-MRS-07-UPGRADE]
.strategy-chain {
display: flex;
flex-wrap: wrap;
gap: 4px;
}
.strategy-tag {
margin-right: 4px;
}
.strategy-chain-editor {
border: 1px solid var(--el-border-color);
border-radius: 4px;
padding: 12px;
background-color: var(--el-fill-color-light);
}
.strategy-chain-header {
display: flex;
justify-content: space-between;
align-items: center;
margin-bottom: 12px;
}
.chain-hint {
font-size: 12px;
color: var(--el-text-color-secondary);
}
.strategy-chain-list {
margin-bottom: 16px;
}
.strategy-list {
display: flex;
flex-direction: column;
gap: 8px;
}
.strategy-item {
display: flex;
align-items: center;
gap: 8px;
padding: 8px 12px;
background-color: var(--el-bg-color);
border-radius: 4px;
border: 1px solid var(--el-border-color-light);
}
.drag-handle {
cursor: move;
color: var(--el-text-color-secondary);
font-size: 16px;
}
.strategy-order {
width: 20px;
text-align: center;
font-weight: 600;
color: var(--el-text-color-secondary);
}
.remove-btn {
margin-left: auto;
}
.strategy-empty {
text-align: center;
padding: 20px;
color: var(--el-text-color-secondary);
}
.strategy-add-section {
display: flex;
gap: 8px;
align-items: center;
padding-top: 12px;
border-top: 1px dashed var(--el-border-color);
}
.strategy-error {
margin-top: 8px;
}
.extract-option {
display: flex;
flex-direction: column;

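The strategy-chain editor above leans on `getEffectiveStrategies` and `validateExtractStrategies` from `@/types/slot-definition`, whose bodies are not shown in this diff. Their apparent contracts, reconstructed as a sketch — `effectiveStrategies` and `validateChain` are stand-in names, and the backward-compatibility behavior is inferred from how `handleEdit` and `handleSubmit` use them:

```typescript
type ExtractStrategy = 'rule' | 'llm' | 'user_input'

interface SlotLike {
  extract_strategies?: ExtractStrategy[]  // new chain field
  extract_strategy?: ExtractStrategy      // legacy single-strategy field
}

// Prefer the new chain; fall back to wrapping the legacy single strategy.
function effectiveStrategies(slot: SlotLike): ExtractStrategy[] {
  if (slot.extract_strategies && slot.extract_strategies.length > 0) {
    return slot.extract_strategies
  }
  return slot.extract_strategy ? [slot.extract_strategy] : []
}

// Validation matching what the form enforces: known values, no duplicates.
function validateChain(chain: ExtractStrategy[]): { valid: boolean; error?: string } {
  const known: ExtractStrategy[] = ['rule', 'llm', 'user_input']
  if (new Set(chain).size !== chain.length) {
    return { valid: false, error: 'duplicate strategy in chain' }
  }
  if (chain.some(s => !known.includes(s))) {
    return { valid: false, error: 'unknown strategy' }
  }
  return { valid: true }
}
```

This keeps older records readable: a slot saved with only `extract_strategy: 'llm'` still renders as a one-element chain in the table's 提取策略链 column.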

@ -410,6 +410,24 @@
<p>配置知识库、意图、规则等的动态元数据字段定义</p>
</div>
</div>
<div class="help-item" @click="navigateTo('/admin/slot-definitions')">
<div class="help-icon success">
<el-icon><Grid /></el-icon>
</div>
<div class="help-text">
<h4>槽位定义</h4>
<p>定义槽位类型、提取策略、校验规则和追问提示</p>
</div>
</div>
<div class="help-item" @click="navigateTo('/admin/scene-slot-bundles')">
<div class="help-icon warning">
<el-icon><Collection /></el-icon>
</div>
<div class="help-text">
<h4>场景槽位包</h4>
<p>配置场景与槽位的映射关系,定义每个场景需要采集的槽位</p>
</div>
</div>
</div>
</el-card>
</el-col>
@ -420,7 +438,7 @@
<script setup lang="ts">
import { ref, reactive, computed, onMounted } from 'vue'
import { useRouter } from 'vue-router'
import { FolderOpened, Document, ChatDotSquare, Monitor, Cpu, InfoFilled, Connection, Timer, DataLine, Aim, DocumentCopy, Share, Warning, Setting } from '@element-plus/icons-vue'
import { FolderOpened, Document, ChatDotSquare, Monitor, Cpu, InfoFilled, Connection, Timer, DataLine, Aim, DocumentCopy, Share, Warning, Setting, Grid, Collection } from '@element-plus/icons-vue'
import { getDashboardStats, type DashboardStats } from '@/api/dashboard'
const router = useRouter()

View File

@ -11,6 +11,14 @@
<el-icon><Upload /></el-icon>
上传文档
</el-button>
<el-button type="success" @click="handleBatchUploadClick">
<el-icon><Upload /></el-icon>
批量上传
</el-button>
<el-button type="warning" @click="handleJsonUploadClick">
<el-icon><Upload /></el-icon>
JSON上传
</el-button>
</div>
</div>
</div>
@ -85,6 +93,8 @@
</el-dialog>
<input ref="fileInput" type="file" style="display: none" @change="handleFileChange" />
<input ref="batchFileInput" type="file" accept=".zip" style="display: none" @change="handleBatchFileChange" />
<input ref="jsonFileInput" type="file" accept=".json,.jsonl,.txt" style="display: none" @change="handleJsonFileChange" />
</div>
</template>
@ -92,7 +102,7 @@
import { ref, onMounted, onUnmounted } from 'vue'
import { ElMessage, ElMessageBox } from 'element-plus'
import { Upload, Document, View, Delete } from '@element-plus/icons-vue'
import { uploadDocument, listDocuments, getIndexJob, deleteDocument } from '@/api/kb'
import { uploadDocument, listDocuments, getIndexJob, deleteDocument, batchUploadDocuments, jsonBatchUpload } from '@/api/kb'
interface DocumentItem {
docId: string
@ -242,11 +252,21 @@ onUnmounted(() => {
})
const fileInput = ref<HTMLInputElement | null>(null)
const batchFileInput = ref<HTMLInputElement | null>(null)
const jsonFileInput = ref<HTMLInputElement | null>(null)
const handleUploadClick = () => {
fileInput.value?.click()
}
const handleBatchUploadClick = () => {
batchFileInput.value?.click()
}
const handleJsonUploadClick = () => {
jsonFileInput.value?.click()
}
const handleFileChange = async (event: Event) => {
const target = event.target as HTMLInputElement
const file = target.files?.[0]
@ -281,6 +301,111 @@ const handleFileChange = async (event: Event) => {
target.value = ''
}
}
const handleBatchFileChange = async (event: Event) => {
const target = event.target as HTMLInputElement
const file = target.files?.[0]
if (!file) return
if (!file.name.toLowerCase().endsWith('.zip')) {
ElMessage.error('请上传 .zip 格式的压缩包')
target.value = ''
return
}
const formData = new FormData()
formData.append('file', file)
formData.append('kb_id', 'kb_default')
try {
loading.value = true
const res: any = await batchUploadDocuments(formData)
const { total, succeeded, failed, results } = res
if (succeeded > 0) {
ElMessage.success(`批量上传成功!成功: ${succeeded}, 失败: ${failed}, 总计: ${total}`)
for (const result of results) {
if (result.status === 'created') {
const newDoc: DocumentItem = {
docId: result.docId,
name: result.fileName,
status: 'pending',
jobId: result.jobId,
createTime: new Date().toLocaleString('zh-CN')
}
tableData.value.unshift(newDoc)
startPolling(result.jobId)
}
}
} else {
ElMessage.error('批量上传失败,请检查压缩包格式')
}
if (failed > 0) {
const failedItems = results.filter((r: any) => r.status === 'failed')
console.error('Failed uploads:', failedItems)
}
fetchDocuments()
} catch (error: any) {
ElMessage.error(error.message || '批量上传失败')
console.error('Batch upload error:', error)
} finally {
loading.value = false
target.value = ''
}
}
const handleJsonFileChange = async (event: Event) => {
const target = event.target as HTMLInputElement
const file = target.files?.[0]
if (!file) return
const fileName = file.name.toLowerCase()
if (!fileName.endsWith('.json') && !fileName.endsWith('.jsonl') && !fileName.endsWith('.txt')) {
ElMessage.error('请上传 .json、.jsonl 或 .txt 格式的文件')
target.value = ''
return
}
const kbId = '75c465fe-277d-455d-a30b-4b168adcc03b'
const formData = new FormData()
formData.append('file', file)
formData.append('tenant_id', 'szmp@ash@2026')
try {
loading.value = true
ElMessage.info('正在上传 JSON 文件...')
const res: any = await jsonBatchUpload(kbId, formData)
if (res.success) {
ElMessage.success(`JSON 批量上传成功!成功: ${res.succeeded}, 失败: ${res.failed}`)
if (res.valid_metadata_fields) {
console.log('Valid metadata fields:', res.valid_metadata_fields)
}
if (res.failed > 0) {
const failedItems = res.results.filter((r: any) => !r.success)
console.warn('Failed items:', failedItems)
}
fetchDocuments()
} else {
ElMessage.error('JSON 批量上传失败')
}
} catch (error: any) {
ElMessage.error(error.message || 'JSON 批量上传失败')
console.error('JSON batch upload error:', error)
} finally {
loading.value = false
target.value = ''
}
}
</script>
<style scoped>
@ -324,6 +449,7 @@ const handleFileChange = async (event: Event) => {
.header-actions {
display: flex;
align-items: center;
gap: 12px;
}
.table-card {

View File

@ -2,6 +2,8 @@
Admin API routes for AI Service management.
[AC-ASA-01, AC-ASA-02, AC-ASA-05, AC-ASA-07, AC-ASA-08, AC-AISVC-50] Admin management endpoints.
[AC-MRS-07,08,16] Slot definition management endpoints.
[AC-SCENE-SLOT-01] Scene slot bundle management endpoints.
[AC-AISVC-RES-01~15] Retrieval strategy management endpoints.
"""
from app.api.admin.api_key import router as api_key_router
@ -18,6 +20,8 @@ from app.api.admin.metadata_schema import router as metadata_schema_router
from app.api.admin.monitoring import router as monitoring_router
from app.api.admin.prompt_templates import router as prompt_templates_router
from app.api.admin.rag import router as rag_router
from app.api.admin.retrieval_strategy import router as retrieval_strategy_router
from app.api.admin.scene_slot_bundle import router as scene_slot_bundle_router
from app.api.admin.script_flows import router as script_flows_router
from app.api.admin.sessions import router as sessions_router
from app.api.admin.slot_definition import router as slot_definition_router
@ -38,6 +42,8 @@ __all__ = [
"monitoring_router",
"prompt_templates_router",
"rag_router",
"retrieval_strategy_router",
"scene_slot_bundle_router",
"script_flows_router",
"sessions_router",
"slot_definition_router",

View File

@ -2,6 +2,8 @@
Intent Rule Management API.
[AC-AISVC-65~AC-AISVC-68] Intent rule CRUD endpoints.
[AC-AISVC-96] Intent rule testing endpoint.
[AC-AISVC-116] Fusion config management endpoints.
[AC-AISVC-114] Intent vector generation endpoint.
"""
import logging
@ -14,6 +16,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_session
from app.models.entities import IntentRuleCreate, IntentRuleUpdate
from app.services.intent.models import DEFAULT_FUSION_CONFIG, FusionConfig
from app.services.intent.rule_service import IntentRuleService
from app.services.intent.tester import IntentRuleTester
@ -21,6 +24,8 @@ logger = logging.getLogger(__name__)
router = APIRouter(prefix="/admin/intent-rules", tags=["Intent Rules"])
_fusion_config = FusionConfig()
def get_tenant_id(x_tenant_id: str = Header(..., alias="X-Tenant-Id")) -> str:
"""Extract tenant ID from header."""
@ -204,3 +209,109 @@ async def test_rule(
result = await tester.test_rule(rule, [body.message], all_rules)
return result.to_dict()
class FusionConfigUpdate(BaseModel):
"""Request body for updating fusion config."""
w_rule: float | None = None
w_semantic: float | None = None
w_llm: float | None = None
semantic_threshold: float | None = None
conflict_threshold: float | None = None
gray_zone_threshold: float | None = None
min_trigger_threshold: float | None = None
clarify_threshold: float | None = None
multi_intent_threshold: float | None = None
llm_judge_enabled: bool | None = None
semantic_matcher_enabled: bool | None = None
semantic_matcher_timeout_ms: int | None = None
llm_judge_timeout_ms: int | None = None
semantic_top_k: int | None = None
@router.get("/fusion-config")
async def get_fusion_config() -> dict[str, Any]:
"""
[AC-AISVC-116] Get current fusion configuration.
"""
logger.info("[AC-AISVC-116] Getting fusion config")
return _fusion_config.to_dict()
@router.put("/fusion-config")
async def update_fusion_config(
body: FusionConfigUpdate,
) -> dict[str, Any]:
"""
[AC-AISVC-116] Update fusion configuration.
"""
global _fusion_config
logger.info(f"[AC-AISVC-116] Updating fusion config: {body.model_dump()}")
current_dict = _fusion_config.to_dict()
update_dict = body.model_dump(exclude_none=True)
current_dict.update(update_dict)
_fusion_config = FusionConfig.from_dict(current_dict)
return _fusion_config.to_dict()
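The update handler follows a dump-merge-rebuild pattern: dump the current config, overlay only the non-None fields from the request body, and reconstruct. A stand-alone sketch of the same pattern; this `FusionConfig` is a stand-in with illustrative fields, not the class from `app.services.intent.models`:

```python
from dataclasses import dataclass, asdict

@dataclass
class FusionConfig:
    # Stand-in for the real FusionConfig; field names here are illustrative
    w_rule: float = 0.4
    w_semantic: float = 0.4
    w_llm: float = 0.2

def merge_update(current: FusionConfig, patch: dict) -> FusionConfig:
    # Mirror the endpoint: dump current config, overlay non-None patch keys, rebuild
    merged = asdict(current)
    merged.update({k: v for k, v in patch.items() if v is not None})
    return FusionConfig(**merged)

updated = merge_update(FusionConfig(), {"w_llm": 0.3, "w_rule": None})
```

Fields omitted from the patch (or sent as null) keep their current values.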
@router.post("/{rule_id}/generate-vector")
async def generate_intent_vector(
rule_id: uuid.UUID,
tenant_id: str = Depends(get_tenant_id),
session: AsyncSession = Depends(get_session),
) -> dict[str, Any]:
"""
[AC-AISVC-114] Generate intent vector for a rule.
Uses the rule's semantic_examples to generate an average vector.
If no semantic_examples exist, returns an error.
"""
logger.info(
f"[AC-AISVC-114] Generating intent vector for tenant={tenant_id}, rule_id={rule_id}"
)
service = IntentRuleService(session)
rule = await service.get_rule(tenant_id, rule_id)
if not rule:
raise HTTPException(status_code=404, detail="Intent rule not found")
if not rule.semantic_examples:
raise HTTPException(
status_code=400,
detail="Rule has no semantic_examples. Please add semantic_examples first."
)
try:
from app.core.dependencies import get_embedding_provider
embedding_provider = get_embedding_provider()
vectors = await embedding_provider.embed_batch(rule.semantic_examples)
import numpy as np
avg_vector = np.mean(vectors, axis=0).tolist()
update_data = IntentRuleUpdate(intent_vector=avg_vector)
updated_rule = await service.update_rule(tenant_id, rule_id, update_data)
logger.info(
f"[AC-AISVC-114] Generated intent vector for rule={rule_id}, "
f"dimension={len(avg_vector)}"
)
return {
"id": str(updated_rule.id),
"intent_vector": updated_rule.intent_vector,
"semantic_examples": updated_rule.semantic_examples,
}
except Exception as e:
logger.error(f"[AC-AISVC-114] Failed to generate intent vector: {e}")
raise HTTPException(
status_code=500,
detail=f"Failed to generate intent vector: {str(e)}"
)
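The vector generation above reduces to `np.mean(vectors, axis=0)` over the embeddings of the rule's `semantic_examples`. The same element-wise mean in plain Python, for illustration:

```python
def average_vector(vectors: list[list[float]]) -> list[float]:
    # Element-wise mean across equal-length embedding vectors
    if not vectors:
        raise ValueError("no vectors to average")
    dim = len(vectors[0])
    return [sum(v[i] for v in vectors) / len(vectors) for i in range(dim)]

avg = average_vector([[1.0, 2.0], [3.0, 4.0]])  # [2.0, 3.0]
```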

View File

@ -6,28 +6,35 @@ Knowledge Base management endpoints.
import logging
import uuid
import json
import hashlib
from dataclasses import dataclass
from typing import Annotated, Any, Optional
import tiktoken
from fastapi import APIRouter, BackgroundTasks, Depends, File, Form, Query, UploadFile
from fastapi import APIRouter, BackgroundTasks, Depends, File, Form, HTTPException, Query, UploadFile
from fastapi.responses import JSONResponse
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import get_settings
from app.core.database import get_session
from app.core.exceptions import MissingTenantIdException
from app.core.tenant import get_tenant_id
from app.models import ErrorResponse
from app.models.entities import (
Document,
DocumentStatus,
IndexJob,
IndexJobStatus,
KBType,
KnowledgeBase,
KnowledgeBaseCreate,
KnowledgeBaseUpdate,
)
from app.services.kb import KBService
from app.services.knowledge_base_service import KnowledgeBaseService
from app.services.metadata_field_definition_service import MetadataFieldDefinitionService
logger = logging.getLogger(__name__)
@ -457,6 +464,7 @@ async def list_documents(
"kbId": doc.kb_id,
"fileName": doc.file_name,
"status": doc.status,
"metadata": doc.doc_metadata,
"jobId": str(latest_job.id) if latest_job else None,
"createdAt": doc.created_at.isoformat() + "Z",
"updatedAt": doc.updated_at.isoformat() + "Z",
@ -585,6 +593,7 @@ async def upload_document(
file_name=file.filename or "unknown",
file_content=file_content,
file_type=file.content_type,
metadata=metadata_dict,
)
await kb_service.update_doc_count(tenant_id, kb_id, delta=1)
@ -915,3 +924,488 @@ async def delete_document(
"message": "Document deleted",
}
)
@router.put(
"/documents/{doc_id}/metadata",
operation_id="updateDocumentMetadata",
summary="Update document metadata",
description="[AC-ASA-08] Update metadata for a specific document.",
responses={
200: {"description": "Metadata updated"},
404: {"description": "Document not found"},
401: {"description": "Unauthorized", "model": ErrorResponse},
403: {"description": "Forbidden", "model": ErrorResponse},
},
)
async def update_document_metadata(
tenant_id: Annotated[str, Depends(get_current_tenant_id)],
session: Annotated[AsyncSession, Depends(get_session)],
doc_id: str,
body: dict,
) -> JSONResponse:
"""
[AC-ASA-08] Update document metadata.
"""
import json
metadata = body.get("metadata")
if metadata is not None and not isinstance(metadata, dict):
try:
metadata = json.loads(metadata) if isinstance(metadata, str) else metadata
except json.JSONDecodeError:
return JSONResponse(
status_code=400,
content={
"code": "INVALID_METADATA",
"message": "Invalid JSON format for metadata",
},
)
logger.info(
f"[AC-ASA-08] Updating document metadata: tenant={tenant_id}, doc_id={doc_id}"
)
from sqlalchemy import select
from app.models.entities import Document
stmt = select(Document).where(
Document.tenant_id == tenant_id,
Document.id == doc_id,
)
result = await session.execute(stmt)
document = result.scalar_one_or_none()
if not document:
return JSONResponse(
status_code=404,
content={
"code": "DOCUMENT_NOT_FOUND",
"message": f"Document {doc_id} not found",
},
)
document.doc_metadata = metadata
await session.commit()
return JSONResponse(
content={
"success": True,
"message": "Metadata updated",
"metadata": document.doc_metadata,
}
)
@router.post(
"/documents/batch-upload",
operation_id="batchUploadDocuments",
summary="Batch upload documents from zip",
description="Upload a zip file containing multiple folders, each with a markdown file and metadata.json",
responses={
200: {"description": "Batch upload result"},
400: {"description": "Bad Request - invalid zip or missing files"},
401: {"description": "Unauthorized", "model": ErrorResponse},
403: {"description": "Forbidden", "model": ErrorResponse},
},
)
async def batch_upload_documents(
tenant_id: Annotated[str, Depends(get_current_tenant_id)],
session: Annotated[AsyncSession, Depends(get_session)],
background_tasks: BackgroundTasks,
file: UploadFile = File(...),
kb_id: str = Form(...),
) -> JSONResponse:
"""
Batch upload documents from a zip file.
Zip structure:
- Each folder contains one .md file and one metadata.json
- metadata.json uses field_key from MetadataFieldDefinition as keys
Example metadata.json:
{
"grade": "高一",
"subject": "数学",
"type": "痛点"
}
"""
import json
import tempfile
import zipfile
from pathlib import Path
from app.services.metadata_field_definition_service import MetadataFieldDefinitionService
logger.info(
f"[BATCH-UPLOAD] Starting batch upload: tenant={tenant_id}, "
f"kb_id={kb_id}, filename={file.filename}"
)
if not file.filename or not file.filename.lower().endswith('.zip'):
return JSONResponse(
status_code=400,
content={
"code": "INVALID_FORMAT",
"message": "Only .zip files are supported",
},
)
kb_service = KnowledgeBaseService(session)
kb = await kb_service.get_knowledge_base(tenant_id, kb_id)
if not kb:
return JSONResponse(
status_code=404,
content={
"code": "KB_NOT_FOUND",
"message": f"Knowledge base {kb_id} not found",
},
)
file_content = await file.read()
results = []
succeeded = 0
failed = 0
with tempfile.TemporaryDirectory() as temp_dir:
zip_path = Path(temp_dir) / "upload.zip"
with open(zip_path, "wb") as f:
f.write(file_content)
try:
with zipfile.ZipFile(zip_path, 'r') as zf:
zf.extractall(temp_dir)
except zipfile.BadZipFile as e:
return JSONResponse(
status_code=400,
content={
"code": "INVALID_ZIP",
"message": f"Invalid zip file: {str(e)}",
},
)
extracted_path = Path(temp_dir)
# List all extracted contents for debugging
all_items = list(extracted_path.iterdir())
logger.info(f"[BATCH-UPLOAD] Extracted items: {[item.name for item in all_items]}")
# Recursively find all folders containing a content .txt/.md file and metadata.json
def find_document_folders(path: Path) -> list[Path]:
"""Recursively find all folders that contain document files."""
doc_folders = []
# Check whether the current folder contains document files
content_files = (
list(path.glob("*.md")) +
list(path.glob("*.markdown")) +
list(path.glob("*.txt"))
)
if content_files:
# This folder contains document files, so it is a document folder
doc_folders.append(path)
logger.info(f"[BATCH-UPLOAD] Found document folder: {path.name}, files: {[f.name for f in content_files]}")
# Recursively check subfolders
for subfolder in [p for p in path.iterdir() if p.is_dir()]:
doc_folders.extend(find_document_folders(subfolder))
return doc_folders
folders = find_document_folders(extracted_path)
if not folders:
logger.error(f"[BATCH-UPLOAD] No document folders found in zip. Items found: {[item.name for item in all_items]}")
return JSONResponse(
status_code=400,
content={
"code": "NO_DOCUMENTS_FOUND",
"message": "压缩包中没有找到包含 .txt/.md 文件的文件夹",
"details": {
"expected_structure": "每个文件夹应包含 content.txt (或 .md) 和 metadata.json",
"found_items": [item.name for item in all_items],
},
},
)
logger.info(f"[BATCH-UPLOAD] Found {len(folders)} document folders")
for folder in folders:
folder_name = folder.name if folder != extracted_path else "root"
content_files = (
list(folder.glob("*.md")) +
list(folder.glob("*.markdown")) +
list(folder.glob("*.txt"))
)
if not content_files:
# This should not happen, since folders were already filtered
failed += 1
results.append({
"folder": folder_name,
"status": "failed",
"error": "No content file found",
})
continue
content_file = content_files[0]
metadata_file = folder / "metadata.json"
metadata_dict = {}
if metadata_file.exists():
try:
with open(metadata_file, 'r', encoding='utf-8') as f:
metadata_dict = json.load(f)
except json.JSONDecodeError as e:
failed += 1
results.append({
"folder": folder_name,
"status": "failed",
"error": f"Invalid metadata.json: {str(e)}",
})
continue
else:
logger.warning(f"[BATCH-UPLOAD] No metadata.json in folder {folder_name}, using empty metadata")
field_def_service = MetadataFieldDefinitionService(session)
is_valid, validation_errors = await field_def_service.validate_metadata_for_create(
tenant_id, metadata_dict, "kb_document"
)
if not is_valid:
failed += 1
results.append({
"folder": folder_name,
"status": "failed",
"error": f"Metadata validation failed: {validation_errors}",
})
continue
try:
with open(content_file, 'rb') as f:
doc_content = f.read()
file_ext = content_file.suffix.lower()
if file_ext == '.txt':
file_type = "text/plain"
else:
file_type = "text/markdown"
doc_kb_service = KBService(session)
document, job = await doc_kb_service.upload_document(
tenant_id=tenant_id,
kb_id=kb_id,
file_name=content_file.name,
file_content=doc_content,
file_type=file_type,
metadata=metadata_dict,
)
await kb_service.update_doc_count(tenant_id, kb_id, delta=1)
await session.commit()
background_tasks.add_task(
_index_document,
tenant_id,
kb_id,
str(job.id),
str(document.id),
doc_content,
content_file.name,
metadata_dict,
)
succeeded += 1
results.append({
"folder": folder_name,
"docId": str(document.id),
"jobId": str(job.id),
"status": "created",
"fileName": content_file.name,
})
logger.info(
f"[BATCH-UPLOAD] Created document: folder={folder_name}, "
f"doc_id={document.id}, job_id={job.id}"
)
except Exception as e:
failed += 1
results.append({
"folder": folder_name,
"status": "failed",
"error": str(e),
})
logger.error(f"[BATCH-UPLOAD] Failed to create document: folder={folder_name}, error={e}")
logger.info(
f"[BATCH-UPLOAD] Completed: total={len(results)}, succeeded={succeeded}, failed={failed}"
)
return JSONResponse(
content={
"success": True,
"total": len(results),
"succeeded": succeeded,
"failed": failed,
"results": results,
}
)
@router.post(
"/{kb_id}/documents/json-batch",
summary="[AC-KB-03] Batch upload documents from JSON",
description="Upload a JSONL file where each line is a JSON object containing text and metadata fields",
)
async def upload_json_batch(
kb_id: str,
tenant_id: str = Query(..., description="Tenant ID"),
file: UploadFile = File(..., description="JSONL file, one JSON object per line"),
session: AsyncSession = Depends(get_session),
background_tasks: BackgroundTasks = None,
):
"""
Batch-upload documents from JSON.
File format: JSONL (one JSON object per line).
Required field: text - the text content to ingest into the knowledge base.
Optional fields: metadata fields such as grade, subject, kb_scene.
Example:
{"text": "课程内容...", "grade": "初二", "subject": "数学", "kb_scene": "课程咨询"}
{"text": "另一条课程内容...", "grade": "初三", "info_type": "课程概述"}
"""
kb = await session.get(KnowledgeBase, kb_id)
if not kb:
raise HTTPException(status_code=404, detail="知识库不存在")
if kb.tenant_id != tenant_id:
raise HTTPException(status_code=403, detail="无权访问此知识库")
valid_field_keys = set()
try:
field_defs = await MetadataFieldDefinitionService(session).get_fields(
tenant_id=tenant_id,
include_inactive=False,
)
valid_field_keys = {f.field_key for f in field_defs}
logger.info(f"[AC-KB-03] Valid metadata fields for tenant {tenant_id}: {valid_field_keys}")
except Exception as e:
logger.warning(f"[AC-KB-03] Failed to get metadata fields: {e}")
content = await file.read()
try:
text_content = content.decode("utf-8")
except UnicodeDecodeError:
try:
text_content = content.decode("gbk")
except UnicodeDecodeError:
raise HTTPException(status_code=400, detail="文件编码不支持,请使用 UTF-8 编码")
stripped = text_content.strip()
if not stripped:
raise HTTPException(status_code=400, detail="文件内容为空")
lines = stripped.split("\n")
results = []
succeeded = 0
failed = 0
kb_service = KBService(session)
for line_num, line in enumerate(lines, 1):
line = line.strip()
if not line:
continue
try:
json_obj = json.loads(line)
except json.JSONDecodeError as e:
failed += 1
results.append({
"line": line_num,
"success": False,
"error": f"JSON解析失败: {e}",
})
continue
text = json_obj.get("text")
if not text:
failed += 1
results.append({
"line": line_num,
"success": False,
"error": "缺少必填字段: text",
})
continue
metadata = {}
for key, value in json_obj.items():
if key == "text":
continue
if valid_field_keys and key not in valid_field_keys:
logger.debug(f"[AC-KB-03] Skipping invalid metadata field: {key}")
continue
if value is not None:
metadata[key] = value
try:
file_name = f"json_batch_line_{line_num}.txt"
file_content = text.encode("utf-8")
document, job = await kb_service.upload_document(
tenant_id=tenant_id,
kb_id=kb_id,
file_name=file_name,
file_content=file_content,
file_type="text/plain",
metadata=metadata,
)
if background_tasks:
background_tasks.add_task(
_index_document,
tenant_id,
kb_id,
str(job.id),
str(document.id),
file_content,
file_name,
metadata,
)
succeeded += 1
results.append({
"line": line_num,
"success": True,
"doc_id": str(document.id),
"job_id": str(job.id),
"metadata": metadata,
})
except Exception as e:
failed += 1
results.append({
"line": line_num,
"success": False,
"error": str(e),
})
logger.error(f"[AC-KB-03] Failed to upload document at line {line_num}: {e}")
await session.commit()
logger.info(f"[AC-KB-03] JSON batch upload completed: kb_id={kb_id}, total={len(lines)}, succeeded={succeeded}, failed={failed}")
return JSONResponse(
content={
"success": True,
"total": len(lines),
"succeeded": succeeded,
"failed": failed,
"valid_metadata_fields": list(valid_field_keys) if valid_field_keys else [],
"results": results,
}
)

View File

@ -51,6 +51,7 @@ def _field_to_dict(f: MetadataFieldDefinition) -> dict[str, Any]:
"scope": f.scope,
"is_filterable": f.is_filterable,
"is_rank_feature": f.is_rank_feature,
"usage_description": f.usage_description,
"field_roles": f.field_roles or [],
"status": f.status,
"version": f.version,

View File

@ -407,6 +407,7 @@ async def get_conversation_detail(
"guardrailTriggered": user_msg.guardrail_triggered,
"guardrailWords": user_msg.guardrail_words,
"executionSteps": execution_steps,
"routeTrace": user_msg.route_trace,
"createdAt": user_msg.created_at.isoformat(),
}
@ -659,8 +660,56 @@ async def _process_export(
except Exception as e:
logger.error(f"[AC-AISVC-110] Export failed: task_id={task_id}, error={e}")
task = await session.get(ExportTask, task_id)
if task:
task.status = ExportTaskStatus.FAILED.value
task.error_message = str(e)
await session.commit()
@router.get("/clarification-metrics")
async def get_clarification_metrics(
tenant_id: str = Depends(get_tenant_id),
total_requests: int = Query(100, ge=1, description="Total requests for rate calculation"),
) -> dict[str, Any]:
"""
[AC-CLARIFY] Get clarification metrics.
Returns:
- clarify_trigger_rate: clarification trigger rate
- clarify_converge_rate: post-clarification convergence rate
- misroute_rate: misrouting rate
"""
from app.services.intent.clarification import get_clarify_metrics
metrics = get_clarify_metrics()
counts = metrics.get_metrics()
rates = metrics.get_rates(total_requests)
return {
"counts": counts,
"rates": rates,
"thresholds": {
"t_high": 0.75,
"t_low": 0.45,
"max_retry": 3,
},
}
@router.post("/clarification-metrics/reset")
async def reset_clarification_metrics(
tenant_id: str = Depends(get_tenant_id),
) -> dict[str, Any]:
"""
[AC-CLARIFY] Reset clarification metrics.
"""
from app.services.intent.clarification import get_clarify_metrics
metrics = get_clarify_metrics()
metrics.reset()
return {
"status": "reset",
"message": "Clarification metrics have been reset.",
}
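The internals of `get_clarify_metrics` are not shown in this diff; as a hedged sketch, `get_rates(total_requests)` presumably divides raw counters by the caller-supplied request total, along these lines:

```python
def compute_rates(counts: dict[str, int], total_requests: int) -> dict[str, float]:
    # Divide each raw counter by the request total to get a rate
    if total_requests <= 0:
        raise ValueError("total_requests must be positive")
    return {f"{name}_rate": count / total_requests for name, count in counts.items()}

rates = compute_rates({"clarify_trigger": 20, "misroute": 5}, 100)
```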

View File

@ -17,6 +17,17 @@ from app.models.entities import PromptTemplateCreate, PromptTemplateUpdate
from app.services.prompt.template_service import PromptTemplateService
from app.services.monitoring.prompt_monitor import PromptMonitor
class PromptTemplateUpdateAPI(BaseModel):
"""API model for updating prompt template with metadata field mapping."""
name: str | None = None
scene: str | None = None
description: str | None = None
system_instruction: str | None = None
variables: list[dict[str, Any]] | None = None
is_default: bool | None = None
metadata: dict[str, Any] | None = None
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/admin/prompt-templates", tags=["Prompt Management"])
@ -108,29 +119,47 @@ async def get_template_detail(
@router.put("/{tpl_id}")
async def update_template(
tpl_id: uuid.UUID,
body: PromptTemplateUpdate,
body: PromptTemplateUpdateAPI,
tenant_id: str = Depends(get_tenant_id),
session: AsyncSession = Depends(get_session),
) -> dict[str, Any]:
"""
[AC-AISVC-53] Update prompt template (creates a new version).
[AC-IDSMETA-16] Return current_content and metadata for frontend display.
"""
logger.info(f"[AC-AISVC-53] Updating template for tenant={tenant_id}, id={tpl_id}")
service = PromptTemplateService(session)
template = await service.update_template(tenant_id, tpl_id, body)
# Convert API model to entity model (metadata -> metadata_)
update_data = PromptTemplateUpdate(
name=body.name,
scene=body.scene,
description=body.description,
system_instruction=body.system_instruction,
variables=body.variables,
is_default=body.is_default,
metadata_=body.metadata,
)
template = await service.update_template(tenant_id, tpl_id, update_data)
if not template:
raise HTTPException(status_code=404, detail="Template not found")
published_version = await service.get_published_version_info(tenant_id, template.id)
# Get latest version content for current_content
latest_version = await service._get_latest_version(template.id)
return {
"id": str(template.id),
"name": template.name,
"scene": template.scene,
"description": template.description,
"is_default": template.is_default,
"current_content": latest_version.system_instruction if latest_version else None,
"metadata": template.metadata_,
"published_version": published_version,
"created_at": template.created_at.isoformat(),
"updated_at": template.updated_at.isoformat(),

View File

@ -0,0 +1,265 @@
"""
Scene Slot Bundle API.
[AC-SCENE-SLOT-01] Scene-to-slot mapping configuration management API.
"""
import logging
from typing import Annotated, Any
from fastapi import APIRouter, Depends, Query
from fastapi.responses import JSONResponse
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_session
from app.core.exceptions import MissingTenantIdException
from app.core.tenant import get_tenant_id
from app.models.entities import SceneSlotBundleCreate, SceneSlotBundleUpdate
from app.services.scene_slot_bundle_service import SceneSlotBundleService
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/admin/scene-slot-bundles", tags=["SceneSlotBundle"])
def get_current_tenant_id() -> str:
"""Get current tenant ID from context."""
tenant_id = get_tenant_id()
if not tenant_id:
raise MissingTenantIdException()
return tenant_id
def _bundle_to_dict(bundle: Any) -> dict[str, Any]:
"""Convert bundle to dict"""
return {
"id": str(bundle.id),
"tenant_id": str(bundle.tenant_id),
"scene_key": bundle.scene_key,
"scene_name": bundle.scene_name,
"description": bundle.description,
"required_slots": bundle.required_slots,
"optional_slots": bundle.optional_slots,
"slot_priority": bundle.slot_priority,
"completion_threshold": bundle.completion_threshold,
"ask_back_order": bundle.ask_back_order,
"status": bundle.status,
"version": bundle.version,
"created_at": bundle.created_at.isoformat() if bundle.created_at else None,
"updated_at": bundle.updated_at.isoformat() if bundle.updated_at else None,
}
@router.get(
"",
operation_id="listSceneSlotBundles",
summary="List scene slot bundles",
description="List scene slot bundles",
)
async def list_scene_slot_bundles(
tenant_id: Annotated[str, Depends(get_current_tenant_id)],
session: Annotated[AsyncSession, Depends(get_session)],
status: Annotated[str | None, Query(
description="Filter by status: draft/active/deprecated"
)] = None,
) -> JSONResponse:
"""
List scene slot bundles.
"""
logger.info(
f"Listing scene slot bundles: tenant={tenant_id}, status={status}"
)
service = SceneSlotBundleService(session)
bundles = await service.list_bundles(tenant_id, status)
return JSONResponse(
content=[_bundle_to_dict(b) for b in bundles]
)
@router.post(
"",
operation_id="createSceneSlotBundle",
summary="Create scene slot bundle",
description="[AC-SCENE-SLOT-01] Create a new scene slot bundle",
status_code=201,
)
async def create_scene_slot_bundle(
tenant_id: Annotated[str, Depends(get_current_tenant_id)],
session: Annotated[AsyncSession, Depends(get_session)],
bundle_create: SceneSlotBundleCreate,
) -> JSONResponse:
"""
[AC-SCENE-SLOT-01] Create a scene slot bundle.
"""
logger.info(
f"[AC-SCENE-SLOT-01] Creating scene slot bundle: "
f"tenant={tenant_id}, scene_key={bundle_create.scene_key}"
)
service = SceneSlotBundleService(session)
try:
bundle = await service.create_bundle(tenant_id, bundle_create)
await session.commit()
except ValueError as e:
return JSONResponse(
status_code=400,
content={
"error_code": "VALIDATION_ERROR",
"message": str(e),
}
)
return JSONResponse(
status_code=201,
content=_bundle_to_dict(bundle)
)
@router.get(
"/by-scene/{scene_key}",
operation_id="getSceneSlotBundleBySceneKey",
summary="Get scene slot bundle by scene key",
description="Get a slot bundle by scene key",
)
async def get_scene_slot_bundle_by_scene_key(
tenant_id: Annotated[str, Depends(get_current_tenant_id)],
session: Annotated[AsyncSession, Depends(get_session)],
scene_key: str,
) -> JSONResponse:
"""
Get a slot bundle by its scene key.
"""
logger.info(
f"Getting scene slot bundle by scene_key: tenant={tenant_id}, scene_key={scene_key}"
)
service = SceneSlotBundleService(session)
bundle = await service.get_bundle_by_scene_key(tenant_id, scene_key)
if not bundle:
return JSONResponse(
status_code=404,
content={
"error_code": "NOT_FOUND",
"message": f"Scene slot bundle with scene_key '{scene_key}' not found",
}
)
return JSONResponse(content=_bundle_to_dict(bundle))
@router.get(
"/{id}",
operation_id="getSceneSlotBundle",
summary="Get scene slot bundle by ID",
description="Get a single scene slot bundle (with slot details)",
)
async def get_scene_slot_bundle(
tenant_id: Annotated[str, Depends(get_current_tenant_id)],
session: Annotated[AsyncSession, Depends(get_session)],
id: str,
) -> JSONResponse:
"""
Get details of a single scene slot bundle.
"""
logger.info(
f"Getting scene slot bundle: tenant={tenant_id}, id={id}"
)
service = SceneSlotBundleService(session)
bundle = await service.get_bundle_with_slot_details(tenant_id, id)
if not bundle:
return JSONResponse(
status_code=404,
content={
"error_code": "NOT_FOUND",
"message": f"Scene slot bundle {id} not found",
}
)
return JSONResponse(content=bundle)
@router.put(
"/{id}",
operation_id="updateSceneSlotBundle",
summary="Update scene slot bundle",
description="Update a scene slot bundle",
)
async def update_scene_slot_bundle(
tenant_id: Annotated[str, Depends(get_current_tenant_id)],
session: Annotated[AsyncSession, Depends(get_session)],
id: str,
bundle_update: SceneSlotBundleUpdate,
) -> JSONResponse:
"""
Update a scene slot bundle.
"""
logger.info(
f"Updating scene slot bundle: tenant={tenant_id}, id={id}"
)
service = SceneSlotBundleService(session)
try:
bundle = await service.update_bundle(tenant_id, id, bundle_update)
except ValueError as e:
return JSONResponse(
status_code=400,
content={
"error_code": "VALIDATION_ERROR",
"message": str(e),
}
)
if not bundle:
return JSONResponse(
status_code=404,
content={
"error_code": "NOT_FOUND",
"message": f"Scene slot bundle {id} not found",
}
)
await session.commit()
return JSONResponse(content=_bundle_to_dict(bundle))
@router.delete(
"/{id}",
operation_id="deleteSceneSlotBundle",
summary="Delete scene slot bundle",
description="[AC-SCENE-SLOT-01] Delete a scene slot bundle",
status_code=204,
)
async def delete_scene_slot_bundle(
tenant_id: Annotated[str, Depends(get_current_tenant_id)],
session: Annotated[AsyncSession, Depends(get_session)],
id: str,
) -> Response:
"""
[AC-SCENE-SLOT-01] Delete a scene slot bundle.
"""
logger.info(
f"[AC-SCENE-SLOT-01] Deleting scene slot bundle: tenant={tenant_id}, id={id}"
)
service = SceneSlotBundleService(session)
success = await service.delete_bundle(tenant_id, id)
if not success:
return JSONResponse(
status_code=404,
content={
"error_code": "NOT_FOUND",
"message": f"Scene slot bundle not found: {id}",
}
)
await session.commit()
# A 204 response must not carry a body; return a bare Response (fastapi.Response, assumed imported) rather than JSONResponse.
return Response(status_code=204)


@@ -38,9 +38,13 @@ def _slot_to_dict(slot: dict[str, Any] | Any) -> dict[str, Any]:
"id": str(slot.id),
"tenant_id": str(slot.tenant_id),
"slot_key": slot.slot_key,
"display_name": slot.display_name,
"description": slot.description,
"type": slot.type,
"required": slot.required,
# [AC-MRS-07-UPGRADE] Return both the new and legacy fields
"extract_strategy": slot.extract_strategy,
"extract_strategies": slot.extract_strategies,
"validation_rule": slot.validation_rule,
"ask_back_prompt": slot.ask_back_prompt,
"default_value": slot.default_value,


@@ -0,0 +1,342 @@
"""
Retrieval Strategy API Endpoints.
[AC-AISVC-RES-01~15] API for strategy management and configuration.
Endpoints:
- GET /strategy/retrieval/current - Get current strategy configuration
- POST /strategy/retrieval/switch - Switch strategy configuration
- POST /strategy/retrieval/validate - Validate strategy configuration
- POST /strategy/retrieval/rollback - Rollback to default strategy
"""
import logging
from typing import Any
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_session
from app.core.tenant import get_tenant_id
from app.services.retrieval.routing_config import (
RagRuntimeMode,
StrategyType,
RoutingConfig,
)
from app.services.retrieval.strategy_router import (
get_strategy_router,
RollbackRecord,
)
from app.services.retrieval.mode_router import get_mode_router
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/strategy/retrieval", tags=["Retrieval Strategy"])
class RoutingConfigRequest(BaseModel):
"""Request model for routing configuration."""
enabled: bool | None = Field(default=None, description="Enable strategy routing")
strategy: StrategyType | None = Field(default=None, description="Retrieval strategy")
grayscale_percentage: float | None = Field(default=None, ge=0.0, le=1.0, description="Grayscale percentage")
grayscale_allowlist: list[str] | None = Field(default=None, description="Grayscale allowlist")
rag_runtime_mode: RagRuntimeMode | None = Field(default=None, description="RAG runtime mode")
react_trigger_confidence_threshold: float | None = Field(default=None, ge=0.0, le=1.0)
react_trigger_complexity_score: float | None = Field(default=None, ge=0.0, le=1.0)
react_max_steps: int | None = Field(default=None, ge=3, le=10)
direct_fallback_on_low_confidence: bool | None = Field(default=None)
direct_fallback_confidence_threshold: float | None = Field(default=None, ge=0.0, le=1.0)
performance_budget_ms: int | None = Field(default=None, ge=1000)
performance_degradation_threshold: float | None = Field(default=None, ge=0.0, le=1.0)
class RoutingConfigResponse(BaseModel):
"""Response model for routing configuration."""
enabled: bool
strategy: StrategyType
grayscale_percentage: float
grayscale_allowlist: list[str]
rag_runtime_mode: RagRuntimeMode
react_trigger_confidence_threshold: float
react_trigger_complexity_score: float
react_max_steps: int
direct_fallback_on_low_confidence: bool
direct_fallback_confidence_threshold: float
performance_budget_ms: int
performance_degradation_threshold: float
class ValidationResponse(BaseModel):
"""Response model for configuration validation."""
is_valid: bool
errors: list[str]
warnings: list[str] = Field(default_factory=list)
class RollbackRequest(BaseModel):
"""Request model for strategy rollback."""
reason: str = Field(..., description="Reason for rollback")
tenant_id: str | None = Field(default=None, description="Optional tenant ID for audit")
class RollbackResponse(BaseModel):
"""Response model for strategy rollback."""
success: bool
previous_strategy: StrategyType
current_strategy: StrategyType
reason: str
rollback_records: list[dict[str, Any]] = Field(default_factory=list)
class RollbackRecordResponse(BaseModel):
"""Response model for rollback record."""
timestamp: float
from_strategy: StrategyType
to_strategy: StrategyType
reason: str
tenant_id: str | None
request_id: str | None
class CurrentStrategyResponse(BaseModel):
"""Response model for current strategy."""
config: RoutingConfigResponse
current_strategy: StrategyType
rollback_records: list[RollbackRecordResponse] = Field(default_factory=list)
@router.get(
"/current",
operation_id="getCurrentRetrievalStrategy",
summary="Get current retrieval strategy configuration",
description="[AC-AISVC-RES-01] Returns the current strategy configuration and recent rollback records.",
response_model=CurrentStrategyResponse,
)
async def get_current_strategy() -> CurrentStrategyResponse:
"""
[AC-AISVC-RES-01] Get current retrieval strategy configuration.
"""
strategy_router = get_strategy_router()
config = strategy_router.config
rollback_records = strategy_router.get_rollback_records(limit=5)
return CurrentStrategyResponse(
config=RoutingConfigResponse(
enabled=config.enabled,
strategy=config.strategy,
grayscale_percentage=config.grayscale_percentage,
grayscale_allowlist=config.grayscale_allowlist,
rag_runtime_mode=config.rag_runtime_mode,
react_trigger_confidence_threshold=config.react_trigger_confidence_threshold,
react_trigger_complexity_score=config.react_trigger_complexity_score,
react_max_steps=config.react_max_steps,
direct_fallback_on_low_confidence=config.direct_fallback_on_low_confidence,
direct_fallback_confidence_threshold=config.direct_fallback_confidence_threshold,
performance_budget_ms=config.performance_budget_ms,
performance_degradation_threshold=config.performance_degradation_threshold,
),
current_strategy=strategy_router.current_strategy,
rollback_records=[
RollbackRecordResponse(
timestamp=r.timestamp,
from_strategy=r.from_strategy,
to_strategy=r.to_strategy,
reason=r.reason,
tenant_id=r.tenant_id,
request_id=r.request_id,
)
for r in rollback_records
],
)
@router.post(
"/switch",
operation_id="switchRetrievalStrategy",
summary="Switch retrieval strategy configuration",
description="[AC-AISVC-RES-02, AC-AISVC-RES-03] Update strategy configuration with hot reload support.",
response_model=RoutingConfigResponse,
)
async def switch_strategy(
request: RoutingConfigRequest,
session: AsyncSession = Depends(get_session),
) -> RoutingConfigResponse:
"""
[AC-AISVC-RES-02, AC-AISVC-RES-03, AC-AISVC-RES-15]
Switch retrieval strategy configuration.
Supports:
- Strategy selection (default/enhanced)
- Grayscale release (percentage/allowlist)
- Mode selection (direct/react/auto)
- Hot reload of routing parameters
"""
strategy_router = get_strategy_router()
mode_router = get_mode_router()
current_config = strategy_router.config
new_config = RoutingConfig(
enabled=request.enabled if request.enabled is not None else current_config.enabled,
strategy=request.strategy if request.strategy is not None else current_config.strategy,
grayscale_percentage=request.grayscale_percentage if request.grayscale_percentage is not None else current_config.grayscale_percentage,
grayscale_allowlist=request.grayscale_allowlist if request.grayscale_allowlist is not None else current_config.grayscale_allowlist,
rag_runtime_mode=request.rag_runtime_mode if request.rag_runtime_mode is not None else current_config.rag_runtime_mode,
react_trigger_confidence_threshold=request.react_trigger_confidence_threshold if request.react_trigger_confidence_threshold is not None else current_config.react_trigger_confidence_threshold,
react_trigger_complexity_score=request.react_trigger_complexity_score if request.react_trigger_complexity_score is not None else current_config.react_trigger_complexity_score,
react_max_steps=request.react_max_steps if request.react_max_steps is not None else current_config.react_max_steps,
direct_fallback_on_low_confidence=request.direct_fallback_on_low_confidence if request.direct_fallback_on_low_confidence is not None else current_config.direct_fallback_on_low_confidence,
direct_fallback_confidence_threshold=request.direct_fallback_confidence_threshold if request.direct_fallback_confidence_threshold is not None else current_config.direct_fallback_confidence_threshold,
performance_budget_ms=request.performance_budget_ms if request.performance_budget_ms is not None else current_config.performance_budget_ms,
performance_degradation_threshold=request.performance_degradation_threshold if request.performance_degradation_threshold is not None else current_config.performance_degradation_threshold,
)
is_valid, errors = new_config.validate()
if not is_valid:
raise HTTPException(
status_code=400,
detail={"errors": errors},
)
strategy_router.update_config(new_config)
mode_router.update_config(new_config)
logger.info(
f"[AC-AISVC-RES-02, AC-AISVC-RES-15] Strategy switched: "
f"strategy={new_config.strategy.value}, mode={new_config.rag_runtime_mode.value}"
)
return RoutingConfigResponse(
enabled=new_config.enabled,
strategy=new_config.strategy,
grayscale_percentage=new_config.grayscale_percentage,
grayscale_allowlist=new_config.grayscale_allowlist,
rag_runtime_mode=new_config.rag_runtime_mode,
react_trigger_confidence_threshold=new_config.react_trigger_confidence_threshold,
react_trigger_complexity_score=new_config.react_trigger_complexity_score,
react_max_steps=new_config.react_max_steps,
direct_fallback_on_low_confidence=new_config.direct_fallback_on_low_confidence,
direct_fallback_confidence_threshold=new_config.direct_fallback_confidence_threshold,
performance_budget_ms=new_config.performance_budget_ms,
performance_degradation_threshold=new_config.performance_degradation_threshold,
)
@router.post(
"/validate",
operation_id="validateRetrievalStrategy",
summary="Validate retrieval strategy configuration",
description="[AC-AISVC-RES-06] Validate strategy configuration for completeness and consistency.",
response_model=ValidationResponse,
)
async def validate_strategy(
request: RoutingConfigRequest,
) -> ValidationResponse:
"""
[AC-AISVC-RES-06] Validate strategy configuration.
Checks:
- Parameter value ranges
- Configuration consistency
- Performance budget constraints
"""
warnings = []
config = RoutingConfig(
enabled=request.enabled if request.enabled is not None else True,
strategy=request.strategy if request.strategy is not None else StrategyType.DEFAULT,
grayscale_percentage=request.grayscale_percentage if request.grayscale_percentage is not None else 0.0,
grayscale_allowlist=request.grayscale_allowlist or [],
rag_runtime_mode=request.rag_runtime_mode if request.rag_runtime_mode is not None else RagRuntimeMode.AUTO,
react_trigger_confidence_threshold=request.react_trigger_confidence_threshold if request.react_trigger_confidence_threshold is not None else 0.6,
react_trigger_complexity_score=request.react_trigger_complexity_score if request.react_trigger_complexity_score is not None else 0.5,
react_max_steps=request.react_max_steps if request.react_max_steps is not None else 5,
direct_fallback_on_low_confidence=request.direct_fallback_on_low_confidence if request.direct_fallback_on_low_confidence is not None else True,
direct_fallback_confidence_threshold=request.direct_fallback_confidence_threshold if request.direct_fallback_confidence_threshold is not None else 0.4,
performance_budget_ms=request.performance_budget_ms if request.performance_budget_ms is not None else 5000,
performance_degradation_threshold=request.performance_degradation_threshold if request.performance_degradation_threshold is not None else 0.2,
)
is_valid, errors = config.validate()
if config.grayscale_percentage > 0.5:
warnings.append("grayscale_percentage > 50% may affect production stability")
if config.react_max_steps > 7:
warnings.append("react_max_steps > 7 may cause high latency")
if config.direct_fallback_confidence_threshold > config.react_trigger_confidence_threshold:
warnings.append(
"direct_fallback_confidence_threshold > react_trigger_confidence_threshold "
"may cause frequent fallbacks"
)
if config.performance_budget_ms < 3000:
warnings.append("performance_budget_ms < 3000ms may be too aggressive for complex queries")
logger.info(
f"[AC-AISVC-RES-06] Strategy validation: is_valid={is_valid}, "
f"errors={len(errors)}, warnings={len(warnings)}"
)
return ValidationResponse(
is_valid=is_valid,
errors=errors,
warnings=warnings,
)
@router.post(
"/rollback",
operation_id="rollbackRetrievalStrategy",
summary="Rollback to default retrieval strategy",
description="[AC-AISVC-RES-07] Force rollback to default strategy with audit logging.",
response_model=RollbackResponse,
)
async def rollback_strategy(
request: RollbackRequest,
session: AsyncSession = Depends(get_session),
) -> RollbackResponse:
"""
[AC-AISVC-RES-07] Rollback to default strategy.
Records the rollback event for audit and monitoring.
"""
strategy_router = get_strategy_router()
mode_router = get_mode_router()
previous_strategy = strategy_router.current_strategy
strategy_router.rollback(
reason=request.reason,
tenant_id=request.tenant_id,
)
new_config = RoutingConfig(
strategy=StrategyType.DEFAULT,
rag_runtime_mode=RagRuntimeMode.AUTO,
)
mode_router.update_config(new_config)
rollback_records = strategy_router.get_rollback_records(limit=5)
logger.info(
f"[AC-AISVC-RES-07] Strategy rollback: "
f"from={previous_strategy.value}, to=DEFAULT, reason={request.reason}"
)
return RollbackResponse(
success=True,
previous_strategy=previous_strategy,
current_strategy=StrategyType.DEFAULT,
reason=request.reason,
rollback_records=[
{
"timestamp": r.timestamp,
"from_strategy": r.from_strategy.value,
"to_strategy": r.to_strategy.value,
"reason": r.reason,
"tenant_id": r.tenant_id,
}
for r in rollback_records
],
)
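The config above carries `grayscale_percentage` and `grayscale_allowlist`, but how `StrategyRouter` consumes them is not shown in this diff. A deterministic hash-bucket scheme is the usual approach for percentage rollout; a sketch under that assumption (not necessarily the router's actual logic):

```python
import hashlib

def in_grayscale(tenant_id: str, percentage: float, allowlist: list[str]) -> bool:
    # Allowlisted tenants always receive the enhanced strategy.
    if tenant_id in allowlist:
        return True
    # Map the tenant to a stable bucket in [0, 1): the same tenant always lands
    # in the same bucket, so raising the percentage only adds tenants and never
    # flaps existing assignments.
    digest = hashlib.sha256(tenant_id.encode("utf-8")).digest()
    bucket = int.from_bytes(digest[:8], "big") / 2**64
    return bucket < percentage
```

Hash-based bucketing (rather than `random()`) keeps a tenant's strategy stable across requests and processes, which matters for comparing metrics between the default and enhanced cohorts.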


@@ -55,6 +55,18 @@ from app.services.mid.segment_humanizer import HumanizeConfig, SegmentHumanizer
from app.services.mid.timeout_governor import TimeoutGovernor
from app.services.mid.tool_registry import ToolRegistry
from app.services.mid.trace_logger import TraceLogger
from app.services.prompt.template_service import PromptTemplateService
from app.services.prompt.variable_resolver import VariableResolver
from app.services.intent.clarification import (
ClarificationEngine,
ClarifyReason,
ClarifySessionManager,
ClarifyState,
HybridIntentResult,
IntentCandidate,
T_HIGH,
T_LOW,
)
logger = logging.getLogger(__name__)
@@ -118,6 +130,7 @@ _kb_search_dynamic_registered: bool = False
_intent_hint_registered: bool = False
_high_risk_check_registered: bool = False
_memory_recall_registered: bool = False
_metadata_discovery_registered: bool = False
def ensure_kb_search_dynamic_registered(
@@ -134,7 +147,7 @@ def ensure_kb_search_dynamic_registered(
config = KbSearchDynamicConfig(
enabled=True,
top_k=5,
- timeout_ms=2000,
+ timeout_ms=10000,
min_score_threshold=0.5,
)
@@ -222,6 +235,26 @@ def ensure_memory_recall_registered(
logger.info("[AC-IDMP-13] memory_recall tool registered to registry")
def ensure_metadata_discovery_registered(
registry: ToolRegistry,
session: AsyncSession,
) -> None:
"""Ensure metadata_discovery tool is registered."""
global _metadata_discovery_registered
if _metadata_discovery_registered:
return
from app.services.mid.metadata_discovery_tool import register_metadata_discovery_tool
register_metadata_discovery_tool(
registry=registry,
session=session,
timeout_governor=get_timeout_governor(),
)
_metadata_discovery_registered = True
logger.info("[MetadataDiscovery] metadata_discovery tool registered to registry")
def get_output_guardrail_executor() -> OutputGuardrailExecutor:
"""Get or create OutputGuardrailExecutor instance."""
if "output_guardrail_executor" not in _mid_services:
@@ -259,6 +292,16 @@ def get_runtime_observer() -> RuntimeObserver:
return _mid_services["runtime_observer"]
def get_clarification_engine() -> ClarificationEngine:
"""Get or create ClarificationEngine instance."""
if "clarification_engine" not in _mid_services:
_mid_services["clarification_engine"] = ClarificationEngine(
t_high=T_HIGH,
t_low=T_LOW,
)
return _mid_services["clarification_engine"]
@router.post(
"/dialogue/respond",
operation_id="respondDialogue",
@@ -288,6 +331,7 @@ async def respond_dialogue(
default_kb_tool_runner: DefaultKbToolRunner = Depends(get_default_kb_tool_runner),
segment_humanizer: SegmentHumanizer = Depends(get_segment_humanizer),
runtime_observer: RuntimeObserver = Depends(get_runtime_observer),
clarification_engine: ClarificationEngine = Depends(get_clarification_engine),
) -> DialogueResponse:
"""
[AC-MARH-01~12] Generate dialogue response with segments and trace.
@@ -316,7 +360,8 @@ async def respond_dialogue(
logger.info(
f"[AC-MARH-01] Dialogue request: tenant={tenant_id}, "
- f"session={dialogue_request.session_id}, request_id={request_id}"
+ f"session={dialogue_request.session_id}, request_id={request_id}, "
+ f"user_message={dialogue_request.user_message[:100] if dialogue_request.user_message else 'None'}..."
)
runtime_ctx = runtime_observer.start_observation(
@@ -340,6 +385,7 @@ async def respond_dialogue(
ensure_intent_hint_registered(tool_registry, session)
ensure_high_risk_check_registered(tool_registry, session)
ensure_memory_recall_registered(tool_registry, session)
ensure_metadata_discovery_registered(tool_registry, session)
try:
interrupt_ctx = interrupt_context_enricher.enrich(
@@ -541,7 +587,15 @@ async def respond_dialogue(
except Exception as e:
latency_ms = int((time.time() - start_time) * 1000)
logger.error(f"[AC-IDMP-06] Dialogue error: {e}")
import traceback
logger.error(
f"[AC-IDMP-06] Dialogue error: {e}\n"
f"Exception type: {type(e).__name__}\n"
f"Request details: session_id={dialogue_request.session_id}, "
f"user_message={dialogue_request.user_message[:200] if dialogue_request.user_message else 'None'}, "
f"scene={dialogue_request.scene}, user_id={dialogue_request.user_id}\n"
f"Traceback:\n{traceback.format_exc()}"
)
trace_logger.update_trace(
request_id=request_id,
@@ -593,7 +647,7 @@ async def _match_intent(
return IntentMatch(
intent_id=str(result.rule.id),
intent_name=result.rule.name,
- confidence=0.8,
+ confidence=1.0,
response_type=result.rule.response_type,
target_kb_ids=result.rule.target_kb_ids,
flow_id=str(result.rule.flow_id) if result.rule.flow_id else None,
@@ -608,6 +662,227 @@ async def _match_intent(
return None
async def _match_intent_hybrid(
tenant_id: str,
request: DialogueRequest,
session: AsyncSession,
clarification_engine: ClarificationEngine,
) -> HybridIntentResult:
"""
[AC-CLARIFY] Hybrid intent matching with clarification support.
Returns HybridIntentResult with unified confidence and candidates.
"""
try:
from app.services.intent.router import IntentRouter
from app.services.intent.rule_service import IntentRuleService
rule_service = IntentRuleService(session)
rules = await rule_service.get_enabled_rules_for_matching(tenant_id)
if not rules:
return HybridIntentResult(
intent=None,
confidence=0.0,
candidates=[],
)
router = IntentRouter()
try:
fusion_result = await router.match_hybrid(
message=request.user_message,
rules=rules,
tenant_id=tenant_id,
)
hybrid_result = HybridIntentResult.from_fusion_result(fusion_result)
logger.info(
f"[AC-CLARIFY] Hybrid intent match: "
f"intent={hybrid_result.intent.intent_name if hybrid_result.intent else None}, "
f"confidence={hybrid_result.confidence:.3f}, "
f"need_clarify={hybrid_result.need_clarify}"
)
return hybrid_result
except Exception as e:
logger.warning(f"[AC-CLARIFY] Hybrid match failed, fallback to rule: {e}")
result = router.match(request.user_message, rules)
if result:
confidence = clarification_engine.compute_confidence(
rule_score=1.0,
semantic_score=0.0,
llm_score=0.0,
)
candidate = IntentCandidate(
intent_id=str(result.rule.id),
intent_name=result.rule.name,
confidence=confidence,
response_type=result.rule.response_type,
target_kb_ids=result.rule.target_kb_ids,
flow_id=str(result.rule.flow_id) if result.rule.flow_id else None,
fixed_reply=result.rule.fixed_reply,
transfer_message=result.rule.transfer_message,
)
return HybridIntentResult(
intent=candidate,
confidence=confidence,
candidates=[candidate],
)
return HybridIntentResult(
intent=None,
confidence=0.0,
candidates=[],
)
except Exception as e:
logger.warning(f"[AC-CLARIFY] Intent match failed: {e}")
return HybridIntentResult(
intent=None,
confidence=0.0,
candidates=[],
)
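The rule-only fallback above feeds `rule_score=1.0` with zeroed semantic and LLM scores into `compute_confidence`, so the engine evidently fuses the three matcher scores into one unified confidence. A typical weighted fusion looks like the sketch below; the weights are illustrative assumptions, not the engine's actual values:

```python
def compute_confidence(rule_score: float, semantic_score: float, llm_score: float) -> float:
    # Weighted fusion of the three matchers, clamped to [0, 1].
    # The weights here are illustrative assumptions, not ClarificationEngine's values.
    w_rule, w_semantic, w_llm = 0.5, 0.3, 0.2
    fused = w_rule * rule_score + w_semantic * semantic_score + w_llm * llm_score
    return max(0.0, min(1.0, fused))
```

Note that with such weights a pure rule hit lands in the middle band, which is exactly why the fallback path may still route the candidate through the clarification thresholds rather than executing it unconditionally.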
async def _handle_clarification(
tenant_id: str,
request: DialogueRequest,
hybrid_result: HybridIntentResult,
clarification_engine: ClarificationEngine,
trace: TraceInfo,
session: AsyncSession,
required_slots: list[str] | None = None,
filled_slots: dict[str, Any] | None = None,
) -> DialogueResponse | None:
"""
[AC-CLARIFY] Handle clarification logic.
Returns DialogueResponse if clarification is needed, None otherwise.
"""
existing_state = ClarifySessionManager.get_session(request.session_id)
if existing_state and not existing_state.is_max_retry():
logger.info(
f"[AC-CLARIFY] Processing clarify response: "
f"session={request.session_id}, retry={existing_state.retry_count}"
)
new_result = clarification_engine.process_clarify_response(
user_message=request.user_message,
state=existing_state,
)
if not new_result.need_clarify:
ClarifySessionManager.clear_session(request.session_id)
if new_result.intent:
return None
clarify_prompt = clarification_engine.generate_clarify_prompt(existing_state)
return DialogueResponse(
segments=[Segment(text=clarify_prompt, delay_after=0)],
trace=TraceInfo(
mode=ExecutionMode.FIXED,
request_id=trace.request_id,
generation_id=trace.generation_id,
fallback_reason_code=f"clarify_retry_{existing_state.retry_count}",
intent=existing_state.candidates[0].intent_name if existing_state.candidates else None,
),
)
should_clarify, clarify_state = clarification_engine.should_trigger_clarify(
result=hybrid_result,
required_slots=required_slots,
filled_slots=filled_slots,
)
if not should_clarify or not clarify_state:
return None
logger.info(
f"[AC-CLARIFY] Clarification triggered: "
f"session={request.session_id}, reason={clarify_state.reason}, "
f"confidence={hybrid_result.confidence:.3f}"
)
ClarifySessionManager.set_session(request.session_id, clarify_state)
clarify_prompt = clarification_engine.generate_clarify_prompt(clarify_state)
return DialogueResponse(
segments=[Segment(text=clarify_prompt, delay_after=0)],
trace=TraceInfo(
mode=ExecutionMode.FIXED,
request_id=trace.request_id,
generation_id=trace.generation_id,
fallback_reason_code=f"clarify_{clarify_state.reason.value}",
intent=hybrid_result.intent.intent_name if hybrid_result.intent else None,
),
)
async def _check_hard_block(
tenant_id: str,
request: DialogueRequest,
hybrid_result: HybridIntentResult,
clarification_engine: ClarificationEngine,
trace: TraceInfo,
required_slots: list[str] | None = None,
filled_slots: dict[str, Any] | None = None,
) -> DialogueResponse | None:
"""
[AC-CLARIFY] Check hard block conditions.
Hard blocks:
1. confidence < T_high: block entering new flow
2. required_slots missing: block flow progression
Returns DialogueResponse if blocked, None otherwise.
"""
is_blocked, block_reason = clarification_engine.check_hard_block(
result=hybrid_result,
required_slots=required_slots,
filled_slots=filled_slots,
)
if not is_blocked:
return None
logger.info(
f"[AC-CLARIFY] Hard block triggered: "
f"session={request.session_id}, reason={block_reason}, "
f"confidence={hybrid_result.confidence:.3f}"
)
clarify_state = ClarifyState(
reason=block_reason,
asked_slot=required_slots[0] if required_slots and block_reason == ClarifyReason.MISSING_SLOT else None,
candidates=hybrid_result.candidates,
)
ClarifySessionManager.set_session(request.session_id, clarify_state)
clarification_engine._metrics.record_clarify_trigger()
clarify_prompt = clarification_engine.generate_clarify_prompt(clarify_state)
return DialogueResponse(
segments=[Segment(text=clarify_prompt, delay_after=0)],
trace=TraceInfo(
mode=ExecutionMode.FIXED,
request_id=trace.request_id,
generation_id=trace.generation_id,
fallback_reason_code=f"hard_block_{block_reason.value}",
intent=hybrid_result.intent.intent_name if hybrid_result.intent else None,
),
)
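Both the clarification and hard-block paths above key off the `T_HIGH`/`T_LOW` pair imported from the clarification module. Their assumed semantics form a three-band policy; a minimal sketch (threshold values are placeholders, not the module's constants):

```python
def decide(confidence: float, t_high: float = 0.85, t_low: float = 0.4) -> str:
    # Dual-threshold policy: the high band executes the matched intent directly,
    # the middle band triggers a clarifying question, and the low band falls
    # back to generic handling. Placeholder thresholds, not T_HIGH/T_LOW.
    if confidence >= t_high:
        return "execute"
    if confidence >= t_low:
        return "clarify"
    return "fallback"
```

This is why `check_hard_block` gates flow entry on `confidence < T_high`: only the top band is trusted to start a new flow without asking back.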
async def _handle_legacy_response(
tenant_id: str,
request: DialogueRequest,
@@ -793,18 +1068,32 @@ async def _execute_agent_mode(
session: AsyncSession | None = None,
tool_registry: ToolRegistry | None = None,
) -> DialogueResponse:
- """[AC-MARH-05, AC-MARH-06, AC-MARH-07, AC-IDMP-13] Execute agent mode with ReAct loop, KB tool, and memory recall."""
+ """
+ [AC-MARH-05, AC-MARH-06, AC-MARH-07, AC-IDMP-13, AC-MRS-SLOT-META-03]
+ Execute agent mode with ReAct loop, KB tool, memory recall, and slot state aggregation.
+ """
from app.services.llm.factory import get_llm_config_manager
from app.services.mid.slot_state_aggregator import SlotStateAggregator
logger.info(
f"[DEBUG-AGENT] Starting _execute_agent_mode: tenant={tenant_id}, "
f"session={request.session_id}, scene={request.scene}, user_id={request.user_id}, "
f"user_message={request.user_message[:100] if request.user_message else 'None'}..."
)
try:
llm_manager = get_llm_config_manager()
llm_client = llm_manager.get_client()
logger.info("[DEBUG-AGENT] LLM client obtained successfully")
except Exception as e:
logger.warning(f"[AC-MARH-07] Failed to get LLM client: {e}")
llm_client = None
base_context = {"history": [h.model_dump() for h in request.history]} if request.history else {}
if request.scene:
base_context["scene"] = request.scene
if interrupt_ctx and interrupt_ctx.consumed:
base_context["interrupted_content"] = interrupt_ctx.interrupted_content
base_context["interrupted_segment_ids"] = interrupt_ctx.interrupted_segment_ids
@@ -813,8 +1102,50 @@ async def _execute_agent_mode(
f"{len(interrupt_ctx.interrupted_content or '')} chars"
)
# [AC-MRS-SLOT-META-03] Initialize the slot state aggregator
slot_state = None
memory_context = ""
memory_missing_slots: list[str] = []
scene_slot_context = None
flow_status = None
is_flow_runtime = False
if session:
from app.services.flow.engine import FlowEngine
flow_engine = FlowEngine(session)
flow_status = await flow_engine.get_flow_status(tenant_id, request.session_id)
is_flow_runtime = bool(flow_status and flow_status.get("status") == "active")
logger.info(
f"[AC-IDMP-02] Runtime track decision: session={request.session_id}, "
f"is_flow_runtime={is_flow_runtime}, flow_id={flow_status.get('flow_id') if flow_status else None}, "
f"current_step={flow_status.get('current_step') if flow_status else None}"
)
# [AC-SCENE-SLOT-02] Load the scene slot bundle (flow runtime only)
if is_flow_runtime and session and request.scene:
from app.services.mid.scene_slot_bundle_loader import SceneSlotBundleLoader
scene_loader = SceneSlotBundleLoader(session)
scene_slot_context = await scene_loader.load_scene_context(
tenant_id=tenant_id,
scene_key=request.scene,
)
if scene_slot_context:
logger.info(
f"[AC-SCENE-SLOT-02] Scene slot bundle loaded: "
f"scene={request.scene}, required={len(scene_slot_context.required_slots)}, "
f"optional={len(scene_slot_context.optional_slots)}, "
f"threshold={scene_slot_context.completion_threshold}"
)
base_context["scene_slot_context"] = {
"scene_key": scene_slot_context.scene_key,
"scene_name": scene_slot_context.scene_name,
"required_slots": scene_slot_context.get_required_slot_keys(),
"optional_slots": scene_slot_context.get_optional_slot_keys(),
}
if session and request.user_id:
memory_recall_tool = MemoryRecallTool(
session=session,
@@ -855,11 +1186,111 @@ async def _execute_agent_mode(
f"[AC-IDMP-13] Memory recall fallback: reason={memory_result.fallback_reason_code}"
)
# [AC-MRS-SLOT-META-03] Run slot aggregation/extraction in flow runtime only
if is_flow_runtime:
slot_aggregator = SlotStateAggregator(
session=session,
tenant_id=tenant_id,
session_id=request.session_id,
)
slot_state = await slot_aggregator.aggregate(
memory_slots=memory_result.slots,
current_input_slots=None,  # could be parsed from the request
context=base_context,
scene_slot_context=scene_slot_context,  # [AC-SCENE-SLOT-02] pass the scene slot context
)
logger.info(
f"[AC-MRS-SLOT-META-03] Flow-runtime slot aggregation complete: "
f"filled={len(slot_state.filled_slots)}, "
f"missing={len(slot_state.missing_required_slots)}, "
f"mappings={slot_state.slot_to_field_map}"
)
# [AC-MRS-SLOT-EXTRACT-01] Auto-extract slots
if slot_state and slot_state.missing_required_slots:
from app.services.mid.slot_extraction_integration import SlotExtractionIntegration
extraction_integration = SlotExtractionIntegration(
session=session,
tenant_id=tenant_id,
session_id=request.session_id,
)
extraction_result = await extraction_integration.extract_and_fill(
user_input=request.user_message,
slot_state=slot_state,
)
if extraction_result.extracted_slots:
slot_state = await slot_aggregator.aggregate(
memory_slots=memory_result.slots,
current_input_slots=extraction_result.extracted_slots,
context=base_context,
scene_slot_context=scene_slot_context,  # [AC-SCENE-SLOT-02] keep the scene context on re-aggregation
)
logger.info(
f"[AC-MRS-SLOT-EXTRACT-01] Flow-runtime auto slot extraction complete: "
f"extracted={list(extraction_result.extracted_slots.keys())}, "
f"time_ms={extraction_result.total_execution_time_ms:.2f}"
)
else:
logger.info(
f"[AC-MRS-SLOT-META-03] Generic Q&A track; skipping slot aggregation and missing-slot ask-back: "
f"session={request.session_id}"
)
kb_hits = []
kb_success = False
kb_fallback_reason = None
kb_applied_filter = {}
kb_missing_slots = []
kb_dynamic_result = None
step_kb_binding_trace: dict[str, Any] | None = None
# [Step-KB-Binding] Fetch the current flow step's KB config (flow runtime only)
step_kb_config = None
if is_flow_runtime and session and flow_status:
current_step_no = flow_status.get("current_step")
flow_id = flow_status.get("flow_id")
if flow_id and current_step_no:
from app.models.entities import ScriptFlow
from sqlalchemy import select
stmt = select(ScriptFlow).where(ScriptFlow.id == flow_id)
result = await session.execute(stmt)
flow = result.scalar_one_or_none()
if flow and flow.steps:
step_idx = current_step_no - 1
if 0 <= step_idx < len(flow.steps):
current_step = flow.steps[step_idx]
# Build the StepKbConfig
from app.services.mid.kb_search_dynamic_tool import StepKbConfig
step_kb_config = StepKbConfig(
allowed_kb_ids=current_step.get("allowed_kb_ids"),
preferred_kb_ids=current_step.get("preferred_kb_ids"),
kb_query_hint=current_step.get("kb_query_hint"),
max_kb_calls=current_step.get("max_kb_calls_per_step", 1),
step_id=f"{flow_id}_step_{current_step_no}",
)
step_kb_binding_trace = {
"flow_id": str(flow_id),
"flow_name": flow_status.get("flow_name"),
"current_step": current_step_no,
"step_id": step_kb_config.step_id,
"allowed_kb_ids": step_kb_config.allowed_kb_ids,
"preferred_kb_ids": step_kb_config.preferred_kb_ids,
}
logger.info(
f"[Step-KB-Binding] Step KB config loaded: "
f"flow={flow_status.get('flow_name')}, step={current_step_no}, "
f"allowed_kb_ids={step_kb_config.allowed_kb_ids}"
)
if session and tool_registry:
kb_tool = KbSearchDynamicTool(
@@ -868,17 +1299,21 @@ async def _execute_agent_mode(
config=KbSearchDynamicConfig(
enabled=True,
top_k=5,
- timeout_ms=2000,
+ timeout_ms=10000,
min_score_threshold=0.5,
),
)
# [AC-MRS-SLOT-META-03] Pass slot_state into the KB search
# [Step-KB-Binding] Pass step_kb_config to apply step-level KB constraints
kb_dynamic_result = await kb_tool.execute(
query=request.user_message,
tenant_id=tenant_id,
scene="open_consult",
top_k=5,
context=base_context,
slot_state=slot_state,
step_kb_config=step_kb_config,
slot_policy="flow_strict" if is_flow_runtime else "agent_relaxed",
)
kb_success = kb_dynamic_result.success
@@ -894,9 +1329,43 @@ async def _execute_agent_mode(
)
logger.info(
- f"[AC-MARH-05] KB dynamic search: success={kb_success}, "
+ f"[AC-MARH-05] KB dynamic search complete: success={kb_success}, "
f"hits={len(kb_hits)}, filter={kb_applied_filter}, "
- f"missing_slots={kb_missing_slots}"
+ f"missing_slots={kb_missing_slots}, track={'flow' if is_flow_runtime else 'agent'}"
)
# [AC-MRS-SLOT-META-03] Handle missing required slots -> ask-back loop (flow runtime only)
if is_flow_runtime and kb_fallback_reason == "MISSING_REQUIRED_SLOTS" and kb_missing_slots:
ask_back_text = await _generate_ask_back_for_missing_slots(
slot_state=slot_state,
missing_slots=kb_missing_slots,
session=session,
tenant_id=tenant_id,
session_id=request.session_id,
scene_slot_context=scene_slot_context,  # [AC-SCENE-SLOT-02] pass scene slot context
)
logger.info(
f"[AC-MRS-SLOT-META-03] 流程态缺槽追问文案已生成: "
f"{ask_back_text[:50]}..."
)
return DialogueResponse(
segments=[Segment(text=ask_back_text, delay_after=0)],
trace=TraceInfo(
mode=ExecutionMode.AGENT,
request_id=trace.request_id,
generation_id=trace.generation_id,
fallback_reason_code="missing_required_slots",
kb_tool_called=True,
kb_hit=False,
# [AC-SCENE-SLOT-02] Scene slot tracing
scene=request.scene,
scene_slot_context=base_context.get("scene_slot_context"),
missing_slots=[s.get("slot_key") for s in kb_missing_slots] if kb_missing_slots else None,
ask_back_triggered=True,
slot_sources=slot_state.slot_sources if slot_state else None,
),
)
else:
kb_result = await default_kb_tool_runner.execute(
@ -933,7 +1402,18 @@ async def _execute_agent_mode(
timeout_governor=timeout_governor,
llm_client=llm_client,
tool_registry=tool_registry,
template_service=PromptTemplateService,
variable_resolver=VariableResolver(),
tenant_id=tenant_id,
user_id=request.user_id,
session_id=request.session_id,
scene=request.scene,
)
logger.info(
f"[DEBUG-AGENT] Calling orchestrator.execute with: "
f"user_message={request.user_message[:100] if request.user_message else 'None'}, "
f"context_keys={list(base_context.keys())}, llm_client={llm_client is not None}"
)
final_answer, react_ctx, agent_trace = await orchestrator.execute(
@ -941,6 +1421,12 @@ async def _execute_agent_mode(
context=base_context,
)
logger.info(
f"[DEBUG-AGENT] orchestrator.execute completed: "
f"final_answer={final_answer[:200] if final_answer else 'None'}, "
f"iterations={react_ctx.iteration}, tool_calls={len(react_ctx.tool_calls) if react_ctx.tool_calls else 0}"
)
runtime_observer.record_react(request_id, react_ctx.iteration, react_ctx.tool_calls)
trace_logger.update_trace(
@ -964,10 +1450,114 @@ async def _execute_agent_mode(
kb_tool_called=True,
kb_hit=kb_success and len(kb_hits) > 0,
fallback_reason_code=kb_fallback_reason,
# [AC-SCENE-SLOT-02] Scene slot tracing
scene=request.scene,
scene_slot_context=base_context.get("scene_slot_context"),
missing_slots=[s.get("slot_key") for s in kb_missing_slots] if kb_missing_slots else None,
ask_back_triggered=False,
slot_sources=slot_state.slot_sources if slot_state else None,
kb_filter_sources=kb_dynamic_result.filter_sources if kb_dynamic_result else None,
# [Step-KB-Binding] Step KB binding trace
step_kb_binding=kb_dynamic_result.step_kb_binding if kb_dynamic_result else step_kb_binding_trace,
),
)
async def _generate_ask_back_for_missing_slots(
slot_state: Any,
missing_slots: list[dict[str, str]],
session: AsyncSession,
tenant_id: str,
session_id: str | None = None,
max_ask_back_slots: int = 2,
scene_slot_context: Any = None,  # [AC-SCENE-SLOT-02] scene slot context
) -> str:
"""
[AC-MRS-SLOT-META-03, AC-MRS-SLOT-ASKBACK-01] 为缺失的必填槽位生成追问响应
[AC-SCENE-SLOT-02] 支持场景槽位包配置的追问策略
支持批量追问多个缺失槽位优先追问必填槽位
"""
if not missing_slots:
return "请提供更多信息以便我更好地帮助您。"
# [AC-SCENE-SLOT-02] If a scene slot context is present, use its configured ask-back policy
if scene_slot_context:
ask_back_order = getattr(scene_slot_context, 'ask_back_order', 'priority')
if ask_back_order == "parallel":
prompts = []
for missing in missing_slots[:max_ask_back_slots]:
ask_back_prompt = missing.get("ask_back_prompt")
if ask_back_prompt:
prompts.append(ask_back_prompt)
else:
slot_key = missing.get("slot_key", "相关信息")
prompts.append(f"您的{slot_key}")
if len(prompts) == 1:
return prompts[0]
elif len(prompts) == 2:
return f"为了更好地为您服务,请告诉我{prompts[0]}{prompts[1]}"
else:
all_but_last = "".join(prompts[:-1])
return f"为了更好地为您服务,请告诉我{all_but_last},以及{prompts[-1]}"
if len(missing_slots) == 1:
first_missing = missing_slots[0]
ask_back_prompt = first_missing.get("ask_back_prompt")
if ask_back_prompt:
return ask_back_prompt
slot_key = first_missing.get("slot_key", "相关信息")
label = first_missing.get("label", slot_key)
return f"为了更好地为您提供帮助,请告诉我您的{label}"
try:
from app.services.mid.batch_ask_back_service import (
BatchAskBackConfig,
BatchAskBackService,
)
config = BatchAskBackConfig(
max_ask_back_slots_per_turn=max_ask_back_slots,
prefer_required=True,
merge_prompts=True,
)
ask_back_service = BatchAskBackService(
session=session,
tenant_id=tenant_id,
session_id=session_id or "",
config=config,
)
result = await ask_back_service.generate_batch_ask_back(
missing_slots=missing_slots,
)
if result.has_ask_back():
return result.get_prompt()
except Exception as e:
logger.warning(f"[AC-MRS-SLOT-ASKBACK-01] Batch ask-back failed, fallback to single: {e}")
prompts = []
for missing in missing_slots[:max_ask_back_slots]:
ask_back_prompt = missing.get("ask_back_prompt")
if ask_back_prompt:
prompts.append(ask_back_prompt)
else:
label = missing.get("label", missing.get("slot_key", "相关信息"))
prompts.append(f"您的{label}")
if len(prompts) == 1:
return prompts[0]
elif len(prompts) == 2:
return f"为了更好地为您服务,请告诉我{prompts[0]}{prompts[1]}"
else:
all_but_last = "".join(prompts[:-1])
return f"为了更好地为您服务,请告诉我{all_but_last},以及{prompts[-1]}"
async def _execute_micro_flow_mode(
tenant_id: str,
request: DialogueRequest,

View File

@ -8,6 +8,7 @@ from typing import Annotated
from fastapi import APIRouter, Depends, Path
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_session
@ -26,6 +27,82 @@ router = APIRouter(prefix="/mid", tags=["Mid Platform Sessions"])
_session_modes: dict[str, SessionMode] = {}
class CancelFlowResponse(BaseModel):
"""Response for cancel flow operation."""
success: bool
message: str
session_id: str
@router.post(
"/sessions/{sessionId}/cancel-flow",
operation_id="cancelActiveFlow",
summary="Cancel active flow",
description="""
Cancel the active flow for a session.
Use this when you encounter "Session already has an active flow" error.
""",
responses={
200: {"description": "Flow cancelled successfully", "model": CancelFlowResponse},
404: {"description": "No active flow found"},
},
)
async def cancel_active_flow(
sessionId: Annotated[str, Path(description="Session ID")],
session: Annotated[AsyncSession, Depends(get_session)],
) -> CancelFlowResponse:
"""
Cancel the active flow for a session.
This endpoint allows you to cancel any active flow instance
so that a new flow can be started.
"""
tenant_id = get_tenant_id()
if not tenant_id:
from app.core.exceptions import MissingTenantIdException
raise MissingTenantIdException()
logger.info(
f"[Cancel Flow] Cancelling active flow: tenant={tenant_id}, session={sessionId}"
)
try:
from app.services.flow.engine import FlowEngine
flow_engine = FlowEngine(session)
cancelled = await flow_engine.cancel_flow(
tenant_id=tenant_id,
session_id=sessionId,
reason="User requested cancellation via API",
)
if cancelled:
logger.info(f"[Cancel Flow] Flow cancelled: session={sessionId}")
return CancelFlowResponse(
success=True,
message="Active flow cancelled successfully",
session_id=sessionId,
)
else:
logger.info(f"[Cancel Flow] No active flow found: session={sessionId}")
return CancelFlowResponse(
success=True,
message="No active flow found for this session",
session_id=sessionId,
)
except Exception as e:
logger.error(f"[Cancel Flow] Failed to cancel flow: {e}")
return CancelFlowResponse(
success=False,
message=f"Failed to cancel flow: {str(e)}",
session_id=sessionId,
)
@router.post(
"/sessions/{sessionId}/mode",
operation_id="switchSessionMode",

View File

@ -64,6 +64,7 @@ class Settings(BaseSettings):
redis_enabled: bool = True
dashboard_cache_ttl: int = 60
stats_counter_ttl: int = 7776000
slot_state_cache_ttl: int = 1800
frontend_base_url: str = "http://localhost:3000"

View File

@ -20,7 +20,7 @@ engine = create_async_engine(
settings.database_url,
pool_size=settings.database_pool_size,
max_overflow=settings.database_max_overflow,
echo=False,
pool_pre_ping=True,
)

View File

@ -114,7 +114,12 @@ class ApiKeyMiddleware(BaseHTTPMiddleware):
from app.core.database import async_session_maker
async with async_session_maker() as session:
await service.initialize(session)
if service._initialized and len(service._keys_cache) > 0:
logger.info(f"[AC-AISVC-50] API key service lazy initialized with {len(service._keys_cache)} keys")
elif service._initialized and len(service._keys_cache) == 0:
logger.warning("[AC-AISVC-50] API key service initialized but no keys found in database")
else:
logger.error("[AC-AISVC-50] API key service lazy initialization failed")
except Exception as e:
logger.error(f"[AC-AISVC-50] Failed to initialize API key service: {e}")

View File

@ -272,20 +272,24 @@ class QdrantClient:
score_threshold: float | None = None,
vector_name: str = "full",
with_vectors: bool = False,
metadata_filter: dict[str, Any] | None = None,
kb_ids: list[str] | None = None,
) -> list[dict[str, Any]]:
"""
[AC-AISVC-10] Search vectors in tenant's collections.
Returns results with score >= score_threshold if specified.
Searches all collections for the tenant (multi-KB support).
Args:
tenant_id: Tenant identifier
query_vector: Query vector for similarity search
limit: Maximum number of results per collection
score_threshold: Minimum score threshold for results
vector_name: Name of the vector to search (for multi-vector collections)
Default is "full" for 768-dim vectors in Matryoshka setup.
with_vectors: Whether to return vectors in results (for two-stage reranking)
metadata_filter: Optional metadata filter to apply during search
kb_ids: Optional list of knowledge base IDs to restrict search to specific KBs
"""
client = await self.get_client()
@ -293,21 +297,36 @@ class QdrantClient:
f"[AC-AISVC-10] Starting search: tenant_id={tenant_id}, "
f"limit={limit}, score_threshold={score_threshold}, vector_dim={len(query_vector)}, vector_name={vector_name}"
)
# Build the Qdrant filter
qdrant_filter = None
if metadata_filter:
qdrant_filter = self._build_qdrant_filter(metadata_filter)
logger.info(f"[AC-AISVC-10] Qdrant filter: {qdrant_filter}")
# Fetch all collections for this tenant
collection_names = await self._get_tenant_collections(client, tenant_id)
# If kb_ids is given, restrict the search to those KB collections
if kb_ids:
target_collections = []
for kb_id in kb_ids:
kb_collection_name = self.get_kb_collection_name(tenant_id, kb_id)
if kb_collection_name in collection_names:
target_collections.append(kb_collection_name)
else:
logger.warning(f"[AC-AISVC-10] KB collection not found: {kb_collection_name} for kb_id={kb_id}")
collection_names = target_collections
logger.info(f"[AC-AISVC-10] Restricted to {len(collection_names)} KB collections: {collection_names}")
else:
logger.info(f"[AC-AISVC-10] Will search in {len(collection_names)} collections: {collection_names}")
all_hits = []
for collection_name in collection_names:
try:
logger.info(f"[AC-AISVC-10] Searching in collection: {collection_name}")
exists = await client.collection_exists(collection_name)
if not exists:
logger.warning(f"[AC-AISVC-10] Collection {collection_name} does not exist")
@ -321,6 +340,7 @@ class QdrantClient:
limit=limit,
with_vectors=with_vectors,
score_threshold=score_threshold,
query_filter=qdrant_filter,
)
except Exception as e:
if "vector name" in str(e).lower() or "Not existing vector" in str(e) or "using" in str(e).lower():
@ -334,6 +354,7 @@ class QdrantClient:
limit=limit,
with_vectors=with_vectors,
score_threshold=score_threshold,
query_filter=qdrant_filter,
)
else:
raise
@ -348,6 +369,7 @@ class QdrantClient:
"id": str(result.id),
"score": result.score,
"payload": result.payload or {},
"collection": collection_name, # 添加 collection 信息
}
if with_vectors and result.vector:
hit["vector"] = result.vector
@ -358,10 +380,6 @@ class QdrantClient:
logger.info(
f"[AC-AISVC-10] Search in collection {collection_name}: {len(hits)} results for tenant={tenant_id}"
)
else:
logger.warning(
f"[AC-AISVC-10] Collection {collection_name} returned no hits (filtered or empty)"
@ -370,9 +388,10 @@ class QdrantClient:
logger.warning(
f"[AC-AISVC-10] Collection {collection_name} not found or error: {e}"
)
continue
# Sort by score and return the top results
all_hits.sort(key=lambda x: x["score"], reverse=True)
all_hits = all_hits[:limit]
logger.info(
f"[AC-AISVC-10] Search returned {len(all_hits)} total results for tenant={tenant_id}"
@ -386,6 +405,113 @@ class QdrantClient:
return all_hits
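The per-collection loop above accumulates hits, tags each with its source collection, then globally sorts and truncates. A minimal, dependency-free sketch of that merge step (the helper name is hypothetical):

```python
from typing import Any


def merge_collection_hits(
    per_collection: dict[str, list[dict[str, Any]]],
    limit: int,
) -> list[dict[str, Any]]:
    # Tag each hit with its source collection, then take the global top-N by score
    all_hits: list[dict[str, Any]] = []
    for collection_name, hits in per_collection.items():
        for hit in hits:
            all_hits.append({**hit, "collection": collection_name})
    all_hits.sort(key=lambda h: h["score"], reverse=True)
    return all_hits[:limit]
```

Note that `limit` caps results per collection in the real method, so the global truncation here can still drop lower-scored hits from any one KB.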
async def _get_tenant_collections(
self,
client: AsyncQdrantClient,
tenant_id: str,
) -> list[str]:
"""
Get all collections for the given tenant.
Reads from the Redis cache first; on a miss, queries Qdrant and caches the result.
Args:
client: Qdrant client
tenant_id: Tenant ID
Returns:
List of collection names
"""
import time
start_time = time.time()
# 1. Try the cache first
from app.services.metadata_cache_service import get_metadata_cache_service
cache_service = await get_metadata_cache_service()
cache_key = f"collections:{tenant_id}"
try:
# Ensure the Redis connection is initialized
redis_client = await cache_service._get_redis()
if redis_client and cache_service._enabled:
cached = await redis_client.get(cache_key)
if cached:
import json
collections = json.loads(cached)
logger.info(
f"[AC-AISVC-10] Cache hit: Found {len(collections)} collections "
f"for tenant={tenant_id} in {(time.time() - start_time)*1000:.2f}ms"
)
return collections
except Exception as e:
logger.warning(f"[AC-AISVC-10] Cache get error: {e}")
# 2. Query Qdrant
safe_tenant_id = tenant_id.replace('@', '_')
prefix = f"{self._collection_prefix}{safe_tenant_id}"
try:
collections = await client.get_collections()
tenant_collections = [
c.name for c in collections.collections
if c.name.startswith(prefix)
]
# Sort by name
tenant_collections.sort()
db_time = (time.time() - start_time) * 1000
logger.info(
f"[AC-AISVC-10] Found {len(tenant_collections)} collections from Qdrant "
f"for tenant={tenant_id} in {db_time:.2f}ms: {tenant_collections}"
)
# 3. Cache the result (5-minute TTL)
try:
redis_client = await cache_service._get_redis()
if redis_client and cache_service._enabled:
import json
await redis_client.setex(
cache_key,
300,  # 5 minutes
json.dumps(tenant_collections)
)
logger.info(f"[AC-AISVC-10] Cached collections for tenant={tenant_id}")
except Exception as e:
logger.warning(f"[AC-AISVC-10] Cache set error: {e}")
return tenant_collections
except Exception as e:
logger.error(f"[AC-AISVC-10] Failed to get collections for tenant={tenant_id}: {e}")
return [self.get_collection_name(tenant_id)]
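The Redis-backed lookup above is a cache-aside pattern with a 5-minute TTL. Stripped of the Redis and Qdrant specifics, the control flow looks like this (in-memory stand-ins, hypothetical names):

```python
import json
import time

# key -> (expires_at, JSON payload); stands in for Redis
_cache: dict[str, tuple[float, str]] = {}


def get_collections_cached(tenant_id: str, fetch, ttl: float = 300.0) -> list[str]:
    # 1. Try the cache first
    key = f"collections:{tenant_id}"
    entry = _cache.get(key)
    if entry and entry[0] > time.time():
        return json.loads(entry[1])
    # 2. On a miss, fetch from the backing store (Qdrant in the real code)
    collections = sorted(fetch(tenant_id))
    # 3. Cache the result with a TTL
    _cache[key] = (time.time() + ttl, json.dumps(collections))
    return collections
```

As in the source, a cache failure should degrade to the backing query rather than raise; the sketch omits that error handling for brevity.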
def _build_qdrant_filter(
self,
metadata_filter: dict[str, Any],
) -> Any:
"""
Build a Qdrant filter from metadata conditions.
Args:
metadata_filter: Metadata filter conditions, e.g. {"grade": "三年级", "subject": "语文"}
Returns:
Qdrant Filter object (or None if no conditions)
"""
from qdrant_client.models import FieldCondition, Filter, MatchValue
must_conditions = []
for key, value in metadata_filter.items():
# Support nested metadata fields, e.g. metadata.grade
field_path = f"metadata.{key}"
condition = FieldCondition(
key=field_path,
match=MatchValue(value=value),
)
must_conditions.append(condition)
return Filter(must=must_conditions) if must_conditions else None
async def delete_collection(self, tenant_id: str) -> bool:
"""
[AC-AISVC-10] Delete tenant's collection.

View File

@ -29,6 +29,8 @@ from app.api.admin import (
monitoring_router,
prompt_templates_router,
rag_router,
retrieval_strategy_router,
scene_slot_bundle_router,
script_flows_router,
sessions_router,
slot_definition_router,
@ -55,6 +57,11 @@ logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)
logging.getLogger("sqlalchemy.pool").setLevel(logging.WARNING)
logging.getLogger("sqlalchemy.dialects").setLevel(logging.WARNING)
logging.getLogger("sqlalchemy.orm").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
@ -88,6 +95,28 @@ async def lifespan(app: FastAPI):
except Exception as e:
logger.error(f"[AC-AISVC-50] API key initialization FAILED: {e}", exc_info=True)
try:
from app.services.mid.tool_guide_registry import init_tool_guide_registry
logger.info("[ToolGuideRegistry] Starting tool guides initialization...")
tool_guide_registry = init_tool_guide_registry()
logger.info(f"[ToolGuideRegistry] Tool guides loaded: {tool_guide_registry.list_tools()}")
except Exception as e:
logger.error(f"[ToolRegistry] Tools initialization FAILED: {e}", exc_info=True)
# [AC-AISVC-29] Pre-initialize the embedding service to avoid first-query latency
try:
from app.services.embedding import get_embedding_provider
logger.info("[AC-AISVC-29] Pre-initializing embedding service...")
embedding_provider = await get_embedding_provider()
logger.info(
f"[AC-AISVC-29] Embedding service pre-initialized: "
f"provider={embedding_provider.PROVIDER_NAME}"
)
except Exception as e:
logger.error(f"[AC-AISVC-29] Embedding service pre-initialization FAILED: {e}", exc_info=True)
yield
await close_db()
@ -171,6 +200,8 @@ app.include_router(metadata_schema_router)
app.include_router(monitoring_router)
app.include_router(prompt_templates_router)
app.include_router(rag_router)
app.include_router(retrieval_strategy_router)
app.include_router(scene_slot_bundle_router)
app.include_router(script_flows_router)
app.include_router(sessions_router)
app.include_router(slot_definition_router)

View File

@ -41,6 +41,7 @@ class ChatMessage(SQLModel, table=True):
[AC-AISVC-13] Chat message entity with tenant isolation.
Messages are scoped by (tenant_id, session_id) for multi-tenant security.
[v0.7.0] Extended with monitoring fields for Dashboard statistics.
[v0.8.0] Extended with route_trace for hybrid routing observability.
"""
__tablename__ = "chat_messages"
@ -90,6 +91,11 @@ class ChatMessage(SQLModel, table=True):
sa_column=Column("guardrail_words", JSON, nullable=True),
description="[v0.7.0] Guardrail trigger details: words, categories, strategy"
)
route_trace: dict[str, Any] | None = Field(
default=None,
sa_column=Column("route_trace", JSON, nullable=True),
description="[v0.8.0] Intent routing trace log for hybrid routing observability"
)
class ChatSessionCreate(SQLModel):
@ -227,6 +233,7 @@ class Document(SQLModel, table=True):
file_type: str | None = Field(default=None, description="File MIME type")
status: str = Field(default=DocumentStatus.PENDING.value, description="Document status")
error_msg: str | None = Field(default=None, description="Error message if failed")
doc_metadata: dict | None = Field(default=None, sa_type=JSON, description="Document metadata as JSON")
created_at: datetime = Field(default_factory=datetime.utcnow, description="Upload time")
updated_at: datetime = Field(default_factory=datetime.utcnow, description="Last update time")
@ -421,6 +428,7 @@ class IntentRule(SQLModel, table=True):
[AC-AISVC-65] Intent rule entity with tenant isolation.
Supports keyword and regex matching for intent recognition.
[AC-IDSMETA-16] Extended with metadata field for unified storage structure.
[v0.8.0] Extended with intent_vector and semantic_examples for hybrid routing.
"""
__tablename__ = "intent_rules"
@ -458,6 +466,16 @@ class IntentRule(SQLModel, table=True):
sa_column=Column("metadata", JSON, nullable=True),
description="[AC-IDSMETA-16] Structured metadata for the intent rule"
)
intent_vector: list[float] | None = Field(
default=None,
sa_column=Column("intent_vector", JSON, nullable=True),
description="[v0.8.0] Pre-computed intent vector for semantic matching"
)
semantic_examples: list[str] | None = Field(
default=None,
sa_column=Column("semantic_examples", JSON, nullable=True),
description="[v0.8.0] Semantic example sentences for dynamic vector computation"
)
created_at: datetime = Field(default_factory=datetime.utcnow, description="Creation time")
updated_at: datetime = Field(default_factory=datetime.utcnow, description="Last update time")
@ -475,6 +493,8 @@ class IntentRuleCreate(SQLModel):
fixed_reply: str | None = None
transfer_message: str | None = None
metadata_: dict[str, Any] | None = None
intent_vector: list[float] | None = None
semantic_examples: list[str] | None = None
class IntentRuleUpdate(SQLModel):
@ -491,6 +511,8 @@ class IntentRuleUpdate(SQLModel):
transfer_message: str | None = None
is_enabled: bool | None = None
metadata_: dict[str, Any] | None = None
intent_vector: list[float] | None = None
semantic_examples: list[str] | None = None
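`intent_vector` is described as a pre-computed vector for semantic matching. A minimal cosine-similarity sketch of how such a vector might be compared against a query embedding (illustrative only; the actual matcher is not shown in this diff):

```python
import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    # Standard cosine similarity; returns 0.0 when either vector has zero norm
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)
```

With `semantic_examples`, one plausible use is embedding each example, scoring the query against all of them, and taking the maximum as the rule's semantic score.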
class IntentMatchResult:
@ -810,6 +832,24 @@ class FlowStep(SQLModel):
default=None,
description="RAG configuration for this step: {'enabled': true, 'tag_filter': {'grade': '${context.grade}', 'type': '痛点'}}"
)
allowed_kb_ids: list[str] | None = Field(
default=None,
description="[Step-KB-Binding] Allowed knowledge base IDs for this step. If set, KB search will be restricted to these KBs."
)
preferred_kb_ids: list[str] | None = Field(
default=None,
description="[Step-KB-Binding] Preferred knowledge base IDs for this step. These KBs will be searched first."
)
kb_query_hint: str | None = Field(
default=None,
description="[Step-KB-Binding] Query hint for KB search in this step, helps improve retrieval accuracy."
)
max_kb_calls_per_step: int | None = Field(
default=None,
ge=1,
le=5,
description="[Step-KB-Binding] Max KB calls allowed per step. Default is 1 if not set."
)
class ScriptFlowCreate(SQLModel):
@ -1078,6 +1118,7 @@ class MetadataFieldDefinition(SQLModel, table=True):
)
is_filterable: bool = Field(default=True, description="是否可用于过滤")
is_rank_feature: bool = Field(default=False, description="是否用于排序特征")
usage_description: str | None = Field(default=None, description="用途说明")
field_roles: list[str] = Field(
default_factory=list,
sa_column=Column("field_roles", JSON, nullable=False, server_default="'[]'"),
@ -1104,6 +1145,7 @@ class MetadataFieldDefinitionCreate(SQLModel):
scope: list[str] = Field(default_factory=lambda: [MetadataScope.KB_DOCUMENT.value])
is_filterable: bool = Field(default=True)
is_rank_feature: bool = Field(default=False)
usage_description: str | None = None
field_roles: list[str] = Field(default_factory=list)
status: str = Field(default=MetadataFieldStatus.DRAFT.value)
@ -1118,6 +1160,7 @@ class MetadataFieldDefinitionUpdate(SQLModel):
scope: list[str] | None = None
is_filterable: bool | None = None
is_rank_feature: bool | None = None
usage_description: str | None = None
field_roles: list[str] | None = None
status: str | None = None
@ -1131,6 +1174,17 @@ class ExtractStrategy(str, Enum):
USER_INPUT = "user_input"
class ExtractFailureType(str, Enum):
"""
[AC-MRS-07-UPGRADE] Extraction failure types.
Unified failure classification for tracing and logging.
"""
EXTRACT_EMPTY = "EXTRACT_EMPTY"  # Extraction returned empty
EXTRACT_PARSE_FAIL = "EXTRACT_PARSE_FAIL"  # Parsing failed
EXTRACT_VALIDATION_FAIL = "EXTRACT_VALIDATION_FAIL"  # Validation failed
EXTRACT_RUNTIME_ERROR = "EXTRACT_RUNTIME_ERROR"  # Runtime error
class SlotValueSource(str, Enum):
"""
[AC-MRS-09] Slot value source
@ -1145,6 +1199,7 @@ class SlotDefinition(SQLModel, table=True):
"""
[AC-MRS-07,08] Slot definition table.
Standalone slot definition model, decoupled from metadata fields but reusable.
[AC-MRS-07-UPGRADE] Supports the extract_strategies chain.
"""
__tablename__ = "slot_definitions"
@ -1162,14 +1217,31 @@ class SlotDefinition(SQLModel, table=True):
min_length=1,
max_length=100,
)
display_name: str | None = Field(
default=None,
description="槽位名称,给运营/教研看的中文名grade -> '当前年级'",
max_length=100,
)
description: str | None = Field(
default=None,
description="槽位说明,解释这个槽位采集什么、用于哪里",
max_length=500,
)
type: str = Field(
default=MetadataFieldType.STRING.value,
description="槽位类型: string/number/boolean/enum/array_enum"
)
required: bool = Field(default=False, description="是否必填槽位")
# [AC-MRS-07-UPGRADE] Legacy field kept for backward-compatible reads
extract_strategy: str | None = Field(
default=None,
description="提取策略: rule/llm/user_input"
description="[兼容字段] 提取策略: rule/llm/user_input已废弃请使用 extract_strategies"
)
# [AC-MRS-07-UPGRADE] New strategy-chain field
extract_strategies: list[str] | None = Field(
default=None,
sa_column=Column("extract_strategies", JSON, nullable=True),
description="[AC-MRS-07-UPGRADE] 提取策略链:有序数组,元素为 rule/llm/user_input按顺序执行直到成功"
)
validation_rule: str | None = Field(
default=None,
@ -1192,14 +1264,72 @@ class SlotDefinition(SQLModel, table=True):
created_at: datetime = Field(default_factory=datetime.utcnow, description="创建时间")
updated_at: datetime = Field(default_factory=datetime.utcnow, description="更新时间")
def get_effective_strategies(self) -> list[str]:
"""
[AC-MRS-07-UPGRADE] Get the effective extraction strategy chain.
Prefers extract_strategies; falls back to the legacy extract_strategy field.
"""
if self.extract_strategies and len(self.extract_strategies) > 0:
return self.extract_strategies
if self.extract_strategy:
return [self.extract_strategy]
return []
def validate_strategies(self) -> tuple[bool, str]:
"""
[AC-MRS-07-UPGRADE] Validate the extraction strategy chain.
Returns:
Tuple of (is_valid, error_message)
"""
valid_strategies = {"rule", "llm", "user_input"}
strategies = self.get_effective_strategies()
if not strategies:
return True, "" # 空策略链视为有效(使用默认行为)
# 校验至少1个策略
if len(strategies) == 0:
return False, "提取策略链不能为空"
# 校验不允许重复策略
if len(strategies) != len(set(strategies)):
return False, "提取策略链中不允许重复的策略"
# 校验策略值有效
invalid = [s for s in strategies if s not in valid_strategies]
if invalid:
return False, f"无效的提取策略: {invalid},有效值为: {list(valid_strategies)}"
return True, ""
class SlotDefinitionCreate(SQLModel):
"""[AC-MRS-07,08] 创建槽位定义"""
slot_key: str = Field(..., min_length=1, max_length=100)
display_name: str | None = Field(
default=None,
description="槽位名称,给运营/教研看的中文名",
max_length=100,
)
description: str | None = Field(
default=None,
description="槽位说明,解释这个槽位采集什么、用于哪里",
max_length=500,
)
type: str = Field(default=MetadataFieldType.STRING.value)
required: bool = Field(default=False)
# [AC-MRS-07-UPGRADE] Strategy chain support
extract_strategies: list[str] | None = Field(
default=None,
description="提取策略链:有序数组,元素为 rule/llm/user_input,按顺序执行直到成功"
)
# [AC-MRS-07-UPGRADE] Legacy field kept for compatibility
extract_strategy: str | None = Field(
default=None,
description="[兼容字段] 单提取策略,已废弃,请使用 extract_strategies"
)
validation_rule: str | None = None
ask_back_prompt: str | None = None
default_value: dict[str, Any] | None = None
@ -1209,9 +1339,28 @@ class SlotDefinitionCreate(SQLModel):
class SlotDefinitionUpdate(SQLModel):
"""[AC-MRS-07] 更新槽位定义"""
display_name: str | None = Field(
default=None,
description="槽位名称,给运营/教研看的中文名",
max_length=100,
)
description: str | None = Field(
default=None,
description="槽位说明,解释这个槽位采集什么、用于哪里",
max_length=500,
)
type: str | None = None
required: bool | None = None
# [AC-MRS-07-UPGRADE] Strategy chain support
extract_strategies: list[str] | None = Field(
default=None,
description="提取策略链:有序数组,元素为 rule/llm/user_input,按顺序执行直到成功"
)
# [AC-MRS-07-UPGRADE] Legacy field kept for compatibility
extract_strategy: str | None = Field(
default=None,
description="[兼容字段] 单提取策略,已废弃,请使用 extract_strategies"
)
validation_rule: str | None = None
ask_back_prompt: str | None = None
default_value: dict[str, Any] | None = None
@ -1522,3 +1671,107 @@ class MidAuditLog(SQLModel, table=True):
high_risk_scenario: str | None = Field(default=None, description="触发的高风险场景")
latency_ms: int | None = Field(default=None, description="总耗时(ms)")
created_at: datetime = Field(default_factory=datetime.utcnow, description="创建时间", index=True)
class SceneSlotBundleStatus(str, Enum):
"""[AC-SCENE-SLOT-01] 场景槽位包状态"""
DRAFT = "draft"
ACTIVE = "active"
DEPRECATED = "deprecated"
class SceneSlotBundle(SQLModel, table=True):
"""
[AC-SCENE-SLOT-01] Scene-to-slot mapping configuration.
Defines the set of slots each scene needs to collect.
Three layers:
- Layer 1: slot metadata (via linked_field_id)
- Layer 2: scene slot_bundle (this model)
- Layer 3: step.expected_variables slot_key (referenced by script steps)
"""
__tablename__ = "scene_slot_bundles"
__table_args__ = (
Index("ix_scene_slot_bundles_tenant", "tenant_id"),
Index("ix_scene_slot_bundles_tenant_scene", "tenant_id", "scene_key", unique=True),
Index("ix_scene_slot_bundles_tenant_status", "tenant_id", "status"),
)
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
tenant_id: str = Field(..., description="Tenant ID for multi-tenant isolation", index=True)
scene_key: str = Field(
...,
description="场景标识,如 'open_consult', 'refund_apply', 'course_recommend'",
min_length=1,
max_length=100,
)
scene_name: str = Field(
...,
description="场景名称,如 '开放咨询', '退款申请', '课程推荐'",
min_length=1,
max_length=100,
)
description: str | None = Field(
default=None,
description="场景描述"
)
required_slots: list[str] = Field(
default_factory=list,
sa_column=Column("required_slots", JSON, nullable=False),
description="必填槽位 slot_key 列表"
)
optional_slots: list[str] = Field(
default_factory=list,
sa_column=Column("optional_slots", JSON, nullable=False),
description="可选槽位 slot_key 列表"
)
slot_priority: list[str] | None = Field(
default=None,
sa_column=Column("slot_priority", JSON, nullable=True),
description="槽位采集优先级顺序slot_key 列表)"
)
completion_threshold: float = Field(
default=1.0,
ge=0.0,
le=1.0,
description="完成阈值0.0-1.0),必填槽位填充比例达到此值视为完成"
)
ask_back_order: str = Field(
default="priority",
description="追问顺序策略: priority/required_first/parallel"
)
status: str = Field(
default=SceneSlotBundleStatus.DRAFT.value,
description="状态: draft/active/deprecated"
)
version: int = Field(default=1, description="版本号")
created_at: datetime = Field(default_factory=datetime.utcnow, description="创建时间")
updated_at: datetime = Field(default_factory=datetime.utcnow, description="更新时间")
class SceneSlotBundleCreate(SQLModel):
"""[AC-SCENE-SLOT-01] 创建场景槽位包"""
scene_key: str = Field(..., min_length=1, max_length=100)
scene_name: str = Field(..., min_length=1, max_length=100)
description: str | None = None
required_slots: list[str] = Field(default_factory=list)
optional_slots: list[str] = Field(default_factory=list)
slot_priority: list[str] | None = None
completion_threshold: float = Field(default=1.0, ge=0.0, le=1.0)
ask_back_order: str = Field(default="priority")
status: str = Field(default=SceneSlotBundleStatus.DRAFT.value)
class SceneSlotBundleUpdate(SQLModel):
"""[AC-SCENE-SLOT-01] 更新场景槽位包"""
scene_name: str | None = Field(default=None, min_length=1, max_length=100)
description: str | None = None
required_slots: list[str] | None = None
optional_slots: list[str] | None = None
slot_priority: list[str] | None = None
completion_threshold: float | None = Field(default=None, ge=0.0, le=1.0)
ask_back_order: str | None = None
status: str | None = None

View File

@ -73,6 +73,7 @@ class DialogueRequest(BaseModel):
history: list[HistoryMessage] = Field(default_factory=list, description="已送达历史")
interrupted_segments: list[InterruptedSegment] | None = Field(default=None, description="打断的分段")
feature_flags: FeatureFlags | None = Field(default=None, description="特性开关")
scene: str | None = Field(default=None, description="Scene identifier for KB filtering, e.g. 'open_consult', 'after_sale'")
class Segment(BaseModel):
@ -127,6 +128,7 @@ class TraceInfo(BaseModel):
)
tools_used: list[str] | None = Field(default=None, description="使用的工具列表")
tool_calls: list[ToolCallTraceModel] | None = Field(default=None, description="工具调用追踪")
step_kb_binding: dict[str, Any] | None = Field(default=None, description="[Step-KB-Binding] Step KB binding info")
class DialogueResponse(BaseModel):

View File

@ -86,6 +86,7 @@ class DialogueRequest(BaseModel):
humanize_config: HumanizeConfigRequest | None = Field(
default=None, description="Humanize config for segment delay"
)
scene: str | None = Field(default=None, description="Scene identifier for KB filtering, e.g., 'open_consult', 'after_sale'")
class Segment(BaseModel):
@ -122,6 +123,8 @@ class ToolCallTrace(BaseModel):
error_code: str | None = Field(default=None, description="Error code if failed")
args_digest: str | None = Field(default=None, description="Arguments digest for logging")
result_digest: str | None = Field(default=None, description="Result digest for logging")
arguments: dict[str, Any] | None = Field(default=None, description="Full tool call arguments")
result: Any = Field(default=None, description="Full tool call result")
class SegmentStats(BaseModel):
@ -133,7 +136,7 @@ class SegmentStats(BaseModel):
class TraceInfo(BaseModel):
"""[AC-MARH-02, AC-MARH-03, AC-MARH-05, AC-MARH-06, AC-MARH-07, AC-MARH-11,
AC-MARH-12, AC-MARH-18, AC-MARH-19, AC-MARH-20] Trace info for observability."""
AC-MARH-12, AC-MARH-18, AC-MARH-19, AC-MARH-20, AC-SCENE-SLOT-02] Trace info for observability."""
mode: ExecutionMode = Field(..., description="Execution mode")
intent: str | None = Field(default=None, description="Matched intent")
request_id: str | None = Field(
@ -156,6 +159,17 @@ class TraceInfo(BaseModel):
high_risk_policy_set: list[HighRiskScenario] | None = Field(default=None, description="Active high-risk policy set")
tools_used: list[str] | None = Field(default=None, description="Tools used in this request")
tool_calls: list[ToolCallTrace] | None = Field(default=None, description="Tool call traces")
duration_ms: int = Field(default=0, ge=0, description="Execution duration in milliseconds")
created_at: str | None = Field(default=None, description="Creation timestamp")
# [AC-SCENE-SLOT-02] Scene slot tracing fields
scene: str | None = Field(default=None, description="Current scene identifier")
scene_slot_context: dict[str, Any] | None = Field(default=None, description="Scene slot context info")
missing_slots: list[str] | None = Field(default=None, description="List of missing required slots")
ask_back_triggered: bool | None = Field(default=False, description="Whether an ask-back was triggered")
slot_sources: dict[str, str] | None = Field(default=None, description="Mapping of slot values to their sources")
kb_filter_sources: dict[str, str] | None = Field(default=None, description="Mapping of KB filter conditions to their sources")
# [Step-KB-Binding] Step knowledge-base binding trace
step_kb_binding: dict[str, Any] | None = Field(default=None, description="Step KB binding info, including step_id, allowed_kb_ids, used_kb_ids, etc.")
class DialogueResponse(BaseModel):

View File

@ -43,6 +43,8 @@ class ToolCallTrace:
- error_code: error code
- args_digest: arguments digest (redacted)
- result_digest: result digest
- arguments: full arguments
- result: full result
"""
tool_name: str
duration_ms: int
@ -53,6 +55,8 @@ class ToolCallTrace:
error_code: str | None = None
args_digest: str | None = None
result_digest: str | None = None
arguments: dict[str, Any] | None = None
result: Any = None
started_at: datetime = field(default_factory=datetime.utcnow)
completed_at: datetime | None = None
@ -74,6 +78,10 @@ class ToolCallTrace:
result["args_digest"] = self.args_digest
if self.result_digest:
result["result_digest"] = self.result_digest
if self.arguments:
result["arguments"] = self.arguments
if self.result is not None:
result["result"] = self.result
return result
@staticmethod

View File

@ -139,7 +139,16 @@ class SlotDefinitionResponse(BaseModel):
slot_key: str = Field(..., description="Slot key")
type: str = Field(..., description="Slot type")
required: bool = Field(default=False, description="Whether the slot is required")
extract_strategy: str | None = Field(default=None, description="Extraction strategy")
# [AC-MRS-07-UPGRADE] Keep the legacy field for compatibility
extract_strategy: str | None = Field(
default=None,
description="[Legacy field] Single extraction strategy, deprecated"
)
# [AC-MRS-07-UPGRADE] New strategy-chain field
extract_strategies: list[str] | None = Field(
default=None,
description="[AC-MRS-07-UPGRADE] Extraction strategy chain: ordered list of rule/llm/user_input"
)
validation_rule: str | None = Field(default=None, description="Validation rule")
ask_back_prompt: str | None = Field(default=None, description="Ask-back prompt template")
default_value: dict[str, Any] | None = Field(default=None, description="Default value")
@ -157,9 +166,15 @@ class SlotDefinitionCreateRequest(BaseModel):
slot_key: str = Field(..., min_length=1, max_length=100, description="Slot key")
type: str = Field(default="string", description="Slot type")
required: bool = Field(default=False, description="Whether the slot is required")
# [AC-MRS-07-UPGRADE] Support a strategy chain
extract_strategies: list[str] | None = Field(
default=None,
description="[AC-MRS-07-UPGRADE] Extraction strategy chain: ordered list of rule/llm/user_input, executed in order until one succeeds"
)
# [AC-MRS-07-UPGRADE] Keep the legacy field for compatibility
extract_strategy: str | None = Field(
default=None,
description="Extraction strategy: rule/llm/user_input"
description="[Legacy field] Single extraction strategy, deprecated; use extract_strategies instead"
)
validation_rule: str | None = Field(default=None, description="Validation rule")
ask_back_prompt: str | None = Field(default=None, description="Ask-back prompt template")
@ -172,7 +187,16 @@ class SlotDefinitionUpdateRequest(BaseModel):
type: str | None = None
required: bool | None = None
extract_strategy: str | None = None
# [AC-MRS-07-UPGRADE] Support a strategy chain
extract_strategies: list[str] | None = Field(
default=None,
description="[AC-MRS-07-UPGRADE] Extraction strategy chain: ordered list of rule/llm/user_input, executed in order until one succeeds"
)
# [AC-MRS-07-UPGRADE] Keep the legacy field for compatibility
extract_strategy: str | None = Field(
default=None,
description="[Legacy field] Single extraction strategy, deprecated; use extract_strategies instead"
)
validation_rule: str | None = None
ask_back_prompt: str | None = None
default_value: dict[str, Any] | None = None

View File

@ -0,0 +1,198 @@
"""
Retrieval Strategy Schemas for AI Service.
[AC-AISVC-RES-01~15] Request/Response models aligned with OpenAPI contract.
"""
from enum import Enum
from typing import Any
from pydantic import BaseModel, Field, model_validator
class StrategyType(str, Enum):
DEFAULT = "default"
ENHANCED = "enhanced"
class ReactMode(str, Enum):
REACT = "react"
NON_REACT = "non_react"
class RolloutMode(str, Enum):
OFF = "off"
PERCENTAGE = "percentage"
ALLOWLIST = "allowlist"
class ValidationCheckType(str, Enum):
METADATA_CONSISTENCY = "metadata_consistency"
EMBEDDING_PREFIX = "embedding_prefix"
RRF_CONFIG = "rrf_config"
PERFORMANCE_BUDGET = "performance_budget"
class RolloutConfig(BaseModel):
"""
[AC-AISVC-RES-03] Grayscale rollout configuration.
"""
mode: RolloutMode = Field(..., description="Rollout mode: off, percentage, or allowlist")
percentage: float | None = Field(
default=None,
ge=0,
le=100,
description="Percentage of traffic for grayscale (0-100)",
)
allowlist: list[str] | None = Field(
default=None,
description="List of tenant IDs in allowlist",
)
@model_validator(mode="after")
def validate_rollout_config(self) -> "RolloutConfig":
if self.mode == RolloutMode.PERCENTAGE and self.percentage is None:
raise ValueError("percentage is required when mode is 'percentage'")
if self.mode == RolloutMode.ALLOWLIST and (
self.allowlist is None or len(self.allowlist) == 0
):
raise ValueError("allowlist is required when mode is 'allowlist'")
return self
class RetrievalStrategyStatus(BaseModel):
"""
[AC-AISVC-RES-01] Current retrieval strategy status.
"""
active_strategy: StrategyType = Field(
...,
alias="active_strategy",
description="Current active strategy: default or enhanced",
)
react_mode: ReactMode = Field(
...,
alias="react_mode",
description="ReAct mode: react or non_react",
)
rollout: RolloutConfig = Field(..., description="Grayscale rollout configuration")
model_config = {"populate_by_name": True}
class RetrievalStrategySwitchRequest(BaseModel):
"""
[AC-AISVC-RES-02, AC-AISVC-RES-03, AC-AISVC-RES-05] Request to switch retrieval strategy.
"""
target_strategy: StrategyType = Field(
...,
alias="target_strategy",
description="Target strategy to switch to",
)
react_mode: ReactMode | None = Field(
default=None,
alias="react_mode",
description="ReAct mode to use",
)
rollout: RolloutConfig | None = Field(
default=None,
description="Grayscale rollout configuration",
)
reason: str | None = Field(
default=None,
description="Reason for strategy switch",
)
model_config = {"populate_by_name": True}
class RetrievalStrategySwitchResponse(BaseModel):
"""
[AC-AISVC-RES-02] Response after strategy switch.
"""
previous: RetrievalStrategyStatus = Field(..., description="Previous strategy status")
current: RetrievalStrategyStatus = Field(..., description="Current strategy status")
class RetrievalStrategyValidationRequest(BaseModel):
"""
[AC-AISVC-RES-04, AC-AISVC-RES-06, AC-AISVC-RES-08] Request to validate strategy.
"""
strategy: StrategyType = Field(..., description="Strategy to validate")
react_mode: ReactMode | None = Field(
default=None,
description="ReAct mode to validate",
)
checks: list[ValidationCheckType] | None = Field(
default=None,
description="List of checks to perform",
)
model_config = {"populate_by_name": True}
class ValidationResult(BaseModel):
"""
[AC-AISVC-RES-06] Single validation check result.
"""
check: str = Field(..., description="Check name")
passed: bool = Field(..., description="Whether the check passed")
message: str | None = Field(default=None, description="Additional message")
class RetrievalStrategyValidationResponse(BaseModel):
"""
[AC-AISVC-RES-06] Validation response with all check results.
"""
passed: bool = Field(..., description="Whether all checks passed")
results: list[ValidationResult] = Field(..., description="Individual check results")
class RetrievalStrategyRollbackResponse(BaseModel):
"""
[AC-AISVC-RES-07] Response after strategy rollback.
"""
current: RetrievalStrategyStatus = Field(..., description="Strategy status that was active before the rollback")
rollback_to: RetrievalStrategyStatus = Field(..., description="Strategy status now active after the rollback")
class StrategyAuditLog(BaseModel):
"""
[AC-AISVC-RES-07] Audit log entry for strategy operations.
"""
timestamp: str = Field(..., description="ISO timestamp of the operation")
operation: str = Field(..., description="Operation type: switch, rollback, validate")
previous_strategy: str | None = Field(default=None, description="Previous strategy")
new_strategy: str | None = Field(default=None, description="New strategy")
previous_react_mode: str | None = Field(default=None, description="Previous react mode")
new_react_mode: str | None = Field(default=None, description="New react mode")
reason: str | None = Field(default=None, description="Reason for the operation")
operator: str | None = Field(default=None, description="Operator who performed the operation")
tenant_id: str | None = Field(default=None, description="Tenant ID if applicable")
metadata: dict[str, Any] | None = Field(default=None, description="Additional metadata")
class StrategyMetrics(BaseModel):
"""
[AC-AISVC-RES-03, AC-AISVC-RES-08] Metrics for strategy operations.
"""
strategy: StrategyType = Field(..., description="Current strategy")
react_mode: ReactMode = Field(..., description="Current react mode")
total_requests: int = Field(default=0, description="Total requests count")
successful_requests: int = Field(default=0, description="Successful requests count")
failed_requests: int = Field(default=0, description="Failed requests count")
avg_latency_ms: float = Field(default=0.0, description="Average latency in ms")
p99_latency_ms: float = Field(default=0.0, description="P99 latency in ms")
direct_route_count: int = Field(default=0, description="Direct route count")
react_route_count: int = Field(default=0, description="React route count")
auto_route_count: int = Field(default=0, description="Auto route count")
fallback_count: int = Field(default=0, description="Fallback to default count")
last_updated: str | None = Field(default=None, description="Last update timestamp")

View File

@ -81,7 +81,6 @@ class ApiKeyService:
return
except Exception as e:
logger.warning(f"[AC-AISVC-50] Full API key schema load failed, fallback to legacy columns: {e}")
await session.rollback()
# Backward-compat fallback for environments without new columns
try:

View File

@ -0,0 +1,274 @@
"""
Scene Slot Bundle Cache Service.
[AC-SCENE-SLOT-03] Scene slot bundle cache service.
Responsibilities:
1. Cache scene slot bundle configs to reduce database queries
2. Support cache invalidation and refresh
3. Support tenant isolation
"""
import json
import logging
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Any
logger = logging.getLogger(__name__)
try:
import redis.asyncio as redis
REDIS_AVAILABLE = True
except ImportError:
REDIS_AVAILABLE = False
@dataclass
class CachedSceneSlotBundle:
"""缓存的场景槽位包"""
scene_key: str
scene_name: str
description: str | None
required_slots: list[str]
optional_slots: list[str]
slot_priority: list[str] | None
completion_threshold: float
ask_back_order: str
status: str
version: int
cached_at: datetime = field(default_factory=datetime.utcnow)
def to_dict(self) -> dict[str, Any]:
return {
"scene_key": self.scene_key,
"scene_name": self.scene_name,
"description": self.description,
"required_slots": self.required_slots,
"optional_slots": self.optional_slots,
"slot_priority": self.slot_priority,
"completion_threshold": self.completion_threshold,
"ask_back_order": self.ask_back_order,
"status": self.status,
"version": self.version,
"cached_at": self.cached_at.isoformat(),
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "CachedSceneSlotBundle":
return cls(
scene_key=data["scene_key"],
scene_name=data["scene_name"],
description=data.get("description"),
required_slots=data.get("required_slots", []),
optional_slots=data.get("optional_slots", []),
slot_priority=data.get("slot_priority"),
completion_threshold=data.get("completion_threshold", 1.0),
ask_back_order=data.get("ask_back_order", "priority"),
status=data.get("status", "draft"),
version=data.get("version", 1),
cached_at=datetime.fromisoformat(data["cached_at"]) if data.get("cached_at") else datetime.utcnow(),
)
class SceneSlotBundleCache:
"""
[AC-SCENE-SLOT-03] Scene slot bundle cache
Caches scene slot bundle configs in Redis or in memory.
"""
CACHE_PREFIX = "scene_slot_bundle"
CACHE_TTL_SECONDS = 300  # 5-minute cache
def __init__(
self,
redis_client: Any = None,
ttl_seconds: int = 300,
):
self._redis = redis_client
self._ttl = ttl_seconds or self.CACHE_TTL_SECONDS
self._memory_cache: dict[str, tuple[CachedSceneSlotBundle, datetime]] = {}
def _get_cache_key(self, tenant_id: str, scene_key: str) -> str:
"""生成缓存键"""
return f"{self.CACHE_PREFIX}:{tenant_id}:{scene_key}"
async def get(
self,
tenant_id: str,
scene_key: str,
) -> CachedSceneSlotBundle | None:
"""
Get a cached scene slot bundle
Args:
tenant_id: Tenant ID
scene_key: Scene key
Returns:
The cached scene slot bundle, or None
"""
cache_key = self._get_cache_key(tenant_id, scene_key)
if self._redis and REDIS_AVAILABLE:
try:
cached_data = await self._redis.get(cache_key)
if cached_data:
data = json.loads(cached_data)
return CachedSceneSlotBundle.from_dict(data)
except Exception as e:
logger.warning(
f"[AC-SCENE-SLOT-03] Redis get failed: {e}, falling back to memory cache"
)
if cache_key in self._memory_cache:
cached_bundle, cached_at = self._memory_cache[cache_key]
if datetime.utcnow() - cached_at < timedelta(seconds=self._ttl):
return cached_bundle
else:
del self._memory_cache[cache_key]
return None
async def set(
self,
tenant_id: str,
scene_key: str,
bundle: CachedSceneSlotBundle,
) -> bool:
"""
Set the cache entry
Args:
tenant_id: Tenant ID
scene_key: Scene key
bundle: Scene slot bundle to cache
Returns:
Whether the operation succeeded
"""
cache_key = self._get_cache_key(tenant_id, scene_key)
if self._redis and REDIS_AVAILABLE:
try:
await self._redis.setex(
cache_key,
self._ttl,
json.dumps(bundle.to_dict()),
)
return True
except Exception as e:
logger.warning(
f"[AC-SCENE-SLOT-03] Redis set failed: {e}, falling back to memory cache"
)
self._memory_cache[cache_key] = (bundle, datetime.utcnow())
return True
async def delete(
self,
tenant_id: str,
scene_key: str,
) -> bool:
"""
Delete the cache entry
Args:
tenant_id: Tenant ID
scene_key: Scene key
Returns:
Whether the operation succeeded
"""
cache_key = self._get_cache_key(tenant_id, scene_key)
if self._redis and REDIS_AVAILABLE:
try:
await self._redis.delete(cache_key)
except Exception as e:
logger.warning(
f"[AC-SCENE-SLOT-03] Redis delete failed: {e}"
)
if cache_key in self._memory_cache:
del self._memory_cache[cache_key]
return True
async def delete_by_tenant(self, tenant_id: str) -> bool:
"""
Delete all cache entries for a tenant
Args:
tenant_id: Tenant ID
Returns:
Whether the operation succeeded
"""
pattern = f"{self.CACHE_PREFIX}:{tenant_id}:*"
if self._redis and REDIS_AVAILABLE:
try:
keys = []
async for key in self._redis.scan_iter(match=pattern):
keys.append(key)
if keys:
await self._redis.delete(*keys)
except Exception as e:
logger.warning(
f"[AC-SCENE-SLOT-03] Redis delete by tenant failed: {e}"
)
keys_to_delete = [
k for k in self._memory_cache
if k.startswith(f"{self.CACHE_PREFIX}:{tenant_id}:")
]
for key in keys_to_delete:
del self._memory_cache[key]
return True
async def invalidate_on_update(
self,
tenant_id: str,
scene_key: str,
) -> bool:
"""
Invalidate the cache when a scene slot bundle is updated
Args:
tenant_id: Tenant ID
scene_key: Scene key
Returns:
Whether the operation succeeded
"""
logger.info(
f"[AC-SCENE-SLOT-03] Invalidating cache for scene slot bundle: "
f"tenant={tenant_id}, scene={scene_key}"
)
return await self.delete(tenant_id, scene_key)
_scene_slot_bundle_cache: SceneSlotBundleCache | None = None
def get_scene_slot_bundle_cache() -> SceneSlotBundleCache:
"""获取场景槽位包缓存实例"""
global _scene_slot_bundle_cache
if _scene_slot_bundle_cache is None:
_scene_slot_bundle_cache = SceneSlotBundleCache()
return _scene_slot_bundle_cache
def init_scene_slot_bundle_cache(redis_client: Any = None, ttl_seconds: int = 300) -> None:
"""初始化场景槽位包缓存"""
global _scene_slot_bundle_cache
_scene_slot_bundle_cache = SceneSlotBundleCache(
redis_client=redis_client,
ttl_seconds=ttl_seconds,
)
logger.info(
f"[AC-SCENE-SLOT-03] Scene slot bundle cache initialized: "
f"ttl={ttl_seconds}s, redis={redis_client is not None}"
)

View File

@ -0,0 +1,397 @@
"""
Slot State Cache Layer.
Slot state cache layer: session-level slot state persistence.
[AC-MRS-SLOT-CACHE-01] Multi-turn state persistence
Features:
- L1: In-memory cache (process-level, 5 min TTL)
- L2: Redis cache (shared, configurable TTL)
- Automatic fallback on cache miss
- Support for slot value source priority
Key format: slot_state:{tenant_id}:{session_id}
TTL: Configurable (default 30 minutes)
"""
import json
import logging
import time
from dataclasses import asdict, dataclass, field
from datetime import datetime
from typing import Any
import redis.asyncio as redis
from app.core.config import get_settings
from app.models.mid.schemas import SlotSource
logger = logging.getLogger(__name__)
@dataclass
class CachedSlotValue:
"""
Cached slot value
Attributes:
value: slot value
source: value source (user_confirmed, rule_extracted, llm_inferred, default, context)
confidence: confidence score
updated_at: update timestamp
"""
value: Any
source: str
confidence: float = 1.0
updated_at: float = field(default_factory=time.time)
def to_dict(self) -> dict[str, Any]:
return {
"value": self.value,
"source": self.source,
"confidence": self.confidence,
"updated_at": self.updated_at,
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "CachedSlotValue":
return cls(
value=data["value"],
source=data["source"],
confidence=data.get("confidence", 1.0),
updated_at=data.get("updated_at", time.time()),
)
@dataclass
class CachedSlotState:
"""
Cached slot state
Attributes:
filled_slots: dict of filled slot values {slot_key: CachedSlotValue}
slot_to_field_map: mapping from slots to metadata fields
created_at: creation time
updated_at: last update time
"""
filled_slots: dict[str, CachedSlotValue] = field(default_factory=dict)
slot_to_field_map: dict[str, str] = field(default_factory=dict)
created_at: float = field(default_factory=time.time)
updated_at: float = field(default_factory=time.time)
def to_dict(self) -> dict[str, Any]:
return {
"filled_slots": {
k: v.to_dict() for k, v in self.filled_slots.items()
},
"slot_to_field_map": self.slot_to_field_map,
"created_at": self.created_at,
"updated_at": self.updated_at,
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "CachedSlotState":
filled_slots = {}
for k, v in data.get("filled_slots", {}).items():
if isinstance(v, dict):
filled_slots[k] = CachedSlotValue.from_dict(v)
else:
filled_slots[k] = CachedSlotValue(value=v, source="unknown")
return cls(
filled_slots=filled_slots,
slot_to_field_map=data.get("slot_to_field_map", {}),
created_at=data.get("created_at", time.time()),
updated_at=data.get("updated_at", time.time()),
)
def get_simple_filled_slots(self) -> dict[str, Any]:
"""获取简化的已填充槽位字典(仅值)"""
return {k: v.value for k, v in self.filled_slots.items()}
def get_slot_sources(self) -> dict[str, str]:
"""获取槽位来源字典"""
return {k: v.source for k, v in self.filled_slots.items()}
def get_slot_confidence(self) -> dict[str, float]:
"""获取槽位置信度字典"""
return {k: v.confidence for k, v in self.filled_slots.items()}
class SlotStateCache:
"""
[AC-MRS-SLOT-CACHE-01] Slot state cache layer
Provides session-level slot state persistence:
- L1: in-memory cache (process-level, 5 min TTL)
- L2: Redis cache (shared, configurable TTL)
- Automatic degradation: memory-only when Redis is unavailable
- Slot value merging by source priority
Key format: slot_state:{tenant_id}:{session_id}
TTL: Configurable via settings.slot_state_cache_ttl (default 1800 seconds = 30 minutes)
"""
_local_cache: dict[str, tuple[CachedSlotState, float]] = {}
_local_cache_ttl = 300
SOURCE_PRIORITY = {
SlotSource.USER_CONFIRMED.value: 100,
"user_confirmed": 100,
SlotSource.RULE_EXTRACTED.value: 80,
"rule_extracted": 80,
SlotSource.LLM_INFERRED.value: 60,
"llm_inferred": 60,
"context": 40,
SlotSource.DEFAULT.value: 20,
"default": 20,
"unknown": 0,
}
def __init__(self, redis_client: redis.Redis | None = None):
self._redis = redis_client
self._settings = get_settings()
self._enabled = self._settings.redis_enabled
self._cache_ttl = getattr(self._settings, "slot_state_cache_ttl", 1800)
async def _get_client(self) -> redis.Redis | None:
"""Get or create Redis client."""
if not self._enabled:
return None
if self._redis is None:
try:
self._redis = redis.from_url(
self._settings.redis_url,
encoding="utf-8",
decode_responses=True,
)
except Exception as e:
logger.warning(f"[SlotStateCache] Failed to connect to Redis: {e}")
self._enabled = False
return None
return self._redis
def _make_key(self, tenant_id: str, session_id: str) -> str:
"""Generate cache key."""
return f"slot_state:{tenant_id}:{session_id}"
def _make_local_key(self, tenant_id: str, session_id: str) -> str:
"""Generate local cache key."""
return f"{tenant_id}:{session_id}"
def _get_source_priority(self, source: str) -> int:
"""Get priority for a source."""
return self.SOURCE_PRIORITY.get(source, 0)
async def get(
self,
tenant_id: str,
session_id: str,
) -> CachedSlotState | None:
"""
Get cached slot state (L1 -> L2).
Args:
tenant_id: Tenant ID for isolation
session_id: Session ID
Returns:
CachedSlotState or None if not found
"""
local_key = self._make_local_key(tenant_id, session_id)
if local_key in self._local_cache:
state, timestamp = self._local_cache[local_key]
if time.time() - timestamp < self._local_cache_ttl:
logger.debug(f"[SlotStateCache] L1 hit: {local_key}")
return state
else:
del self._local_cache[local_key]
client = await self._get_client()
if client is None:
return None
key = self._make_key(tenant_id, session_id)
try:
data = await client.get(key)
if data:
logger.debug(f"[SlotStateCache] L2 hit: {key}")
state_dict = json.loads(data)
state = CachedSlotState.from_dict(state_dict)
self._local_cache[local_key] = (state, time.time())
return state
return None
except Exception as e:
logger.warning(f"[SlotStateCache] Failed to get from cache: {e}")
return None
async def set(
self,
tenant_id: str,
session_id: str,
state: CachedSlotState,
) -> bool:
"""
Set slot state to cache (L1 + L2).
Args:
tenant_id: Tenant ID for isolation
session_id: Session ID
state: CachedSlotState to cache
Returns:
True if successful
"""
local_key = self._make_local_key(tenant_id, session_id)
state.updated_at = time.time()
self._local_cache[local_key] = (state, time.time())
client = await self._get_client()
if client is None:
return False
key = self._make_key(tenant_id, session_id)
try:
await client.setex(
key,
self._cache_ttl,
json.dumps(state.to_dict(), default=str),
)
logger.debug(f"[SlotStateCache] Set cache: {key}")
return True
except Exception as e:
logger.warning(f"[SlotStateCache] Failed to set cache: {e}")
return False
async def merge_and_set(
self,
tenant_id: str,
session_id: str,
new_slots: dict[str, CachedSlotValue],
slot_to_field_map: dict[str, str] | None = None,
) -> CachedSlotState:
"""
Merge new slot values with cached state and save.
Priority: user_confirmed > rule_extracted > llm_inferred > context > default
Args:
tenant_id: Tenant ID
session_id: Session ID
new_slots: New slot values to merge
slot_to_field_map: Slot to field mapping
Returns:
Updated CachedSlotState
"""
state = await self.get(tenant_id, session_id)
if state is None:
state = CachedSlotState()
for slot_key, new_value in new_slots.items():
if slot_key in state.filled_slots:
existing = state.filled_slots[slot_key]
existing_priority = self._get_source_priority(existing.source)
new_priority = self._get_source_priority(new_value.source)
if new_priority >= existing_priority:
state.filled_slots[slot_key] = new_value
logger.debug(
f"[SlotStateCache] Slot '{slot_key}' updated: "
f"{existing.source}({existing_priority}) -> "
f"{new_value.source}({new_priority})"
)
else:
state.filled_slots[slot_key] = new_value
logger.debug(
f"[SlotStateCache] Slot '{slot_key}' added: "
f"source={new_value.source}, value={new_value.value}"
)
if slot_to_field_map:
state.slot_to_field_map.update(slot_to_field_map)
await self.set(tenant_id, session_id, state)
return state
async def delete(
self,
tenant_id: str,
session_id: str,
) -> bool:
"""
Delete slot state from cache (L1 + L2).
Args:
tenant_id: Tenant ID for isolation
session_id: Session ID
Returns:
True if successful
"""
local_key = self._make_local_key(tenant_id, session_id)
if local_key in self._local_cache:
del self._local_cache[local_key]
client = await self._get_client()
if client is None:
return False
key = self._make_key(tenant_id, session_id)
try:
await client.delete(key)
logger.debug(f"[SlotStateCache] Deleted cache: {key}")
return True
except Exception as e:
logger.warning(f"[SlotStateCache] Failed to delete cache: {e}")
return False
async def clear_slot(
self,
tenant_id: str,
session_id: str,
slot_key: str,
) -> bool:
"""
Clear a specific slot from cached state.
Args:
tenant_id: Tenant ID
session_id: Session ID
slot_key: Slot key to clear
Returns:
True if successful
"""
state = await self.get(tenant_id, session_id)
if state is None:
return True
if slot_key in state.filled_slots:
del state.filled_slots[slot_key]
await self.set(tenant_id, session_id, state)
logger.debug(f"[SlotStateCache] Cleared slot: {slot_key}")
return True
async def close(self) -> None:
"""Close Redis connection."""
if self._redis:
await self._redis.close()
_slot_state_cache: SlotStateCache | None = None
def get_slot_state_cache() -> SlotStateCache:
"""Get singleton SlotStateCache instance."""
global _slot_state_cache
if _slot_state_cache is None:
_slot_state_cache = SlotStateCache()
return _slot_state_cache
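The `merge_and_set` method above resolves conflicts via `SOURCE_PRIORITY`, letting an incoming value win whenever its priority is greater than or equal to the existing one (ties favour the newer value). A self-contained sketch of that merge rule (hypothetical `merge_slots` helper operating on `(value, source)` tuples rather than `CachedSlotValue`):

```python
SOURCE_PRIORITY = {
    "user_confirmed": 100,
    "rule_extracted": 80,
    "llm_inferred": 60,
    "context": 40,
    "default": 20,
    "unknown": 0,
}

def merge_slots(existing: dict, incoming: dict) -> dict:
    """Merge incoming slot values into existing ones. An incoming value wins
    when its source priority is >= the existing one, matching merge_and_set's
    tie-breaking. Each value is a (value, source) tuple."""
    merged = dict(existing)
    for key, (value, source) in incoming.items():
        if key in merged:
            old_priority = SOURCE_PRIORITY.get(merged[key][1], 0)
            new_priority = SOURCE_PRIORITY.get(source, 0)
            if new_priority < old_priority:
                continue  # keep the higher-trust existing value
        merged[key] = (value, source)
    return merged
```

So an `llm_inferred` value never overwrites a `user_confirmed` one, but a fresh `user_confirmed` value always replaces an older one of equal priority.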

View File

@ -12,6 +12,9 @@ import logging
from pathlib import Path
from typing import Any
import redis
from app.core.config import get_settings
from app.services.embedding.base import EmbeddingException, EmbeddingProvider
from app.services.embedding.nomic_provider import NomicEmbeddingProvider
from app.services.embedding.ollama_provider import OllamaEmbeddingProvider
@ -20,6 +23,7 @@ from app.services.embedding.openai_provider import OpenAIEmbeddingProvider
logger = logging.getLogger(__name__)
EMBEDDING_CONFIG_FILE = Path("config/embedding_config.json")
EMBEDDING_CONFIG_REDIS_KEY = "ai_service:config:embedding"
class EmbeddingProviderFactory:
@ -170,8 +174,32 @@ class EmbeddingConfigManager:
self._config = self._default_config.copy()
self._provider: EmbeddingProvider | None = None
self._settings = get_settings()
self._redis_client: redis.Redis | None = None
self._load_from_redis()
self._load_from_file()
def _load_from_redis(self) -> None:
"""Load configuration from Redis if exists."""
try:
if not self._settings.redis_enabled:
return
self._redis_client = redis.from_url(
self._settings.redis_url,
encoding="utf-8",
decode_responses=True,
)
saved_raw = self._redis_client.get(EMBEDDING_CONFIG_REDIS_KEY)
if not saved_raw:
return
saved = json.loads(saved_raw)
self._provider_name = saved.get("provider", self._default_provider)
self._config = saved.get("config", self._default_config.copy())
logger.info(f"Loaded embedding config from Redis: provider={self._provider_name}")
except Exception as e:
logger.warning(f"Failed to load embedding config from Redis: {e}")
def _load_from_file(self) -> None:
"""Load configuration from file if exists."""
try:
@ -184,6 +212,28 @@ class EmbeddingConfigManager:
except Exception as e:
logger.warning(f"Failed to load embedding config from file: {e}")
def _save_to_redis(self) -> None:
"""Save configuration to Redis."""
try:
if not self._settings.redis_enabled:
return
if self._redis_client is None:
self._redis_client = redis.from_url(
self._settings.redis_url,
encoding="utf-8",
decode_responses=True,
)
self._redis_client.set(
EMBEDDING_CONFIG_REDIS_KEY,
json.dumps({
"provider": self._provider_name,
"config": self._config,
}, ensure_ascii=False),
)
logger.info(f"Saved embedding config to Redis: provider={self._provider_name}")
except Exception as e:
logger.warning(f"Failed to save embedding config to Redis: {e}")
def _save_to_file(self) -> None:
"""Save configuration to file."""
try:
@ -262,6 +312,7 @@ class EmbeddingConfigManager:
self._config = config
self._provider = new_provider_instance
self._save_to_redis()
self._save_to_file()
logger.info(f"Updated embedding config: provider={provider}")

View File

@ -322,7 +322,7 @@ class FlowEngine:
stmt = select(FlowInstance).where(
FlowInstance.tenant_id == tenant_id,
FlowInstance.session_id == session_id,
).order_by(col(FlowInstance.created_at).desc())
).order_by(col(FlowInstance.started_at).desc())
result = await self._session.execute(stmt)
instance = result.scalar_one_or_none()

View File

@ -0,0 +1,385 @@
"""
Clarification mechanism for intent recognition.
[AC-CLARIFY] Clarification mechanism implementation.
Core features:
1. Unified confidence computation
2. Hard-gate rules: confidence check, required_slots check
3. Clarification state management
4. Metrics collection
"""
import logging
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Any
import uuid
logger = logging.getLogger(__name__)
T_HIGH = 0.75
T_LOW = 0.45
MAX_CLARIFY_RETRY = 3
class ClarifyReason(str, Enum):
INTENT_AMBIGUITY = "intent_ambiguity"
MISSING_SLOT = "missing_slot"
LOW_CONFIDENCE = "low_confidence"
MULTI_INTENT = "multi_intent"
class ClarifyMetrics:
_instance = None
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._clarify_trigger_count = 0
cls._instance._clarify_converge_count = 0
cls._instance._misroute_count = 0
return cls._instance
def record_clarify_trigger(self) -> None:
self._clarify_trigger_count += 1
logger.debug(f"[AC-CLARIFY-METRICS] clarify_trigger_count: {self._clarify_trigger_count}")
def record_clarify_converge(self) -> None:
self._clarify_converge_count += 1
logger.debug(f"[AC-CLARIFY-METRICS] clarify_converge_count: {self._clarify_converge_count}")
def record_misroute(self) -> None:
self._misroute_count += 1
logger.debug(f"[AC-CLARIFY-METRICS] misroute_count: {self._misroute_count}")
def get_metrics(self) -> dict[str, int]:
return {
"clarify_trigger_rate": self._clarify_trigger_count,
"clarify_converge_rate": self._clarify_converge_count,
"misroute_rate": self._misroute_count,
}
def get_rates(self, total_requests: int) -> dict[str, float]:
if total_requests == 0:
return {
"clarify_trigger_rate": 0.0,
"clarify_converge_rate": 0.0,
"misroute_rate": 0.0,
}
return {
"clarify_trigger_rate": self._clarify_trigger_count / total_requests,
"clarify_converge_rate": self._clarify_converge_count / total_requests if self._clarify_trigger_count > 0 else 0.0,
"misroute_rate": self._misroute_count / total_requests,
}
def reset(self) -> None:
self._clarify_trigger_count = 0
self._clarify_converge_count = 0
self._misroute_count = 0
def get_clarify_metrics() -> ClarifyMetrics:
return ClarifyMetrics()
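`ClarifyMetrics` uses a `__new__`-based singleton, so counters accumulate process-wide no matter where the class is instantiated. A minimal sketch of the pattern (hypothetical `Metrics` class, one counter instead of three):

```python
class Metrics:
    """Minimal sketch of the __new__-based singleton used by ClarifyMetrics:
    every call to Metrics() returns the same instance, so counters accumulate
    process-wide."""
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance.count = 0  # initialized exactly once
        return cls._instance

    def record(self) -> None:
        self.count += 1
```

Because initialization happens only on the first call, `get_clarify_metrics()` can simply construct the class and still return the shared instance.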
@dataclass
class IntentCandidate:
intent_id: str
intent_name: str
confidence: float
response_type: str | None = None
target_kb_ids: list[str] | None = None
flow_id: str | None = None
fixed_reply: str | None = None
transfer_message: str | None = None
def to_dict(self) -> dict[str, Any]:
return {
"intent_id": self.intent_id,
"intent_name": self.intent_name,
"confidence": self.confidence,
"response_type": self.response_type,
"target_kb_ids": self.target_kb_ids,
"flow_id": self.flow_id,
"fixed_reply": self.fixed_reply,
"transfer_message": self.transfer_message,
}
@dataclass
class HybridIntentResult:
intent: IntentCandidate | None
confidence: float
candidates: list[IntentCandidate] = field(default_factory=list)
need_clarify: bool = False
clarify_reason: ClarifyReason | None = None
missing_slots: list[str] = field(default_factory=list)
def to_dict(self) -> dict[str, Any]:
return {
"intent": self.intent.to_dict() if self.intent else None,
"confidence": self.confidence,
"candidates": [c.to_dict() for c in self.candidates],
"need_clarify": self.need_clarify,
"clarify_reason": self.clarify_reason.value if self.clarify_reason else None,
"missing_slots": self.missing_slots,
}
@classmethod
def from_fusion_result(cls, fusion_result: Any) -> "HybridIntentResult":
candidates = []
if fusion_result.clarify_candidates:
for c in fusion_result.clarify_candidates:
candidates.append(IntentCandidate(
intent_id=str(c.id),
intent_name=c.name,
confidence=0.0,
response_type=getattr(c, "response_type", None),
target_kb_ids=getattr(c, "target_kb_ids", None),
flow_id=str(c.flow_id) if getattr(c, "flow_id", None) else None,
fixed_reply=getattr(c, "fixed_reply", None),
transfer_message=getattr(c, "transfer_message", None),
))
if fusion_result.final_intent:
final_candidate = IntentCandidate(
intent_id=str(fusion_result.final_intent.id),
intent_name=fusion_result.final_intent.name,
confidence=fusion_result.final_confidence,
response_type=fusion_result.final_intent.response_type,
target_kb_ids=fusion_result.final_intent.target_kb_ids,
flow_id=str(fusion_result.final_intent.flow_id) if fusion_result.final_intent.flow_id else None,
fixed_reply=fusion_result.final_intent.fixed_reply,
transfer_message=fusion_result.final_intent.transfer_message,
)
if not any(c.intent_id == final_candidate.intent_id for c in candidates):
candidates.insert(0, final_candidate)
clarify_reason = None
if fusion_result.need_clarify:
if fusion_result.decision_reason == "multi_intent":
clarify_reason = ClarifyReason.MULTI_INTENT
elif fusion_result.decision_reason == "gray_zone":
clarify_reason = ClarifyReason.INTENT_AMBIGUITY
else:
clarify_reason = ClarifyReason.LOW_CONFIDENCE
return cls(
intent=candidates[0] if candidates else None,
confidence=fusion_result.final_confidence,
candidates=candidates,
need_clarify=fusion_result.need_clarify,
clarify_reason=clarify_reason,
)
@dataclass
class ClarifyState:
reason: ClarifyReason
asked_slot: str | None = None
retry_count: int = 0
candidates: list[IntentCandidate] = field(default_factory=list)
asked_intent_ids: list[str] = field(default_factory=list)
created_at: float = field(default_factory=time.time)
def to_dict(self) -> dict[str, Any]:
return {
"reason": self.reason.value,
"asked_slot": self.asked_slot,
"retry_count": self.retry_count,
"candidates": [c.to_dict() for c in self.candidates],
"asked_intent_ids": self.asked_intent_ids,
"created_at": self.created_at,
}
def increment_retry(self) -> "ClarifyState":
self.retry_count += 1
return self
def is_max_retry(self) -> bool:
return self.retry_count >= MAX_CLARIFY_RETRY
class ClarificationEngine:
def __init__(
self,
t_high: float = T_HIGH,
t_low: float = T_LOW,
max_retry: int = MAX_CLARIFY_RETRY,
):
self._t_high = t_high
self._t_low = t_low
self._max_retry = max_retry
self._metrics = get_clarify_metrics()
def compute_confidence(
self,
rule_score: float = 0.0,
semantic_score: float = 0.0,
llm_score: float = 0.0,
w_rule: float = 0.5,
w_semantic: float = 0.3,
w_llm: float = 0.2,
) -> float:
total_weight = w_rule + w_semantic + w_llm
if total_weight == 0:
return 0.0
weighted_score = (
rule_score * w_rule +
semantic_score * w_semantic +
llm_score * w_llm
)
return min(1.0, max(0.0, weighted_score / total_weight))
def check_hard_block(
self,
result: HybridIntentResult,
required_slots: list[str] | None = None,
filled_slots: dict[str, Any] | None = None,
) -> tuple[bool, ClarifyReason | None]:
if result.confidence < self._t_high:
return True, ClarifyReason.LOW_CONFIDENCE
if required_slots and filled_slots is not None:
missing = [s for s in required_slots if s not in filled_slots]
if missing:
return True, ClarifyReason.MISSING_SLOT
return False, None
def should_trigger_clarify(
self,
result: HybridIntentResult,
required_slots: list[str] | None = None,
filled_slots: dict[str, Any] | None = None,
) -> tuple[bool, ClarifyState | None]:
if result.confidence >= self._t_high:
if required_slots and filled_slots is not None:
missing = [s for s in required_slots if s not in filled_slots]
if missing:
self._metrics.record_clarify_trigger()
return True, ClarifyState(
reason=ClarifyReason.MISSING_SLOT,
asked_slot=missing[0],
candidates=result.candidates,
)
return False, None
if result.confidence < self._t_low:
self._metrics.record_clarify_trigger()
return True, ClarifyState(
reason=ClarifyReason.LOW_CONFIDENCE,
candidates=result.candidates,
)
self._metrics.record_clarify_trigger()
reason = result.clarify_reason or ClarifyReason.INTENT_AMBIGUITY
return True, ClarifyState(
reason=reason,
candidates=result.candidates,
)
def generate_clarify_prompt(
self,
state: ClarifyState,
slot_label: str | None = None,
) -> str:
if state.reason == ClarifyReason.MISSING_SLOT:
slot_name = slot_label or state.asked_slot or "相关信息"
return f"为了更好地为您服务,请告诉我您的{slot_name}"
if state.reason == ClarifyReason.LOW_CONFIDENCE:
return "抱歉,我不太理解您的意思,能否请您详细描述一下您的需求?"
if state.reason == ClarifyReason.MULTI_INTENT and len(state.candidates) > 1:
candidates = state.candidates[:3]
if len(candidates) == 2:
return (
f"请问您是想「{candidates[0].intent_name}」"
f"还是「{candidates[1].intent_name}」?"
)
else:
options = "、".join([f"「{c.intent_name}」" for c in candidates[:-1]])
return f"请问您是想{options},还是「{candidates[-1].intent_name}」?"
if state.reason == ClarifyReason.INTENT_AMBIGUITY and len(state.candidates) > 1:
candidates = state.candidates[:2]
return (
f"请问您是想「{candidates[0].intent_name}」"
f"还是「{candidates[1].intent_name}」?"
)
return "请问您具体想了解什么?"
def process_clarify_response(
self,
user_message: str,
state: ClarifyState,
intent_router: Any = None,
rules: list[Any] | None = None,
) -> HybridIntentResult:
state.increment_retry()
if state.is_max_retry():
self._metrics.record_misroute()
return HybridIntentResult(
intent=None,
confidence=0.0,
need_clarify=False,
)
if state.reason == ClarifyReason.MISSING_SLOT:
self._metrics.record_clarify_converge()
return HybridIntentResult(
intent=state.candidates[0] if state.candidates else None,
confidence=0.8,
candidates=state.candidates,
need_clarify=False,
)
return HybridIntentResult(
intent=None,
confidence=0.0,
candidates=state.candidates,
need_clarify=True,
clarify_reason=state.reason,
)
def get_metrics(self) -> dict[str, int]:
return self._metrics.get_metrics()
def get_rates(self, total_requests: int) -> dict[str, float]:
return self._metrics.get_rates(total_requests)
class ClarifySessionManager:
_sessions: dict[str, ClarifyState] = {}
@classmethod
def get_session(cls, session_id: str) -> ClarifyState | None:
return cls._sessions.get(session_id)
@classmethod
def set_session(cls, session_id: str, state: ClarifyState) -> None:
cls._sessions[session_id] = state
logger.debug(f"[AC-CLARIFY] Session state set: session={session_id}, reason={state.reason}")
@classmethod
def clear_session(cls, session_id: str) -> None:
if session_id in cls._sessions:
del cls._sessions[session_id]
logger.debug(f"[AC-CLARIFY] Session state cleared: session={session_id}")
@classmethod
def has_active_clarify(cls, session_id: str) -> bool:
state = cls._sessions.get(session_id)
if state:
return not state.is_max_retry()
return False


@@ -0,0 +1,254 @@
"""
[v0.8.0] Fusion policy for hybrid intent routing.
[AC-AISVC-115~AC-AISVC-117] Fusion decision logic for three-way matching.
"""
import logging
from collections.abc import Callable
from typing import TYPE_CHECKING
from app.services.intent.models import (
FusionConfig,
FusionResult,
LlmJudgeResult,
RouteTrace,
RuleMatchResult,
SemanticMatchResult,
)
if TYPE_CHECKING:
from app.models.entities import IntentRule
logger = logging.getLogger(__name__)
DecisionCondition = Callable[[RuleMatchResult, SemanticMatchResult, LlmJudgeResult], bool]
class FusionPolicy:
"""
[AC-AISVC-115] Fusion decision policy for hybrid routing.
Decision priority:
1. rule_high_confidence: RuleMatcher hit with score=1.0
2. llm_judge: LlmJudge triggered and returned valid intent
3. semantic_override: RuleMatcher missed but SemanticMatcher high confidence
4. rule_semantic_agree: Both match same intent
5. semantic_fallback: SemanticMatcher medium confidence
6. rule_fallback: Only rule matched
7. no_match: All low confidence
"""
DECISION_PRIORITY: list[tuple[str, DecisionCondition]] = [
("rule_high_confidence", lambda r, s, llm: r.score == 1.0 and r.rule is not None),
("llm_judge", lambda r, s, llm: llm.triggered and llm.intent_id is not None),
(
"semantic_override",
lambda r, s, llm: r.score == 0
and s.top_score > 0.7
and not s.skipped
and len(s.candidates) > 0,
),
(
"rule_semantic_agree",
lambda r, s, llm: r.score > 0
and s.top_score > 0.5
and not s.skipped
and len(s.candidates) > 0
and r.rule_id == s.candidates[0].rule.id,
),
(
"semantic_fallback",
lambda r, s, llm: s.top_score > 0.5 and not s.skipped and len(s.candidates) > 0,
),
("rule_fallback", lambda r, s, llm: r.score > 0),
("no_match", lambda r, s, llm: True),
]
def __init__(self, config: FusionConfig | None = None):
"""
Initialize fusion policy with configuration.
Args:
config: Fusion configuration, uses default if not provided
"""
self._config = config or FusionConfig()
def fuse(
self,
rule_result: RuleMatchResult,
semantic_result: SemanticMatchResult,
llm_result: LlmJudgeResult | None,
) -> FusionResult:
"""
[AC-AISVC-115] Execute fusion decision.
Args:
rule_result: Rule matching result
semantic_result: Semantic matching result
llm_result: LLM judge result (may be None)
Returns:
FusionResult with final intent, confidence, and trace
"""
trace = self._build_trace(rule_result, semantic_result, llm_result)
final_intent = None
final_confidence = 0.0
decision_reason = "no_match"
effective_llm_result = llm_result or LlmJudgeResult.empty()
for reason, condition in self.DECISION_PRIORITY:
if condition(rule_result, semantic_result, effective_llm_result):
decision_reason = reason
break
if decision_reason == "rule_high_confidence":
final_intent = rule_result.rule
final_confidence = 1.0
elif decision_reason == "llm_judge" and llm_result:
final_intent = self._find_rule_by_id(
llm_result.intent_id, rule_result, semantic_result
)
final_confidence = llm_result.score
elif decision_reason == "semantic_override":
final_intent = semantic_result.candidates[0].rule
final_confidence = semantic_result.top_score
elif decision_reason == "rule_semantic_agree":
final_intent = rule_result.rule
final_confidence = self._calculate_weighted_confidence(
rule_result, semantic_result, llm_result
)
elif decision_reason == "semantic_fallback":
final_intent = semantic_result.candidates[0].rule
final_confidence = semantic_result.top_score
elif decision_reason == "rule_fallback":
final_intent = rule_result.rule
final_confidence = rule_result.score
need_clarify = final_confidence < self._config.clarify_threshold
clarify_candidates = None
if need_clarify and len(semantic_result.candidates) > 1:
clarify_candidates = [c.rule for c in semantic_result.candidates[:3]]
trace.fusion = {
"weights": {
"w_rule": self._config.w_rule,
"w_semantic": self._config.w_semantic,
"w_llm": self._config.w_llm,
},
"final_confidence": final_confidence,
"decision_reason": decision_reason,
"need_clarify": need_clarify,
}
logger.info(
f"[AC-AISVC-115] Fusion decision: reason={decision_reason}, "
f"confidence={final_confidence:.3f}, need_clarify={need_clarify}"
)
return FusionResult(
final_intent=final_intent,
final_confidence=final_confidence,
decision_reason=decision_reason,
need_clarify=need_clarify,
clarify_candidates=clarify_candidates,
trace=trace,
)
def _build_trace(
self,
rule_result: RuleMatchResult,
semantic_result: SemanticMatchResult,
llm_result: LlmJudgeResult | None,
) -> RouteTrace:
"""
[AC-AISVC-122] Build route trace log.
"""
return RouteTrace(
rule_match={
"rule_id": str(rule_result.rule_id) if rule_result.rule_id else None,
"rule_name": rule_result.rule.name if rule_result.rule else None,
"match_type": rule_result.match_type,
"matched_text": rule_result.matched_text,
"score": rule_result.score,
"duration_ms": rule_result.duration_ms,
},
semantic_match={
"top_candidates": [
{
"rule_id": str(c.rule.id),
"rule_name": c.rule.name,
"score": c.score,
}
for c in semantic_result.candidates
],
"top_score": semantic_result.top_score,
"duration_ms": semantic_result.duration_ms,
"skipped": semantic_result.skipped,
"skip_reason": semantic_result.skip_reason,
},
llm_judge={
"triggered": llm_result.triggered if llm_result else False,
"intent_id": llm_result.intent_id if llm_result else None,
"intent_name": llm_result.intent_name if llm_result else None,
"score": llm_result.score if llm_result else 0.0,
"reasoning": llm_result.reasoning if llm_result else None,
"duration_ms": llm_result.duration_ms if llm_result else 0,
"tokens_used": llm_result.tokens_used if llm_result else 0,
},
fusion={},
)
def _calculate_weighted_confidence(
self,
rule_result: RuleMatchResult,
semantic_result: SemanticMatchResult,
llm_result: LlmJudgeResult | None,
) -> float:
"""
[AC-AISVC-116] Calculate weighted confidence.
Formula:
final_confidence = (w_rule * rule_score + w_semantic * semantic_score + w_llm * llm_score) / total_weight
Returns:
Weighted confidence in [0.0, 1.0]
"""
rule_score = rule_result.score
semantic_score = semantic_result.top_score if not semantic_result.skipped else 0.0
llm_score = llm_result.score if llm_result and llm_result.triggered else 0.0
total_weight = self._config.w_rule + self._config.w_semantic
if llm_result and llm_result.triggered:
total_weight += self._config.w_llm
if total_weight == 0:
return 0.0
confidence = (
self._config.w_rule * rule_score
+ self._config.w_semantic * semantic_score
+ self._config.w_llm * llm_score
) / total_weight
return min(1.0, max(0.0, confidence))
def _find_rule_by_id(
self,
intent_id: str | None,
rule_result: RuleMatchResult,
semantic_result: SemanticMatchResult,
) -> "IntentRule | None":
"""Find rule by ID from rule or semantic results."""
if not intent_id:
return None
if rule_result.rule_id and str(rule_result.rule_id) == intent_id:
return rule_result.rule
for candidate in semantic_result.candidates:
if str(candidate.rule.id) == intent_id:
return candidate.rule
return None
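The core of `FusionPolicy.fuse` is a first-match-wins scan over `DECISION_PRIORITY`. A runnable sketch with a trimmed subset of the predicates; `SimpleNamespace` objects stand in for the real result dataclasses:

```python
from types import SimpleNamespace as NS

PRIORITY = [
    ("rule_high_confidence", lambda r, s, llm: r.score == 1.0 and r.rule is not None),
    ("llm_judge", lambda r, s, llm: llm.triggered and llm.intent_id is not None),
    ("semantic_override", lambda r, s, llm: r.score == 0 and s.top_score > 0.7
        and not s.skipped and len(s.candidates) > 0),
    ("rule_fallback", lambda r, s, llm: r.score > 0),
    ("no_match", lambda r, s, llm: True),  # terminal catch-all, always true
]

def decide(r, s, llm) -> str:
    # The first predicate that holds determines the decision reason.
    return next(reason for reason, cond in PRIORITY if cond(r, s, llm))
```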


@@ -0,0 +1,246 @@
"""
LLM judge for intent arbitration.
[AC-AISVC-118, AC-AISVC-119] LLM-based intent arbitration.
"""
import asyncio
import json
import logging
import time
from typing import TYPE_CHECKING, Any
from app.services.intent.models import (
FusionConfig,
LlmJudgeInput,
LlmJudgeResult,
RuleMatchResult,
SemanticMatchResult,
)
if TYPE_CHECKING:
from app.services.llm.base import LLMClient
logger = logging.getLogger(__name__)
class LlmJudge:
"""
[AC-AISVC-118] LLM-based intent arbitrator.
Triggered when:
- Rule vs Semantic conflict
- Gray zone (low confidence)
- Multiple intent candidates with similar scores
"""
JUDGE_PROMPT = """你是一个意图识别仲裁器。根据用户消息和候选意图,判断最匹配的意图。
用户消息:{message}
候选意图:
{candidates}
请返回 JSON 格式(不要包含 ```json 标记):
{{
"intent_id": "最匹配的意图ID",
"intent_name": "意图名称",
"confidence": 0.0-1.0之间的置信度,
"reasoning": "判断理由"
}}"""
def __init__(
self,
llm_client: "LLMClient",
config: FusionConfig,
):
"""
Initialize LLM judge.
Args:
llm_client: LLM client for generating responses
config: Fusion configuration
"""
self._llm_client = llm_client
self._config = config
def should_trigger(
self,
rule_result: RuleMatchResult,
semantic_result: SemanticMatchResult,
config: FusionConfig | None = None,
) -> tuple[bool, str]:
"""
[AC-AISVC-118] Check if LLM judge should be triggered.
Trigger conditions:
1. Conflict: Rule and Semantic match different intents with close scores
2. Gray zone: Max confidence in gray zone range
3. Multi-intent: Multiple candidates with similar scores
Args:
rule_result: Rule matching result
semantic_result: Semantic matching result
config: Optional config override
Returns:
Tuple of (should_trigger, trigger_reason)
"""
effective_config = config or self._config
if not effective_config.llm_judge_enabled:
return False, "disabled"
rule_score = rule_result.score
semantic_score = semantic_result.top_score
if rule_score > 0 and semantic_score > 0:
if semantic_result.candidates:
top_semantic_rule_id = semantic_result.candidates[0].rule.id
if rule_result.rule_id != top_semantic_rule_id:
if abs(rule_score - semantic_score) < effective_config.conflict_threshold:
logger.info(
f"[AC-AISVC-118] LLM judge triggered: rule_semantic_conflict, "
f"rule_id={rule_result.rule_id}, semantic_id={top_semantic_rule_id}, "
f"rule_score={rule_score}, semantic_score={semantic_score}"
)
return True, "rule_semantic_conflict"
max_score = max(rule_score, semantic_score)
if effective_config.min_trigger_threshold < max_score < effective_config.gray_zone_threshold:
logger.info(
f"[AC-AISVC-118] LLM judge triggered: gray_zone, "
f"max_score={max_score}"
)
return True, "gray_zone"
if len(semantic_result.candidates) >= 2:
top1_score = semantic_result.candidates[0].score
top2_score = semantic_result.candidates[1].score
if abs(top1_score - top2_score) < effective_config.multi_intent_threshold:
logger.info(
f"[AC-AISVC-118] LLM judge triggered: multi_intent, "
f"top1_score={top1_score}, top2_score={top2_score}"
)
return True, "multi_intent"
return False, ""
async def judge(
self,
input_data: LlmJudgeInput,
tenant_id: str,
) -> LlmJudgeResult:
"""
[AC-AISVC-119] Perform LLM arbitration.
Args:
input_data: Judge input with message and candidates
tenant_id: Tenant ID for isolation
Returns:
LlmJudgeResult with arbitration decision
"""
start_time = time.time()
candidates_text = "\n".join([
f"- ID: {c['id']}, 名称: {c['name']}, 描述: {c.get('description', 'N/A')}"
for c in input_data.candidates
])
prompt = self.JUDGE_PROMPT.format(
message=input_data.message,
candidates=candidates_text,
)
try:
from app.services.llm.base import LLMConfig
response = await asyncio.wait_for(
self._llm_client.generate(
messages=[{"role": "user", "content": prompt}],
config=LLMConfig(
max_tokens=200,
temperature=0,
),
),
timeout=self._config.llm_judge_timeout_ms / 1000,
)
result = self._parse_response(response.content or "")
duration_ms = int((time.time() - start_time) * 1000)
tokens_used = 0
if response.usage:
tokens_used = response.usage.get("total_tokens", 0)
logger.info(
f"[AC-AISVC-119] LLM judge completed for tenant={tenant_id}, "
f"intent_id={result.get('intent_id')}, confidence={result.get('confidence', 0):.3f}, "
f"duration={duration_ms}ms, tokens={tokens_used}"
)
return LlmJudgeResult(
intent_id=result.get("intent_id"),
intent_name=result.get("intent_name"),
score=float(result.get("confidence", 0.5)),
reasoning=result.get("reasoning"),
duration_ms=duration_ms,
tokens_used=tokens_used,
triggered=True,
)
except asyncio.TimeoutError:
duration_ms = int((time.time() - start_time) * 1000)
logger.warning(
f"[AC-AISVC-119] LLM judge timeout for tenant={tenant_id}, "
f"timeout={self._config.llm_judge_timeout_ms}ms"
)
return LlmJudgeResult(
intent_id=None,
intent_name=None,
score=0.0,
reasoning="LLM timeout",
duration_ms=duration_ms,
tokens_used=0,
triggered=True,
)
except Exception as e:
duration_ms = int((time.time() - start_time) * 1000)
logger.error(
f"[AC-AISVC-119] LLM judge error for tenant={tenant_id}: {e}"
)
return LlmJudgeResult(
intent_id=None,
intent_name=None,
score=0.0,
reasoning=f"LLM error: {str(e)}",
duration_ms=duration_ms,
tokens_used=0,
triggered=True,
)
def _parse_response(self, content: str) -> dict[str, Any]:
"""
Parse LLM response to extract JSON result.
Args:
content: LLM response content
Returns:
Parsed dictionary with intent_id, intent_name, confidence, reasoning
"""
try:
cleaned = content.strip()
if cleaned.startswith("```json"):
cleaned = cleaned[7:]
if cleaned.startswith("```"):
cleaned = cleaned[3:]
if cleaned.endswith("```"):
cleaned = cleaned[:-3]
cleaned = cleaned.strip()
result: dict[str, Any] = json.loads(cleaned)
return result
except json.JSONDecodeError as e:
logger.warning(f"[AC-AISVC-119] Failed to parse LLM response: {e}")
return {}
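The fence-stripping in `_parse_response` handles models that wrap the JSON in Markdown fences despite the prompt's instructions. A standalone sketch (the fence string is built at runtime so this example contains no literal fence):

```python
import json
from typing import Any

FENCE = "`" * 3  # "```", built indirectly to keep this example fence-free

def parse_judge_response(content: str) -> dict[str, Any]:
    # Strip an optional leading "```json" or "```" and a trailing "```",
    # then parse; an empty dict signals a parse failure.
    cleaned = content.strip()
    if cleaned.startswith(FENCE + "json"):
        cleaned = cleaned[7:]
    if cleaned.startswith(FENCE):
        cleaned = cleaned[3:]
    if cleaned.endswith(FENCE):
        cleaned = cleaned[:-3]
    try:
        return json.loads(cleaned.strip())
    except json.JSONDecodeError:
        return {}
```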


@@ -0,0 +1,226 @@
"""
Intent routing data models.
[AC-AISVC-111~AC-AISVC-125] Data models for hybrid routing.
"""
import uuid
from dataclasses import dataclass, field
from typing import Any
@dataclass
class RuleMatchResult:
"""
[AC-AISVC-112] Result of rule matching.
Contains matched rule and score.
"""
rule_id: uuid.UUID | None
rule: Any | None
match_type: str | None
matched_text: str | None
score: float
duration_ms: int
def to_dict(self) -> dict[str, Any]:
return {
"rule_id": str(self.rule_id) if self.rule_id else None,
"rule_name": self.rule.name if self.rule else None,
"match_type": self.match_type,
"matched_text": self.matched_text,
"score": self.score,
"duration_ms": self.duration_ms,
}
@dataclass
class SemanticCandidate:
"""
[AC-AISVC-113] Semantic match candidate.
"""
rule: Any
score: float
def to_dict(self) -> dict[str, Any]:
return {
"rule_id": str(self.rule.id),
"rule_name": self.rule.name,
"score": self.score,
}
@dataclass
class SemanticMatchResult:
"""
[AC-AISVC-113] Result of semantic matching.
"""
candidates: list[SemanticCandidate]
top_score: float
duration_ms: int
skipped: bool
skip_reason: str | None
def to_dict(self) -> dict[str, Any]:
return {
"top_candidates": [c.to_dict() for c in self.candidates],
"top_score": self.top_score,
"duration_ms": self.duration_ms,
"skipped": self.skipped,
"skip_reason": self.skip_reason,
}
@dataclass
class LlmJudgeInput:
"""
[AC-AISVC-119] Input for LLM judge.
"""
message: str
candidates: list[dict[str, Any]]
conflict_type: str
@dataclass
class LlmJudgeResult:
"""
[AC-AISVC-119] Result of LLM judge.
"""
intent_id: str | None
intent_name: str | None
score: float
reasoning: str | None
duration_ms: int
tokens_used: int
triggered: bool
def to_dict(self) -> dict[str, Any]:
return {
"triggered": self.triggered,
"intent_id": self.intent_id,
"intent_name": self.intent_name,
"score": self.score,
"reasoning": self.reasoning,
"duration_ms": self.duration_ms,
"tokens_used": self.tokens_used,
}
@classmethod
def empty(cls) -> "LlmJudgeResult":
return cls(
intent_id=None,
intent_name=None,
score=0.0,
reasoning=None,
duration_ms=0,
tokens_used=0,
triggered=False,
)
@dataclass
class FusionConfig:
"""
[AC-AISVC-116] Fusion configuration.
"""
w_rule: float = 0.5
w_semantic: float = 0.3
w_llm: float = 0.2
semantic_threshold: float = 0.7
conflict_threshold: float = 0.2
gray_zone_threshold: float = 0.6
min_trigger_threshold: float = 0.3
clarify_threshold: float = 0.4
multi_intent_threshold: float = 0.15
llm_judge_enabled: bool = True
semantic_matcher_enabled: bool = True
semantic_matcher_timeout_ms: int = 100
llm_judge_timeout_ms: int = 2000
semantic_top_k: int = 3
def to_dict(self) -> dict[str, Any]:
return {
"w_rule": self.w_rule,
"w_semantic": self.w_semantic,
"w_llm": self.w_llm,
"semantic_threshold": self.semantic_threshold,
"conflict_threshold": self.conflict_threshold,
"gray_zone_threshold": self.gray_zone_threshold,
"min_trigger_threshold": self.min_trigger_threshold,
"clarify_threshold": self.clarify_threshold,
"multi_intent_threshold": self.multi_intent_threshold,
"llm_judge_enabled": self.llm_judge_enabled,
"semantic_matcher_enabled": self.semantic_matcher_enabled,
"semantic_matcher_timeout_ms": self.semantic_matcher_timeout_ms,
"llm_judge_timeout_ms": self.llm_judge_timeout_ms,
"semantic_top_k": self.semantic_top_k,
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "FusionConfig":
return cls(
w_rule=data.get("w_rule", 0.5),
w_semantic=data.get("w_semantic", 0.3),
w_llm=data.get("w_llm", 0.2),
semantic_threshold=data.get("semantic_threshold", 0.7),
conflict_threshold=data.get("conflict_threshold", 0.2),
gray_zone_threshold=data.get("gray_zone_threshold", 0.6),
min_trigger_threshold=data.get("min_trigger_threshold", 0.3),
clarify_threshold=data.get("clarify_threshold", 0.4),
multi_intent_threshold=data.get("multi_intent_threshold", 0.15),
llm_judge_enabled=data.get("llm_judge_enabled", True),
semantic_matcher_enabled=data.get("semantic_matcher_enabled", True),
semantic_matcher_timeout_ms=data.get("semantic_matcher_timeout_ms", 100),
llm_judge_timeout_ms=data.get("llm_judge_timeout_ms", 2000),
semantic_top_k=data.get("semantic_top_k", 3),
)
@dataclass
class RouteTrace:
"""
[AC-AISVC-122] Route trace log.
"""
rule_match: dict[str, Any] = field(default_factory=dict)
semantic_match: dict[str, Any] = field(default_factory=dict)
llm_judge: dict[str, Any] = field(default_factory=dict)
fusion: dict[str, Any] = field(default_factory=dict)
def to_dict(self) -> dict[str, Any]:
return {
"rule_match": self.rule_match,
"semantic_match": self.semantic_match,
"llm_judge": self.llm_judge,
"fusion": self.fusion,
}
@dataclass
class FusionResult:
"""
[AC-AISVC-115] Fusion decision result.
"""
final_intent: Any | None
final_confidence: float
decision_reason: str
need_clarify: bool
clarify_candidates: list[Any] | None
trace: RouteTrace
def to_dict(self) -> dict[str, Any]:
return {
"final_intent": {
"id": str(self.final_intent.id),
"name": self.final_intent.name,
"response_type": self.final_intent.response_type,
} if self.final_intent else None,
"final_confidence": self.final_confidence,
"decision_reason": self.decision_reason,
"need_clarify": self.need_clarify,
"clarify_candidates": [
{"id": str(c.id), "name": c.name}
for c in (self.clarify_candidates or [])
],
"trace": self.trace.to_dict(),
}
DEFAULT_FUSION_CONFIG = FusionConfig()
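A worked instance of the [AC-AISVC-116] weighted-confidence formula with the `FusionConfig` defaults (`w_rule=0.5`, `w_semantic=0.3`, `w_llm=0.2`): when the LLM judge did not trigger, its weight leaves the denominator, matching `_calculate_weighted_confidence` in the fusion policy.

```python
def weighted_confidence(rule_score: float, semantic_score: float,
                        llm_score: float, llm_triggered: bool,
                        w_rule: float = 0.5, w_semantic: float = 0.3,
                        w_llm: float = 0.2) -> float:
    # The LLM weight only participates when the judge actually ran.
    total = w_rule + w_semantic + (w_llm if llm_triggered else 0.0)
    if total == 0:
        return 0.0
    llm_part = w_llm * llm_score if llm_triggered else 0.0
    score = (w_rule * rule_score + w_semantic * semantic_score + llm_part) / total
    return min(1.0, max(0.0, score))
```

For example, a rule hit (1.0) agreeing with a 0.8 semantic score and no LLM judge yields (0.5·1.0 + 0.3·0.8) / 0.8 = 0.925.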


@@ -1,14 +1,30 @@
"""
Intent router for AI Service.
[AC-AISVC-69, AC-AISVC-70] Intent matching engine with keyword and regex support.
[v0.8.0] Upgraded to hybrid routing with RuleMatcher + SemanticMatcher + LlmJudge + FusionPolicy.
"""
import logging
import re
import time
from dataclasses import dataclass
from typing import Any
from typing import TYPE_CHECKING, Any
from app.models.entities import IntentRule
from app.services.intent.models import (
FusionConfig,
FusionResult,
LlmJudgeInput,
LlmJudgeResult,
RouteTrace,
RuleMatchResult,
SemanticMatchResult,
)
if TYPE_CHECKING:
from app.services.intent.fusion_policy import FusionPolicy
from app.services.intent.llm_judge import LlmJudge
from app.services.intent.semantic_matcher import SemanticMatcher
logger = logging.getLogger(__name__)
@@ -38,38 +54,36 @@ class IntentMatchResult:
}
class IntentRouter:
class RuleMatcher:
"""
[AC-AISVC-69] Intent matching engine.
Matching algorithm:
1. Load rules ordered by priority DESC
2. For each rule, try keyword matching first
3. If no keyword match, try regex pattern matching
4. Return first match (highest priority)
5. If no match, return None (fallback to default RAG)
[v0.8.0] Rule matcher for keyword and regex matching.
Extracted from IntentRouter for hybrid routing.
"""
def __init__(self):
pass
def match(
self,
message: str,
rules: list[IntentRule],
) -> IntentMatchResult | None:
def match(self, message: str, rules: list[IntentRule]) -> RuleMatchResult:
"""
[AC-AISVC-69] Match user message against intent rules.
[AC-AISVC-112] Match user message against intent rules.
Returns RuleMatchResult with score (1.0 for match, 0.0 for no match).
Args:
message: User input message
rules: List of enabled rules ordered by priority DESC
Returns:
IntentMatchResult if matched, None otherwise
RuleMatchResult with match details
"""
start_time = time.time()
if not message or not rules:
return None
duration_ms = int((time.time() - start_time) * 1000)
return RuleMatchResult(
rule_id=None,
rule=None,
match_type=None,
matched_text=None,
score=0.0,
duration_ms=duration_ms,
)
message_lower = message.lower()
@@ -79,22 +93,46 @@ class IntentRouter:
keyword_result = self._match_keywords(message, message_lower, rule)
if keyword_result:
duration_ms = int((time.time() - start_time) * 1000)
logger.info(
f"[AC-AISVC-69] Intent matched by keyword: "
f"rule={rule.name}, matched='{keyword_result.matched}'"
)
return keyword_result
return RuleMatchResult(
rule_id=rule.id,
rule=rule,
match_type="keyword",
matched_text=keyword_result.matched,
score=1.0,
duration_ms=duration_ms,
)
regex_result = self._match_patterns(message, rule)
if regex_result:
duration_ms = int((time.time() - start_time) * 1000)
logger.info(
f"[AC-AISVC-69] Intent matched by regex: "
f"rule={rule.name}, matched='{regex_result.matched}'"
)
return regex_result
return RuleMatchResult(
rule_id=rule.id,
rule=rule,
match_type="regex",
matched_text=regex_result.matched,
score=1.0,
duration_ms=duration_ms,
)
duration_ms = int((time.time() - start_time) * 1000)
logger.debug("[AC-AISVC-70] No intent matched, will fallback to default RAG")
return None
return RuleMatchResult(
rule_id=None,
rule=None,
match_type=None,
matched_text=None,
score=0.0,
duration_ms=duration_ms,
)
def _match_keywords(
self,
@@ -153,6 +191,74 @@ class IntentRouter:
return None
class IntentRouter:
"""
[AC-AISVC-69] Intent matching engine.
[v0.8.0] Upgraded to support hybrid routing.
Matching algorithm:
1. Load rules ordered by priority DESC
2. For each rule, try keyword matching first
3. If no keyword match, try regex pattern matching
4. Return first match (highest priority)
5. If no match, return None (fallback to default RAG)
Hybrid routing (match_hybrid):
1. Parallel execute RuleMatcher + SemanticMatcher
2. Conditionally trigger LlmJudge
3. Execute FusionPolicy for final decision
"""
def __init__(
self,
rule_matcher: RuleMatcher | None = None,
semantic_matcher: "SemanticMatcher | None" = None,
llm_judge: "LlmJudge | None" = None,
fusion_policy: "FusionPolicy | None" = None,
config: FusionConfig | None = None,
):
"""
[v0.8.0] Initialize with optional dependencies for DI.
Args:
rule_matcher: Rule matcher for keyword/regex matching
semantic_matcher: Semantic matcher for vector similarity
llm_judge: LLM judge for arbitration
fusion_policy: Fusion policy for decision making
config: Fusion configuration
"""
self._rule_matcher = rule_matcher or RuleMatcher()
self._semantic_matcher = semantic_matcher
self._llm_judge = llm_judge
self._fusion_policy = fusion_policy
self._config = config or FusionConfig()
def match(
self,
message: str,
rules: list[IntentRule],
) -> IntentMatchResult | None:
"""
[AC-AISVC-69] Match user message against intent rules.
Preserved for backward compatibility.
Args:
message: User input message
rules: List of enabled rules ordered by priority DESC
Returns:
IntentMatchResult if matched, None otherwise
"""
result = self._rule_matcher.match(message, rules)
if result.rule:
return IntentMatchResult(
rule=result.rule,
match_type=result.match_type or "keyword",
matched=result.matched_text or "",
)
return None
def match_with_stats(
self,
message: str,
@@ -168,3 +274,300 @@ class IntentRouter:
if result:
return result, str(result.rule.id)
return None, None
async def match_hybrid(
self,
message: str,
rules: list[IntentRule],
tenant_id: str,
config: FusionConfig | None = None,
) -> FusionResult:
"""
[AC-AISVC-111] Hybrid routing entry point.
Flow:
1. Parallel execute RuleMatcher + SemanticMatcher
2. Check if LlmJudge should trigger
3. Execute FusionPolicy for final decision
Args:
message: User input message
rules: List of enabled rules ordered by priority DESC
tenant_id: Tenant ID for isolation
config: Optional fusion config override
Returns:
FusionResult with final intent, confidence, and trace
"""
effective_config = config or self._config
start_time = time.time()
rule_result = self._rule_matcher.match(message, rules)
semantic_result = await self._execute_semantic_matcher(
message, rules, tenant_id, effective_config
)
llm_result = await self._conditionally_execute_llm_judge(
message, rule_result, semantic_result, tenant_id, effective_config
)
if self._fusion_policy:
fusion_result = self._fusion_policy.fuse(
rule_result, semantic_result, llm_result
)
else:
fusion_result = self._default_fusion(
rule_result, semantic_result, llm_result, effective_config
)
total_duration_ms = int((time.time() - start_time) * 1000)
fusion_result.trace.fusion["total_duration_ms"] = total_duration_ms
logger.info(
f"[AC-AISVC-111] Hybrid routing completed: "
f"decision={fusion_result.decision_reason}, "
f"confidence={fusion_result.final_confidence:.3f}, "
f"duration={total_duration_ms}ms"
)
return fusion_result
async def _execute_semantic_matcher(
self,
message: str,
rules: list[IntentRule],
tenant_id: str,
config: FusionConfig,
) -> SemanticMatchResult:
"""Execute semantic matcher if available and enabled."""
if not self._semantic_matcher:
return SemanticMatchResult(
candidates=[],
top_score=0.0,
duration_ms=0,
skipped=True,
skip_reason="not_configured",
)
if not config.semantic_matcher_enabled:
return SemanticMatchResult(
candidates=[],
top_score=0.0,
duration_ms=0,
skipped=True,
skip_reason="disabled",
)
try:
return await self._semantic_matcher.match(
message=message,
rules=rules,
tenant_id=tenant_id,
top_k=config.semantic_top_k,
)
except Exception as e:
logger.warning(f"[AC-AISVC-113] Semantic matcher failed: {e}")
return SemanticMatchResult(
candidates=[],
top_score=0.0,
duration_ms=0,
skipped=True,
skip_reason=f"error: {str(e)}",
)
async def _conditionally_execute_llm_judge(
self,
message: str,
rule_result: RuleMatchResult,
semantic_result: SemanticMatchResult,
tenant_id: str,
config: FusionConfig,
) -> LlmJudgeResult | None:
"""Conditionally execute LLM judge based on trigger conditions."""
if not self._llm_judge:
return None
if not config.llm_judge_enabled:
return None
should_trigger, trigger_reason = self._check_llm_trigger(
rule_result, semantic_result, config
)
if not should_trigger:
return None
logger.info(f"[AC-AISVC-118] LLM judge triggered: reason={trigger_reason}")
candidates = self._build_llm_candidates(rule_result, semantic_result)
if not candidates:
return None
try:
return await self._llm_judge.judge(
LlmJudgeInput(
message=message,
candidates=candidates,
conflict_type=trigger_reason,
),
tenant_id,
)
except Exception as e:
logger.warning(f"[AC-AISVC-119] LLM judge failed: {e}")
return LlmJudgeResult(
intent_id=None,
intent_name=None,
score=0.0,
reasoning=f"LLM error: {str(e)}",
duration_ms=0,
tokens_used=0,
triggered=True,
)
def _check_llm_trigger(
self,
rule_result: RuleMatchResult,
semantic_result: SemanticMatchResult,
config: FusionConfig,
) -> tuple[bool, str]:
"""
[AC-AISVC-118] Check if LLM judge should trigger.
Trigger conditions:
1. Conflict: RuleMatcher and SemanticMatcher match different intents
2. Gray zone: Max confidence in gray zone range
3. Multi-intent: Multiple candidates with close scores
Returns:
(should_trigger, trigger_reason)
"""
rule_score = rule_result.score
semantic_score = semantic_result.top_score
if rule_score > 0 and semantic_score > 0 and not semantic_result.skipped:
if semantic_result.candidates:
top_semantic_rule_id = semantic_result.candidates[0].rule.id
if rule_result.rule_id != top_semantic_rule_id:
if abs(rule_score - semantic_score) < config.conflict_threshold:
return True, "rule_semantic_conflict"
max_score = max(rule_score, semantic_score)
if config.min_trigger_threshold < max_score < config.gray_zone_threshold:
return True, "gray_zone"
if len(semantic_result.candidates) >= 2:
top1_score = semantic_result.candidates[0].score
top2_score = semantic_result.candidates[1].score
if abs(top1_score - top2_score) < config.multi_intent_threshold:
return True, "multi_intent"
return False, ""
def _build_llm_candidates(
self,
rule_result: RuleMatchResult,
semantic_result: SemanticMatchResult,
) -> list[dict[str, Any]]:
"""Build candidate list for LLM judge."""
candidates = []
if rule_result.rule:
candidates.append({
"id": str(rule_result.rule_id),
"name": rule_result.rule.name,
"description": f"匹配方式: {rule_result.match_type}, 匹配内容: {rule_result.matched_text}",
})
for candidate in semantic_result.candidates[:3]:
if not any(c["id"] == str(candidate.rule.id) for c in candidates):
candidates.append({
"id": str(candidate.rule.id),
"name": candidate.rule.name,
"description": f"语义相似度: {candidate.score:.2f}",
})
return candidates
def _default_fusion(
self,
rule_result: RuleMatchResult,
semantic_result: SemanticMatchResult,
llm_result: LlmJudgeResult | None,
config: FusionConfig,
) -> FusionResult:
"""Default fusion logic when FusionPolicy is not available."""
trace = RouteTrace(
rule_match=rule_result.to_dict(),
semantic_match=semantic_result.to_dict(),
llm_judge=llm_result.to_dict() if llm_result else {},
fusion={},
)
final_intent = None
final_confidence = 0.0
decision_reason = "no_match"
if rule_result.score == 1.0 and rule_result.rule:
final_intent = rule_result.rule
final_confidence = 1.0
decision_reason = "rule_high_confidence"
elif llm_result and llm_result.triggered and llm_result.intent_id:
final_intent = self._find_rule_by_id(
llm_result.intent_id, rule_result, semantic_result
)
final_confidence = llm_result.score
decision_reason = "llm_judge"
elif rule_result.score == 0 and semantic_result.top_score > config.semantic_threshold:
if semantic_result.candidates:
final_intent = semantic_result.candidates[0].rule
final_confidence = semantic_result.top_score
decision_reason = "semantic_override"
elif semantic_result.top_score > 0.5:
if semantic_result.candidates:
final_intent = semantic_result.candidates[0].rule
final_confidence = semantic_result.top_score
decision_reason = "semantic_fallback"
need_clarify = final_confidence < config.clarify_threshold
clarify_candidates = None
if need_clarify and len(semantic_result.candidates) > 1:
clarify_candidates = [c.rule for c in semantic_result.candidates[:3]]
trace.fusion = {
"weights": {
"w_rule": config.w_rule,
"w_semantic": config.w_semantic,
"w_llm": config.w_llm,
},
"final_confidence": final_confidence,
"decision_reason": decision_reason,
}
return FusionResult(
final_intent=final_intent,
final_confidence=final_confidence,
decision_reason=decision_reason,
need_clarify=need_clarify,
clarify_candidates=clarify_candidates,
trace=trace,
)
def _find_rule_by_id(
self,
intent_id: str | None,
rule_result: RuleMatchResult,
semantic_result: SemanticMatchResult,
) -> IntentRule | None:
"""Find rule by ID from rule or semantic results."""
if not intent_id:
return None
if rule_result.rule_id and str(rule_result.rule_id) == intent_id:
return rule_result.rule
for candidate in semantic_result.candidates:
if str(candidate.rule.id) == intent_id:
return candidate.rule
return None

View File

@ -106,6 +106,8 @@ class IntentRuleService:
is_enabled=True,
hit_count=0,
metadata_=create_data.metadata_,
intent_vector=create_data.intent_vector,
semantic_examples=create_data.semantic_examples,
)
self._session.add(rule)
await self._session.flush()
@ -195,6 +197,10 @@ class IntentRuleService:
rule.is_enabled = update_data.is_enabled
if update_data.metadata_ is not None:
rule.metadata_ = update_data.metadata_
if update_data.intent_vector is not None:
rule.intent_vector = update_data.intent_vector
if update_data.semantic_examples is not None:
rule.semantic_examples = update_data.semantic_examples
rule.updated_at = datetime.utcnow()
await self._session.flush()
@ -267,7 +273,7 @@ class IntentRuleService:
select(IntentRule)
.where(
IntentRule.tenant_id == tenant_id,
IntentRule.is_enabled == True,
IntentRule.is_enabled == True, # noqa: E712
)
.order_by(col(IntentRule.priority).desc())
)
@ -300,6 +306,8 @@ class IntentRuleService:
"is_enabled": rule.is_enabled,
"hit_count": rule.hit_count,
"metadata": rule.metadata_,
"created_at": rule.created_at.isoformat(),
"updated_at": rule.updated_at.isoformat(),
"intent_vector": rule.intent_vector,
"semantic_examples": rule.semantic_examples,
"created_at": rule.created_at.isoformat() if rule.created_at else None,
"updated_at": rule.updated_at.isoformat() if rule.updated_at else None,
}

View File

@ -0,0 +1,233 @@
"""
Semantic matcher for intent recognition.
[AC-AISVC-113, AC-AISVC-114] Vector-based semantic matching.
"""
import asyncio
import logging
import time
from typing import TYPE_CHECKING
import numpy as np
from app.services.intent.models import (
FusionConfig,
SemanticCandidate,
SemanticMatchResult,
)
if TYPE_CHECKING:
from app.models.entities import IntentRule
from app.services.embedding.base import EmbeddingProvider
logger = logging.getLogger(__name__)
class SemanticMatcher:
"""
[AC-AISVC-113] Semantic matcher using vector similarity.
Supports two matching modes:
- Mode A: Use pre-computed intent_vector for direct similarity calculation
- Mode B: Use semantic_examples for dynamic vector computation
"""
def __init__(
self,
embedding_provider: "EmbeddingProvider",
config: FusionConfig,
):
"""
Initialize semantic matcher.
Args:
embedding_provider: Provider for generating embeddings
config: Fusion configuration
"""
self._embedding_provider = embedding_provider
self._config = config
async def match(
self,
message: str,
rules: list["IntentRule"],
tenant_id: str,
top_k: int | None = None,
) -> SemanticMatchResult:
"""
[AC-AISVC-113] Perform vector semantic matching.
Args:
message: User message
rules: List of intent rules
tenant_id: Tenant ID for isolation
top_k: Number of top candidates to return
Returns:
SemanticMatchResult with candidates and scores
"""
start_time = time.time()
effective_top_k = top_k or self._config.semantic_top_k
if not self._config.semantic_matcher_enabled:
return SemanticMatchResult(
candidates=[],
top_score=0.0,
duration_ms=0,
skipped=True,
skip_reason="disabled",
)
rules_with_semantic = [r for r in rules if self._has_semantic_config(r)]
if not rules_with_semantic:
duration_ms = int((time.time() - start_time) * 1000)
logger.debug(
f"[AC-AISVC-113] No rules with semantic config for tenant={tenant_id}"
)
return SemanticMatchResult(
candidates=[],
top_score=0.0,
duration_ms=duration_ms,
skipped=True,
skip_reason="no_semantic_config",
)
try:
message_vector = await asyncio.wait_for(
self._embedding_provider.embed(message),
timeout=self._config.semantic_matcher_timeout_ms / 1000,
)
except asyncio.TimeoutError:
duration_ms = int((time.time() - start_time) * 1000)
logger.warning(
f"[AC-AISVC-113] Embedding timeout for tenant={tenant_id}, "
f"timeout={self._config.semantic_matcher_timeout_ms}ms"
)
return SemanticMatchResult(
candidates=[],
top_score=0.0,
duration_ms=duration_ms,
skipped=True,
skip_reason="embedding_timeout",
)
except Exception as e:
duration_ms = int((time.time() - start_time) * 1000)
logger.error(
f"[AC-AISVC-113] Embedding error for tenant={tenant_id}: {e}"
)
return SemanticMatchResult(
candidates=[],
top_score=0.0,
duration_ms=duration_ms,
skipped=True,
skip_reason=f"embedding_error: {str(e)}",
)
candidates = []
for rule in rules_with_semantic:
try:
score = await self._calculate_similarity(message_vector, rule)
if score > 0:
candidates.append(SemanticCandidate(rule=rule, score=score))
except Exception as e:
logger.warning(
f"[AC-AISVC-114] Similarity calculation failed for rule={rule.id}: {e}"
)
continue
candidates.sort(key=lambda x: x.score, reverse=True)
candidates = candidates[:effective_top_k]
duration_ms = int((time.time() - start_time) * 1000)
logger.info(
f"[AC-AISVC-113] Semantic match completed for tenant={tenant_id}, "
f"candidates={len(candidates)}, top_score={candidates[0].score if candidates else 0:.3f}, "
f"duration={duration_ms}ms"
)
return SemanticMatchResult(
candidates=candidates,
top_score=candidates[0].score if candidates else 0.0,
duration_ms=duration_ms,
skipped=False,
skip_reason=None,
)
def _has_semantic_config(self, rule: "IntentRule") -> bool:
"""
Check if rule has semantic configuration.
Args:
rule: Intent rule to check
Returns:
True if rule has intent_vector or semantic_examples
"""
return bool(rule.intent_vector) or bool(rule.semantic_examples)
async def _calculate_similarity(
self,
message_vector: list[float],
rule: "IntentRule",
) -> float:
"""
[AC-AISVC-114] Calculate similarity between message and rule.
Mode A: Use pre-computed intent_vector
Mode B: Use semantic_examples for dynamic computation
Args:
message_vector: Message embedding vector
rule: Intent rule with semantic config
Returns:
Similarity score (0.0 ~ 1.0)
"""
if rule.intent_vector:
return self._cosine_similarity(message_vector, rule.intent_vector)
elif rule.semantic_examples:
try:
example_vectors = await self._embedding_provider.embed_batch(
rule.semantic_examples
)
similarities = [
self._cosine_similarity(message_vector, v)
for v in example_vectors
]
return max(similarities) if similarities else 0.0
except Exception as e:
logger.warning(
f"[AC-AISVC-114] Failed to compute example vectors for rule={rule.id}: {e}"
)
return 0.0
return 0.0
def _cosine_similarity(
self,
v1: list[float],
v2: list[float],
) -> float:
"""
Calculate cosine similarity between two vectors.
Args:
v1: First vector
v2: Second vector
Returns:
Cosine similarity (0.0 ~ 1.0)
"""
if not v1 or not v2:
return 0.0
v1_arr = np.array(v1)
v2_arr = np.array(v2)
norm1 = np.linalg.norm(v1_arr)
norm2 = np.linalg.norm(v2_arr)
if norm1 == 0 or norm2 == 0:
return 0.0
similarity = float(np.dot(v1_arr, v2_arr) / (norm1 * norm2))
return max(0.0, min(1.0, similarity))
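As a quick numeric check, the clamped cosine above and Mode B's max-over-examples aggregation reduce to a few lines of dependency-free Python (toy 2-D vectors, no NumPy):

```python
import math

def cosine(v1: list[float], v2: list[float]) -> float:
    """Clamped cosine similarity, mirroring _cosine_similarity."""
    n1 = math.sqrt(sum(x * x for x in v1))
    n2 = math.sqrt(sum(x * x for x in v2))
    if n1 == 0 or n2 == 0:
        return 0.0
    dot = sum(a * b for a, b in zip(v1, v2))
    return max(0.0, min(1.0, dot / (n1 * n2)))

message_vec = [1.0, 0.0]
# Mode B: score the message against each example vector, keep the best
example_vecs = [[0.0, 1.0], [1.0, 1.0], [-1.0, 0.0]]
scores = [cosine(message_vec, v) for v in example_vecs]
best = max(scores)  # 1/sqrt(2); orthogonal and opposite examples score 0
```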

View File

@ -83,6 +83,7 @@ class KBService:
file_name: str,
file_content: bytes,
file_type: str | None = None,
metadata: dict | None = None,
) -> tuple[Document, IndexJob]:
"""
[AC-ASA-01] Upload document and create indexing job.
@ -108,6 +109,7 @@ class KBService:
file_size=len(file_content),
file_type=file_type,
status=DocumentStatus.PENDING.value,
doc_metadata=metadata,
)
self._session.add(document)

View File

@ -3,7 +3,14 @@ LLM Adapter module for AI Service.
[AC-AISVC-02, AC-AISVC-06] Provides unified interface for LLM providers.
"""
from app.services.llm.base import LLMClient, LLMConfig, LLMResponse, LLMStreamChunk
from app.services.llm.base import (
LLMClient,
LLMConfig,
LLMResponse,
LLMStreamChunk,
ToolCall,
ToolDefinition,
)
from app.services.llm.openai_client import OpenAIClient
__all__ = [
@ -12,4 +19,6 @@ __all__ = [
"LLMResponse",
"LLMStreamChunk",
"OpenAIClient",
"ToolCall",
"ToolDefinition",
]

View File

@ -28,18 +28,46 @@ class LLMConfig:
extra_params: dict[str, Any] = field(default_factory=dict)
@dataclass
class ToolCall:
"""
Represents a function call from the LLM.
Used in Function Calling mode.
"""
id: str
name: str
arguments: dict[str, Any]
def to_dict(self) -> dict[str, Any]:
import json
return {
"id": self.id,
"type": "function",
"function": {
"name": self.name,
"arguments": json.dumps(self.arguments, ensure_ascii=False),
}
}
@dataclass
class LLMResponse:
"""
Response from LLM generation.
[AC-AISVC-02] Contains generated content and metadata.
"""
content: str
model: str
content: str | None = None
model: str = ""
usage: dict[str, int] = field(default_factory=dict)
finish_reason: str = "stop"
tool_calls: list[ToolCall] = field(default_factory=list)
metadata: dict[str, Any] = field(default_factory=dict)
@property
def has_tool_calls(self) -> bool:
"""Check if response contains tool calls."""
return len(self.tool_calls) > 0
@dataclass
class LLMStreamChunk:
@ -50,9 +78,33 @@ class LLMStreamChunk:
delta: str
model: str
finish_reason: str | None = None
tool_calls_delta: list[dict[str, Any]] = field(default_factory=list)
metadata: dict[str, Any] = field(default_factory=dict)
@dataclass
class ToolDefinition:
"""
Tool definition for Function Calling.
Compatible with OpenAI/DeepSeek function calling format.
"""
name: str
description: str
parameters: dict[str, Any]
type: str = "function"
def to_openai_format(self) -> dict[str, Any]:
"""Convert to OpenAI tools format."""
return {
"type": self.type,
"function": {
"name": self.name,
"description": self.description,
"parameters": self.parameters,
}
}
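For example, a hypothetical weather-lookup tool (name and schema invented for illustration) serializes into the nested `tools` payload that OpenAI-compatible chat APIs expect. The dataclass is re-created here so the snippet runs standalone:

```python
from dataclasses import dataclass
from typing import Any

@dataclass
class ToolDefinition:
    name: str
    description: str
    parameters: dict[str, Any]
    type: str = "function"

    def to_openai_format(self) -> dict[str, Any]:
        # Nest name/description/parameters under "function", as the API requires
        return {
            "type": self.type,
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": self.parameters,
            },
        }

weather = ToolDefinition(
    name="get_weather",  # hypothetical tool
    description="Look up current weather for a city",
    parameters={
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
    },
)
payload = weather.to_openai_format()
```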
class LLMClient(ABC):
"""
Abstract base class for LLM clients.
@ -67,6 +119,8 @@ class LLMClient(ABC):
self,
messages: list[dict[str, str]],
config: LLMConfig | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: str | dict[str, Any] | None = None,
**kwargs: Any,
) -> LLMResponse:
"""
@ -76,10 +130,12 @@ class LLMClient(ABC):
Args:
messages: List of chat messages with 'role' and 'content'.
config: Optional LLM configuration overrides.
tools: Optional list of tools for function calling.
tool_choice: Tool choice strategy ("auto", "none", or specific tool).
**kwargs: Additional provider-specific parameters.
Returns:
LLMResponse with generated content and metadata.
LLMResponse with generated content, tool_calls, and metadata.
Raises:
LLMException: If generation fails.
@ -91,6 +147,8 @@ class LLMClient(ABC):
self,
messages: list[dict[str, str]],
config: LLMConfig | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: str | dict[str, Any] | None = None,
**kwargs: Any,
) -> AsyncGenerator[LLMStreamChunk, None]:
"""
@ -100,6 +158,8 @@ class LLMClient(ABC):
Args:
messages: List of chat messages with 'role' and 'content'.
config: Optional LLM configuration overrides.
tools: Optional list of tools for function calling.
tool_choice: Tool choice strategy ("auto", "none", or specific tool).
**kwargs: Additional provider-specific parameters.
Yields:

View File

@ -11,12 +11,16 @@ from dataclasses import dataclass
from pathlib import Path
from typing import Any
import redis
from app.core.config import get_settings
from app.services.llm.base import LLMClient, LLMConfig
from app.services.llm.openai_client import OpenAIClient
logger = logging.getLogger(__name__)
LLM_CONFIG_FILE = Path("config/llm_config.json")
LLM_CONFIG_REDIS_KEY = "ai_service:config:llm"
@dataclass
@ -286,6 +290,8 @@ class LLMConfigManager:
from app.core.config import get_settings
settings = get_settings()
self._settings = settings
self._redis_client: redis.Redis | None = None
self._current_provider: str = settings.llm_provider
self._current_config: dict[str, Any] = {
@ -299,8 +305,75 @@ class LLMConfigManager:
}
self._client: LLMClient | None = None
self._load_from_redis()
self._load_from_file()
def _load_from_redis(self) -> None:
"""Load configuration from Redis if exists."""
try:
if not self._settings.redis_enabled:
return
self._redis_client = redis.from_url(
self._settings.redis_url,
encoding="utf-8",
decode_responses=True,
)
saved_raw = self._redis_client.get(LLM_CONFIG_REDIS_KEY)
if not saved_raw:
return
saved = json.loads(saved_raw)
self._current_provider = saved.get("provider", self._current_provider)
saved_config = saved.get("config", {})
if saved_config:
self._current_config.update(saved_config)
logger.info(f"[AC-ASA-16] Loaded LLM config from Redis: provider={self._current_provider}")
except Exception as e:
logger.warning(f"[AC-ASA-16] Failed to load LLM config from Redis: {e}")
def _save_to_redis(self) -> None:
"""Save configuration to Redis."""
try:
if not self._settings.redis_enabled:
return
if self._redis_client is None:
self._redis_client = redis.from_url(
self._settings.redis_url,
encoding="utf-8",
decode_responses=True,
)
self._redis_client.set(
LLM_CONFIG_REDIS_KEY,
json.dumps({
"provider": self._current_provider,
"config": self._current_config,
}, ensure_ascii=False),
)
logger.info(f"[AC-ASA-16] Saved LLM config to Redis: provider={self._current_provider}")
except Exception as e:
logger.warning(f"[AC-ASA-16] Failed to save LLM config to Redis: {e}")
def _load_from_file(self) -> None:
"""Load configuration from file if exists."""
try:
@ -364,6 +437,7 @@ class LLMConfigManager:
self._current_provider = provider
self._current_config = validated_config
self._save_to_redis()
self._save_to_file()
logger.info(f"[AC-ASA-16] LLM config updated: provider={provider}")

View File

@ -22,7 +22,14 @@ from tenacity import (
from app.core.config import get_settings
from app.core.exceptions import AIServiceException, ErrorCode, TimeoutException
from app.services.llm.base import LLMClient, LLMConfig, LLMResponse, LLMStreamChunk
from app.services.llm.base import (
LLMClient,
LLMConfig,
LLMResponse,
LLMStreamChunk,
ToolCall,
ToolDefinition,
)
logger = logging.getLogger(__name__)
@ -95,6 +102,8 @@ class OpenAIClient(LLMClient):
messages: list[dict[str, str]],
config: LLMConfig,
stream: bool = False,
tools: list[ToolDefinition] | None = None,
tool_choice: str | dict[str, Any] | None = None,
**kwargs: Any,
) -> dict[str, Any]:
"""Build request body for OpenAI API."""
@ -106,6 +115,13 @@ class OpenAIClient(LLMClient):
"top_p": config.top_p,
"stream": stream,
}
if tools:
body["tools"] = [tool.to_openai_format() for tool in tools]
if tool_choice:
body["tool_choice"] = tool_choice
body.update(config.extra_params)
body.update(kwargs)
return body
@ -119,6 +135,8 @@ class OpenAIClient(LLMClient):
self,
messages: list[dict[str, str]],
config: LLMConfig | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: str | dict[str, Any] | None = None,
**kwargs: Any,
) -> LLMResponse:
"""
@ -128,10 +146,12 @@ class OpenAIClient(LLMClient):
Args:
messages: List of chat messages with 'role' and 'content'.
config: Optional LLM configuration overrides.
tools: Optional list of tools for function calling.
tool_choice: Tool choice strategy ("auto", "none", or specific tool).
**kwargs: Additional provider-specific parameters.
Returns:
LLMResponse with generated content and metadata.
LLMResponse with generated content, tool_calls, and metadata.
Raises:
LLMException: If generation fails.
@ -140,9 +160,14 @@ class OpenAIClient(LLMClient):
effective_config = config or self._default_config
client = self._get_client(effective_config.timeout_seconds)
body = self._build_request_body(messages, effective_config, stream=False, **kwargs)
body = self._build_request_body(
messages, effective_config, stream=False,
tools=tools, tool_choice=tool_choice, **kwargs
)
logger.info(f"[AC-AISVC-02] Generating response with model={effective_config.model}")
if tools:
logger.info(f"[AC-AISVC-02] Function calling enabled with {len(tools)} tools")
logger.info("[AC-AISVC-02] ========== FULL PROMPT TO AI ==========")
for i, msg in enumerate(messages):
role = msg.get("role", "unknown")
@ -177,14 +202,18 @@ class OpenAIClient(LLMClient):
try:
choice = data["choices"][0]
content = choice["message"]["content"]
message = choice["message"]
content = message.get("content")
usage = data.get("usage", {})
finish_reason = choice.get("finish_reason", "stop")
tool_calls = self._parse_tool_calls(message)
logger.info(
f"[AC-AISVC-02] Generated response: "
f"tokens={usage.get('total_tokens', 'N/A')}, "
f"finish_reason={finish_reason}"
f"finish_reason={finish_reason}, "
f"tool_calls={len(tool_calls)}"
)
return LLMResponse(
@ -192,6 +221,7 @@ class OpenAIClient(LLMClient):
model=data.get("model", effective_config.model),
usage=usage,
finish_reason=finish_reason,
tool_calls=tool_calls,
metadata={"raw_response": data},
)
@ -202,10 +232,33 @@ class OpenAIClient(LLMClient):
details=[{"response": str(data)}],
)
def _parse_tool_calls(self, message: dict[str, Any]) -> list[ToolCall]:
"""Parse tool calls from LLM response message."""
tool_calls = []
raw_tool_calls = message.get("tool_calls", [])
for tc in raw_tool_calls:
if tc.get("type") == "function":
func = tc.get("function", {})
try:
arguments = json.loads(func.get("arguments", "{}"))
except json.JSONDecodeError:
arguments = {}
tool_calls.append(ToolCall(
id=tc.get("id", ""),
name=func.get("name", ""),
arguments=arguments,
))
return tool_calls
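Given a raw chat-completion `message` dict, the parser tolerates malformed argument JSON by degrading to empty arguments rather than failing the whole response. A self-contained version (`ToolCall` re-declared so the snippet runs standalone):

```python
import json
from dataclasses import dataclass
from typing import Any

@dataclass
class ToolCall:
    id: str
    name: str
    arguments: dict[str, Any]

def parse_tool_calls(message: dict[str, Any]) -> list[ToolCall]:
    """Mirror of _parse_tool_calls: skip non-function entries, tolerate bad JSON."""
    calls = []
    for tc in message.get("tool_calls", []):
        if tc.get("type") != "function":
            continue
        func = tc.get("function", {})
        try:
            args = json.loads(func.get("arguments", "{}"))
        except json.JSONDecodeError:
            args = {}  # degrade gracefully instead of failing the response
        calls.append(ToolCall(id=tc.get("id", ""), name=func.get("name", ""),
                              arguments=args))
    return calls

message = {
    "tool_calls": [
        {"id": "call_1", "type": "function",
         "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'}},
        {"id": "call_2", "type": "function",
         "function": {"name": "broken", "arguments": "{not json"}},
    ]
}
calls = parse_tool_calls(message)
```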
async def stream_generate(
self,
messages: list[dict[str, str]],
config: LLMConfig | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: str | dict[str, Any] | None = None,
**kwargs: Any,
) -> AsyncGenerator[LLMStreamChunk, None]:
"""
@ -215,6 +268,8 @@ class OpenAIClient(LLMClient):
Args:
messages: List of chat messages with 'role' and 'content'.
config: Optional LLM configuration overrides.
tools: Optional list of tools for function calling.
tool_choice: Tool choice strategy ("auto", "none", or specific tool).
**kwargs: Additional provider-specific parameters.
Yields:
@ -227,9 +282,14 @@ class OpenAIClient(LLMClient):
effective_config = config or self._default_config
client = self._get_client(effective_config.timeout_seconds)
body = self._build_request_body(messages, effective_config, stream=True, **kwargs)
body = self._build_request_body(
messages, effective_config, stream=True,
tools=tools, tool_choice=tool_choice, **kwargs
)
logger.info(f"[AC-AISVC-06] Starting streaming generation with model={effective_config.model}")
if tools:
logger.info(f"[AC-AISVC-06] Function calling enabled with {len(tools)} tools")
logger.info("[AC-AISVC-06] ========== FULL PROMPT TO AI (STREAMING) ==========")
for i, msg in enumerate(messages):
role = msg.get("role", "unknown")

View File

@ -0,0 +1,202 @@
"""
Metadata field definition cache service.
Caches metadata_field_definitions in Redis to reduce database queries.
"""
import json
import logging
from typing import Any
from app.core.config import get_settings
logger = logging.getLogger(__name__)
class MetadataCacheService:
"""
Metadata field definition cache service.
Caching strategy:
- Key: metadata:fields:{tenant_id}
- Value: JSON-serialized list of field definitions
- TTL: 1 hour (3600 seconds)
- Update policy: invalidate on write + periodic refresh
"""
CACHE_KEY_PREFIX = "metadata:fields"
DEFAULT_TTL = 3600  # 1 hour
def __init__(self):
self._settings = get_settings()
self._redis_client = None
self._enabled = self._settings.redis_enabled
async def _get_redis(self):
"""获取 Redis 连接(延迟初始化)"""
if not self._enabled:
return None
if self._redis_client is None:
try:
import redis.asyncio as redis
self._redis_client = redis.from_url(
self._settings.redis_url,
decode_responses=True
)
except Exception as e:
logger.error(f"[MetadataCache] Failed to connect Redis: {e}")
self._enabled = False
return None
return self._redis_client
def _make_key(self, tenant_id: str) -> str:
"""生成缓存 key"""
return f"{self.CACHE_KEY_PREFIX}:{tenant_id}"
async def get_fields(self, tenant_id: str) -> list[dict[str, Any]] | None:
"""
Get cached field definitions.
Args:
tenant_id: Tenant ID
Returns:
List of field definitions, or None on cache miss
"""
if not self._enabled:
return None
try:
redis = await self._get_redis()
if not redis:
return None
key = self._make_key(tenant_id)
cached_data = await redis.get(key)
if cached_data:
logger.info(f"[MetadataCache] Cache hit for tenant={tenant_id}")
return json.loads(cached_data)
logger.info(f"[MetadataCache] Cache miss for tenant={tenant_id}")
return None
except Exception as e:
logger.error(f"[MetadataCache] Get cache error: {e}")
return None
async def set_fields(
self,
tenant_id: str,
fields: list[dict[str, Any]],
ttl: int | None = None
) -> bool:
"""
Cache field definitions.
Args:
tenant_id: Tenant ID
fields: List of field definitions
ttl: Expiry in seconds (defaults to 1 hour)
Returns:
True on success
"""
if not self._enabled:
return False
try:
redis = await self._get_redis()
if not redis:
return False
key = self._make_key(tenant_id)
ttl = ttl or self.DEFAULT_TTL
await redis.setex(
key,
ttl,
json.dumps(fields, ensure_ascii=False, default=str)
)
logger.info(
f"[MetadataCache] Cached {len(fields)} fields for tenant={tenant_id}, "
f"ttl={ttl}s"
)
return True
except Exception as e:
logger.error(f"[MetadataCache] Set cache error: {e}")
return False
async def invalidate(self, tenant_id: str) -> bool:
"""
Invalidate the cache (called when field definitions are updated).
Args:
tenant_id: Tenant ID
Returns:
True on success
"""
if not self._enabled:
return False
try:
redis = await self._get_redis()
if not redis:
return False
key = self._make_key(tenant_id)
result = await redis.delete(key)
if result:
logger.info(f"[MetadataCache] Invalidated cache for tenant={tenant_id}")
return bool(result)
except Exception as e:
logger.error(f"[MetadataCache] Invalidate error: {e}")
return False
async def invalidate_all(self) -> bool:
"""
Invalidate all metadata caches.
Returns:
True on success
"""
if not self._enabled:
return False
try:
redis = await self._get_redis()
if not redis:
return False
# Find all metadata cache keys
pattern = f"{self.CACHE_KEY_PREFIX}:*"
keys = []
async for key in redis.scan_iter(match=pattern):
keys.append(key)
if keys:
await redis.delete(*keys)
logger.info(f"[MetadataCache] Invalidated {len(keys)} cache entries")
return True
except Exception as e:
logger.error(f"[MetadataCache] Invalidate all error: {e}")
return False
# Global cache service instance
_metadata_cache_service: MetadataCacheService | None = None
async def get_metadata_cache_service() -> MetadataCacheService:
"""获取元数据缓存服务实例(单例)"""
global _metadata_cache_service
if _metadata_cache_service is None:
_metadata_cache_service = MetadataCacheService()
return _metadata_cache_service
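Setting the Redis calls aside, the cache contract is just a key scheme plus a JSON round-trip; those pure parts can be sketched and tested without a live Redis:

```python
import json

CACHE_KEY_PREFIX = "metadata:fields"
DEFAULT_TTL = 3600  # 1 hour, matching MetadataCacheService.DEFAULT_TTL

def make_key(tenant_id: str) -> str:
    """Key scheme: metadata:fields:{tenant_id}."""
    return f"{CACHE_KEY_PREFIX}:{tenant_id}"

def serialize_fields(fields: list[dict]) -> str:
    # default=str covers datetimes and other non-JSON types, as in set_fields
    return json.dumps(fields, ensure_ascii=False, default=str)

key = make_key("tenant-42")  # hypothetical tenant ID
blob = serialize_fields([{"field_key": "author", "status": "active"}])
restored = json.loads(blob)
```

In the service itself this value is written with `SETEX key ttl blob`, so entries expire on their own even if a write-time invalidation is missed.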

View File

@ -40,6 +40,19 @@ class MetadataFieldDefinitionService:
def __init__(self, session: AsyncSession):
self._session = session
async def _invalidate_cache(self, tenant_id: str) -> None:
"""
Clear the tenant's metadata field cache.
Called when fields are created, updated, or deleted.
"""
try:
from app.services.metadata_cache_service import get_metadata_cache_service
cache_service = await get_metadata_cache_service()
await cache_service.invalidate(tenant_id)
except Exception as e:
# Cache invalidation failure must not affect the main flow
logger.warning(f"[AC-IDSMETA-13] Failed to invalidate cache: {e}")
async def list_field_definitions(
self,
tenant_id: str,
@ -180,6 +193,9 @@ class MetadataFieldDefinitionService:
self._session.add(field)
await self._session.flush()
# Invalidate the cache so the new field takes effect on the next query
await self._invalidate_cache(tenant_id)
logger.info(
f"[AC-IDSMETA-13] [AC-MRS-01] Created field definition: tenant={tenant_id}, "
f"field_key={field.field_key}, status={field.status}, field_roles={field.field_roles}"
@ -223,6 +239,10 @@ class MetadataFieldDefinitionService:
field.is_filterable = field_update.is_filterable
if field_update.is_rank_feature is not None:
field.is_rank_feature = field_update.is_rank_feature
# [AC-MRS-01] Fix: add field_roles update logic
if field_update.field_roles is not None:
self._validate_field_roles(field_update.field_roles)
field.field_roles = field_update.field_roles
if field_update.status is not None:
old_status = field.status
field.status = field_update.status
@ -235,6 +255,9 @@ class MetadataFieldDefinitionService:
field.updated_at = datetime.utcnow()
await self._session.flush()
# Invalidate the cache so the update takes effect on the next query
await self._invalidate_cache(tenant_id)
logger.info(
f"[AC-IDSMETA-14] Updated field definition: tenant={tenant_id}, "
f"field_id={field_id}, version={field.version}"

View File

@ -17,6 +17,18 @@ from .memory_adapter import MemoryAdapter, UserMemory
from .default_kb_tool_runner import DefaultKbToolRunner, KbToolResult, KbToolConfig, get_default_kb_tool_runner
from .segment_humanizer import SegmentHumanizer, HumanizeConfig, LengthBucket, get_segment_humanizer
from .runtime_observer import RuntimeObserver, RuntimeContext, get_runtime_observer
from .slot_validation_service import (
SlotValidationService,
ValidationResult,
SlotValidationError,
BatchValidationResult,
SlotValidationErrorCode,
)
from .slot_manager import (
SlotManager,
SlotWriteResult,
create_slot_manager,
)
__all__ = [
"PolicyRouter",
@ -54,4 +66,12 @@ __all__ = [
"RuntimeObserver",
"RuntimeContext",
"get_runtime_observer",
"SlotValidationService",
"ValidationResult",
"SlotValidationError",
"BatchValidationResult",
"SlotValidationErrorCode",
"SlotManager",
"SlotWriteResult",
"create_slot_manager",
]

View File

@ -2,6 +2,10 @@
Agent Orchestrator for Mid Platform.
[AC-MARH-07] ReAct loop with iteration limit (3-5 iterations).
Supports two execution modes:
1. ReAct (Text-based): Traditional Thought/Action/Observation loop
2. Function Calling: Uses LLM's native function calling capability
ReAct Flow:
1. Thought: Agent thinks about what to do
2. Action: Agent decides to use a tool
@ -16,6 +20,8 @@ import re
import time
import uuid
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Any
from app.models.mid.schemas import (
@ -26,7 +32,10 @@ from app.models.mid.schemas import (
ToolType,
TraceInfo,
)
from app.services.llm.base import ToolDefinition
from app.services.mid.tool_guide_registry import ToolGuideRegistry, get_tool_guide_registry
from app.services.mid.timeout_governor import TimeoutGovernor
from app.services.mid.tool_converter import convert_tools_to_llm_format, build_tool_result_message
from app.services.prompt.template_service import PromptTemplateService
from app.services.prompt.variable_resolver import VariableResolver
@ -36,11 +45,17 @@ DEFAULT_MAX_ITERATIONS = 5
MIN_ITERATIONS = 3
class AgentMode(str, Enum):
"""Agent execution mode."""
REACT = "react"
FUNCTION_CALLING = "function_calling"
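The two execution modes above pair with the iteration clamp applied in the constructor below (requested counts are bounded to `[MIN_ITERATIONS, DEFAULT_MAX_ITERATIONS]`). A minimal standalone sketch — `clamp_iterations` is a hypothetical helper name, not part of the module:

```python
from enum import Enum


class AgentMode(str, Enum):
    """Agent execution mode, mirroring the orchestrator's enum."""
    REACT = "react"
    FUNCTION_CALLING = "function_calling"


MIN_ITERATIONS = 3
DEFAULT_MAX_ITERATIONS = 5


def clamp_iterations(requested: int) -> int:
    # Bound the requested iteration count to the allowed 3..5 window.
    return max(min(requested, DEFAULT_MAX_ITERATIONS), MIN_ITERATIONS)
```

String-valued enums let callers pass either `AgentMode.REACT` or the raw string `"react"` interchangeably.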
@dataclass
class ToolResult:
"""Tool execution result."""
success: bool
output: Any = None
error: str | None = None
duration_ms: int = 0
@ -59,6 +74,7 @@ class AgentOrchestrator:
Features:
- ReAct loop with max 5 iterations (min 3)
- Function Calling mode for supported LLMs (OpenAI, DeepSeek, etc.)
- Per-tool timeout (2s) and end-to-end timeout (8s)
- Automatic fallback on iteration limit or timeout
- Template-based prompt with variable injection
@ -70,17 +86,74 @@ class AgentOrchestrator:
timeout_governor: TimeoutGovernor | None = None,
llm_client: Any = None,
tool_registry: Any = None,
guide_registry: ToolGuideRegistry | None = None,
template_service: PromptTemplateService | None = None,
variable_resolver: VariableResolver | None = None,
tenant_id: str | None = None,
user_id: str | None = None,
session_id: str | None = None,
scene: str | None = None,
mode: AgentMode = AgentMode.FUNCTION_CALLING,
):
self._max_iterations = max(min(max_iterations, DEFAULT_MAX_ITERATIONS), MIN_ITERATIONS)
self._timeout_governor = timeout_governor or TimeoutGovernor()
self._llm_client = llm_client
self._tool_registry = tool_registry
self._guide_registry = guide_registry
self._template_service = template_service
self._variable_resolver = variable_resolver or VariableResolver()
self._tenant_id = tenant_id
self._user_id = user_id
self._session_id = session_id
self._scene = scene
self._mode = mode
self._tools_cache: list[ToolDefinition] | None = None
def _get_tools_definition(self) -> list[ToolDefinition]:
"""Get cached tools definition for Function Calling."""
if self._tools_cache is None and self._tool_registry:
tools = self._tool_registry.get_all_tools()
self._tools_cache = convert_tools_to_llm_format(tools)
return self._tools_cache or []
async def _get_tools_definition_async(self) -> list[ToolDefinition]:
"""Get tools definition for Function Calling with dynamic schema support."""
if self._tools_cache is not None:
return self._tools_cache
if not self._tool_registry:
return []
tools = self._tool_registry.get_all_tools()
result = []
for tool in tools:
if tool.name == "kb_search_dynamic" and self._tenant_id:
from app.services.mid.kb_search_dynamic_tool import (
_TOOL_SCHEMA_CACHE,
_TOOL_SCHEMA_CACHE_TTL_SECONDS,
)
cache_key = f"tool_schema:{self._tenant_id}"
current_time = time.time()
if cache_key in _TOOL_SCHEMA_CACHE:
cached_time, cached_schema = _TOOL_SCHEMA_CACHE[cache_key]
if current_time - cached_time < _TOOL_SCHEMA_CACHE_TTL_SECONDS:
result.append(ToolDefinition(
name=cached_schema["name"],
description=cached_schema["description"],
parameters=cached_schema["parameters"],
))
continue
result.append(convert_tool_to_llm_format(tool))
else:
result.append(convert_tool_to_llm_format(tool))
self._tools_cache = result
return result
async def execute(
self,
@ -90,7 +163,7 @@ class AgentOrchestrator:
on_action: Any = None,
) -> tuple[str, ReActContext, TraceInfo]:
"""
[AC-MARH-07] Execute agent loop with iteration control.
Args:
user_message: User input message
@ -101,6 +174,416 @@ class AgentOrchestrator:
Returns:
Tuple of (final_answer, react_context, trace_info)
"""
if self._mode == AgentMode.FUNCTION_CALLING:
return await self._execute_function_calling(user_message, context, on_action)
else:
return await self._execute_react(user_message, context, on_thought, on_action)
async def _execute_function_calling(
self,
user_message: str,
context: dict[str, Any] | None = None,
on_action: Any = None,
) -> tuple[str, ReActContext, TraceInfo]:
"""
Execute using Function Calling mode.
This mode uses the LLM's native function calling capability,
which is more reliable and token-efficient than text-based ReAct.
"""
react_ctx = ReActContext(max_iterations=self._max_iterations)
tool_calls: list[ToolCallTrace] = []
start_time = time.time()
logger.info(
f"[AC-MARH-07] Starting Function Calling loop: max_iterations={self._max_iterations}, "
f"llm_client={self._llm_client is not None}, tool_registry={self._tool_registry is not None}"
)
if not self._llm_client:
logger.error("[DEBUG-ORCH] LLM client is None, returning error response")
return "抱歉,服务配置错误,请联系管理员。", react_ctx, TraceInfo(
mode=ExecutionMode.AGENT,
request_id=str(uuid.uuid4()),
generation_id=str(uuid.uuid4()),
)
try:
messages: list[dict[str, Any]] = [
{"role": "system", "content": await self._build_system_prompt()},
{"role": "user", "content": user_message},
]
tools = await self._get_tools_definition_async()
overall_start = time.time()
end_to_end_timeout = self._timeout_governor.end_to_end_timeout_seconds
llm_timeout = getattr(self._timeout_governor, 'llm_timeout_seconds', 15.0)
while react_ctx.should_continue and react_ctx.iteration < react_ctx.max_iterations:
react_ctx.iteration += 1
elapsed = time.time() - overall_start
remaining_time = end_to_end_timeout - elapsed
if remaining_time <= 0:
logger.warning("[AC-MARH-09] Function Calling loop exceeded end-to-end timeout")
react_ctx.final_answer = "抱歉,处理超时,请稍后重试或联系人工客服。"
break
logger.info(
f"[AC-MARH-07] Function Calling iteration {react_ctx.iteration}/"
f"{react_ctx.max_iterations}, remaining_time={remaining_time:.1f}s"
)
logger.info(
f"[DEBUG-ORCH] Calling LLM generate with messages_count={len(messages)}, "
f"tools_count={len(tools) if tools else 0}"
)
response = await asyncio.wait_for(
self._llm_client.generate(
messages=messages,
tools=tools if tools else None,
tool_choice="auto" if tools else None,
),
timeout=min(llm_timeout, remaining_time)
)
logger.info(
f"[DEBUG-ORCH] LLM response received: has_tool_calls={response.has_tool_calls}, "
f"content_length={len(response.content) if response.content else 0}, "
f"tool_calls_count={len(response.tool_calls) if response.tool_calls else 0}"
)
if response.has_tool_calls:
for tool_call in response.tool_calls:
tool_name = tool_call.name
tool_args = tool_call.arguments
logger.info(f"[AC-MARH-07] Tool call: {tool_name}, args={tool_args}")
messages.append({
"role": "assistant",
"content": response.content,
"tool_calls": [tool_call.to_dict()],
})
tool_result, tool_trace = await self._act_fc(
tool_call.id, tool_name, tool_args, react_ctx
)
tool_calls.append(tool_trace)
react_ctx.tool_calls.append(tool_trace)
if on_action:
await on_action(tool_name, tool_result)
# Extract tool_guide from output if present (added by _act_fc)
result_output = tool_result.output if tool_result.success else {"error": tool_result.error}
tool_guide = None
if isinstance(result_output, dict) and "_tool_guide" in result_output:
result_output = dict(result_output)
tool_guide = result_output.pop("_tool_guide")
messages.append(build_tool_result_message(
tool_call_id=tool_call.id,
tool_name=tool_name,
result=result_output,
tool_guide=tool_guide,
))
if not tool_result.success:
if tool_trace.status == ToolCallStatus.TIMEOUT:
react_ctx.final_answer = "抱歉,操作超时,请稍后重试或联系人工客服。"
react_ctx.should_continue = False
break
else:
react_ctx.final_answer = response.content or "抱歉,我无法处理您的请求。"
react_ctx.should_continue = False
break
if react_ctx.should_continue and not react_ctx.final_answer:
logger.warning(f"[AC-MARH-07] Function Calling reached max iterations: {react_ctx.iteration}")
react_ctx.final_answer = await self._force_final_answer_fc(messages)
except asyncio.TimeoutError:
logger.error("[AC-MARH-09] Function Calling loop timed out (end-to-end)")
react_ctx.final_answer = "抱歉,处理超时,请稍后重试或联系人工客服。"
tool_calls.append(ToolCallTrace(
tool_name="fc_loop",
tool_type=ToolType.INTERNAL,
duration_ms=int((time.time() - start_time) * 1000),
status=ToolCallStatus.TIMEOUT,
error_code="E2E_TIMEOUT",
))
except Exception as e:
logger.error(f"[AC-MARH-07] Function Calling error: {e}", exc_info=True)
react_ctx.final_answer = f"抱歉,处理过程中发生错误:{str(e)}"
total_duration_ms = int((time.time() - start_time) * 1000)
trace = TraceInfo(
mode=ExecutionMode.AGENT,
request_id=str(uuid.uuid4()),
generation_id=str(uuid.uuid4()),
react_iterations=react_ctx.iteration,
tools_used=[tc.tool_name for tc in tool_calls if tc.tool_name not in ("fc_loop", "react_loop")],
tool_calls=tool_calls if tool_calls else None,
)
logger.info(
f"[AC-MARH-07] Function Calling completed: iterations={react_ctx.iteration}, "
f"duration_ms={total_duration_ms}"
)
return react_ctx.final_answer or "抱歉,我暂时无法处理您的请求。", react_ctx, trace
async def _act_fc(
self,
tool_call_id: str,
tool_name: str,
tool_args: dict[str, Any],
react_ctx: ReActContext,
) -> tuple[ToolResult, ToolCallTrace]:
"""Execute tool in Function Calling mode."""
start_time = time.time()
if not self._tool_registry:
duration_ms = int((time.time() - start_time) * 1000)
return ToolResult(
success=False,
error="Tool registry not configured",
duration_ms=duration_ms,
), ToolCallTrace(
tool_name=tool_name,
tool_type=ToolType.INTERNAL,
duration_ms=duration_ms,
status=ToolCallStatus.ERROR,
error_code="NO_REGISTRY",
)
try:
final_args = dict(tool_args)
if self._tenant_id:
final_args["tenant_id"] = self._tenant_id
if self._user_id:
final_args["user_id"] = self._user_id
if self._session_id:
final_args["session_id"] = self._session_id
if tool_name == "kb_search_dynamic":
# Ensure context exists so the AI can pass dynamic filter conditions
if "context" not in final_args:
final_args["context"] = {}
# The scene parameter is chosen by the AI from metadata; the system does not force-override it
logger.info(
f"[AC-MARH-07] FC Tool call starting: tool={tool_name}, "
f"args={tool_args}, final_args={final_args}"
)
result = await asyncio.wait_for(
self._tool_registry.execute(
name=tool_name,
**final_args,
),
timeout=self._timeout_governor.per_tool_timeout_seconds
)
duration_ms = int((time.time() - start_time) * 1000)
called_tools = {tc.tool_name for tc in react_ctx.tool_calls}
is_first_call = tool_name not in called_tools
output = result.output
if is_first_call and result.success:
usage_guide = self._build_tool_usage_guide(tool_name)
if usage_guide:
if isinstance(output, dict):
output = dict(output)
output["_tool_guide"] = usage_guide
elif isinstance(output, str):
output = f"{output}\n\n---\n{usage_guide}"
else:
output = {"result": output, "_tool_guide": usage_guide}
return ToolResult(
success=result.success,
output=output,
error=result.error,
duration_ms=duration_ms,
), ToolCallTrace(
tool_name=tool_name,
tool_type=ToolType.INTERNAL,
duration_ms=duration_ms,
status=ToolCallStatus.OK if result.success else ToolCallStatus.ERROR,
args_digest=str(tool_args)[:100] if tool_args else None,
result_digest=str(result.output)[:100] if result.output else None,
arguments=tool_args,
result=output,
)
except asyncio.TimeoutError:
duration_ms = int((time.time() - start_time) * 1000)
logger.warning(f"[AC-MARH-08] FC Tool timeout: {tool_name}, duration={duration_ms}ms")
return ToolResult(
success=False,
error="Tool timeout",
duration_ms=duration_ms,
), ToolCallTrace(
tool_name=tool_name,
tool_type=ToolType.INTERNAL,
duration_ms=duration_ms,
status=ToolCallStatus.TIMEOUT,
error_code="TOOL_TIMEOUT",
arguments=tool_args,
)
except Exception as e:
duration_ms = int((time.time() - start_time) * 1000)
logger.error(f"[AC-MARH-07] FC Tool error: {tool_name}, error={e}")
return ToolResult(
success=False,
error=str(e),
duration_ms=duration_ms,
), ToolCallTrace(
tool_name=tool_name,
tool_type=ToolType.INTERNAL,
duration_ms=duration_ms,
status=ToolCallStatus.ERROR,
error_code="TOOL_ERROR",
arguments=tool_args,
)
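`_act_fc` (and its ReAct counterpart below) appends a usage guide to a tool's output only on that tool's first successful call, handling dict, string, and arbitrary outputs differently. The branching can be sketched as a pure function — `attach_usage_guide` is a hypothetical standalone name, not an actual method:

```python
def attach_usage_guide(output, guide: str):
    """Attach a first-call usage guide to a tool result, preserving its shape."""
    if isinstance(output, dict):
        enriched = dict(output)  # shallow copy so the original result is untouched
        enriched["_tool_guide"] = guide
        return enriched
    if isinstance(output, str):
        # Plain-text results get the guide appended after a separator.
        return f"{output}\n\n---\n{guide}"
    # Any other type is wrapped so both result and guide survive serialization.
    return {"result": output, "_tool_guide": guide}
```

Downstream, the `_tool_guide` key is popped back out before the result is sent to the LLM as a tool message.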
async def _build_system_prompt(self) -> str:
"""Build system prompt for Function Calling mode with template support."""
default_prompt = """你是一个智能客服助手,正在处理用户请求。
## 决策协议
1. 优先使用已有观察信息,避免重复调用同类工具
2. 当问题需要外部事实或结构化状态时再调用工具;如果可直接回答,则不要调用
3. 缺少关键参数时优先向用户追问,不要使用空参数调用工具
4. 工具失败时先说明已尝试,再给出降级方案或下一步引导
5. 对用户输出必须拟人、自然、有同理心,不暴露"工具调用/路由/策略"等内部术语
## 知识库查询强制流程
当用户问题需要进行知识库查询(kb_search_dynamic)时,必须遵循以下步骤:
**步骤1:先调用 list_document_metadata_fields**
- 在任何知识库搜索之前,必须先调用 `list_document_metadata_fields` 工具
- 获取可用的元数据字段(如 grade, subject, kb_scene)及其常见取值
**步骤2:分析用户意图,选择合适的过滤条件**
- 根据用户问题和返回的元数据字段,确定合适的过滤条件
- 从元数据字段的 common_values 中选择合适的值
**步骤3:调用 kb_search_dynamic 进行搜索**
- 使用步骤1获取的元数据字段构造 context 参数
- scene 参数必须从元数据字段的 kb_scene 常见值中选择,不要硬编码
**示例流程**:
1. 调用 `list_document_metadata_fields` 获取字段信息
2. 根据返回结果发现可用字段:grade(年级)、subject(学科)、kb_scene(场景)
3. 分析用户问题"三年级语文怎么学",确定过滤条件:grade="三年级", subject="语文"
4. 从 kb_scene 的常见值中选择合适的 scene,如"学习方案"
5. 调用 `kb_search_dynamic`,传入构造好的 context 和 scene
## 注意事项
- **严禁**在未调用 list_document_metadata_fields 的情况下直接调用 kb_search_dynamic
"""
if not self._template_service or not self._tenant_id:
return default_prompt
try:
from app.core.database import get_session
from app.core.prompts import SYSTEM_PROMPT
async with get_session() as session:
template_service = PromptTemplateService(session)
base_prompt = await template_service.get_published_template(
tenant_id=self._tenant_id,
scene="agent_fc",
resolver=self._variable_resolver,
)
if not base_prompt or base_prompt == SYSTEM_PROMPT:
base_prompt = await template_service.get_published_template(
tenant_id=self._tenant_id,
scene="default",
resolver=self._variable_resolver,
)
if not base_prompt or base_prompt == SYSTEM_PROMPT:
logger.info("[AC-MARH-07] No published template found for agent_fc or default, using default prompt")
return default_prompt
agent_protocol = """
## 智能体决策协议
1. 优先使用已有观察信息,避免重复调用同类工具
2. 当问题需要外部事实或结构化状态时再调用工具;如果可直接回答,则不要调用
3. 缺少关键参数时优先向用户追问,不要使用空参数调用工具
4. 工具失败时先说明已尝试,再给出降级方案或下一步引导
## 知识库查询强制流程
当用户问题需要进行知识库查询(kb_search_dynamic)时,必须遵循以下步骤:
**步骤1:先调用 list_document_metadata_fields**
- 在任何知识库搜索之前,必须先调用 `list_document_metadata_fields` 工具
- 获取可用的元数据字段(如 grade, subject, kb_scene)及其常见取值
**步骤2:分析用户意图,选择合适的过滤条件**
- 根据用户问题和返回的元数据字段,确定合适的过滤条件
- 从元数据字段的 common_values 中选择合适的值
**步骤3:调用 kb_search_dynamic 进行搜索**
- 使用步骤1获取的元数据字段构造 context 参数
- scene 参数必须从元数据字段的 kb_scene 常见值中选择,不要硬编码
## 注意事项
- **严禁**在未调用 list_document_metadata_fields 的情况下直接调用 kb_search_dynamic
"""
final_prompt = base_prompt + agent_protocol
logger.info(f"[AC-MARH-07] Loaded template for tenant={self._tenant_id}")
return final_prompt
except Exception as e:
logger.warning(f"[AC-MARH-07] Failed to load template, using default: {e}")
return default_prompt
async def _force_final_answer_fc(self, messages: list[dict[str, Any]]) -> str:
"""Force final answer when max iterations reached in Function Calling mode."""
try:
response = await self._llm_client.generate(
messages=messages + [{"role": "user", "content": "请基于以上信息给出最终回答,不要再调用工具。"}],
tools=None,
)
return response.content or "抱歉,我已经尽力处理您的请求,但可能需要更多信息。"
except Exception as e:
logger.error(f"[AC-MARH-07] Force final answer FC failed: {e}")
return "抱歉,我已经尽力处理您的请求,但可能需要更多信息。请稍后重试或联系人工客服。"
async def _execute_react(
self,
user_message: str,
context: dict[str, Any] | None = None,
on_thought: Any = None,
on_action: Any = None,
) -> tuple[str, ReActContext, TraceInfo]:
"""
Execute using traditional ReAct mode (text-based).
This is the original implementation for backward compatibility.
"""
react_ctx = ReActContext(max_iterations=self._max_iterations)
tool_calls: list[ToolCallTrace] = []
start_time = time.time()
@ -321,58 +804,108 @@ Action Input:
return default_template
def _build_tools_section(self) -> str:
"""Build rich tools section for ReAct prompt."""
"""
Build compact tools section for ReAct prompt.
Only includes tool name and brief description for initial scanning.
Detailed usage guides are disclosed on-demand when tool is called.
"""
if not self._tool_registry:
return "当前没有可用的工具。"
tools = self._tool_registry.get_all_tools()
if not tools:
return "当前没有可用的工具。"
lines = ["## 可用工具列表", "", "以下是你可以使用的工具", ""]
for tool in tools:
tool_guide = self._guide_registry.get_tool_guide(tool.name) if self._guide_registry else None
description = tool_guide.description if tool_guide else tool.description
lines.append(f"- **{tool.name}**: {description}")
lines.append("")
lines.append("调用工具时,系统会提供该工具的详细使用说明。")
return "\n".join(lines)
def _build_tool_usage_guide(self, tool_name: str) -> str:
"""
Build detailed usage guide for a specific tool.
Called when the tool is executed to provide on-demand guidance.
"""
tool_guide = self._guide_registry.get_tool_guide(tool_name) if self._guide_registry else None
tool = self._tool_registry.get_tool(tool_name) if self._tool_registry else None
if not tool_guide and not tool:
return ""
lines = [f"## {tool_name} 使用说明", ""]
if tool_guide:
lines.append(f"**用途**: {tool_guide.description}")
lines.append("")
if tool_guide.triggers:
lines.append("**适用场景**:")
for trigger in tool_guide.triggers:
lines.append(f"- {trigger}")
lines.append("")
if tool_guide.anti_triggers:
lines.append("**不适用场景**:")
for anti in tool_guide.anti_triggers:
lines.append(f"- {anti}")
lines.append("")
if tool_guide.content:
lines.append(tool_guide.content)
lines.append("")
if tool:
meta = tool.metadata or {}
lines.append(f"### {tool.name}")
lines.append(f"用途: {tool.description}")
when_to_use = meta.get("when_to_use")
when_not_to_use = meta.get("when_not_to_use")
if when_to_use:
lines.append(f"何时使用: {when_to_use}")
if when_not_to_use:
lines.append(f"何时不要使用: {when_not_to_use}")
params = meta.get("parameters")
if isinstance(params, dict):
properties = params.get("properties", {})
required = params.get("required", [])
if properties:
lines.append("参数:")
lines.append("**参数说明**:")
for param_name, param_info in properties.items():
param_desc = param_info.get("description", "") if isinstance(param_info, dict) else ""
line = f" - {param_name}: {param_desc}".strip()
req_mark = " (必填)" if param_name in required else ""
if param_name == "tenant_id":
line += " (系统注入,模型不要填写)"
elif param_name in required:
line += " (必填)"
lines.append(line)
req_mark = " (系统注入)"
lines.append(f"- `{param_name}`: {param_desc}{req_mark}")
lines.append("")
if meta.get("example_action_input"):
lines.append("示例入参(JSON):")
lines.append("**调用示例**:")
try:
example_text = json.dumps(meta["example_action_input"], ensure_ascii=False)
example_text = json.dumps(meta["example_action_input"], ensure_ascii=False, indent=2)
except Exception:
example_text = str(meta["example_action_input"])
lines.append(example_text)
lines.append(f"```json\n{example_text}\n```")
lines.append("")
if meta.get("result_interpretation"):
lines.append(f"结果解释: {meta['result_interpretation']}")
lines.append(f"**结果说明**: {meta['result_interpretation']}")
lines.append("")
return "\n".join(lines)
def _build_tools_guide_section(self, tool_names: list[str] | None = None) -> str:
"""
Build detailed tools guide section for ReAct prompt.
This provides comprehensive usage guides from ToolRegistry.
Called separately from _build_tools_section for flexibility.
Args:
tool_names: If provided, only include tools for these names.
If None, include all tools.
"""
return self._guide_registry.build_tools_prompt_section(tool_names)
def _extract_json_object(self, text: str) -> dict[str, Any] | None:
"""Extract the first valid JSON object from free text."""
candidates = []
@ -438,6 +971,8 @@ Action Input:
) -> tuple[ToolResult, ToolCallTrace]:
"""
[AC-MARH-07, AC-MARH-08] Execute tool action with timeout.
On first call to a tool, appends detailed usage guide to observation.
"""
tool_name = thought.action or "unknown"
start_time = time.time()
@ -461,6 +996,23 @@ Action Input:
if self._tenant_id:
tool_args["tenant_id"] = self._tenant_id
if self._user_id:
tool_args["user_id"] = self._user_id
if self._session_id:
tool_args["session_id"] = self._session_id
if tool_name == "kb_search_dynamic":
# Ensure context exists so the AI can pass dynamic filter conditions
if "context" not in tool_args:
tool_args["context"] = {}
# The scene parameter is chosen by the AI from metadata; the system does not force-override it
logger.info(
f"[AC-MARH-07] Tool call starting: tool={tool_name}, "
f"action_input={thought.action_input}, final_args={tool_args}"
)
result = await asyncio.wait_for(
self._tool_registry.execute(
tool_name=tool_name,
@ -470,9 +1022,25 @@ Action Input:
)
duration_ms = int((time.time() - start_time) * 1000)
called_tools = {tc.tool_name for tc in react_ctx.tool_calls}
is_first_call = tool_name not in called_tools
output = result.output
if is_first_call and result.success:
usage_guide = self._build_tool_usage_guide(tool_name)
if usage_guide:
if isinstance(output, dict):
output = dict(output)
output["_tool_guide"] = usage_guide
elif isinstance(output, str):
output = f"{output}\n\n---\n{usage_guide}"
else:
output = {"result": output, "_tool_guide": usage_guide}
return ToolResult(
success=result.success,
output=output,
error=result.error,
duration_ms=duration_ms,
), ToolCallTrace(
@ -482,6 +1050,8 @@ Action Input:
status=ToolCallStatus.OK if result.success else ToolCallStatus.ERROR,
args_digest=str(thought.action_input)[:100] if thought.action_input else None,
result_digest=str(result.output)[:100] if result.output else None,
arguments=thought.action_input,
result=output,
)
except asyncio.TimeoutError:
@ -497,6 +1067,7 @@ Action Input:
duration_ms=duration_ms,
status=ToolCallStatus.TIMEOUT,
error_code="TOOL_TIMEOUT",
arguments=thought.action_input,
)
except Exception as e:
@ -512,6 +1083,7 @@ Action Input:
duration_ms=duration_ms,
status=ToolCallStatus.ERROR,
error_code="TOOL_ERROR",
arguments=thought.action_input,
)
async def _force_final_answer(


@ -0,0 +1,342 @@
"""
Batch Ask-Back Service.
批量追问服务 - 支持一次追问多个缺失槽位
[AC-MRS-SLOT-ASKBACK-01] 批量追问
职责
1. 支持一次追问多个缺失槽位
2. 选择策略必填优先场景相关优先最近未追问过优先
3. 输出形式单条自然语言合并提问或分段提问
"""
import logging
import time
from dataclasses import dataclass, field
from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.cache.slot_state_cache import get_slot_state_cache
from app.services.slot_definition_service import SlotDefinitionService
logger = logging.getLogger(__name__)
@dataclass
class AskBackSlot:
"""
追问槽位信息
Attributes:
slot_key: 槽位键名
label: 显示标签
ask_back_prompt: 追问提示
priority: 优先级数值越大越优先
last_asked_at: 上次追问时间戳
is_required: 是否必填
scene_relevance: 场景相关度
"""
slot_key: str
label: str
ask_back_prompt: str | None = None
priority: int = 0
last_asked_at: float | None = None
is_required: bool = False
scene_relevance: float = 0.0
@dataclass
class BatchAskBackConfig:
"""
批量追问配置
Attributes:
max_ask_back_slots_per_turn: 每轮最多追问槽位数
prefer_required: 是否优先追问必填槽位
prefer_scene_relevant: 是否优先追问场景相关槽位
avoid_recent_asked: 是否避免最近追问过的槽位
recent_asked_threshold_seconds: 最近追问阈值
merge_prompts: 是否合并追问提示
merge_template: 合并模板
"""
max_ask_back_slots_per_turn: int = 2
prefer_required: bool = True
prefer_scene_relevant: bool = True
avoid_recent_asked: bool = True
recent_asked_threshold_seconds: float = 60.0
merge_prompts: bool = True
merge_template: str = "为了更好地为您服务,请告诉我:{prompts}"
@dataclass
class BatchAskBackResult:
"""
批量追问结果
Attributes:
selected_slots: 选中的追问槽位列表
prompts: 追问提示列表
merged_prompt: 合并后的追问提示
ask_back_count: 追问数量
"""
selected_slots: list[AskBackSlot] = field(default_factory=list)
prompts: list[str] = field(default_factory=list)
merged_prompt: str | None = None
ask_back_count: int = 0
def has_ask_back(self) -> bool:
return self.ask_back_count > 0
def get_prompt(self) -> str:
"""获取最终追问提示"""
if self.merged_prompt:
return self.merged_prompt
if self.prompts:
return self.prompts[0]
return "请提供更多信息以便我更好地帮助您。"
class BatchAskBackService:
"""
[AC-MRS-SLOT-ASKBACK-01] 批量追问服务
支持一次追问多个缺失槽位提高补全效率
"""
ASK_BACK_HISTORY_KEY_PREFIX = "slot_ask_back_history"
ASK_BACK_HISTORY_TTL = 300
def __init__(
self,
session: AsyncSession,
tenant_id: str,
session_id: str,
config: BatchAskBackConfig | None = None,
):
self._session = session
self._tenant_id = tenant_id
self._session_id = session_id
self._config = config or BatchAskBackConfig()
self._slot_def_service = SlotDefinitionService(session)
self._cache = get_slot_state_cache()
async def generate_batch_ask_back(
self,
missing_slots: list[dict[str, str]],
current_scene: str | None = None,
) -> BatchAskBackResult:
"""
生成批量追问
Args:
missing_slots: 缺失槽位列表
current_scene: 当前场景
Returns:
BatchAskBackResult: 批量追问结果
"""
if not missing_slots:
return BatchAskBackResult()
ask_back_slots = await self._prepare_ask_back_slots(missing_slots, current_scene)
selected_slots = self._select_slots_for_ask_back(ask_back_slots)
asked_history = await self._get_asked_history()
selected_slots = self._filter_recently_asked(selected_slots, asked_history)
if not selected_slots:
selected_slots = ask_back_slots[:self._config.max_ask_back_slots_per_turn]
prompts = self._generate_prompts(selected_slots)
merged_prompt = self._merge_prompts(prompts) if self._config.merge_prompts else None
await self._record_ask_back_history([s.slot_key for s in selected_slots])
return BatchAskBackResult(
selected_slots=selected_slots,
prompts=prompts,
merged_prompt=merged_prompt,
ask_back_count=len(selected_slots),
)
async def _prepare_ask_back_slots(
self,
missing_slots: list[dict[str, str]],
current_scene: str | None,
) -> list[AskBackSlot]:
"""准备追问槽位列表"""
ask_back_slots = []
for missing in missing_slots:
slot_key = missing.get("slot_key", "")
label = missing.get("label", slot_key)
ask_back_prompt = missing.get("ask_back_prompt")
field_key = missing.get("field_key")
slot_def = await self._slot_def_service.get_slot_definition_by_key(
self._tenant_id, slot_key
)
is_required = False
scene_relevance = 0.0
if slot_def:
is_required = slot_def.required
if not ask_back_prompt:
ask_back_prompt = slot_def.ask_back_prompt
if current_scene and slot_def.scene_scope:
if current_scene in slot_def.scene_scope:
scene_relevance = 1.0
priority = self._calculate_priority(is_required, scene_relevance)
ask_back_slots.append(AskBackSlot(
slot_key=slot_key,
label=label,
ask_back_prompt=ask_back_prompt,
priority=priority,
is_required=is_required,
scene_relevance=scene_relevance,
))
return ask_back_slots
def _calculate_priority(self, is_required: bool, scene_relevance: float) -> int:
"""计算槽位优先级"""
priority = 0
if self._config.prefer_required and is_required:
priority += 100
if self._config.prefer_scene_relevant:
priority += int(scene_relevance * 50)
return priority
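The scoring above weights required slots far above scene relevance (+100 versus at most +50), so a required slot always outranks an optional one. As a pure-function sketch with the config flags as parameters (`calculate_priority` is an illustrative name, not the method itself):

```python
def calculate_priority(is_required: bool, scene_relevance: float,
                       prefer_required: bool = True,
                       prefer_scene_relevant: bool = True) -> int:
    """Score a slot: +100 if required (when preferred), +0..50 for scene relevance."""
    priority = 0
    if prefer_required and is_required:
        priority += 100
    if prefer_scene_relevant:
        # scene_relevance is 0.0..1.0, scaled into a 0..50 contribution.
        priority += int(scene_relevance * 50)
    return priority
```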
def _select_slots_for_ask_back(
self,
ask_back_slots: list[AskBackSlot],
) -> list[AskBackSlot]:
"""选择要追问的槽位"""
sorted_slots = sorted(
ask_back_slots,
key=lambda s: s.priority,
reverse=True,
)
return sorted_slots[:self._config.max_ask_back_slots_per_turn]
async def _get_asked_history(self) -> dict[str, float]:
"""获取最近追问历史"""
history_key = f"{self.ASK_BACK_HISTORY_KEY_PREFIX}:{self._tenant_id}:{self._session_id}"
try:
client = await self._cache._get_client()
if client is None:
return {}
import json
data = await client.get(history_key)
if data:
return json.loads(data)
except Exception as e:
logger.warning(f"[BatchAskBack] Failed to get asked history: {e}")
return {}
async def _record_ask_back_history(self, slot_keys: list[str]) -> None:
"""记录追问历史"""
history_key = f"{self.ASK_BACK_HISTORY_KEY_PREFIX}:{self._tenant_id}:{self._session_id}"
try:
client = await self._cache._get_client()
if client is None:
return
history = await self._get_asked_history()
current_time = time.time()
for slot_key in slot_keys:
history[slot_key] = current_time
import json
await client.setex(
history_key,
self.ASK_BACK_HISTORY_TTL,
json.dumps(history),
)
except Exception as e:
logger.warning(f"[BatchAskBack] Failed to record asked history: {e}")
def _filter_recently_asked(
self,
slots: list[AskBackSlot],
asked_history: dict[str, float],
) -> list[AskBackSlot]:
"""过滤最近追问过的槽位"""
if not self._config.avoid_recent_asked:
return slots
current_time = time.time()
threshold = self._config.recent_asked_threshold_seconds
return [
slot for slot in slots
if slot.slot_key not in asked_history or
current_time - asked_history[slot.slot_key] > threshold
]
def _generate_prompts(self, slots: list[AskBackSlot]) -> list[str]:
"""生成追问提示列表"""
prompts = []
for slot in slots:
if slot.ask_back_prompt:
prompts.append(slot.ask_back_prompt)
else:
prompts.append(f"请告诉我您的{slot.label}")
return prompts
def _merge_prompts(self, prompts: list[str]) -> str | None:
"""Merge the ask-back prompts into one sentence using the configured template."""
if not prompts:
return None
if len(prompts) == 1:
return prompts[0]
joined = ",".join(prompts[:-1]) + f",以及{prompts[-1]}"
return self._config.merge_template.format(prompts=joined)
def create_batch_ask_back_service(
session: AsyncSession,
tenant_id: str,
session_id: str,
config: BatchAskBackConfig | None = None,
) -> BatchAskBackService:
"""
创建批量追问服务实例
Args:
session: 数据库会话
tenant_id: 租户 ID
session_id: 会话 ID
config: 配置
Returns:
BatchAskBackService: 批量追问服务实例
"""
return BatchAskBackService(
session=session,
tenant_id=tenant_id,
session_id=session_id,
config=config,
)


@ -17,15 +17,20 @@ import logging
import time
import uuid
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Literal
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.mid.schemas import ToolCallStatus, ToolCallTrace, ToolType
from app.services.mid.metadata_filter_builder import (
FilterBuildResult,
FilterFieldInfo,
MetadataFilterBuilder,
)
from app.services.mid.slot_state_aggregator import (
SlotState,
SlotStateAggregator,
)
from app.services.mid.timeout_governor import TimeoutGovernor
if TYPE_CHECKING:
@ -37,6 +42,9 @@ DEFAULT_TOP_K = 5
DEFAULT_TIMEOUT_MS = 2000
KB_SEARCH_DYNAMIC_TOOL_NAME = "kb_search_dynamic"
_TOOL_SCHEMA_CACHE: dict[str, tuple[float, dict[str, Any]]] = {}
_TOOL_SCHEMA_CACHE_TTL_SECONDS = 300 # 5 minutes
@dataclass
class KbSearchDynamicResult:
@ -46,9 +54,21 @@ class KbSearchDynamicResult:
applied_filter: dict[str, Any] = field(default_factory=dict)
missing_required_slots: list[dict[str, str]] = field(default_factory=list)
filter_debug: dict[str, Any] = field(default_factory=dict)
filter_sources: dict[str, str] = field(default_factory=dict)  # [AC-SCENE-SLOT-02] Source of each applied filter condition
fallback_reason_code: str | None = None
duration_ms: int = 0
tool_trace: ToolCallTrace | None = None
step_kb_binding: dict[str, Any] | None = None  # [Step-KB-Binding] Step knowledge-base binding info
@dataclass
class StepKbConfig:
"""[Step-KB-Binding] 步骤级别的知识库配置。"""
allowed_kb_ids: list[str] | None = None
preferred_kb_ids: list[str] | None = None
kb_query_hint: str | None = None
max_kb_calls: int = 1
step_id: str | None = None
@dataclass
@ -86,12 +106,14 @@ class KbSearchDynamicTool:
session: AsyncSession,
timeout_governor: TimeoutGovernor | None = None,
config: KbSearchDynamicConfig | None = None,
slot_state_aggregator: SlotStateAggregator | None = None,
):
self._session = session
self._timeout_governor = timeout_governor or TimeoutGovernor()
self._config = config or KbSearchDynamicConfig()
self._vector_retriever = None
self._filter_builder: MetadataFilterBuilder | None = None
self._slot_state_aggregator = slot_state_aggregator
@property
def name(self) -> str:
@ -140,6 +162,128 @@ class KbSearchDynamicTool:
},
}
async def get_dynamic_tool_schema(self, tenant_id: str) -> dict[str, Any]:
"""
获取动态生成的工具 Schema包含租户的元数据过滤字段
使用缓存机制避免每次都查询数据库
只显示关联知识库的元数据过滤字段field_roles 包含 resource_filter
Args:
tenant_id: 租户 ID
Returns:
动态生成的工具 Schema
"""
current_time = time.time()
cache_key = f"tool_schema:{tenant_id}"
if cache_key in _TOOL_SCHEMA_CACHE:
cached_time, cached_schema = _TOOL_SCHEMA_CACHE[cache_key]
if current_time - cached_time < _TOOL_SCHEMA_CACHE_TTL_SECONDS:
logger.debug(f"[AC-MARH-05] Tool schema cache hit for tenant={tenant_id}")
return cached_schema
logger.info(f"[AC-MARH-05] Building dynamic tool schema for tenant={tenant_id}")
base_properties = {
"query": {
"type": "string",
"description": "检索查询文本",
},
"top_k": {
"type": "integer",
"description": "返回结果数量",
"default": DEFAULT_TOP_K,
},
}
required_fields = ["query"]
context_properties = {}
try:
if self._filter_builder is None:
self._filter_builder = MetadataFilterBuilder(self._session)
filterable_fields = await self._filter_builder._get_filterable_fields(tenant_id)
for field_info in filterable_fields:
field_schema = self._build_field_schema(field_info)
context_properties[field_info.field_key] = field_schema
if field_info.required:
required_fields.append(field_info.field_key)
logger.info(
f"[AC-MARH-05] Dynamic schema built: tenant={tenant_id}, "
f"context_fields={len(context_properties)}, required={required_fields}"
)
except Exception as e:
logger.warning(f"[AC-MARH-05] Failed to get filterable fields: {e}, using base schema")
if context_properties:
base_properties["context"] = {
"type": "object",
"description": "过滤条件,根据用户意图选择合适的字段传递",
"properties": context_properties,
}
else:
base_properties["context"] = {
"type": "object",
"description": "过滤条件(当前租户未配置元数据字段)",
}
schema = {
"name": self.name,
"description": self.description,
"parameters": {
"type": "object",
"properties": base_properties,
"required": required_fields,
},
}
_TOOL_SCHEMA_CACHE[cache_key] = (current_time, schema)
return schema
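The TTL-cache check at the top of `get_dynamic_tool_schema` can be sketched standalone. This is a minimal sketch, assuming a module-level cache dict and TTL constant shaped like `_TOOL_SCHEMA_CACHE` / `_TOOL_SCHEMA_CACHE_TTL_SECONDS` above; the `build` callback and the TTL value here are illustrative stand-ins, not the real module's values:

```python
import time

# Hypothetical stand-ins for the module-level cache used by get_dynamic_tool_schema:
# the cache maps "tool_schema:{tenant_id}" -> (cached_at, schema).
_TOOL_SCHEMA_CACHE: dict[str, tuple[float, dict]] = {}
_TOOL_SCHEMA_CACHE_TTL_SECONDS = 300  # assumed TTL for illustration

def get_cached_schema(tenant_id: str, build) -> dict:
    """Return the cached schema if still fresh, otherwise rebuild and re-cache it."""
    now = time.time()
    key = f"tool_schema:{tenant_id}"
    if key in _TOOL_SCHEMA_CACHE:
        cached_at, schema = _TOOL_SCHEMA_CACHE[key]
        if now - cached_at < _TOOL_SCHEMA_CACHE_TTL_SECONDS:
            return schema  # cache hit: skip the rebuild
    schema = build()  # cache miss: rebuild (the real code queries the DB here)
    _TOOL_SCHEMA_CACHE[key] = (now, schema)
    return schema
```

A second call for the same tenant within the TTL returns the cached schema without invoking `build` again.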
def _build_field_schema(self, field_info: "FilterFieldInfo") -> dict[str, Any]:
"""
Build a JSON Schema fragment from the field definition.
Args:
field_info: Field definition info
Returns:
JSON Schema for the field
"""
schema: dict[str, Any] = {
"description": field_info.label or field_info.field_key,
}
if field_info.required:
schema["description"] += "(必填)"
field_type = field_info.field_type.lower() if field_info.field_type else "string"
if field_type in ("enum", "select", "array_enum", "multi_select"):
schema["type"] = "string"
if field_info.options:
schema["enum"] = field_info.options
schema["description"] += f",可选值:{', '.join(field_info.options)}"
elif field_type in ("number", "integer", "float"):
schema["type"] = "number"
elif field_type == "boolean":
schema["type"] = "boolean"
else:
schema["type"] = "string"
if field_info.default_value is not None:
schema["default"] = field_info.default_value
return schema
async def execute(
self,
query: str,
@ -147,16 +291,24 @@ class KbSearchDynamicTool:
scene: str = "open_consult",
top_k: int | None = None,
context: dict[str, Any] | None = None,
slot_state: SlotState | None = None,
step_kb_config: StepKbConfig | None = None,
slot_policy: Literal["flow_strict", "agent_relaxed"] = "flow_strict",
) -> KbSearchDynamicResult:
"""
[AC-MARH-05] Execute dynamic KB retrieval.
[AC-MRS-SLOT-META-02] Supports slot-state aggregation and filter-build priority.
[Step-KB-Binding] Supports step-level knowledge-base constraints.
Args:
query: Search query
tenant_id: Tenant ID
scene: Scene identifier
scene: Scene identifier; the default is overridden by `scene` in context
top_k: Number of results to return
context: Context containing dynamic filter values
context: Context containing dynamic filter values (including scene)
slot_state: Pre-aggregated slot state (optional, takes priority when provided)
step_kb_config: Step-level knowledge-base config (optional)
slot_policy: Slot policy (flow_strict = strict flow mode; agent_relaxed = relaxed general-QA mode)
Returns:
KbSearchDynamicResult with retrieval hits and trace info
@ -171,14 +323,357 @@ class KbSearchDynamicTool:
start_time = time.time()
top_k = top_k or self._config.top_k
effective_context = dict(context) if context else {}
effective_scene = effective_context.get("scene", scene)
# [Step-KB-Binding] Record step-level KB constraints
step_kb_binding_info: dict[str, Any] = {}
if step_kb_config:
step_kb_binding_info = {
"step_id": step_kb_config.step_id,
"allowed_kb_ids": step_kb_config.allowed_kb_ids,
"preferred_kb_ids": step_kb_config.preferred_kb_ids,
"kb_query_hint": step_kb_config.kb_query_hint,
"max_kb_calls": step_kb_config.max_kb_calls,
}
logger.info(
f"[Step-KB-Binding] Step KB config applied: "
f"step_id={step_kb_config.step_id}, "
f"allowed_kb_ids={step_kb_config.allowed_kb_ids}, "
f"preferred_kb_ids={step_kb_config.preferred_kb_ids}"
)
filter_result: FilterBuildResult | None = None
logger.info(
f"[AC-MARH-05] Starting KB dynamic search: tenant={tenant_id}, "
f"query={query[:50]}..., scene={effective_scene}, top_k={top_k}, "
f"slot_policy={slot_policy}, context_keys={list(effective_context.keys())}"
)
# [AC-MRS-SLOT-META-02] If slot_state is provided, use it with priority
if slot_state is not None:
logger.info(
f"[AC-MRS-SLOT-META-02] Using provided slot_state: "
f"filled={len(slot_state.filled_slots)}, "
f"missing={len(slot_state.missing_required_slots)}"
)
# Check for missing required slots (blocks only in flow_strict mode)
if slot_state.missing_required_slots:
if slot_policy == "flow_strict":
duration_ms = int((time.time() - start_time) * 1000)
logger.info(
f"[AC-MRS-SLOT-META-03] flow_strict mode hit missing required slots, triggering ask-back: "
f"tenant={tenant_id}, missing={len(slot_state.missing_required_slots)}"
)
tool_trace = ToolCallTrace(
tool_name=self.name,
tool_type=ToolType.INTERNAL,
duration_ms=duration_ms,
status=ToolCallStatus.ERROR,
error_code="MISSING_REQUIRED_SLOTS",
args_digest=f"query={query[:50]}, scene={effective_scene}",
result_digest=f"missing={len(slot_state.missing_required_slots)}",
arguments={"query": query, "scene": effective_scene, "context": context},
result={"missing_required_slots": slot_state.missing_required_slots},
)
return KbSearchDynamicResult(
success=False,
applied_filter={},
missing_required_slots=slot_state.missing_required_slots,
filter_debug={
"source": "slot_state",
"filled_slots": slot_state.filled_slots,
"slot_to_field_map": slot_state.slot_to_field_map,
},
fallback_reason_code="MISSING_REQUIRED_SLOTS",
duration_ms=duration_ms,
tool_trace=tool_trace,
)
logger.info(
f"[AC-MRS-SLOT-META-03] agent_relaxed mode detected missing slots but does not block retrieval: "
f"tenant={tenant_id}, missing={len(slot_state.missing_required_slots)}"
)
# Build the filter from slot_state
metadata_filter, filter_sources = await self._build_filter_from_slot_state(
tenant_id=tenant_id,
slot_state=slot_state,
context=effective_context,
scene_slot_context=effective_context.get("scene_slot_context"), # [AC-SCENE-SLOT-02]
)
else:
# Legacy path: build the metadata filter
# If the context is simple (plain key-value pairs), build the filter directly, skipping MetadataFilterBuilder
metadata_filter = await self._build_filter_legacy(
tenant_id=tenant_id,
context=effective_context,
query=query,
effective_scene=effective_scene,
start_time=start_time,
)
if isinstance(metadata_filter, KbSearchDynamicResult):
# Build failed; return the error result directly
return metadata_filter
try:
hits = await self._retrieve_with_timeout(
tenant_id=tenant_id,
query=query,
metadata_filter=metadata_filter,
top_k=top_k,
step_kb_config=step_kb_config,
)
duration_ms = int((time.time() - start_time) * 1000)
kb_hit = len(hits) > 0
# [Step-KB-Binding] Record which knowledge bases were hit
hit_kb_ids = list(set(hit.get("kb_id") for hit in hits if hit.get("kb_id")))
if step_kb_binding_info:
step_kb_binding_info["used_kb_ids"] = hit_kb_ids
step_kb_binding_info["kb_hit"] = kb_hit
tool_trace = ToolCallTrace(
tool_name=self.name,
tool_type=ToolType.INTERNAL,
duration_ms=duration_ms,
status=ToolCallStatus.OK,
args_digest=f"query={query[:50]}, scene={effective_scene}",
result_digest=f"hits={len(hits)}",
arguments={"query": query, "scene": effective_scene, "context": context},
result={"hits_count": len(hits), "kb_hit": kb_hit},
)
logger.info(
f"[AC-MARH-05] KB dynamic search completed: tenant={tenant_id}, "
f"hits={len(hits)}, duration_ms={duration_ms}, kb_hit={kb_hit}, "
f"hit_kb_ids={hit_kb_ids}"
)
# Determine the filter source for debugging
filter_source = "slot_state" if slot_state is not None else "builder"
kb_filter_sources = filter_sources if slot_state is not None else {}
return KbSearchDynamicResult(
success=True,
hits=hits,
applied_filter=metadata_filter or {},
filter_debug={"source": filter_source, "filter_sources": kb_filter_sources},
duration_ms=duration_ms,
tool_trace=tool_trace,
step_kb_binding=step_kb_binding_info if step_kb_binding_info else None,
)
except asyncio.TimeoutError:
duration_ms = int((time.time() - start_time) * 1000)
logger.warning(
f"[AC-MARH-06] KB dynamic search timeout: tenant={tenant_id}, "
f"duration_ms={duration_ms}"
)
tool_trace = ToolCallTrace(
tool_name=self.name,
tool_type=ToolType.INTERNAL,
duration_ms=duration_ms,
status=ToolCallStatus.TIMEOUT,
error_code="KB_TIMEOUT",
arguments={"query": query, "scene": effective_scene, "context": context},
)
return KbSearchDynamicResult(
success=False,
applied_filter=metadata_filter or {},
missing_required_slots=[],
filter_debug={"error": "timeout"},
fallback_reason_code="KB_TIMEOUT",
duration_ms=duration_ms,
tool_trace=tool_trace,
step_kb_binding=step_kb_binding_info if step_kb_binding_info else None,
)
except Exception as e:
duration_ms = int((time.time() - start_time) * 1000)
logger.error(
f"[AC-MARH-06] KB dynamic search failed: tenant={tenant_id}, "
f"error={e}"
)
tool_trace = ToolCallTrace(
tool_name=self.name,
tool_type=ToolType.INTERNAL,
duration_ms=duration_ms,
status=ToolCallStatus.ERROR,
error_code="KB_ERROR",
arguments={"query": query, "scene": effective_scene, "context": context},
)
return KbSearchDynamicResult(
success=False,
applied_filter=metadata_filter or {},
missing_required_slots=[],
filter_debug={"error": str(e)},
fallback_reason_code="KB_ERROR",
duration_ms=duration_ms,
tool_trace=tool_trace,
)
async def _build_filter_legacy(
self,
tenant_id: str,
context: dict[str, Any],
query: str,
effective_scene: str,
start_time: float,
) -> dict[str, Any] | KbSearchDynamicResult:
"""
[AC-MRS-SLOT-META-02] Legacy logic for building the metadata filter.
Returns:
dict: the successfully built filter
KbSearchDynamicResult: error result when the build fails
"""
metadata_filter: dict[str, Any] | None = None
# Simple context: build the filter directly (trusting the values passed by the AI)
# Complex context: validate strictly via MetadataFilterBuilder
is_simple_context = all(
isinstance(v, (str, int, float, bool))
for v in context.values()
)
if is_simple_context:
# Build the filter directly without querying the database
metadata_filter = context
logger.info(
f"[AC-MARH-05] Using simple context as filter directly: "
f"{metadata_filter}"
)
else:
# Complex context: use MetadataFilterBuilder
filter_result = await self._build_filter_with_builder(
tenant_id=tenant_id,
context=context,
query=query,
effective_scene=effective_scene,
start_time=start_time,
)
if isinstance(filter_result, KbSearchDynamicResult):
# Build failed; return the error result directly
return filter_result
metadata_filter = filter_result
return metadata_filter or {}
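The "simple context" fast path above hinges on one check: every context value must be a plain scalar. A minimal sketch of that check, using the same `isinstance` test as `_build_filter_legacy`:

```python
from typing import Any

def is_simple_context(context: dict[str, Any]) -> bool:
    """A context is 'simple' when every value is a plain scalar; such a
    context is trusted and used as the metadata filter as-is, skipping
    MetadataFilterBuilder (sketch of the check in _build_filter_legacy)."""
    return all(isinstance(v, (str, int, float, bool)) for v in context.values())
```

Nested structures (lists, operator dicts such as `{"$in": [...]}`) fail the check and are routed through the strict builder instead.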
async def _build_filter_from_slot_state(
self,
tenant_id: str,
slot_state: SlotState,
context: dict[str, Any],
scene_slot_context: Any = None,  # [AC-SCENE-SLOT-02] Scene slot-bundle context
) -> tuple[dict[str, Any], dict[str, str]]:
"""
[AC-MRS-SLOT-META-02] Build filter conditions from the slot state.
[AC-SCENE-SLOT-02] Honors the scene slot-bundle configuration.
Filter value priority:
1. Confirmed slot values (slot_state.filled_slots)
2. Explicit values in the current request context
3. Metadata default values
Args:
tenant_id: Tenant ID
slot_state: Slot state
context: Context
scene_slot_context: Scene slot-bundle context
Returns:
(filter dict, mapping of field key to filter source)
"""
if self._filter_builder is None:
self._filter_builder = MetadataFilterBuilder(self._session)
# Fetch the filterable field definitions
filterable_fields = await self._filter_builder._get_filterable_fields(tenant_id)
applied_filter: dict[str, Any] = {}
filter_debug_sources: dict[str, str] = {}
# [AC-SCENE-SLOT-02] If a scene slot context exists, process scene-defined slots first
scene_slot_keys = set()
if scene_slot_context:
scene_slot_keys = set(scene_slot_context.get_all_slot_keys())
logger.debug(
f"[AC-SCENE-SLOT-02] Processing scene slots: "
f"scene={scene_slot_context.scene_key}, slots={scene_slot_keys}"
)
for field_info in filterable_fields:
field_key = field_info.field_key
value = None
source = None
# Priority 1: confirmed slot values (mapped via slot_to_field_map)
if slot_state.slot_to_field_map:
# Find which slot maps to this field
for slot_key, mapped_field_key in slot_state.slot_to_field_map.items():
if mapped_field_key == field_key and slot_key in slot_state.filled_slots:
value = slot_state.filled_slots[slot_key]
source = "slot"
break
# If the slot mapping did not match, check whether slot_key equals field_key directly
if value is None and field_key in slot_state.filled_slots:
value = slot_state.filled_slots[field_key]
source = "slot"
# Priority 2: explicit value from the current request context
if value is None and field_key in context:
value = context[field_key]
source = "context"
# Priority 3: metadata default value
if value is None and field_info.default_value is not None:
value = field_info.default_value
source = "default"
# Build the filter entry
if value is not None:
filter_value = self._filter_builder._build_field_filter(
field_info, value
)
if filter_value is not None:
applied_filter[field_key] = filter_value
filter_debug_sources[field_key] = source
logger.info(
f"[AC-MRS-SLOT-META-02] Filter built from slot_state: "
f"fields={len(applied_filter)}, sources={filter_debug_sources}"
)
return applied_filter, filter_debug_sources
async def _build_filter_with_builder(
self,
tenant_id: str,
context: dict[str, Any],
query: str,
effective_scene: str,
start_time: float,
) -> dict[str, Any] | KbSearchDynamicResult:
"""
Build the filter with MetadataFilterBuilder (complex case).
Returns:
dict: the successfully built filter
KbSearchDynamicResult: error result when the build fails
"""
from app.services.mid.metadata_filter_builder import FilterBuildResult
if self._filter_builder is None:
self._filter_builder = MetadataFilterBuilder(self._session)
@ -200,8 +695,10 @@ class KbSearchDynamicTool:
duration_ms=duration_ms,
status=ToolCallStatus.ERROR,
error_code="MISSING_REQUIRED_SLOTS",
args_digest=f"query={query[:50]}, scene={scene}",
args_digest=f"query={query[:50]}, scene={effective_scene}",
result_digest=f"missing={len(filter_result.missing_required_slots)}",
arguments={"query": query, "scene": effective_scene, "context": context},
result={"missing_required_slots": filter_result.missing_required_slots},
)
return KbSearchDynamicResult(
@ -214,90 +711,7 @@ class KbSearchDynamicTool:
tool_trace=tool_trace,
)
metadata_filter = filter_result.applied_filter if filter_result.success else None
hits = await self._retrieve_with_timeout(
tenant_id=tenant_id,
query=query,
metadata_filter=metadata_filter,
top_k=top_k,
)
duration_ms = int((time.time() - start_time) * 1000)
kb_hit = len(hits) > 0
tool_trace = ToolCallTrace(
tool_name=self.name,
tool_type=ToolType.INTERNAL,
duration_ms=duration_ms,
status=ToolCallStatus.OK,
args_digest=f"query={query[:50]}, scene={scene}",
result_digest=f"hits={len(hits)}",
)
logger.info(
f"[AC-MARH-05] KB dynamic search completed: tenant={tenant_id}, "
f"hits={len(hits)}, duration_ms={duration_ms}, kb_hit={kb_hit}"
)
return KbSearchDynamicResult(
success=True,
hits=hits,
applied_filter=filter_result.applied_filter if filter_result else {},
filter_debug=filter_result.debug_info if filter_result else {},
duration_ms=duration_ms,
tool_trace=tool_trace,
)
except asyncio.TimeoutError:
duration_ms = int((time.time() - start_time) * 1000)
logger.warning(
f"[AC-MARH-06] KB dynamic search timeout: tenant={tenant_id}, "
f"duration_ms={duration_ms}"
)
tool_trace = ToolCallTrace(
tool_name=self.name,
tool_type=ToolType.INTERNAL,
duration_ms=duration_ms,
status=ToolCallStatus.TIMEOUT,
error_code="KB_TIMEOUT",
)
return KbSearchDynamicResult(
success=False,
applied_filter=filter_result.applied_filter if filter_result else {},
missing_required_slots=filter_result.missing_required_slots if filter_result else [],
filter_debug=filter_result.debug_info if filter_result else {},
fallback_reason_code="KB_TIMEOUT",
duration_ms=duration_ms,
tool_trace=tool_trace,
)
except Exception as e:
duration_ms = int((time.time() - start_time) * 1000)
logger.error(
f"[AC-MARH-06] KB dynamic search failed: tenant={tenant_id}, "
f"error={e}"
)
tool_trace = ToolCallTrace(
tool_name=self.name,
tool_type=ToolType.INTERNAL,
duration_ms=duration_ms,
status=ToolCallStatus.ERROR,
error_code="KB_ERROR",
)
return KbSearchDynamicResult(
success=False,
applied_filter=filter_result.applied_filter if filter_result else {},
missing_required_slots=filter_result.missing_required_slots if filter_result else [],
filter_debug={"error": str(e)},
fallback_reason_code="KB_ERROR",
duration_ms=duration_ms,
tool_trace=tool_trace,
)
return filter_result.applied_filter if filter_result.success else {}
async def _retrieve_with_timeout(
self,
@ -305,13 +719,14 @@ class KbSearchDynamicTool:
query: str,
metadata_filter: dict[str, Any] | None = None,
top_k: int = DEFAULT_TOP_K,
step_kb_config: StepKbConfig | None = None,
) -> list[dict[str, Any]]:
"""Retrieval with timeout control."""
timeout_seconds = self._config.timeout_ms / 1000.0
try:
return await asyncio.wait_for(
self._do_retrieve(tenant_id, query, metadata_filter, top_k),
self._do_retrieve(tenant_id, query, metadata_filter, top_k, step_kb_config),
timeout=timeout_seconds,
)
except asyncio.TimeoutError:
@ -323,18 +738,36 @@ class KbSearchDynamicTool:
query: str,
metadata_filter: dict[str, Any] | None = None,
top_k: int = DEFAULT_TOP_K,
step_kb_config: StepKbConfig | None = None,
) -> list[dict[str, Any]]:
"""Perform the actual retrieval."""
"""Perform the actual retrieval. [Step-KB-Binding] Honors step-level knowledge-base constraints."""
if self._vector_retriever is None:
from app.services.retrieval.vector_retriever import get_vector_retriever
self._vector_retriever = await get_vector_retriever()
from app.services.retrieval.base import RetrievalContext
# [Step-KB-Binding] Determine the knowledge-base scope for retrieval
kb_ids = None
if step_kb_config:
# If allowed_kb_ids is configured, search only those knowledge bases
if step_kb_config.allowed_kb_ids:
kb_ids = step_kb_config.allowed_kb_ids
logger.info(
f"[Step-KB-Binding] Restricting KB search to: {kb_ids}"
)
# If only preferred_kb_ids is configured, search those knowledge bases first
elif step_kb_config.preferred_kb_ids:
kb_ids = step_kb_config.preferred_kb_ids
logger.info(
f"[Step-KB-Binding] Preferring KB search in: {kb_ids}"
)
ctx = RetrievalContext(
tenant_id=tenant_id,
query=query,
metadata=metadata_filter,
metadata_filter=metadata_filter,
kb_ids=kb_ids,
)
result = await self._vector_retriever.retrieve(ctx)
@ -347,6 +780,7 @@ class KbSearchDynamicTool:
"content": hit.text,
"score": hit.score,
"metadata": hit.metadata,
"kb_id": hit.metadata.get("kb_id"),
})
return hits[:top_k]
@ -384,6 +818,7 @@ async def create_kb_search_dynamic_handler(
scene: str = "open_consult",
top_k: int = DEFAULT_TOP_K,
context: dict[str, Any] | None = None,
**kwargs,  # Accept extra system-injected arguments (user_id, session_id, etc.)
) -> dict[str, Any]:
"""
KB dynamic retrieval handler
@ -391,9 +826,9 @@ async def create_kb_search_dynamic_handler(
Args:
query: Search query
tenant_id: Tenant ID
scene: Scene identifier
scene: Scene identifier; the default is overridden by `scene` in context
top_k: Number of results to return
context: Context
context: Context containing scene and other filter fields
Returns:
Retrieval result dict
@ -442,6 +877,7 @@ def register_kb_search_dynamic_tool(
scene: str = "open_consult",
top_k: int = DEFAULT_TOP_K,
context: dict[str, Any] | None = None,
**kwargs,  # Accept extra system-injected arguments (user_id, session_id, etc.)
) -> dict[str, Any]:
tool = KbSearchDynamicTool(
session=session,


@ -542,9 +542,9 @@ def register_memory_recall_tool(
cfg = config or MemoryRecallConfig()
async def memory_recall_handler(
tenant_id: str,
user_id: str,
session_id: str,
tenant_id: str = "",
user_id: str = "",
session_id: str = "",
recall_scope: list[str] | None = None,
max_recent_messages: int | None = None,
) -> dict[str, Any]:
@ -554,6 +554,7 @@ def register_memory_recall_tool(
timeout_governor=timeout_governor,
config=cfg,
)
result = await tool.execute(
tenant_id=tenant_id,
user_id=user_id,
@ -587,12 +588,9 @@ def register_memory_recall_tool(
"recall_scope": {"type": "array", "description": "召回范围,例如 profile/facts/preferences/summary/slots"},
"max_recent_messages": {"type": "integer", "description": "历史回填窗口大小"}
},
"required": ["tenant_id", "user_id", "session_id"]
"required": []
},
"example_action_input": {
"tenant_id": "default",
"user_id": "u_10086",
"session_id": "s_abc_001",
"recall_scope": ["profile", "facts", "preferences", "summary", "slots"],
"max_recent_messages": 8
},


@ -0,0 +1,281 @@
"""
Metadata Discovery Tool for Mid Platform.
[AC-MARH-XX] Metadata discovery tool: queries the currently available metadata fields and their common values.
Core features:
- List the metadata fields used by documents in the current knowledge bases
- Return each field's common values (aggregated from existing documents)
- Support filtering by knowledge base
- Return field definition info (type, usage description, etc.)
"""
from __future__ import annotations
import asyncio
import logging
from collections import Counter
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.entities import Document, MetadataFieldDefinition, MetadataFieldStatus
from app.services.mid.timeout_governor import TimeoutGovernor
if TYPE_CHECKING:
from app.services.mid.tool_registry import ToolRegistry
logger = logging.getLogger(__name__)
DEFAULT_TIMEOUT_MS = 2000
DEFAULT_TOP_VALUES = 10
@dataclass
class MetadataFieldDiscoveryConfig:
"""Configuration for metadata field discovery tool."""
timeout_ms: int = DEFAULT_TIMEOUT_MS
top_values_count: int = DEFAULT_TOP_VALUES
@dataclass
class MetadataFieldInfo:
"""Information about a metadata field."""
field_key: str
field_type: str = "string"
label: str = ""
description: str | None = None
common_values: list[str] = field(default_factory=list)
value_count: int = 0
is_filterable: bool = True
options: list[str] | None = None
@dataclass
class MetadataDiscoveryResult:
"""Result of metadata discovery."""
success: bool
fields: list[MetadataFieldInfo] = field(default_factory=list)
total_documents: int = 0
error: str | None = None
duration_ms: int = 0
class MetadataDiscoveryTool:
"""
[AC-MARH-XX] Metadata discovery tool.
Queries the metadata fields used by documents in the current knowledge bases and their common values,
helping the AI understand the available filter fields and construct better search requests.
"""
def __init__(
self,
session: AsyncSession,
timeout_governor: TimeoutGovernor | None = None,
config: MetadataFieldDiscoveryConfig | None = None,
):
self._session = session
self._timeout_governor = timeout_governor or TimeoutGovernor()
self._config = config or MetadataFieldDiscoveryConfig()
async def execute(
self,
tenant_id: str,
kb_id: str | None = None,
include_values: bool = True,
top_n: int | None = None,
) -> MetadataDiscoveryResult:
"""
Execute metadata discovery.
Args:
tenant_id: Tenant ID
kb_id: Optional knowledge base ID to filter
include_values: Whether to include common values (default True)
top_n: Number of top values to return per field (default from config)
Returns:
MetadataDiscoveryResult with field information
"""
start_time = asyncio.get_event_loop().time()
try:
top_n = top_n or self._config.top_values_count
field_definitions = await self._get_field_definitions(tenant_id)
document_metadata = await self._get_document_metadata(tenant_id, kb_id)
total_docs = len(document_metadata)
field_values: dict[str, Counter] = {}
for doc_meta in document_metadata:
if not doc_meta:
continue
for key, value in doc_meta.items():
if key not in field_values:
field_values[key] = Counter()
str_value = str(value) if value is not None else ""
field_values[key].update([str_value])
fields: list[MetadataFieldInfo] = []
for field_key, values_counter in field_values.items():
field_def = field_definitions.get(field_key)
common_values = []
if include_values:
most_common = values_counter.most_common(top_n)
common_values = [v for v, _ in most_common if v]
field_info = MetadataFieldInfo(
field_key=field_key,
field_type=field_def.type if field_def else "string",
label=field_def.label if field_def else field_key,
description=field_def.usage_description if field_def else None,
common_values=common_values,
value_count=len(values_counter),
is_filterable=field_def.is_filterable if field_def else True,
options=field_def.options if field_def else None,
)
fields.append(field_info)
fields.sort(key=lambda f: f.value_count, reverse=True)
duration_ms = int((asyncio.get_event_loop().time() - start_time) * 1000)
logger.info(
f"[MetadataDiscovery] Discovered {len(fields)} fields from {total_docs} documents, "
f"duration={duration_ms}ms"
)
return MetadataDiscoveryResult(
success=True,
fields=fields,
total_documents=total_docs,
duration_ms=duration_ms,
)
except Exception as e:
logger.error(f"[MetadataDiscovery] Discovery failed: {e}")
return MetadataDiscoveryResult(
success=False,
error=str(e),
)
async def _get_field_definitions(
self,
tenant_id: str,
) -> dict[str, MetadataFieldDefinition]:
"""Get field definitions for tenant."""
stmt = select(MetadataFieldDefinition).where(
MetadataFieldDefinition.tenant_id == tenant_id,
MetadataFieldDefinition.status == MetadataFieldStatus.ACTIVE.value,
)
result = await self._session.execute(stmt)
definitions = result.scalars().all()
return {d.field_key: d for d in definitions}
async def _get_document_metadata(
self,
tenant_id: str,
kb_id: str | None = None,
) -> list[dict[str, Any]]:
"""Get all document metadata for tenant."""
stmt = select(Document.doc_metadata).where(
Document.tenant_id == tenant_id,
)
if kb_id:
stmt = stmt.where(Document.kb_id == kb_id)
result = await self._session.execute(stmt)
rows = result.scalars().all()
return [row for row in rows if row]
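The value aggregation in `MetadataDiscoveryTool.execute` (a `Counter` per field, keeping the `top_n` most common non-empty stringified values) can be sketched without the database layer. The function name below is illustrative:

```python
from collections import Counter
from typing import Any

def aggregate_field_values(
    document_metadata: list[dict[str, Any]],
    top_n: int = 10,
) -> dict[str, list[str]]:
    """Sketch of the aggregation in MetadataDiscoveryTool.execute: count every
    stringified value per field, then keep the top_n most common non-empty ones."""
    field_values: dict[str, Counter] = {}
    for doc_meta in document_metadata:
        if not doc_meta:
            continue
        for key, value in doc_meta.items():
            # None stringifies to "" and is filtered out of common_values below
            field_values.setdefault(key, Counter()).update(
                [str(value) if value is not None else ""]
            )
    return {
        key: [v for v, _ in counter.most_common(top_n) if v]
        for key, counter in field_values.items()
    }
```

Fields are later sorted by how many distinct values they carry, so the busiest fields surface first in the tool's response.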
def register_metadata_discovery_tool(
registry: "ToolRegistry",
session: AsyncSession,
timeout_governor: TimeoutGovernor | None = None,
config: MetadataFieldDiscoveryConfig | None = None,
) -> None:
"""Register metadata discovery tool to registry."""
from app.services.mid.tool_registry import ToolType
cfg = config or MetadataFieldDiscoveryConfig()
async def metadata_discovery_handler(
tenant_id: str = "",
kb_id: str | None = None,
include_values: bool = True,
top_n: int | None = None,
**kwargs,  # Accept extra system-injected arguments (user_id, session_id, etc.)
) -> dict[str, Any]:
"""Metadata discovery tool handler."""
tool = MetadataDiscoveryTool(
session=session,
timeout_governor=timeout_governor,
config=cfg,
)
result = await tool.execute(
tenant_id=tenant_id,
kb_id=kb_id,
include_values=include_values,
top_n=top_n,
)
# Convert the dataclass to a dict
return {
"success": result.success,
"fields": [
{
"field_key": f.field_key,
"field_type": f.field_type,
"label": f.label,
"description": f.description,
"common_values": f.common_values,
"value_count": f.value_count,
"is_filterable": f.is_filterable,
"options": f.options,
}
for f in result.fields
],
"total_documents": result.total_documents,
"error": result.error,
"duration_ms": result.duration_ms,
}
registry.register(
name="list_document_metadata_fields",
description="列出当前知识库文档中使用的元数据字段及其常见取值,用于后续的知识库搜索过滤",
handler=metadata_discovery_handler,
tool_type=ToolType.INTERNAL,
version="1.0.0",
auth_required=False,
timeout_ms=cfg.timeout_ms,
enabled=True,
metadata={
"when_to_use": "当需要了解知识库中有哪些可用的元数据过滤字段时使用。",
"when_not_to_use": "当已知可用的过滤字段,或不需要元数据过滤时不需要调用。",
"parameters": {
"type": "object",
"properties": {
"tenant_id": {"type": "string", "description": "租户 ID"},
"kb_id": {"type": "string", "description": "知识库 ID可选用于限定范围"},
"include_values": {"type": "boolean", "description": "是否包含常见值列表,默认 true"},
"top_n": {"type": "integer", "description": "每个字段返回的常见值数量,默认 10"},
},
"required": [],
},
"example_action_input": {
"include_values": True,
"top_n": 5,
},
"result_interpretation": "fields 数组包含每个字段的详细信息common_values 是该字段在文档中的常见取值value_count 表示该字段在多少文档中出现。",
},
)
logger.info("[MetadataDiscovery] Tool registered: list_document_metadata_fields")


@ -165,21 +165,56 @@ class MetadataFilterBuilder:
) -> list[FilterFieldInfo]:
"""
[AC-MRS-11] Get the filterable field definitions.
Reads from the Redis cache first; on a miss, queries the database and populates the cache.
Conditions:
- status == active
- field_roles contains resource_filter
"""
import time
start_time = time.time()
# 1. Try the cache first
from app.services.metadata_cache_service import get_metadata_cache_service
cache_service = await get_metadata_cache_service()
cached_fields = await cache_service.get_fields(tenant_id)
if cached_fields is not None:
# Cache hit; return directly
logger.info(
f"[AC-MRS-11] Cache hit: Retrieved {len(cached_fields)} fields "
f"for tenant={tenant_id} in {(time.time() - start_time)*1000:.2f}ms"
)
return [
FilterFieldInfo(
field_key=f["field_key"],
label=f["label"],
field_type=f["field_type"],
required=f["required"],
options=f.get("options"),
default_value=f.get("default_value"),
is_filterable=f["is_filterable"],
)
for f in cached_fields
]
# 2. Cache miss; query the database
logger.info(f"[AC-MRS-11] Cache miss: Querying database for tenant={tenant_id}")
db_start = time.time()
fields = await self._role_provider.get_fields_by_role(
tenant_id=tenant_id,
role=FieldRole.RESOURCE_FILTER.value,
)
db_time = (time.time() - db_start) * 1000
logger.info(
f"[AC-MRS-11] Retrieved {len(fields)} resource_filter fields for tenant={tenant_id}"
f"[AC-MRS-11] Retrieved {len(fields)} resource_filter fields "
f"for tenant={tenant_id} from DB in {db_time:.2f}ms"
)
return [
# 3. Convert to a list of FilterFieldInfo
filter_fields = [
FilterFieldInfo(
field_key=f.field_key,
label=f.label,
@ -192,6 +227,29 @@ class MetadataFilterBuilder:
for f in fields
]
# 4. Cache in Redis
cache_data = [
{
"field_key": f.field_key,
"label": f.label,
"field_type": f.field_type,
"required": f.required,
"options": f.options,
"default_value": f.default_value,
"is_filterable": f.is_filterable,
}
for f in filter_fields
]
await cache_service.set_fields(tenant_id, cache_data)
total_time = (time.time() - start_time) * 1000
logger.info(
f"[AC-MRS-11] Total time for tenant={tenant_id}: {total_time:.2f}ms "
f"(DB: {db_time:.2f}ms)"
)
return filter_fields
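The four steps above form a classic read-through cache. A minimal async sketch of the pattern, assuming an in-process dict stands in for the Redis-backed `metadata_cache_service` and `load_from_db` stands in for `get_fields_by_role` (both names hypothetical):

```python
from typing import Any, Awaitable, Callable

# Hypothetical in-process stand-in for the Redis cache used by _get_filterable_fields.
_FIELDS_CACHE: dict[str, list[dict[str, Any]]] = {}

async def get_fields_cached(
    tenant_id: str,
    load_from_db: Callable[[str], Awaitable[list[dict[str, Any]]]],
) -> list[dict[str, Any]]:
    """Read-through cache sketch: return cached field dicts on a hit,
    otherwise load from the database and populate the cache."""
    cached = _FIELDS_CACHE.get(tenant_id)
    if cached is not None:
        return cached  # cache hit
    fields = await load_from_db(tenant_id)  # cache miss
    _FIELDS_CACHE[tenant_id] = fields
    return fields
```

As in the real code, the cached payload is plain serializable dicts rather than ORM objects, so it can round-trip through Redis.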
def _build_field_filter(
self,
field_info: FilterFieldInfo,


@ -175,7 +175,9 @@ class RoleBasedFieldProvider:
"slot_key": slot.slot_key,
"type": slot.type,
"required": slot.required,
# [AC-MRS-07-UPGRADE] Return both the old and new fields
"extract_strategy": slot.extract_strategy,
"extract_strategies": slot.extract_strategies,
"validation_rule": slot.validation_rule,
"ask_back_prompt": slot.ask_back_prompt,
"default_value": slot.default_value,
@ -217,7 +219,9 @@ class RoleBasedFieldProvider:
"slot_key": field.field_key,
"type": field.type,
"required": field.required,
# [AC-MRS-07-UPGRADE] Return both the old and new fields
"extract_strategy": None,
"extract_strategies": None,
"validation_rule": None,
"ask_back_prompt": None,
"default_value": field.default_value,


@ -53,6 +53,7 @@ class RuntimeContext:
def to_trace_info(self) -> TraceInfo:
"""Convert to TraceInfo."""
try:
return TraceInfo(
mode=self.mode,
intent=self.intent,
@ -71,6 +72,16 @@ class RuntimeContext:
tools_used=[tc.tool_name for tc in self.tool_calls] if self.tool_calls else None,
tool_calls=self.tool_calls if self.tool_calls else None,
)
except Exception as e:
import traceback
logger.error(
f"[RuntimeObserver] Failed to create TraceInfo: {e}\n"
f"Exception type: {type(e).__name__}\n"
f"Context: mode={self.mode}, request_id={self.request_id}, "
f"generation_id={self.generation_id}\n"
f"Traceback:\n{traceback.format_exc()}"
)
raise
class RuntimeObserver:


@ -0,0 +1,423 @@
"""
Scene Slot Bundle Loader Service.
[AC-SCENE-SLOT-02] Runtime scene slot-bundle loader.
[AC-SCENE-SLOT-03] Supports a cache layer.
Responsibilities:
1. Load the slot-bundle configuration by scene identifier
2. Aggregate slot definition details
3. Compute missing slots
4. Generate ask-back prompts
"""
import logging
from dataclasses import dataclass, field
from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.entities import SceneSlotBundleStatus
from app.services.scene_slot_bundle_service import SceneSlotBundleService
from app.services.slot_definition_service import SlotDefinitionService
from app.services.cache.scene_slot_bundle_cache import (
CachedSceneSlotBundle,
get_scene_slot_bundle_cache,
)
logger = logging.getLogger(__name__)
@dataclass
class SlotInfo:
"""Slot information."""
slot_key: str
type: str
required: bool
ask_back_prompt: str | None = None
validation_rule: str | None = None
linked_field_id: str | None = None
default_value: Any = None
@dataclass
class SceneSlotContext:
"""
[AC-SCENE-SLOT-02] 场景槽位上下文
运行时使用的场景槽位包信息
"""
scene_key: str
scene_name: str
required_slots: list[SlotInfo] = field(default_factory=list)
optional_slots: list[SlotInfo] = field(default_factory=list)
slot_priority: list[str] = field(default_factory=list)
completion_threshold: float = 1.0
ask_back_order: str = "priority"
def get_all_slot_keys(self) -> list[str]:
"""Return all slot keys."""
return [s.slot_key for s in self.required_slots] + [s.slot_key for s in self.optional_slots]
def get_required_slot_keys(self) -> list[str]:
"""Return the required slot keys."""
return [s.slot_key for s in self.required_slots]
def get_optional_slot_keys(self) -> list[str]:
"""Return the optional slot keys."""
return [s.slot_key for s in self.optional_slots]
def get_missing_slots(
self,
filled_slots: dict[str, Any],
) -> list[dict[str, Any]]:
"""
Get info about the missing required slots.
Args:
filled_slots: Already-filled slot values
Returns:
List of missing-slot info dicts
"""
missing = []
for slot_info in self.required_slots:
if slot_info.slot_key not in filled_slots:
missing.append({
"slot_key": slot_info.slot_key,
"type": slot_info.type,
"required": True,
"ask_back_prompt": slot_info.ask_back_prompt,
"validation_rule": slot_info.validation_rule,
"linked_field_id": slot_info.linked_field_id,
})
return missing
def get_ordered_missing_slots(
self,
filled_slots: dict[str, Any],
) -> list[dict[str, Any]]:
"""
Get the missing required slots in priority order.
Args:
filled_slots: Already-filled slot values
Returns:
Missing-slot info list sorted by priority
"""
missing = self.get_missing_slots(filled_slots)
if not missing:
return []
if self.ask_back_order == "required_first":
return missing
if self.ask_back_order == "priority" and self.slot_priority:
priority_map = {slot_key: idx for idx, slot_key in enumerate(self.slot_priority)}
missing.sort(key=lambda x: priority_map.get(x["slot_key"], 999))
return missing
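The `priority` ordering above ranks slots by their index in `slot_priority` and sinks unlisted slots to the end via a large sentinel. A standalone sketch of that sort (function name illustrative):

```python
from typing import Any

def order_by_priority(
    missing: list[dict[str, Any]],
    slot_priority: list[str],
) -> list[dict[str, Any]]:
    """Sketch of the ordering in get_ordered_missing_slots: slots listed in
    slot_priority come first, in that order; unlisted slots keep their
    relative order at the end (sentinel rank 999)."""
    priority_map = {slot_key: idx for idx, slot_key in enumerate(slot_priority)}
    return sorted(missing, key=lambda x: priority_map.get(x["slot_key"], 999))
```

Because `sorted` is stable, slots sharing the sentinel rank retain their original order.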
def get_completion_ratio(
self,
filled_slots: dict[str, Any],
) -> float:
"""
Compute the completion ratio.
Args:
filled_slots: Already-filled slot values
Returns:
Completion ratio (0.0 - 1.0)
"""
if not self.required_slots:
return 1.0
filled_count = sum(
1 for slot_info in self.required_slots
if slot_info.slot_key in filled_slots
)
return filled_count / len(self.required_slots)
def is_complete(
self,
filled_slots: dict[str, Any],
) -> bool:
"""
Check whether the slot bundle is complete.
Args:
filled_slots: Already-filled slot values
Returns:
Whether the completion threshold has been reached
"""
return self.get_completion_ratio(filled_slots) >= self.completion_threshold
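The completion logic above reduces to two small functions: the ratio of filled required slots (empty requirements count as complete) and a threshold comparison. A minimal sketch with hypothetical free-function names:

```python
from typing import Any

def completion_ratio(required_keys: list[str], filled_slots: dict[str, Any]) -> float:
    """Sketch of get_completion_ratio: fraction of required slots that are
    filled; no required slots means the bundle is trivially complete (1.0)."""
    if not required_keys:
        return 1.0
    filled = sum(1 for key in required_keys if key in filled_slots)
    return filled / len(required_keys)

def is_complete(
    required_keys: list[str],
    filled_slots: dict[str, Any],
    threshold: float = 1.0,
) -> bool:
    """A bundle is complete once the ratio reaches completion_threshold."""
    return completion_ratio(required_keys, filled_slots) >= threshold
```

A `completion_threshold` below 1.0 lets a scene proceed with some required slots still unfilled, which is why the threshold is part of the bundle config rather than hard-coded.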
class SceneSlotBundleLoader:
"""
[AC-SCENE-SLOT-02] Scene slot-bundle loader.
[AC-SCENE-SLOT-03] Supports a cache layer.
Loads the scene slot-bundle configuration at runtime.
"""
def __init__(self, session: AsyncSession, use_cache: bool = True):
self._session = session
self._bundle_service = SceneSlotBundleService(session)
self._slot_service = SlotDefinitionService(session)
self._use_cache = use_cache
self._cache = get_scene_slot_bundle_cache() if use_cache else None
async def load_scene_context(
self,
tenant_id: str,
scene_key: str,
) -> SceneSlotContext | None:
"""
加载场景槽位上下文
Args:
tenant_id: 租户 ID
scene_key: 场景标识
Returns:
场景槽位上下文或 None
"""
# [AC-SCENE-SLOT-03] 尝试从缓存获取
if self._use_cache and self._cache:
cached = await self._cache.get(tenant_id, scene_key)
if cached and cached.status == SceneSlotBundleStatus.ACTIVE.value:
logger.debug(
f"[AC-SCENE-SLOT-03] Cache hit for scene: {scene_key}"
)
return await self._build_context_from_cached(tenant_id, cached)
bundle = await self._bundle_service.get_active_bundle_by_scene(
tenant_id=tenant_id,
scene_key=scene_key,
)
if not bundle:
logger.debug(
f"[AC-SCENE-SLOT-02] No active bundle found for scene: {scene_key}"
)
return None
# [AC-SCENE-SLOT-03] 写入缓存
if self._use_cache and self._cache:
cached_bundle = CachedSceneSlotBundle(
scene_key=bundle.scene_key,
scene_name=bundle.scene_name,
description=bundle.description,
required_slots=bundle.required_slots,
optional_slots=bundle.optional_slots,
slot_priority=bundle.slot_priority,
completion_threshold=bundle.completion_threshold,
ask_back_order=bundle.ask_back_order,
status=bundle.status,
version=bundle.version,
)
await self._cache.set(tenant_id, scene_key, cached_bundle)
return await self._build_context_from_bundle(tenant_id, bundle)
async def _build_context_from_cached(
self,
tenant_id: str,
cached: CachedSceneSlotBundle,
) -> SceneSlotContext:
"""从缓存构建场景槽位上下文"""
all_slots = await self._slot_service.list_slot_definitions(tenant_id)
slot_map = {slot.slot_key: slot for slot in all_slots}
return self._build_context(cached, slot_map)
async def _build_context_from_bundle(
self,
tenant_id: str,
bundle: Any,
) -> SceneSlotContext:
"""从数据库模型构建场景槽位上下文"""
all_slots = await self._slot_service.list_slot_definitions(tenant_id)
slot_map = {slot.slot_key: slot for slot in all_slots}
cached = CachedSceneSlotBundle(
scene_key=bundle.scene_key,
scene_name=bundle.scene_name,
description=bundle.description,
required_slots=bundle.required_slots,
optional_slots=bundle.optional_slots,
slot_priority=bundle.slot_priority,
completion_threshold=bundle.completion_threshold,
ask_back_order=bundle.ask_back_order,
status=bundle.status,
version=bundle.version,
)
return self._build_context(cached, slot_map)
def _build_context(
self,
cached: CachedSceneSlotBundle,
slot_map: dict[str, Any],
) -> SceneSlotContext:
"""构建场景槽位上下文"""
required_slot_infos = []
for slot_key in cached.required_slots:
if slot_key in slot_map:
slot_def = slot_map[slot_key]
required_slot_infos.append(SlotInfo(
slot_key=slot_def.slot_key,
type=slot_def.type,
required=True,
ask_back_prompt=slot_def.ask_back_prompt,
validation_rule=slot_def.validation_rule,
linked_field_id=str(slot_def.linked_field_id) if slot_def.linked_field_id else None,
default_value=slot_def.default_value,
))
else:
logger.warning(
f"[AC-SCENE-SLOT-02] Required slot not found: {slot_key}"
)
optional_slot_infos = []
for slot_key in cached.optional_slots:
if slot_key in slot_map:
slot_def = slot_map[slot_key]
optional_slot_infos.append(SlotInfo(
slot_key=slot_def.slot_key,
type=slot_def.type,
required=slot_def.required,
ask_back_prompt=slot_def.ask_back_prompt,
validation_rule=slot_def.validation_rule,
linked_field_id=str(slot_def.linked_field_id) if slot_def.linked_field_id else None,
default_value=slot_def.default_value,
))
else:
logger.warning(
f"[AC-SCENE-SLOT-02] Optional slot not found: {slot_key}"
)
context = SceneSlotContext(
scene_key=cached.scene_key,
scene_name=cached.scene_name,
required_slots=required_slot_infos,
optional_slots=optional_slot_infos,
slot_priority=cached.slot_priority or [],
completion_threshold=cached.completion_threshold,
ask_back_order=cached.ask_back_order,
)
logger.info(
f"[AC-SCENE-SLOT-02] Loaded scene context: scene={cached.scene_key}, "
f"required={len(required_slot_infos)}, optional={len(optional_slot_infos)}, "
f"threshold={cached.completion_threshold}"
)
return context
async def invalidate_cache(
self,
tenant_id: str,
scene_key: str,
) -> bool:
"""
[AC-SCENE-SLOT-03] 使缓存失效
Args:
tenant_id: 租户 ID
scene_key: 场景标识
Returns:
是否成功
"""
if self._cache:
return await self._cache.invalidate_on_update(tenant_id, scene_key)
return True
async def get_missing_slots_for_scene(
self,
tenant_id: str,
scene_key: str,
filled_slots: dict[str, Any],
) -> list[dict[str, Any]]:
"""
获取场景缺失的必填槽位
Args:
tenant_id: 租户 ID
scene_key: 场景标识
filled_slots: 已填充的槽位值
Returns:
缺失槽位信息列表
"""
context = await self.load_scene_context(tenant_id, scene_key)
if not context:
return []
return context.get_ordered_missing_slots(filled_slots)
async def generate_ask_back_prompt(
self,
tenant_id: str,
scene_key: str,
filled_slots: dict[str, Any],
max_slots: int = 2,
) -> str | None:
"""
生成追问提示
Args:
tenant_id: 租户 ID
scene_key: 场景标识
filled_slots: 已填充的槽位值
max_slots: 最多追问的槽位数量
Returns:
追问提示或 None
"""
missing_slots = await self.get_missing_slots_for_scene(
tenant_id=tenant_id,
scene_key=scene_key,
filled_slots=filled_slots,
)
if not missing_slots:
return None
context = await self.load_scene_context(tenant_id, scene_key)
ask_back_order = context.ask_back_order if context else "priority"
if ask_back_order == "parallel":
prompts = []
for missing in missing_slots[:max_slots]:
if missing.get("ask_back_prompt"):
prompts.append(missing["ask_back_prompt"])
else:
slot_key = missing.get("slot_key", "相关信息")
prompts.append(f"您的{slot_key}")
if len(prompts) == 1:
return prompts[0]
elif len(prompts) == 2:
return f"为了更好地为您服务,请告诉我{prompts[0]}和{prompts[1]}"
else:
all_but_last = "、".join(prompts[:-1])
return f"为了更好地为您服务,请告诉我{all_but_last},以及{prompts[-1]}"
else:
first_missing = missing_slots[0]
ask_back_prompt = first_missing.get("ask_back_prompt")
if ask_back_prompt:
return ask_back_prompt
slot_key = first_missing.get("slot_key", "相关信息")
return f"为了更好地为您提供帮助,请告诉我您的{slot_key}"
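The "parallel" branch of `generate_ask_back_prompt` is essentially a joining rule over the collected prompts; a standalone sketch (the 和/顿号 separators are an assumed reading of the intended output):

```python
# One prompt: return it as-is. Two: join with 和.
# Three or more: 顿号-join the head, then ",以及" before the last.
def join_ask_back_prompts(prompts: list[str]) -> str:
    if len(prompts) == 1:
        return prompts[0]
    if len(prompts) == 2:
        return f"为了更好地为您服务,请告诉我{prompts[0]}和{prompts[1]}"
    head = "、".join(prompts[:-1])
    return f"为了更好地为您服务,请告诉我{head},以及{prompts[-1]}"
```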


@@ -0,0 +1,334 @@
"""
Scene Slot Bundle Metrics Service.
[AC-SCENE-SLOT-04] 场景槽位包监控指标
职责
1. 收集场景槽位包相关的监控指标
2. 提供告警检测接口
3. 支持指标导出
"""
import logging
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any
from collections import defaultdict
import threading
logger = logging.getLogger(__name__)
@dataclass
class SceneSlotMetricPoint:
"""单个指标数据点"""
timestamp: datetime
tenant_id: str
scene_key: str
metric_name: str
value: float
tags: dict[str, str] = field(default_factory=dict)
@dataclass
class SceneSlotMetricsSummary:
"""指标汇总"""
total_requests: int = 0
cache_hits: int = 0
cache_misses: int = 0
missing_slots_triggered: int = 0
ask_back_triggered: int = 0
scene_not_configured: int = 0
slot_not_found: int = 0
avg_completion_ratio: float = 0.0
@property
def cache_hit_rate(self) -> float:
if self.total_requests == 0:
return 0.0
return self.cache_hits / self.total_requests
class SceneSlotMetricsCollector:
"""
[AC-SCENE-SLOT-04] 场景槽位包指标收集器
收集以下指标
- scene_slot_requests_total: 场景槽位请求总数
- scene_slot_cache_hits: 缓存命中次数
- scene_slot_cache_misses: 缓存未命中次数
- scene_slot_missing_triggered: 缺失槽位触发次数
- scene_slot_ask_back_triggered: 追问触发次数
- scene_slot_not_configured: 场景未配置次数
- scene_slot_not_found: 槽位未找到次数
- scene_slot_completion_ratio: 槽位完成比例
"""
def __init__(self, max_points: int = 10000):
self._max_points = max_points
self._points: list[SceneSlotMetricPoint] = []
self._counters: dict[str, int] = defaultdict(int)
self._lock = threading.Lock()
self._start_time = datetime.utcnow()
def record(
self,
tenant_id: str,
scene_key: str,
metric_name: str,
value: float = 1.0,
tags: dict[str, str] | None = None,
) -> None:
"""
记录指标
Args:
tenant_id: 租户 ID
scene_key: 场景标识
metric_name: 指标名称
value: 指标值
tags: 额外标签
"""
point = SceneSlotMetricPoint(
timestamp=datetime.utcnow(),
tenant_id=tenant_id,
scene_key=scene_key,
metric_name=metric_name,
value=value,
tags=tags or {},
)
with self._lock:
self._points.append(point)
if len(self._points) > self._max_points:
self._points = self._points[-self._max_points:]
counter_key = f"{tenant_id}:{scene_key}:{metric_name}"
self._counters[counter_key] += 1
def record_cache_hit(self, tenant_id: str, scene_key: str) -> None:
"""记录缓存命中"""
self.record(tenant_id, scene_key, "cache_hit")
def record_cache_miss(self, tenant_id: str, scene_key: str) -> None:
"""记录缓存未命中"""
self.record(tenant_id, scene_key, "cache_miss")
def record_missing_slots(self, tenant_id: str, scene_key: str, count: int = 1) -> None:
"""记录缺失槽位触发"""
self.record(tenant_id, scene_key, "missing_slots_triggered", float(count))
def record_ask_back(self, tenant_id: str, scene_key: str) -> None:
"""记录追问触发"""
self.record(tenant_id, scene_key, "ask_back_triggered")
def record_scene_not_configured(self, tenant_id: str, scene_key: str) -> None:
"""记录场景未配置"""
self.record(tenant_id, scene_key, "scene_not_configured")
def record_slot_not_found(self, tenant_id: str, scene_key: str, slot_key: str) -> None:
"""记录槽位未找到"""
self.record(tenant_id, scene_key, "slot_not_found", tags={"slot_key": slot_key})
def record_completion_ratio(self, tenant_id: str, scene_key: str, ratio: float) -> None:
"""记录槽位完成比例"""
self.record(tenant_id, scene_key, "completion_ratio", ratio)
def get_summary(self, tenant_id: str | None = None) -> SceneSlotMetricsSummary:
"""
获取指标汇总
Args:
tenant_id: 租户 ID(可选),为 None 时返回所有租户的汇总
Returns:
指标汇总
"""
summary = SceneSlotMetricsSummary()
ratio_sum = 0.0
ratio_count = 0
with self._lock:
points = self._points.copy()
for point in points:
if tenant_id and point.tenant_id != tenant_id:
continue
summary.total_requests += 1
if point.metric_name == "cache_hit":
summary.cache_hits += 1
elif point.metric_name == "cache_miss":
summary.cache_misses += 1
elif point.metric_name == "missing_slots_triggered":
summary.missing_slots_triggered += int(point.value)
elif point.metric_name == "ask_back_triggered":
summary.ask_back_triggered += 1
elif point.metric_name == "scene_not_configured":
summary.scene_not_configured += 1
elif point.metric_name == "slot_not_found":
summary.slot_not_found += 1
elif point.metric_name == "completion_ratio":
# 仅对 completion_ratio 数据点求均值;原实现按 total_requests 加权,会把其他指标的计数混入均值
ratio_sum += point.value
ratio_count += 1
if ratio_count:
summary.avg_completion_ratio = ratio_sum / ratio_count
return summary
def get_metrics_by_scene(
self,
tenant_id: str,
scene_key: str,
) -> dict[str, Any]:
"""
获取特定场景的指标
Args:
tenant_id: 租户 ID
scene_key: 场景标识
Returns:
指标字典
"""
metrics = {
"requests": 0,
"cache_hits": 0,
"cache_misses": 0,
"missing_slots_triggered": 0,
"ask_back_triggered": 0,
"scene_not_configured": 0,
"slot_not_found": 0,
"avg_completion_ratio": 0.0,
}
# 记录时的指标名(如 cache_hit)与输出键(如 cache_hits)不一致,需显式映射,否则相应计数永远为 0
name_map = {
"cache_hit": "cache_hits",
"cache_miss": "cache_misses",
"missing_slots_triggered": "missing_slots_triggered",
"ask_back_triggered": "ask_back_triggered",
"scene_not_configured": "scene_not_configured",
"slot_not_found": "slot_not_found",
}
with self._lock:
points = [p for p in self._points if p.tenant_id == tenant_id and p.scene_key == scene_key]
if not points:
return metrics
completion_ratios = []
for point in points:
metrics["requests"] += 1
if point.metric_name == "completion_ratio":
completion_ratios.append(point.value)
elif point.metric_name in name_map:
metrics[name_map[point.metric_name]] += int(point.value)
if completion_ratios:
metrics["avg_completion_ratio"] = sum(completion_ratios) / len(completion_ratios)
return metrics
def check_alerts(
self,
tenant_id: str,
scene_key: str,
thresholds: dict[str, float] | None = None,
) -> list[dict[str, Any]]:
"""
检查告警条件
Args:
tenant_id: 租户 ID
scene_key: 场景标识
thresholds: 告警阈值配置
Returns:
告警列表
"""
default_thresholds = {
"cache_hit_rate_low": 0.5,
"missing_slots_rate_high": 0.3,
"scene_not_configured_rate_high": 0.1,
}
effective_thresholds = {**default_thresholds, **(thresholds or {})}
metrics = self.get_metrics_by_scene(tenant_id, scene_key)
alerts = []
if metrics["requests"] > 0:
cache_hit_rate = metrics["cache_hits"] / metrics["requests"]
if cache_hit_rate < effective_thresholds["cache_hit_rate_low"]:
alerts.append({
"alert_type": "cache_hit_rate_low",
"severity": "warning",
"message": f"场景 {scene_key} 的缓存命中率 ({cache_hit_rate:.2%}) 低于阈值 ({effective_thresholds['cache_hit_rate_low']:.0%})",
"current_value": cache_hit_rate,
"threshold": effective_thresholds["cache_hit_rate_low"],
"suggestion": "检查场景槽位包配置是否频繁变更,或增加缓存 TTL",
})
missing_slots_rate = metrics["missing_slots_triggered"] / metrics["requests"]
if missing_slots_rate > effective_thresholds["missing_slots_rate_high"]:
alerts.append({
"alert_type": "missing_slots_rate_high",
"severity": "warning",
"message": f"场景 {scene_key} 的缺失槽位触发率 ({missing_slots_rate:.2%}) 高于阈值 ({effective_thresholds['missing_slots_rate_high']:.0%})",
"current_value": missing_slots_rate,
"threshold": effective_thresholds["missing_slots_rate_high"],
"suggestion": "检查槽位配置是否合理,或优化槽位提取策略",
})
scene_not_configured_rate = metrics.get("scene_not_configured", 0) / metrics["requests"]
if scene_not_configured_rate > effective_thresholds["scene_not_configured_rate_high"]:
alerts.append({
"alert_type": "scene_not_configured_rate_high",
"severity": "error",
"message": f"场景 {scene_key} 未配置率 ({scene_not_configured_rate:.2%}) 高于阈值 ({effective_thresholds['scene_not_configured_rate_high']:.0%})",
"current_value": scene_not_configured_rate,
"threshold": effective_thresholds["scene_not_configured_rate_high"],
"suggestion": "请为该场景创建场景槽位包配置",
})
return alerts
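The override handling in `check_alerts` boils down to a dict merge where caller-supplied thresholds win over defaults; a minimal sketch (`merge_thresholds` is a hypothetical name for illustration):

```python
# Later entries in the {**a, **b} expansion override earlier ones,
# so caller overrides replace defaults while missing keys keep defaults.
def merge_thresholds(defaults, overrides=None):
    return {**defaults, **(overrides or {})}
```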
def export_metrics(self, tenant_id: str | None = None) -> dict[str, Any]:
"""
导出指标数据
Args:
tenant_id: 租户 ID(可选)
Returns:
指标数据字典
"""
with self._lock:
points = [
{
"timestamp": p.timestamp.isoformat(),
"tenant_id": p.tenant_id,
"scene_key": p.scene_key,
"metric_name": p.metric_name,
"value": p.value,
"tags": p.tags,
}
for p in self._points
if tenant_id is None or p.tenant_id == tenant_id
]
return {
"start_time": self._start_time.isoformat(),
"end_time": datetime.utcnow().isoformat(),
"total_points": len(points),
"points": points,
"summary": self.get_summary(tenant_id).__dict__,
}
def reset(self) -> None:
"""重置指标收集器"""
with self._lock:
self._points = []
self._counters = defaultdict(int)
self._start_time = datetime.utcnow()
_metrics_collector: SceneSlotMetricsCollector | None = None
def get_scene_slot_metrics_collector() -> SceneSlotMetricsCollector:
"""获取场景槽位指标收集器实例"""
global _metrics_collector
if _metrics_collector is None:
_metrics_collector = SceneSlotMetricsCollector()
return _metrics_collector
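The hit-rate math exposed by `SceneSlotMetricsSummary.cache_hit_rate` can be shown with a minimal stand-in (a sketch, not the real class; only the two counters involved are modeled):

```python
from dataclasses import dataclass

# Guarding against total_requests == 0 avoids a ZeroDivisionError
# on a freshly created (empty) summary.
@dataclass
class MiniSummary:
    total_requests: int = 0
    cache_hits: int = 0

    @property
    def cache_hit_rate(self) -> float:
        if self.total_requests == 0:
            return 0.0
        return self.cache_hits / self.total_requests
```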


@@ -0,0 +1,500 @@
"""
Slot Backfill Service.
槽位回填服务 - 处理槽位值的提取、校验、确认、写回
[AC-MRS-SLOT-BACKFILL-01] 槽位值回填确认
职责
1. 从用户回复提取候选槽位值
2. 调用 SlotManager 校验并归一化
3. 校验失败返回 ask_back_prompt 二次追问
4. 校验通过写入状态并标记 source/confidence
5. 对低置信度值增加确认话术
"""
import logging
from dataclasses import dataclass, field
from enum import Enum
from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.mid.schemas import SlotSource
from app.services.mid.slot_manager import SlotManager, SlotWriteResult
from app.services.mid.slot_state_aggregator import SlotStateAggregator, SlotState
from app.services.mid.slot_strategy_executor import (
ExtractContext,
SlotStrategyExecutor,
StrategyChainResult,
)
from app.services.slot_definition_service import SlotDefinitionService
logger = logging.getLogger(__name__)
class BackfillStatus(str, Enum):
"""回填状态"""
SUCCESS = "success"
VALIDATION_FAILED = "validation_failed"
EXTRACTION_FAILED = "extraction_failed"
NEEDS_CONFIRMATION = "needs_confirmation"
NO_CANDIDATES = "no_candidates"
@dataclass
class BackfillResult:
"""
回填结果
Attributes:
status: 回填状态
slot_key: 槽位键名
value: 最终值(校验通过后)
normalized_value: 归一化后的值
source: 值来源
confidence: 置信度
error_message: 错误信息
ask_back_prompt: 追问提示
confirmation_prompt: 确认提示(低置信度时)
updated_state: 更新后的槽位状态
"""
status: BackfillStatus
slot_key: str | None = None
value: Any = None
normalized_value: Any = None
source: str = "unknown"
confidence: float = 0.0
error_message: str | None = None
ask_back_prompt: str | None = None
confirmation_prompt: str | None = None
updated_state: SlotState | None = None
def is_success(self) -> bool:
return self.status == BackfillStatus.SUCCESS
def needs_ask_back(self) -> bool:
return self.status in (
BackfillStatus.VALIDATION_FAILED,
BackfillStatus.EXTRACTION_FAILED,
)
def needs_confirmation(self) -> bool:
return self.status == BackfillStatus.NEEDS_CONFIRMATION
def to_dict(self) -> dict[str, Any]:
return {
"status": self.status.value,
"slot_key": self.slot_key,
"value": self.value,
"normalized_value": self.normalized_value,
"source": self.source,
"confidence": self.confidence,
"error_message": self.error_message,
"ask_back_prompt": self.ask_back_prompt,
"confirmation_prompt": self.confirmation_prompt,
}
@dataclass
class BatchBackfillResult:
"""批量回填结果"""
results: list[BackfillResult] = field(default_factory=list)
success_count: int = 0
failed_count: int = 0
confirmation_needed_count: int = 0
def add_result(self, result: BackfillResult) -> None:
self.results.append(result)
if result.is_success():
self.success_count += 1
elif result.needs_confirmation():
self.confirmation_needed_count += 1
else:
self.failed_count += 1
def get_ask_back_prompts(self) -> list[str]:
"""获取所有追问提示"""
return [
r.ask_back_prompt
for r in self.results
if r.ask_back_prompt
]
def get_confirmation_prompts(self) -> list[str]:
"""获取所有确认提示"""
return [
r.confirmation_prompt
for r in self.results
if r.confirmation_prompt
]
class SlotBackfillService:
"""
[AC-MRS-SLOT-BACKFILL-01] 槽位回填服务
处理槽位值的提取、校验、确认、写回流程:
1. 从用户回复提取候选槽位值
2. SlotManager 校验并归一化
3. 校验失败返回 ask_back_prompt 二次追问
4. 校验通过写入状态并标记 source/confidence
5. 对低置信度值增加确认话术
"""
CONFIDENCE_THRESHOLD_LOW = 0.5
CONFIDENCE_THRESHOLD_HIGH = 0.8
def __init__(
self,
session: AsyncSession,
tenant_id: str,
session_id: str | None = None,
slot_manager: SlotManager | None = None,
strategy_executor: SlotStrategyExecutor | None = None,
):
self._session = session
self._tenant_id = tenant_id
self._session_id = session_id
self._slot_manager = slot_manager or SlotManager(session=session)
self._strategy_executor = strategy_executor or SlotStrategyExecutor()
self._slot_def_service = SlotDefinitionService(session)
self._state_aggregator: SlotStateAggregator | None = None
async def _get_state_aggregator(self) -> SlotStateAggregator:
"""获取状态聚合器"""
if self._state_aggregator is None:
self._state_aggregator = SlotStateAggregator(
session=self._session,
tenant_id=self._tenant_id,
session_id=self._session_id,
)
return self._state_aggregator
async def backfill_single_slot(
self,
slot_key: str,
candidate_value: Any,
source: str = "user_confirmed",
confidence: float = 1.0,
strategies: list[str] | None = None,
) -> BackfillResult:
"""
回填单个槽位
执行流程
1. 如果有提取策略,先执行提取
2. 校验候选值
3. 根据校验结果决定下一步
4. 写入状态
Args:
slot_key: 槽位键名
candidate_value: 候选值
source: 值来源
confidence: 初始置信度
strategies: 提取策略链(可选)
Returns:
BackfillResult: 回填结果
"""
final_value = candidate_value
final_source = source
final_confidence = confidence
if strategies:
extracted_result = await self._extract_value(
slot_key=slot_key,
user_input=str(candidate_value),
strategies=strategies,
)
if extracted_result.success:
final_value = extracted_result.final_value
final_source = extracted_result.final_strategy or source
final_confidence = self._get_confidence_for_strategy(final_source)
else:
ask_back_prompt = await self._slot_manager.get_ask_back_prompt(
self._tenant_id, slot_key
)
return BackfillResult(
status=BackfillStatus.EXTRACTION_FAILED,
slot_key=slot_key,
value=candidate_value,
source=source,
confidence=confidence,
error_message=extracted_result.steps[-1].failure_reason if extracted_result.steps else "提取失败",
ask_back_prompt=ask_back_prompt,
)
write_result = await self._slot_manager.write_slot(
tenant_id=self._tenant_id,
slot_key=slot_key,
value=final_value,
source=SlotSource(final_source) if final_source in [s.value for s in SlotSource] else SlotSource.USER_CONFIRMED,
confidence=final_confidence,
)
if not write_result.success:
return BackfillResult(
status=BackfillStatus.VALIDATION_FAILED,
slot_key=slot_key,
value=final_value,
source=final_source,
confidence=final_confidence,
error_message=write_result.error.error_message if write_result.error else "校验失败",
ask_back_prompt=write_result.ask_back_prompt,
)
normalized_value = write_result.value
updated_state = None
if self._session_id:
aggregator = await self._get_state_aggregator()
updated_state = await aggregator.update_slot(
slot_key=slot_key,
value=normalized_value,
source=final_source,
confidence=final_confidence,
)
result_status = BackfillStatus.SUCCESS
confirmation_prompt = None
if final_confidence < self.CONFIDENCE_THRESHOLD_LOW:
result_status = BackfillStatus.NEEDS_CONFIRMATION
confirmation_prompt = self._generate_confirmation_prompt(
slot_key, normalized_value
)
return BackfillResult(
status=result_status,
slot_key=slot_key,
value=final_value,
normalized_value=normalized_value,
source=final_source,
confidence=final_confidence,
confirmation_prompt=confirmation_prompt,
updated_state=updated_state,
)
async def backfill_multiple_slots(
self,
candidates: dict[str, Any],
source: str = "user_confirmed",
confidence: float = 1.0,
) -> BatchBackfillResult:
"""
批量回填槽位
Args:
candidates: 候选值字典 {slot_key: value}
source: 值来源
confidence: 初始置信度
Returns:
BatchBackfillResult: 批量回填结果
"""
batch_result = BatchBackfillResult()
for slot_key, value in candidates.items():
result = await self.backfill_single_slot(
slot_key=slot_key,
candidate_value=value,
source=source,
confidence=confidence,
)
batch_result.add_result(result)
return batch_result
async def backfill_from_user_response(
self,
user_response: str,
expected_slots: list[str],
strategies: list[str] | None = None,
) -> BatchBackfillResult:
"""
从用户回复中提取并回填槽位
Args:
user_response: 用户回复文本
expected_slots: 期望提取的槽位列表
strategies: 提取策略链
Returns:
BatchBackfillResult: 批量回填结果
"""
batch_result = BatchBackfillResult()
for slot_key in expected_slots:
slot_def = await self._slot_def_service.get_slot_definition_by_key(
self._tenant_id, slot_key
)
if not slot_def:
continue
extract_strategies = strategies or ["rule", "llm"]
extracted_result = await self._extract_value(
slot_key=slot_key,
user_input=user_response,
strategies=extract_strategies,
slot_type=slot_def.type,
validation_rule=slot_def.validation_rule,
)
if not extracted_result.success:
ask_back_prompt = slot_def.ask_back_prompt or f"请提供{slot_key}信息"
batch_result.add_result(BackfillResult(
status=BackfillStatus.EXTRACTION_FAILED,
slot_key=slot_key,
error_message="无法从回复中提取",
ask_back_prompt=ask_back_prompt,
))
continue
source = self._get_source_for_strategy(extracted_result.final_strategy)
confidence = self._get_confidence_for_strategy(source)
result = await self.backfill_single_slot(
slot_key=slot_key,
candidate_value=extracted_result.final_value,
source=source,
confidence=confidence,
)
batch_result.add_result(result)
return batch_result
async def _extract_value(
self,
slot_key: str,
user_input: str,
strategies: list[str],
slot_type: str = "string",
validation_rule: str | None = None,
) -> StrategyChainResult:
"""
执行槽位值提取
Args:
slot_key: 槽位键名
user_input: 用户输入
strategies: 提取策略链
slot_type: 槽位类型
validation_rule: 校验规则
Returns:
StrategyChainResult: 提取结果
"""
context = ExtractContext(
tenant_id=self._tenant_id,
slot_key=slot_key,
user_input=user_input,
slot_type=slot_type,
validation_rule=validation_rule,
)
return await self._strategy_executor.execute_chain(
strategies=strategies,
context=context,
)
def _get_source_for_strategy(self, strategy: str | None) -> str:
"""根据策略获取来源"""
strategy_source_map = {
"rule": SlotSource.RULE_EXTRACTED.value,
"llm": SlotSource.LLM_INFERRED.value,
"user_input": SlotSource.USER_CONFIRMED.value,
}
return strategy_source_map.get(strategy or "", "unknown")
def _get_confidence_for_strategy(self, source: str) -> float:
"""根据来源获取置信度"""
confidence_map = {
SlotSource.USER_CONFIRMED.value: 1.0,
SlotSource.RULE_EXTRACTED.value: 0.9,
SlotSource.LLM_INFERRED.value: 0.7,
"context": 0.5,
SlotSource.DEFAULT.value: 0.3,
}
return confidence_map.get(source, 0.5)
def _generate_confirmation_prompt(
self,
slot_key: str,
value: Any,
) -> str:
"""生成确认提示"""
return f"我理解您说的是「{value}」,对吗?"
async def confirm_low_confidence_slot(
self,
slot_key: str,
confirmed: bool,
) -> BackfillResult:
"""
确认低置信度槽位
Args:
slot_key: 槽位键名
confirmed: 用户是否确认
Returns:
BackfillResult: 确认结果
"""
if not self._session_id:
return BackfillResult(
status=BackfillStatus.SUCCESS,
slot_key=slot_key,
)
aggregator = await self._get_state_aggregator()
if confirmed:
updated_state = await aggregator.update_slot(
slot_key=slot_key,
source=SlotSource.USER_CONFIRMED.value,
confidence=1.0,
)
return BackfillResult(
status=BackfillStatus.SUCCESS,
slot_key=slot_key,
source=SlotSource.USER_CONFIRMED.value,
confidence=1.0,
updated_state=updated_state,
)
else:
await aggregator.clear_slot(slot_key)
ask_back_prompt = await self._slot_manager.get_ask_back_prompt(
self._tenant_id, slot_key
)
return BackfillResult(
status=BackfillStatus.VALIDATION_FAILED,
slot_key=slot_key,
ask_back_prompt=ask_back_prompt or f"请重新提供{slot_key}信息",
)
def create_slot_backfill_service(
session: AsyncSession,
tenant_id: str,
session_id: str | None = None,
) -> SlotBackfillService:
"""
创建槽位回填服务实例
Args:
session: 数据库会话
tenant_id: 租户 ID
session_id: 会话 ID
Returns:
SlotBackfillService: 槽位回填服务实例
"""
return SlotBackfillService(
session=session,
tenant_id=tenant_id,
session_id=session_id,
)
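The backfill service's strategy-to-source-to-confidence chain can be sketched standalone (a sketch under assumptions: the string values here mirror what the `SlotSource` enum members appear to be, but the enum's actual values are not shown in this excerpt; the 0.5 fallback matches the source):

```python
# Strategy name -> value source, then source -> confidence.
STRATEGY_SOURCE = {
    "rule": "rule_extracted",
    "llm": "llm_inferred",
    "user_input": "user_confirmed",
}
SOURCE_CONFIDENCE = {
    "user_confirmed": 1.0,
    "rule_extracted": 0.9,
    "llm_inferred": 0.7,
    "context": 0.5,
    "default": 0.3,
}

def confidence_for_strategy(strategy: str) -> float:
    # Unknown strategies map to an unknown source, which falls back to 0.5.
    source = STRATEGY_SOURCE.get(strategy, "unknown")
    return SOURCE_CONFIDENCE.get(source, 0.5)
```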


@@ -0,0 +1,368 @@
"""
Slot Extraction Integration Service.
槽位提取集成服务 - 将自动提取能力接入主链路
[AC-MRS-SLOT-EXTRACT-01] slot extraction 集成
职责
1. 接入点:memory_recall 之后、KB 检索之前
2. 执行策略链:rule -> llm -> user_input
3. 抽取结果统一走 SlotManager 校验
4. 提供 trace:extracted_slots、validation_pass/fail、ask_back_triggered
"""
import logging
import time
from dataclasses import dataclass, field
from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.mid.schemas import SlotSource
from app.services.mid.slot_backfill_service import (
BackfillResult,
BackfillStatus,
SlotBackfillService,
)
from app.services.mid.slot_manager import SlotManager
from app.services.mid.slot_state_aggregator import SlotState, SlotStateAggregator
from app.services.mid.slot_strategy_executor import (
ExtractContext,
SlotStrategyExecutor,
StrategyChainResult,
)
from app.services.slot_definition_service import SlotDefinitionService
logger = logging.getLogger(__name__)
@dataclass
class ExtractionTrace:
"""
提取追踪信息
Attributes:
slot_key: 槽位键名
strategy: 使用的策略
extracted_value: 提取的值
validation_passed: 校验是否通过
final_value: 最终值(校验后)
execution_time_ms: 执行时间
failure_reason: 失败原因
"""
slot_key: str
strategy: str | None = None
extracted_value: Any = None
validation_passed: bool = False
final_value: Any = None
execution_time_ms: float = 0.0
failure_reason: str | None = None
def to_dict(self) -> dict[str, Any]:
return {
"slot_key": self.slot_key,
"strategy": self.strategy,
"extracted_value": self.extracted_value,
"validation_passed": self.validation_passed,
"final_value": self.final_value,
"execution_time_ms": self.execution_time_ms,
"failure_reason": self.failure_reason,
}
@dataclass
class ExtractionResult:
"""
提取结果
Attributes:
success: 是否成功
extracted_slots: 成功提取的槽位
failed_slots: 提取失败的槽位
traces: 提取追踪信息列表
total_execution_time_ms: 总执行时间
ask_back_triggered: 是否触发追问
ask_back_prompts: 追问提示列表
"""
success: bool = False
extracted_slots: dict[str, Any] = field(default_factory=dict)
failed_slots: list[str] = field(default_factory=list)
traces: list[ExtractionTrace] = field(default_factory=list)
total_execution_time_ms: float = 0.0
ask_back_triggered: bool = False
ask_back_prompts: list[str] = field(default_factory=list)
def to_dict(self) -> dict[str, Any]:
return {
"success": self.success,
"extracted_slots": self.extracted_slots,
"failed_slots": self.failed_slots,
"traces": [t.to_dict() for t in self.traces],
"total_execution_time_ms": self.total_execution_time_ms,
"ask_back_triggered": self.ask_back_triggered,
"ask_back_prompts": self.ask_back_prompts,
}
class SlotExtractionIntegration:
"""
[AC-MRS-SLOT-EXTRACT-01] 槽位提取集成服务
将自动提取能力接入主链路
- 接入点:memory_recall 之后、KB 检索之前
- 执行策略链:rule -> llm -> user_input
- 抽取结果统一走 SlotManager 校验
- 提供 trace
"""
DEFAULT_STRATEGIES = ["rule", "llm"]
def __init__(
self,
session: AsyncSession,
tenant_id: str,
session_id: str | None = None,
slot_manager: SlotManager | None = None,
strategy_executor: SlotStrategyExecutor | None = None,
):
self._session = session
self._tenant_id = tenant_id
self._session_id = session_id
self._slot_manager = slot_manager or SlotManager(session=session)
self._strategy_executor = strategy_executor or SlotStrategyExecutor()
self._slot_def_service = SlotDefinitionService(session)
self._backfill_service: SlotBackfillService | None = None
async def _get_backfill_service(self) -> SlotBackfillService:
"""获取回填服务"""
if self._backfill_service is None:
self._backfill_service = SlotBackfillService(
session=self._session,
tenant_id=self._tenant_id,
session_id=self._session_id,
slot_manager=self._slot_manager,
strategy_executor=self._strategy_executor,
)
return self._backfill_service
async def extract_and_fill(
self,
user_input: str,
target_slots: list[str] | None = None,
strategies: list[str] | None = None,
slot_state: SlotState | None = None,
) -> ExtractionResult:
"""
执行提取并填充槽位
Args:
user_input: 用户输入
target_slots: 目标槽位列表,为空则提取所有必填槽位
strategies: 提取策略链(默认 rule -> llm)
slot_state: 当前槽位状态,用于识别缺失槽位
Returns:
ExtractionResult: 提取结果
"""
start_time = time.time()
strategies = strategies or self.DEFAULT_STRATEGIES
if target_slots is None:
target_slots = await self._get_missing_required_slots(slot_state)
if not target_slots:
return ExtractionResult(success=True)
result = ExtractionResult()
for slot_key in target_slots:
trace = await self._extract_single_slot(
slot_key=slot_key,
user_input=user_input,
strategies=strategies,
)
result.traces.append(trace)
if trace.validation_passed and trace.final_value is not None:
result.extracted_slots[slot_key] = trace.final_value
else:
result.failed_slots.append(slot_key)
if trace.failure_reason:
ask_back_prompt = await self._slot_manager.get_ask_back_prompt(
self._tenant_id, slot_key
)
if ask_back_prompt:
result.ask_back_prompts.append(ask_back_prompt)
result.total_execution_time_ms = (time.time() - start_time) * 1000
result.success = len(result.extracted_slots) > 0
result.ask_back_triggered = len(result.ask_back_prompts) > 0
if result.extracted_slots and self._session_id:
await self._save_extracted_slots(result.extracted_slots)
logger.info(
f"[AC-MRS-SLOT-EXTRACT-01] Extraction completed: "
f"tenant={self._tenant_id}, extracted={len(result.extracted_slots)}, "
f"failed={len(result.failed_slots)}, time_ms={result.total_execution_time_ms:.2f}"
)
return result
async def _extract_single_slot(
self,
slot_key: str,
user_input: str,
strategies: list[str],
) -> ExtractionTrace:
"""提取单个槽位"""
start_time = time.time()
trace = ExtractionTrace(slot_key=slot_key)
slot_def = await self._slot_def_service.get_slot_definition_by_key(
self._tenant_id, slot_key
)
if not slot_def:
trace.failure_reason = "Slot definition not found"
trace.execution_time_ms = (time.time() - start_time) * 1000
return trace
context = ExtractContext(
tenant_id=self._tenant_id,
slot_key=slot_key,
user_input=user_input,
slot_type=slot_def.type,
validation_rule=slot_def.validation_rule,
session_id=self._session_id,
)
chain_result = await self._strategy_executor.execute_chain(
strategies=strategies,
context=context,
ask_back_prompt=slot_def.ask_back_prompt,
)
trace.strategy = chain_result.final_strategy
trace.extracted_value = chain_result.final_value
trace.execution_time_ms = (time.time() - start_time) * 1000
if not chain_result.success:
trace.failure_reason = "Extraction failed"
if chain_result.steps:
last_step = chain_result.steps[-1]
trace.failure_reason = last_step.failure_reason
return trace
backfill_service = await self._get_backfill_service()
source = self._get_source_for_strategy(chain_result.final_strategy)
backfill_result = await backfill_service.backfill_single_slot(
slot_key=slot_key,
candidate_value=chain_result.final_value,
source=source,
confidence=self._get_confidence_for_source(source),
)
trace.validation_passed = backfill_result.is_success()
trace.final_value = backfill_result.normalized_value
if not backfill_result.is_success():
trace.failure_reason = backfill_result.error_message or "Validation failed"
return trace
async def _get_missing_required_slots(
self,
slot_state: SlotState | None,
) -> list[str]:
"""获取缺失的必填槽位"""
if slot_state and slot_state.missing_required_slots:
return [
s.get("slot_key")
for s in slot_state.missing_required_slots
if s.get("slot_key")
]
required_defs = await self._slot_def_service.list_slot_definitions(
tenant_id=self._tenant_id,
required=True,
)
return [d.slot_key for d in required_defs]
async def _save_extracted_slots(
self,
extracted_slots: dict[str, Any],
) -> None:
"""保存提取的槽位到缓存"""
if not self._session_id:
return
aggregator = SlotStateAggregator(
session=self._session,
tenant_id=self._tenant_id,
session_id=self._session_id,
)
for slot_key, value in extracted_slots.items():
await aggregator.update_slot(
slot_key=slot_key,
value=value,
source=SlotSource.RULE_EXTRACTED.value,
confidence=0.9,
)
def _get_source_for_strategy(self, strategy: str | None) -> str:
"""根据策略获取来源"""
strategy_source_map = {
"rule": SlotSource.RULE_EXTRACTED.value,
"llm": SlotSource.LLM_INFERRED.value,
"user_input": SlotSource.USER_CONFIRMED.value,
}
return strategy_source_map.get(strategy or "", "unknown")
def _get_confidence_for_source(self, source: str) -> float:
"""根据来源获取置信度"""
confidence_map = {
SlotSource.USER_CONFIRMED.value: 1.0,
SlotSource.RULE_EXTRACTED.value: 0.9,
SlotSource.LLM_INFERRED.value: 0.7,
}
return confidence_map.get(source, 0.5)
async def integrate_slot_extraction(
session: AsyncSession,
tenant_id: str,
session_id: str,
user_input: str,
slot_state: SlotState | None = None,
strategies: list[str] | None = None,
) -> ExtractionResult:
"""
便捷函数集成槽位提取
Args:
session: 数据库会话
tenant_id: 租户 ID
session_id: 会话 ID
user_input: 用户输入
slot_state: 当前槽位状态
strategies: 提取策略链
Returns:
ExtractionResult: 提取结果
"""
integration = SlotExtractionIntegration(
session=session,
tenant_id=tenant_id,
session_id=session_id,
)
return await integration.extract_and_fill(
user_input=user_input,
slot_state=slot_state,
strategies=strategies,
)

@@ -0,0 +1,379 @@
"""
Slot Manager Service.
槽位管理服务 - 统一槽位写入入口集成校验逻辑
职责
1. 在槽位值写入前执行校验
2. 管理槽位值的来源和置信度
3. 提供槽位写入的统一接口
4. 返回校验失败时的追问提示
"""
import logging
from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.entities import SlotDefinition
from app.models.mid.schemas import SlotSource
from app.services.mid.slot_validation_service import (
BatchValidationResult,
SlotValidationError,
SlotValidationService,
)
from app.services.slot_definition_service import SlotDefinitionService
logger = logging.getLogger(__name__)
class SlotWriteResult:
"""
槽位写入结果
Attributes:
success: 是否成功校验通过并写入
slot_key: 槽位键名
value: 最终写入的值
error: 校验错误信息校验失败时
ask_back_prompt: 追问提示语校验失败时
"""
def __init__(
self,
success: bool,
slot_key: str,
value: Any | None = None,
error: SlotValidationError | None = None,
ask_back_prompt: str | None = None,
):
self.success = success
self.slot_key = slot_key
self.value = value
self.error = error
self.ask_back_prompt = ask_back_prompt
def to_dict(self) -> dict[str, Any]:
"""转换为字典"""
result = {
"success": self.success,
"slot_key": self.slot_key,
"value": self.value,
}
if self.error:
result["error"] = {
"error_code": self.error.error_code,
"error_message": self.error.error_message,
}
if self.ask_back_prompt:
result["ask_back_prompt"] = self.ask_back_prompt
return result
class SlotManager:
"""
槽位管理器
统一槽位写入入口在写入前执行校验
支持从 SlotDefinition 加载校验规则并执行
"""
def __init__(
self,
session: AsyncSession | None = None,
validation_service: SlotValidationService | None = None,
slot_def_service: SlotDefinitionService | None = None,
):
"""
初始化槽位管理器
Args:
session: 数据库会话
validation_service: 校验服务实例
slot_def_service: 槽位定义服务实例
"""
self._session = session
self._validation_service = validation_service or SlotValidationService()
self._slot_def_service = slot_def_service
self._slot_def_cache: dict[str, SlotDefinition | None] = {}
async def write_slot(
self,
tenant_id: str,
slot_key: str,
value: Any,
source: SlotSource = SlotSource.USER_CONFIRMED,
confidence: float = 1.0,
skip_validation: bool = False,
) -> SlotWriteResult:
"""
写入单个槽位值带校验
执行流程
1. 加载槽位定义
2. 执行校验如果未跳过
3. 返回校验结果
Args:
tenant_id: 租户 ID
slot_key: 槽位键名
value: 槽位值
source: 值来源
confidence: 置信度
skip_validation: 是否跳过校验用于特殊场景
Returns:
SlotWriteResult: 写入结果
"""
# 加载槽位定义
slot_def = await self._get_slot_definition(tenant_id, slot_key)
# 如果没有定义且非跳过校验,允许写入(动态槽位)
if slot_def is None and skip_validation:
logger.debug(
f"[SlotManager] Writing slot without definition: "
f"tenant_id={tenant_id}, slot_key={slot_key}"
)
return SlotWriteResult(
success=True,
slot_key=slot_key,
value=value,
)
if slot_def is None:
# Undefined slot: allow the write but log it
logger.info(
f"[SlotManager] Slot definition not found, allowing write: "
f"tenant_id={tenant_id}, slot_key={slot_key}"
)
return SlotWriteResult(
success=True,
slot_key=slot_key,
value=value,
)
# Run validation
if not skip_validation:
validation_result = self._validation_service.validate_slot_value(
slot_def, value, tenant_id
)
if not validation_result.ok:
logger.info(
f"[SlotManager] Slot validation failed: "
f"tenant_id={tenant_id}, slot_key={slot_key}, "
f"error_code={validation_result.error_code}"
)
return SlotWriteResult(
success=False,
slot_key=slot_key,
error=SlotValidationError(
slot_key=slot_key,
error_code=validation_result.error_code or "VALIDATION_FAILED",
error_message=validation_result.error_message or "校验失败",
ask_back_prompt=validation_result.ask_back_prompt,
),
ask_back_prompt=validation_result.ask_back_prompt,
)
# Use the normalized value
value = validation_result.normalized_value
logger.debug(
f"[SlotManager] Slot validation passed: "
f"tenant_id={tenant_id}, slot_key={slot_key}"
)
return SlotWriteResult(
success=True,
slot_key=slot_key,
value=value,
)
async def write_slots(
self,
tenant_id: str,
values: dict[str, Any],
source: SlotSource = SlotSource.USER_CONFIRMED,
confidence: float = 1.0,
skip_validation: bool = False,
) -> BatchValidationResult:
"""
批量写入槽位值带校验
Args:
tenant_id: 租户 ID
values: 槽位值字典 {slot_key: value}
source: 值来源
confidence: 置信度
skip_validation: 是否跳过校验
Returns:
BatchValidationResult: 批量校验结果
"""
if skip_validation:
return BatchValidationResult(
ok=True,
validated_values=values,
)
# Load all relevant slot definitions
slot_defs = await self._get_slot_definitions(tenant_id, list(values.keys()))
# Run batch validation
result = self._validation_service.validate_slots(
slot_defs, values, tenant_id
)
if not result.ok:
logger.info(
f"[SlotManager] Batch slot validation failed: "
f"tenant_id={tenant_id}, errors={[e.slot_key for e in result.errors]}"
)
else:
logger.debug(
f"[SlotManager] Batch slot validation passed: "
f"tenant_id={tenant_id}, slots={list(values.keys())}"
)
return result
async def validate_before_write(
self,
tenant_id: str,
slot_key: str,
value: Any,
) -> tuple[bool, SlotValidationError | None]:
"""
在写入前预校验槽位值
Args:
tenant_id: 租户 ID
slot_key: 槽位键名
value: 槽位值
Returns:
Tuple of (是否通过, 错误信息)
"""
slot_def = await self._get_slot_definition(tenant_id, slot_key)
if slot_def is None:
# 未定义槽位,视为通过
return True, None
result = self._validation_service.validate_slot_value(
slot_def, value, tenant_id
)
if result.ok:
return True, None
return False, SlotValidationError(
slot_key=slot_key,
error_code=result.error_code or "VALIDATION_FAILED",
error_message=result.error_message or "校验失败",
ask_back_prompt=result.ask_back_prompt,
)
async def get_ask_back_prompt(
self,
tenant_id: str,
slot_key: str,
) -> str | None:
"""
获取槽位的追问提示语
Args:
tenant_id: 租户 ID
slot_key: 槽位键名
Returns:
追问提示语或 None
"""
slot_def = await self._get_slot_definition(tenant_id, slot_key)
if slot_def is None:
return None
if isinstance(slot_def, SlotDefinition):
return slot_def.ask_back_prompt
return slot_def.get("ask_back_prompt")
async def _get_slot_definition(
self,
tenant_id: str,
slot_key: str,
) -> SlotDefinition | dict[str, Any] | None:
"""
获取槽位定义带缓存
Args:
tenant_id: 租户 ID
slot_key: 槽位键名
Returns:
槽位定义或 None
"""
cache_key = f"{tenant_id}:{slot_key}"
if cache_key in self._slot_def_cache:
return self._slot_def_cache[cache_key]
slot_def = None
if self._slot_def_service:
slot_def = await self._slot_def_service.get_slot_definition_by_key(
tenant_id, slot_key
)
elif self._session:
service = SlotDefinitionService(self._session)
slot_def = await service.get_slot_definition_by_key(tenant_id, slot_key)
self._slot_def_cache[cache_key] = slot_def
return slot_def
async def _get_slot_definitions(
self,
tenant_id: str,
slot_keys: list[str],
) -> list[SlotDefinition | dict[str, Any]]:
"""
批量获取槽位定义
Args:
tenant_id: 租户 ID
slot_keys: 槽位键名列表
Returns:
槽位定义列表
"""
slot_defs = []
for key in slot_keys:
slot_def = await self._get_slot_definition(tenant_id, key)
if slot_def:
slot_defs.append(slot_def)
return slot_defs
def clear_cache(self) -> None:
"""清除槽位定义缓存"""
self._slot_def_cache.clear()
def create_slot_manager(
session: AsyncSession | None = None,
validation_service: SlotValidationService | None = None,
slot_def_service: SlotDefinitionService | None = None,
) -> SlotManager:
"""
创建槽位管理器实例
Args:
session: 数据库会话
validation_service: 校验服务实例
slot_def_service: 槽位定义服务实例
Returns:
SlotManager: 槽位管理器实例
"""
return SlotManager(
session=session,
validation_service=validation_service,
slot_def_service=slot_def_service,
)
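The branching in `write_slot` above can be condensed into a synchronous sketch: an undefined slot is allowed through, a value failing the regex rule returns the configured ask-back prompt, and anything else succeeds. The async service calls are replaced by plain dict lookups, and the `defs` table below is purely illustrative:

```python
# Hedged sketch of SlotManager.write_slot's decision flow, with the
# slot-definition service and validation service replaced by a dict
# of definitions and a bare re.match check.
import re
from typing import Any

def write_slot(definitions: dict[str, dict], slot_key: str, value: Any) -> dict:
    slot_def = definitions.get(slot_key)
    if slot_def is None:
        # Undefined slot: allow the write, as the real service does
        return {"success": True, "slot_key": slot_key, "value": value}
    rule = slot_def.get("validation_rule")
    if rule and not re.match(rule, str(value)):
        # Validation failed: surface the configured ask-back prompt
        return {
            "success": False,
            "slot_key": slot_key,
            "ask_back_prompt": slot_def.get("ask_back_prompt"),
        }
    return {"success": True, "slot_key": slot_key, "value": value}

defs = {"phone": {"validation_rule": r"^\d{11}$", "ask_back_prompt": "请提供11位手机号"}}
print(write_slot(defs, "phone", "123"))      # fails validation -> ask-back
print(write_slot(defs, "nickname", "Amy"))   # no definition -> allowed
```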

@@ -0,0 +1,562 @@
"""
Slot State Aggregator Service.
槽位状态聚合服务 - 统一维护本轮槽位状态
职责
1. 聚合来自 memory_recall 的槽位值
2. 叠加本轮输入的槽位值
3. 识别缺失的必填槽位
4. 支持槽位与元数据字段的关联映射
5. KB 检索过滤提供统一的槽位值来源
6. [AC-MRS-SLOT-CACHE-01] 多轮状态持久化
[AC-MRS-SLOT-META-01] 槽位与元数据关联机制
"""
import logging
from dataclasses import dataclass, field
from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.mid.schemas import MemorySlot, SlotSource
from app.services.cache.slot_state_cache import (
CachedSlotState,
CachedSlotValue,
get_slot_state_cache,
)
from app.services.mid.slot_manager import SlotManager
from app.services.slot_definition_service import SlotDefinitionService
logger = logging.getLogger(__name__)
@dataclass
class SlotState:
"""
槽位状态聚合结果
Attributes:
filled_slots: 已填充的槽位值字典 {slot_key: value}
missing_required_slots: 缺失的必填槽位列表
slot_sources: 槽位值来源字典 {slot_key: source}
slot_confidence: 槽位置信度字典 {slot_key: confidence}
slot_to_field_map: 槽位到元数据字段的映射 {slot_key: field_key}
"""
filled_slots: dict[str, Any] = field(default_factory=dict)
missing_required_slots: list[dict[str, str]] = field(default_factory=list)
slot_sources: dict[str, str] = field(default_factory=dict)
slot_confidence: dict[str, float] = field(default_factory=dict)
slot_to_field_map: dict[str, str] = field(default_factory=dict)
def get_value_for_filter(self, field_key: str) -> Any:
"""
获取用于 KB 过滤的字段值
优先从 slot_to_field_map 反向查找
"""
# 直接匹配
if field_key in self.filled_slots:
return self.filled_slots[field_key]
# 通过 slot_to_field_map 反向查找
for slot_key, mapped_field_key in self.slot_to_field_map.items():
if mapped_field_key == field_key and slot_key in self.filled_slots:
return self.filled_slots[slot_key]
return None
def to_debug_info(self) -> dict[str, Any]:
"""转换为调试信息字典"""
return {
"filled_slots": self.filled_slots,
"missing_required_slots": self.missing_required_slots,
"slot_sources": self.slot_sources,
"slot_to_field_map": self.slot_to_field_map,
}
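The `get_value_for_filter` lookup above has two paths: a direct hit on `filled_slots`, or a reverse scan of `slot_to_field_map` to find a slot whose linked metadata field matches. A self-contained re-implementation on a trimmed-down dataclass (the field names mirror the excerpt; the sample values are illustrative):

```python
# Minimal illustration of SlotState.get_value_for_filter: direct match
# first, then a reverse lookup through the slot-to-field mapping.
from dataclasses import dataclass, field
from typing import Any

@dataclass
class MiniSlotState:
    filled_slots: dict[str, Any] = field(default_factory=dict)
    slot_to_field_map: dict[str, str] = field(default_factory=dict)

    def get_value_for_filter(self, field_key: str) -> Any:
        # Direct match on the slot key
        if field_key in self.filled_slots:
            return self.filled_slots[field_key]
        # Reverse lookup: which slot maps to this metadata field?
        for slot_key, mapped in self.slot_to_field_map.items():
            if mapped == field_key and slot_key in self.filled_slots:
                return self.filled_slots[slot_key]
        return None

state = MiniSlotState(
    filled_slots={"region": "east"},
    slot_to_field_map={"region": "kb_region"},
)
print(state.get_value_for_filter("kb_region"))  # "east" via reverse lookup
print(state.get_value_for_filter("grade"))      # None
```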
class SlotStateAggregator:
"""
[AC-MRS-SLOT-META-01] 槽位状态聚合器
统一维护本轮槽位状态支持
- memory_recall 初始化槽位
- 叠加本轮输入的槽位值
- 识别缺失的必填槽位
- 建立槽位与元数据字段的关联
- [AC-MRS-SLOT-CACHE-01] 多轮状态持久化
"""
def __init__(
self,
session: AsyncSession,
tenant_id: str,
slot_manager: SlotManager | None = None,
session_id: str | None = None,
):
self._session = session
self._tenant_id = tenant_id
self._session_id = session_id
self._slot_manager = slot_manager or SlotManager(session=session)
self._slot_def_service = SlotDefinitionService(session)
self._cache = get_slot_state_cache()
async def aggregate(
self,
memory_slots: dict[str, MemorySlot] | None = None,
current_input_slots: dict[str, Any] | None = None,
context: dict[str, Any] | None = None,
use_cache: bool = True,
scene_slot_context: Any = None,  # [AC-SCENE-SLOT-02] scene slot context
) -> SlotState:
"""
聚合槽位状态
执行流程
1. [AC-MRS-SLOT-CACHE-01] 从缓存加载已有状态
2. memory_slots 初始化已填充槽位
3. 叠加 current_input_slots优先级更高
4. context 提取槽位值
5. 识别缺失的必填槽位
6. 建立槽位与元数据字段的关联映射
7. [AC-MRS-SLOT-CACHE-01] 回写缓存
Args:
memory_slots: memory_recall 召回的槽位
current_input_slots: 本轮输入的槽位值
context: 上下文信息可能包含槽位值
use_cache: 是否使用缓存默认 True
Returns:
SlotState: 聚合后的槽位状态
"""
state = SlotState()
# [AC-MRS-SLOT-CACHE-01] 1. 从缓存加载已有状态
cached_state = None
if use_cache and self._session_id:
cached_state = await self._cache.get(self._tenant_id, self._session_id)
if cached_state:
for slot_key, cached_value in cached_state.filled_slots.items():
state.filled_slots[slot_key] = cached_value.value
state.slot_sources[slot_key] = cached_value.source
state.slot_confidence[slot_key] = cached_value.confidence
state.slot_to_field_map = cached_state.slot_to_field_map.copy()
logger.info(
f"[AC-MRS-SLOT-CACHE-01] Loaded from cache: "
f"tenant={self._tenant_id}, session={self._session_id}, "
f"slots={list(state.filled_slots.keys())}"
)
# 2. Initialize from memory_slots
if memory_slots:
for slot_key, memory_slot in memory_slots.items():
state.filled_slots[slot_key] = memory_slot.value
state.slot_sources[slot_key] = memory_slot.source.value
state.slot_confidence[slot_key] = memory_slot.confidence
logger.info(
f"[AC-MRS-SLOT-META-01] Initialized from memory: "
f"tenant={self._tenant_id}, slots={list(memory_slots.keys())}"
)
# 3. Overlay the current turn's input (higher priority)
if current_input_slots:
for slot_key, value in current_input_slots.items():
if value is not None:
state.filled_slots[slot_key] = value
state.slot_sources[slot_key] = SlotSource.USER_CONFIRMED.value
state.slot_confidence[slot_key] = 1.0
logger.info(
f"[AC-MRS-SLOT-META-01] Merged current input: "
f"tenant={self._tenant_id}, slots={list(current_input_slots.keys())}"
)
# 4. Extract slot values from context (lowest priority)
if context:
context_slots = self._extract_slots_from_context(context)
for slot_key, value in context_slots.items():
if slot_key not in state.filled_slots and value is not None:
state.filled_slots[slot_key] = value
state.slot_sources[slot_key] = "context"
state.slot_confidence[slot_key] = 0.5
# 5. Load slot definitions and build mappings
await self._build_slot_mappings(state)
# 6. Identify missing required slots
await self._identify_missing_required_slots(state, scene_slot_context)
# [AC-MRS-SLOT-CACHE-01] 7. Write back to the cache
if use_cache and self._session_id:
await self._save_to_cache(state)
logger.info(
f"[AC-MRS-SLOT-META-01] Slot state aggregated: "
f"tenant={self._tenant_id}, filled={len(state.filled_slots)}, "
f"missing={len(state.missing_required_slots)}"
)
return state
async def _save_to_cache(self, state: SlotState) -> None:
"""
[AC-MRS-SLOT-CACHE-01] 保存槽位状态到缓存
"""
if not self._session_id:
return
cached_slots = {}
for slot_key, value in state.filled_slots.items():
source = state.slot_sources.get(slot_key, "unknown")
confidence = state.slot_confidence.get(slot_key, 1.0)
cached_slots[slot_key] = CachedSlotValue(
value=value,
source=source,
confidence=confidence,
)
cached_state = CachedSlotState(
filled_slots=cached_slots,
slot_to_field_map=state.slot_to_field_map.copy(),
)
await self._cache.set(self._tenant_id, self._session_id, cached_state)
logger.debug(
f"[AC-MRS-SLOT-CACHE-01] Saved to cache: "
f"tenant={self._tenant_id}, session={self._session_id}"
)
async def update_slot(
self,
slot_key: str,
value: Any,
source: str = "user_confirmed",
confidence: float = 1.0,
) -> SlotState | None:
"""
[AC-MRS-SLOT-CACHE-01] 更新单个槽位值并保存到缓存
Args:
slot_key: 槽位键名
value: 槽位值
source: 值来源
confidence: 置信度
Returns:
更新后的槽位状态如果没有 session_id 则返回 None
"""
if not self._session_id:
return None
cached_value = CachedSlotValue(
value=value,
source=source,
confidence=confidence,
)
cached_state = await self._cache.merge_and_set(
tenant_id=self._tenant_id,
session_id=self._session_id,
new_slots={slot_key: cached_value},
)
state = SlotState()
state.filled_slots = cached_state.get_simple_filled_slots()
state.slot_sources = cached_state.get_slot_sources()
state.slot_confidence = cached_state.get_slot_confidence()
state.slot_to_field_map = cached_state.slot_to_field_map.copy()
await self._identify_missing_required_slots(state)
return state
async def clear_slot(self, slot_key: str) -> bool:
"""
[AC-MRS-SLOT-CACHE-01] 清除单个槽位值
Args:
slot_key: 槽位键名
Returns:
是否成功
"""
if not self._session_id:
return False
return await self._cache.clear_slot(
tenant_id=self._tenant_id,
session_id=self._session_id,
slot_key=slot_key,
)
async def clear_all_slots(self) -> bool:
"""
[AC-MRS-SLOT-CACHE-01] 清除所有槽位状态
Returns:
是否成功
"""
if not self._session_id:
return False
return await self._cache.delete(
tenant_id=self._tenant_id,
session_id=self._session_id,
)
def _extract_slots_from_context(self, context: dict[str, Any]) -> dict[str, Any]:
"""
从上下文中提取可能的槽位值
常见的槽位值可能存在于
- scene
- product_line
- region
- grade
等字段
"""
slots = {}
slot_candidates = [
"scene", "product_line", "region", "grade",
"category", "type", "status", "priority"
]
for key in slot_candidates:
if key in context and context[key] is not None:
slots[key] = context[key]
return slots
async def _build_slot_mappings(self, state: SlotState) -> None:
"""
建立槽位与元数据字段的关联映射
通过 linked_field_id 关联 SlotDefinition MetadataFieldDefinition
"""
try:
# 获取所有槽位定义
slot_defs = await self._slot_def_service.list_slot_definitions(
tenant_id=self._tenant_id
)
for slot_def in slot_defs:
if slot_def.linked_field_id:
# Load the linked metadata field
from app.services.metadata_field_definition_service import (
MetadataFieldDefinitionService
)
field_service = MetadataFieldDefinitionService(self._session)
linked_field = await field_service.get_field_definition(
tenant_id=self._tenant_id,
field_id=str(slot_def.linked_field_id)
)
if linked_field:
state.slot_to_field_map[slot_def.slot_key] = linked_field.field_key
# Check type consistency and warn on mismatch
await self._check_type_consistency(
slot_def, linked_field, state
)
except Exception as e:
logger.warning(
f"[AC-MRS-SLOT-META-01] Failed to build slot mappings: {e}"
)
async def _check_type_consistency(
self,
slot_def: Any,
linked_field: Any,
state: SlotState,
) -> None:
"""
检查槽位与关联元数据字段的类型一致性
当不一致时记录告警日志不强拦截
"""
# 检查类型一致性
if slot_def.type != linked_field.type:
logger.warning(
f"[AC-MRS-SLOT-META-01] Type mismatch: "
f"slot='{slot_def.slot_key}' type={slot_def.type} vs "
f"field='{linked_field.field_key}' type={linked_field.type}"
)
# Check required consistency
if slot_def.required != linked_field.required:
logger.warning(
f"[AC-MRS-SLOT-META-01] Required mismatch: "
f"slot='{slot_def.slot_key}' required={slot_def.required} vs "
f"field='{linked_field.field_key}' required={linked_field.required}"
)
# Check options consistency (for enum/array_enum types)
if slot_def.type in ["enum", "array_enum"]:
slot_options = set(slot_def.options or [])
field_options = set(linked_field.options or [])
if slot_options != field_options:
logger.warning(
f"[AC-MRS-SLOT-META-01] Options mismatch: "
f"slot='{slot_def.slot_key}' options={slot_options} vs "
f"field='{linked_field.field_key}' options={field_options}"
)
async def _identify_missing_required_slots(
self,
state: SlotState,
scene_slot_context: Any = None,  # [AC-SCENE-SLOT-02] scene slot context
) -> None:
"""
识别缺失的必填槽位
基于 SlotDefinition required 字段和 linked_field required 字段
[AC-SCENE-SLOT-02] 当有场景槽位上下文时优先使用场景定义的必填槽位
"""
try:
# [AC-SCENE-SLOT-02] 如果有场景槽位上下文,使用场景定义的必填槽位
if scene_slot_context:
scene_required_keys = set(scene_slot_context.get_required_slot_keys())
logger.info(
f"[AC-SCENE-SLOT-02] Using scene required slots: "
f"scene={scene_slot_context.scene_key}, "
f"required_keys={scene_required_keys}"
)
# Load definitions for the slots defined in the scene
all_slot_defs = await self._slot_def_service.list_slot_definitions(
tenant_id=self._tenant_id,
)
slot_def_map = {slot.slot_key: slot for slot in all_slot_defs}
for slot_key in scene_required_keys:
if slot_key not in state.filled_slots:
slot_def = slot_def_map.get(slot_key)
ask_back_prompt = slot_def.ask_back_prompt if slot_def else None
if not ask_back_prompt:
ask_back_prompt = f"请提供{slot_key}信息"
missing_info = {
"slot_key": slot_key,
"label": slot_key,
"reason": "scene_required_slot_missing",
"ask_back_prompt": ask_back_prompt,
"scene": scene_slot_context.scene_key,
}
if slot_def and slot_def.linked_field_id:
missing_info["linked_field_id"] = str(slot_def.linked_field_id)
state.missing_required_slots.append(missing_info)
return
# Load all required slot definitions
required_slot_defs = await self._slot_def_service.list_slot_definitions(
tenant_id=self._tenant_id,
required=True,
)
for slot_def in required_slot_defs:
if slot_def.slot_key not in state.filled_slots:
# Resolve the ask-back prompt
ask_back_prompt = slot_def.ask_back_prompt
# Fall back to a generic template when no prompt is configured
if not ask_back_prompt:
ask_back_prompt = f"请提供{slot_def.slot_key}信息"
missing_info = {
"slot_key": slot_def.slot_key,
"label": slot_def.slot_key,
"reason": "required_slot_missing",
"ask_back_prompt": ask_back_prompt,
}
# When a field is linked, use the field's label
if slot_def.linked_field_id:
from app.services.metadata_field_definition_service import (
MetadataFieldDefinitionService
)
field_service = MetadataFieldDefinitionService(self._session)
linked_field = await field_service.get_field_definition(
tenant_id=self._tenant_id,
field_id=str(slot_def.linked_field_id)
)
if linked_field:
missing_info["label"] = linked_field.label
missing_info["field_key"] = linked_field.field_key
state.missing_required_slots.append(missing_info)
logger.info(
f"[AC-MRS-SLOT-META-01] Missing required slot: "
f"slot_key={slot_def.slot_key}"
)
except Exception as e:
logger.warning(
f"[AC-MRS-SLOT-META-01] Failed to identify missing slots: {e}"
)
async def generate_ask_back_response(
self,
state: SlotState,
missing_slot_key: str | None = None,
) -> str | None:
"""
生成追问响应文案
Args:
state: 当前槽位状态
missing_slot_key: 指定要追问的槽位键名 None 时追问第一个缺失槽位
Returns:
追问文案或 None如果没有缺失槽位
"""
if not state.missing_required_slots:
return None
missing_info = None
# When a slot key is given, find its ask-back prompt
if missing_slot_key:
for missing in state.missing_required_slots:
if missing.get("slot_key") == missing_slot_key:
missing_info = missing
break
else:
# Use the first missing slot
missing_info = state.missing_required_slots[0]
if missing_info is None:
return None
# Prefer the configured ask_back_prompt
ask_back_prompt = missing_info.get("ask_back_prompt")
if ask_back_prompt:
return ask_back_prompt
# Fall back to a generic template
label = missing_info.get("label", missing_info.get("slot_key", "相关信息"))
return f"为了更好地为您提供帮助,请告诉我您的{label}"
def create_slot_state_aggregator(
session: AsyncSession,
tenant_id: str,
) -> SlotStateAggregator:
"""
创建槽位状态聚合器实例
Args:
session: 数据库会话
tenant_id: 租户 ID
Returns:
SlotStateAggregator: 槽位状态聚合器实例
"""
return SlotStateAggregator(
session=session,
tenant_id=tenant_id,
)
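The seven-step flow in `aggregate` boils down to a layered merge: cached state is loaded first, memory-recalled slots override it, the current turn's input overrides both, and context values only fill remaining gaps. A standalone sketch of just that precedence (sample values are illustrative, not from the codebase):

```python
# Sketch of the precedence implemented by SlotStateAggregator.aggregate:
# cache < memory_slots < current_input_slots, with context filling gaps only.
from typing import Any

def merge_slots(cached: dict[str, Any], memory: dict[str, Any],
                current: dict[str, Any], context: dict[str, Any]) -> dict[str, Any]:
    filled = dict(cached)
    filled.update(memory)                # memory-recalled slots override the cache
    for k, v in current.items():         # the current turn's input wins outright
        if v is not None:
            filled[k] = v
    for k, v in context.items():         # context only fills slots still missing
        if k not in filled and v is not None:
            filled[k] = v
    return filled

out = merge_slots(
    cached={"region": "north"},
    memory={"region": "east", "grade": "A"},
    current={"grade": "B"},
    context={"region": "west", "scene": "faq"},
)
print(out)  # {'region': 'east', 'grade': 'B', 'scene': 'faq'}
```

Note that `region` ends up as `"east"`: memory beats the cache, and the context value `"west"` is discarded because the slot is already filled.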

@@ -0,0 +1,355 @@
"""
Slot Strategy Executor.
[AC-MRS-07-UPGRADE] 槽位提取策略链执行器
按顺序执行提取策略链直到成功提取并通过校验
支持失败分类和详细日志追踪
"""
import logging
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable
from app.models.entities import ExtractFailureType
logger = logging.getLogger(__name__)
class ExtractStrategyType(str, Enum):
"""提取策略类型"""
RULE = "rule"
LLM = "llm"
USER_INPUT = "user_input"
@dataclass
class StrategyStepResult:
"""单步策略执行结果"""
strategy: str
success: bool
value: Any = None
failure_type: ExtractFailureType | None = None
failure_reason: str = ""
execution_time_ms: float = 0.0
@dataclass
class StrategyChainResult:
"""策略链执行结果"""
slot_key: str
success: bool
final_value: Any = None
final_strategy: str | None = None
steps: list[StrategyStepResult] = field(default_factory=list)
total_execution_time_ms: float = 0.0
ask_back_prompt: str | None = None
def to_dict(self) -> dict[str, Any]:
"""转换为字典"""
return {
"slot_key": self.slot_key,
"success": self.success,
"final_value": self.final_value,
"final_strategy": self.final_strategy,
"steps": [
{
"strategy": step.strategy,
"success": step.success,
"value": step.value if step.success else None,
"failure_type": step.failure_type.value if step.failure_type else None,
"failure_reason": step.failure_reason,
"execution_time_ms": step.execution_time_ms,
}
for step in self.steps
],
"total_execution_time_ms": self.total_execution_time_ms,
"ask_back_prompt": self.ask_back_prompt,
}
@dataclass
class ExtractContext:
"""提取上下文"""
tenant_id: str
slot_key: str
user_input: str
slot_type: str
validation_rule: str | None = None
history: list[dict[str, str]] | None = None
session_id: str | None = None
class SlotStrategyExecutor:
"""
[AC-MRS-07-UPGRADE] 槽位提取策略链执行器
职责
1. 按策略链顺序执行提取
2. 某一步成功且校验通过 -> 停止并返回结果
3. 当前策略失败 -> 记录失败原因继续下一策略
4. 全部失败 -> 返回结构化失败结果
5. 提供可追踪日志slot_keystrategyreason
"""
def __init__(
self,
rule_extractor: Callable[[ExtractContext], Any] | None = None,
llm_extractor: Callable[[ExtractContext], Any] | None = None,
user_input_extractor: Callable[[ExtractContext], Any] | None = None,
):
"""
初始化执行器
Args:
rule_extractor: 规则提取器函数
llm_extractor: LLM提取器函数
user_input_extractor: 用户输入提取器函数
"""
self._extractors: dict[str, Callable[[ExtractContext], Any]] = {
ExtractStrategyType.RULE.value: rule_extractor or self._default_rule_extract,
ExtractStrategyType.LLM.value: llm_extractor or self._default_llm_extract,
ExtractStrategyType.USER_INPUT.value: user_input_extractor or self._default_user_input_extract,
}
async def execute_chain(
self,
strategies: list[str],
context: ExtractContext,
ask_back_prompt: str | None = None,
) -> StrategyChainResult:
"""
执行提取策略链
Args:
strategies: 策略链 ["user_input", "rule", "llm"]
context: 提取上下文
ask_back_prompt: 追问提示语全部失败时使用
Returns:
StrategyChainResult: 执行结果
"""
import time
start_time = time.time()
steps: list[StrategyStepResult] = []
logger.info(
f"[SlotStrategyExecutor] Starting strategy chain for slot '{context.slot_key}': "
f"strategies={strategies}, tenant={context.tenant_id}"
)
for idx, strategy in enumerate(strategies):
step_start = time.time()
logger.info(
f"[SlotStrategyExecutor] Executing step {idx + 1}/{len(strategies)}: "
f"slot_key={context.slot_key}, strategy={strategy}"
)
step_result = await self._execute_single_strategy(strategy, context)
step_result.execution_time_ms = (time.time() - step_start) * 1000
steps.append(step_result)
if step_result.success:
total_time = (time.time() - start_time) * 1000
logger.info(
f"[SlotStrategyExecutor] Strategy chain succeeded at step {idx + 1}: "
f"slot_key={context.slot_key}, strategy={strategy}, "
f"total_time_ms={total_time:.2f}"
)
return StrategyChainResult(
slot_key=context.slot_key,
success=True,
final_value=step_result.value,
final_strategy=strategy,
steps=steps,
total_execution_time_ms=total_time,
)
else:
logger.warning(
f"[SlotStrategyExecutor] Step {idx + 1} failed: "
f"slot_key={context.slot_key}, strategy={strategy}, "
f"failure_type={step_result.failure_type}, "
f"reason={step_result.failure_reason}"
)
# All strategies failed
total_time = (time.time() - start_time) * 1000
logger.warning(
f"[SlotStrategyExecutor] All strategies failed for slot '{context.slot_key}': "
f"attempted={len(strategies)}, total_time_ms={total_time:.2f}"
)
return StrategyChainResult(
slot_key=context.slot_key,
success=False,
final_value=None,
final_strategy=None,
steps=steps,
total_execution_time_ms=total_time,
ask_back_prompt=ask_back_prompt,
)
async def _execute_single_strategy(
self,
strategy: str,
context: ExtractContext,
) -> StrategyStepResult:
"""
执行单个提取策略
Args:
strategy: 策略类型
context: 提取上下文
Returns:
StrategyStepResult: 单步执行结果
"""
extractor = self._extractors.get(strategy)
if not extractor:
return StrategyStepResult(
strategy=strategy,
success=False,
failure_type=ExtractFailureType.EXTRACT_RUNTIME_ERROR,
failure_reason=f"Unknown strategy: {strategy}",
)
try:
# Run the extraction
value = await extractor(context)
# Check for an empty result
if value is None or value == "":
return StrategyStepResult(
strategy=strategy,
success=False,
failure_type=ExtractFailureType.EXTRACT_EMPTY,
failure_reason="Extracted value is empty",
)
# Run validation (when a validation rule is set)
if context.validation_rule:
is_valid, error_msg = self._validate_value(value, context)
if not is_valid:
return StrategyStepResult(
strategy=strategy,
success=False,
failure_type=ExtractFailureType.EXTRACT_VALIDATION_FAIL,
failure_reason=f"Validation failed: {error_msg}",
)
return StrategyStepResult(
strategy=strategy,
success=True,
value=value,
)
except Exception as e:
logger.exception(
f"[SlotStrategyExecutor] Runtime error in strategy '{strategy}' "
f"for slot '{context.slot_key}': {e}"
)
return StrategyStepResult(
strategy=strategy,
success=False,
failure_type=ExtractFailureType.EXTRACT_RUNTIME_ERROR,
failure_reason=f"Runtime error: {str(e)}",
)
def _validate_value(self, value: Any, context: ExtractContext) -> tuple[bool, str]:
"""
校验提取的值
Args:
value: 提取的值
context: 提取上下文
Returns:
Tuple of (是否通过, 错误信息)
"""
import re
validation_rule = context.validation_rule
if not validation_rule:
return True, ""
try:
# Try the rule as a regular expression
if validation_rule.startswith("^") or validation_rule.endswith("$"):
if re.match(validation_rule, str(value)):
return True, ""
return False, f"Value '{value}' does not match pattern '{validation_rule}'"
# Other validation rules can be added here
return True, ""
except re.error as e:
logger.warning(f"[SlotStrategyExecutor] Invalid validation rule pattern: {e}")
return True, "" # 正则错误时放行
async def _default_rule_extract(self, context: ExtractContext) -> Any:
"""默认规则提取实现(占位)"""
# 实际项目中应该调用 VariableExtractor 或其他规则引擎
logger.debug(f"[SlotStrategyExecutor] Default rule extract for '{context.slot_key}'")
return None
async def _default_llm_extract(self, context: ExtractContext) -> Any:
"""默认LLM提取实现占位"""
logger.debug(f"[SlotStrategyExecutor] Default LLM extract for '{context.slot_key}'")
return None
async def _default_user_input_extract(self, context: ExtractContext) -> Any:
"""默认用户输入提取实现"""
# user_input 策略通常表示需要向用户询问,这里返回空表示需要追问
logger.debug(f"[SlotStrategyExecutor] User input required for '{context.slot_key}'")
return None
# Convenience function
async def execute_extract_strategies(
strategies: list[str],
tenant_id: str,
slot_key: str,
user_input: str,
slot_type: str = "string",
validation_rule: str | None = None,
ask_back_prompt: str | None = None,
history: list[dict[str, str]] | None = None,
rule_extractor: Callable[[ExtractContext], Any] | None = None,
llm_extractor: Callable[[ExtractContext], Any] | None = None,
) -> StrategyChainResult:
"""
便捷函数执行提取策略链
Args:
strategies: 策略链
tenant_id: 租户ID
slot_key: 槽位键名
user_input: 用户输入
slot_type: 槽位类型
validation_rule: 校验规则
ask_back_prompt: 追问提示语
history: 对话历史
rule_extractor: 规则提取器
llm_extractor: LLM提取器
Returns:
StrategyChainResult: 执行结果
"""
executor = SlotStrategyExecutor(
rule_extractor=rule_extractor,
llm_extractor=llm_extractor,
)
context = ExtractContext(
tenant_id=tenant_id,
slot_key=slot_key,
user_input=user_input,
slot_type=slot_type,
validation_rule=validation_rule,
history=history,
)
return await executor.execute_chain(strategies, context, ask_back_prompt)
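The executor's control flow above (try each strategy in order, stop at the first non-empty value, otherwise surface an ask-back prompt) can be exercised with a runnable toy: the extractors below are hypothetical stand-ins, not the project's real `rule`/`llm` extractors.

```python
# Runnable sketch of SlotStrategyExecutor.execute_chain's control flow:
# try each async extractor in order, stop at the first non-empty value,
# and fall back to the ask-back prompt when every strategy fails.
import asyncio
from typing import Any, Awaitable, Callable

Extractor = Callable[[str], Awaitable[Any]]

async def run_chain(strategies: list[str], extractors: dict[str, Extractor],
                    user_input: str, ask_back: str) -> dict[str, Any]:
    for strategy in strategies:
        value = await extractors[strategy](user_input)
        if value not in (None, ""):
            return {"success": True, "value": value, "strategy": strategy}
    return {"success": False, "ask_back_prompt": ask_back}

async def rule_extract(text: str) -> Any:
    # Toy rule: pick the first all-digit token
    return next((t for t in text.split() if t.isdigit()), None)

async def llm_extract(text: str) -> Any:
    return None  # stand-in for an LLM call that found nothing

result = asyncio.run(run_chain(
    ["rule", "llm"], {"rule": rule_extract, "llm": llm_extract},
    "order 42 please", "请提供订单号",
))
print(result)  # {'success': True, 'value': '42', 'strategy': 'rule'}
```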

@@ -0,0 +1,572 @@
"""
Slot Validation Service.
槽位校验规则 runtime 生效服务
提供槽位值的运行时校验能力支持
1. 正则表达式校验
2. JSON Schema 校验
3. 类型校验string/number/boolean/enum/array_enum
4. 必填校验
"""
import json
import logging
import re
from dataclasses import dataclass, field
from typing import Any
import jsonschema
from jsonschema.exceptions import ValidationError as JsonSchemaValidationError
from app.models.entities import SlotDefinition
logger = logging.getLogger(__name__)
# Error code definitions
class SlotValidationErrorCode:
"""槽位校验错误码"""
SLOT_REQUIRED_MISSING = "SLOT_REQUIRED_MISSING"
SLOT_TYPE_INVALID = "SLOT_TYPE_INVALID"
SLOT_REGEX_MISMATCH = "SLOT_REGEX_MISMATCH"
SLOT_JSON_SCHEMA_MISMATCH = "SLOT_JSON_SCHEMA_MISMATCH"
SLOT_VALIDATION_RULE_INVALID = "SLOT_VALIDATION_RULE_INVALID"
SLOT_ENUM_INVALID = "SLOT_ENUM_INVALID"
SLOT_ARRAY_ENUM_INVALID = "SLOT_ARRAY_ENUM_INVALID"
@dataclass
class ValidationResult:
"""
单个槽位校验结果
Attributes:
ok: 校验是否通过
normalized_value: 归一化后的值如类型转换后
error_code: 错误码校验失败时
error_message: 错误描述校验失败时
ask_back_prompt: 追问提示语校验失败且配置了 ask_back_prompt
"""
ok: bool
normalized_value: Any | None = None
error_code: str | None = None
error_message: str | None = None
ask_back_prompt: str | None = None
@dataclass
class SlotValidationError:
"""
槽位校验错误详情
Attributes:
slot_key: 槽位键名
error_code: 错误码
error_message: 错误描述
ask_back_prompt: 追问提示语
"""
slot_key: str
error_code: str
error_message: str
ask_back_prompt: str | None = None
@dataclass
class BatchValidationResult:
"""
批量槽位校验结果
Attributes:
ok: 是否全部校验通过
errors: 校验错误列表
validated_values: 校验通过的值字典
"""
ok: bool
errors: list[SlotValidationError] = field(default_factory=list)
validated_values: dict[str, Any] = field(default_factory=dict)
class SlotValidationService:
"""
槽位校验服务
负责在槽位值写回前执行校验支持
- 正则表达式校验
- JSON Schema 校验
- 类型校验
- 必填校验
"""
# 支持的槽位类型
VALID_TYPES = ["string", "number", "boolean", "enum", "array_enum"]
def __init__(self):
"""初始化槽位校验服务"""
self._schema_cache: dict[str, dict] = {}
def validate_slot_value(
self,
slot_def: dict[str, Any] | SlotDefinition,
value: Any,
tenant_id: str | None = None,
) -> ValidationResult:
"""
校验单个槽位值
校验顺序
1. 必填校验如果 required=true 且值为空
2. 类型校验
3. validation_rule 校验正则或 JSON Schema
Args:
slot_def: 槽位定义dict SlotDefinition 对象
value: 待校验的值
tenant_id: 租户 ID用于日志记录
Returns:
ValidationResult: 校验结果
"""
# Normalize to a dict for uniform handling
if isinstance(slot_def, SlotDefinition):
slot_dict = {
"slot_key": slot_def.slot_key,
"type": slot_def.type,
"required": slot_def.required,
"validation_rule": slot_def.validation_rule,
"ask_back_prompt": slot_def.ask_back_prompt,
}
else:
slot_dict = slot_def
slot_key = slot_dict.get("slot_key", "unknown")
slot_type = slot_dict.get("type", "string")
required = slot_dict.get("required", False)
validation_rule = slot_dict.get("validation_rule")
ask_back_prompt = slot_dict.get("ask_back_prompt")
# 1. Required check
if required and self._is_empty_value(value):
logger.info(
f"[SlotValidation] Required slot missing: "
f"tenant_id={tenant_id}, slot_key={slot_key}"
)
return ValidationResult(
ok=False,
error_code=SlotValidationErrorCode.SLOT_REQUIRED_MISSING,
error_message=f"槽位 '{slot_key}' 为必填项",
ask_back_prompt=ask_back_prompt,
)
# Empty and not required: skip the remaining checks
if self._is_empty_value(value):
return ValidationResult(ok=True, normalized_value=value)
# 2. Type check
type_result = self._validate_type(slot_dict, value, tenant_id)
if not type_result.ok:
return ValidationResult(
ok=False,
error_code=type_result.error_code,
error_message=type_result.error_message,
ask_back_prompt=ask_back_prompt,
)
normalized_value = type_result.normalized_value
# 3. validation_rule check
if validation_rule and str(validation_rule).strip():
rule_result = self._validate_rule(
slot_dict, normalized_value, tenant_id
)
if not rule_result.ok:
return ValidationResult(
ok=False,
error_code=rule_result.error_code,
error_message=rule_result.error_message,
ask_back_prompt=ask_back_prompt,
)
normalized_value = rule_result.normalized_value if rule_result.normalized_value is not None else normalized_value
logger.debug(
f"[SlotValidation] Slot validation passed: "
f"tenant_id={tenant_id}, slot_key={slot_key}, type={slot_type}"
)
return ValidationResult(ok=True, normalized_value=normalized_value)
def validate_slots(
self,
slot_defs: list[dict[str, Any] | SlotDefinition],
values: dict[str, Any],
tenant_id: str | None = None,
) -> BatchValidationResult:
"""
批量校验多个槽位值
Args:
slot_defs: 槽位定义列表
values: 槽位值字典 {slot_key: value}
tenant_id: 租户 ID用于日志记录
Returns:
BatchValidationResult: 批量校验结果
"""
errors: list[SlotValidationError] = []
validated_values: dict[str, Any] = {}
# Build a slot_def lookup map
slot_def_map: dict[str, dict[str, Any] | SlotDefinition] = {}
for slot_def in slot_defs:
if isinstance(slot_def, SlotDefinition):
slot_def_map[slot_def.slot_key] = slot_def
else:
slot_def_map[slot_def.get("slot_key", "")] = slot_def
# Validate each provided value
for slot_key, value in values.items():
slot_def = slot_def_map.get(slot_key)
if not slot_def:
# Undefined slot: skip validation (dynamic slots are allowed)
validated_values[slot_key] = value
continue
result = self.validate_slot_value(slot_def, value, tenant_id)
if result.ok:
validated_values[slot_key] = result.normalized_value
else:
errors.append(
SlotValidationError(
slot_key=slot_key,
error_code=result.error_code or "UNKNOWN_ERROR",
error_message=result.error_message or "校验失败",
ask_back_prompt=result.ask_back_prompt,
)
)
# Check for missing required slots
for slot_def in slot_defs:
if isinstance(slot_def, SlotDefinition):
slot_key = slot_def.slot_key
required = slot_def.required
ask_back_prompt = slot_def.ask_back_prompt
else:
slot_key = slot_def.get("slot_key", "")
required = slot_def.get("required", False)
ask_back_prompt = slot_def.get("ask_back_prompt")
if required and slot_key not in values:
# Avoid duplicate errors for the same slot
if not any(e.slot_key == slot_key for e in errors):
errors.append(
SlotValidationError(
slot_key=slot_key,
error_code=SlotValidationErrorCode.SLOT_REQUIRED_MISSING,
error_message=f"槽位 '{slot_key}' 为必填项",
ask_back_prompt=ask_back_prompt,
)
)
return BatchValidationResult(
ok=len(errors) == 0,
errors=errors,
validated_values=validated_values,
)
def _is_empty_value(self, value: Any) -> bool:
"""判断值是否为空"""
if value is None:
return True
if isinstance(value, str) and not value.strip():
return True
if isinstance(value, list) and len(value) == 0:
return True
return False
def _validate_type(
self,
slot_def: dict[str, Any],
value: Any,
tenant_id: str | None,
) -> ValidationResult:
"""
类型校验
Args:
slot_def: 槽位定义字典
value: 待校验的值
tenant_id: 租户 ID
Returns:
ValidationResult: 校验结果
"""
slot_key = slot_def.get("slot_key", "unknown")
slot_type = slot_def.get("type", "string")
if slot_type not in self.VALID_TYPES:
logger.warning(
f"[SlotValidation] Unknown slot type: "
f"tenant_id={tenant_id}, slot_key={slot_key}, type={slot_type}"
)
# Unknown types do not block; only log a warning
return ValidationResult(ok=True, normalized_value=value)
try:
if slot_type == "string":
if not isinstance(value, str):
# Try to coerce to string
normalized = str(value)
return ValidationResult(ok=True, normalized_value=normalized)
return ValidationResult(ok=True, normalized_value=value)
elif slot_type == "number":
if isinstance(value, bool):
return ValidationResult(
ok=False,
error_code=SlotValidationErrorCode.SLOT_TYPE_INVALID,
error_message=f"槽位 '{slot_key}' 类型应为数字,但得到布尔值",
)
if isinstance(value, int | float):
return ValidationResult(ok=True, normalized_value=value)
# Try to coerce to number
if isinstance(value, str):
try:
if "." in value:
normalized = float(value)
else:
normalized = int(value)
return ValidationResult(ok=True, normalized_value=normalized)
except ValueError:
pass
return ValidationResult(
ok=False,
error_code=SlotValidationErrorCode.SLOT_TYPE_INVALID,
error_message=f"槽位 '{slot_key}' 类型应为数字",
)
elif slot_type == "boolean":
if isinstance(value, bool):
return ValidationResult(ok=True, normalized_value=value)
if isinstance(value, str):
lower_val = value.lower()
if lower_val in ("true", "1", "yes", "", ""):
return ValidationResult(ok=True, normalized_value=True)
if lower_val in ("false", "0", "no", "", ""):
return ValidationResult(ok=True, normalized_value=False)
return ValidationResult(
ok=False,
error_code=SlotValidationErrorCode.SLOT_TYPE_INVALID,
error_message=f"槽位 '{slot_key}' 类型应为布尔值",
)
elif slot_type == "enum":
if not isinstance(value, str):
return ValidationResult(
ok=False,
error_code=SlotValidationErrorCode.SLOT_TYPE_INVALID,
error_message=f"槽位 '{slot_key}' 类型应为字符串(枚举)",
)
# If options are defined, check the value against them
options = slot_def.get("options") or []
if options and value not in options:
return ValidationResult(
ok=False,
error_code=SlotValidationErrorCode.SLOT_ENUM_INVALID,
error_message=f"槽位 '{slot_key}' 的值 '{value}' 不在允许选项 {options}",
)
return ValidationResult(ok=True, normalized_value=value)
elif slot_type == "array_enum":
if not isinstance(value, list):
return ValidationResult(
ok=False,
error_code=SlotValidationErrorCode.SLOT_TYPE_INVALID,
error_message=f"槽位 '{slot_key}' 类型应为数组",
)
# Validate array elements
options = slot_def.get("options") or []
for item in value:
if not isinstance(item, str):
return ValidationResult(
ok=False,
error_code=SlotValidationErrorCode.SLOT_ARRAY_ENUM_INVALID,
error_message=f"槽位 '{slot_key}' 的数组元素应为字符串",
)
if options and item not in options:
return ValidationResult(
ok=False,
error_code=SlotValidationErrorCode.SLOT_ARRAY_ENUM_INVALID,
error_message=f"槽位 '{slot_key}' 的值 '{item}' 不在允许选项 {options}",
)
return ValidationResult(ok=True, normalized_value=value)
except Exception as e:
logger.error(
f"[SlotValidation] Type validation error: "
f"tenant_id={tenant_id}, slot_key={slot_key}, error={e}"
)
return ValidationResult(
ok=False,
error_code=SlotValidationErrorCode.SLOT_TYPE_INVALID,
error_message=f"槽位 '{slot_key}' 类型校验异常: {str(e)}",
)
return ValidationResult(ok=True, normalized_value=value)
def _validate_rule(
self,
slot_def: dict[str, Any],
value: Any,
tenant_id: str | None,
) -> ValidationResult:
"""
校验规则校验正则或 JSON Schema
判定逻辑
1. 如果规则是 JSON 对象字符串 { [ 开头 JSON Schema 处理
2. 否则按正则表达式处理
Args:
slot_def: 槽位定义字典
value: 待校验的值
tenant_id: 租户 ID
Returns:
ValidationResult: 校验结果
"""
slot_key = slot_def.get("slot_key", "unknown")
validation_rule = str(slot_def.get("validation_rule", "")).strip()
if not validation_rule:
return ValidationResult(ok=True, normalized_value=value)
# Decide between JSON Schema and regular expression:
# a JSON Schema usually starts with { or [
is_json_schema = validation_rule.strip().startswith(("{", "["))
if is_json_schema:
return self._validate_json_schema(
slot_key, validation_rule, value, tenant_id
)
else:
return self._validate_regex(
slot_key, validation_rule, value, tenant_id
)
def _validate_regex(
self,
slot_key: str,
pattern: str,
value: Any,
tenant_id: str | None,
) -> ValidationResult:
"""
正则表达式校验
Args:
slot_key: 槽位键名
pattern: 正则表达式
value: 待校验的值
tenant_id: 租户 ID
Returns:
ValidationResult: 校验结果
"""
try:
# Convert the value to a string for matching
str_value = str(value) if value is not None else ""
if not re.search(pattern, str_value):
logger.info(
f"[SlotValidation] Regex mismatch: "
f"tenant_id={tenant_id}, slot_key={slot_key}"
)
return ValidationResult(
ok=False,
error_code=SlotValidationErrorCode.SLOT_REGEX_MISMATCH,
error_message=f"槽位 '{slot_key}' 的值不符合格式要求",
)
return ValidationResult(ok=True, normalized_value=value)
except re.error as e:
logger.warning(
f"[SlotValidation] Invalid regex pattern: "
f"tenant_id={tenant_id}, slot_key={slot_key}, pattern={pattern}, error={e}"
)
return ValidationResult(
ok=False,
error_code=SlotValidationErrorCode.SLOT_VALIDATION_RULE_INVALID,
error_message=f"槽位 '{slot_key}' 的校验规则配置无效(非法正则)",
)
def _validate_json_schema(
self,
slot_key: str,
schema_str: str,
value: Any,
tenant_id: str | None,
) -> ValidationResult:
"""
JSON Schema 校验
Args:
slot_key: 槽位键名
schema_str: JSON Schema 字符串
value: 待校验的值
tenant_id: 租户 ID
Returns:
ValidationResult: 校验结果
"""
try:
# Parse the JSON Schema (with caching)
schema = self._schema_cache.get(schema_str)
if schema is None:
schema = json.loads(schema_str)
self._schema_cache[schema_str] = schema
# Run the validation
jsonschema.validate(instance=value, schema=schema)
return ValidationResult(ok=True, normalized_value=value)
except json.JSONDecodeError as e:
logger.warning(
f"[SlotValidation] Invalid JSON schema: "
f"tenant_id={tenant_id}, slot_key={slot_key}, error={e}"
)
return ValidationResult(
ok=False,
error_code=SlotValidationErrorCode.SLOT_VALIDATION_RULE_INVALID,
error_message=f"槽位 '{slot_key}' 的校验规则配置无效(非法 JSON",
)
except JsonSchemaValidationError as e:
logger.info(
f"[SlotValidation] JSON schema mismatch: "
f"tenant_id={tenant_id}, slot_key={slot_key}, error={e.message}"
)
return ValidationResult(
ok=False,
error_code=SlotValidationErrorCode.SLOT_JSON_SCHEMA_MISMATCH,
error_message=f"槽位 '{slot_key}' 的值不符合格式要求: {e.message}",
)
except Exception as e:
logger.error(
f"[SlotValidation] JSON schema validation error: "
f"tenant_id={tenant_id}, slot_key={slot_key}, error={e}"
)
return ValidationResult(
ok=False,
error_code=SlotValidationErrorCode.SLOT_VALIDATION_RULE_INVALID,
error_message=f"槽位 '{slot_key}' 的校验规则执行异常: {str(e)}",
)
def clear_cache(self) -> None:
"""清除 JSON Schema 缓存"""
self._schema_cache.clear()
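The regex-vs-schema dispatch in `_validate_rule` can be exercised as a standalone sketch. To keep it stdlib-only, `check_schema` below is a deliberately tiny approximation of `jsonschema.validate` (the service itself uses the real `jsonschema` library); the function names are ours, not the module's:

```python
import json
import re

def check_schema(value, schema: dict) -> bool:
    # Tiny approximation of jsonschema.validate: handles only "type" and "minimum".
    type_map = {"string": str, "number": (int, float), "boolean": bool,
                "array": list, "object": dict}
    expected = schema.get("type")
    if expected == "number" and isinstance(value, bool):
        return False  # bool is an int subclass; reject it as a number
    if expected and not isinstance(value, type_map[expected]):
        return False
    minimum = schema.get("minimum")
    if minimum is not None and value < minimum:
        return False
    return True

def validate_rule(value, rule: str):
    # Same dispatch as _validate_rule: JSON-looking rules are schemas, others are regexes.
    rule = rule.strip()
    if rule.startswith(("{", "[")):
        try:
            schema = json.loads(rule)
        except json.JSONDecodeError:
            return False, "SLOT_VALIDATION_RULE_INVALID"
        if check_schema(value, schema):
            return True, None
        return False, "SLOT_JSON_SCHEMA_MISMATCH"
    try:
        matched = re.search(rule, str(value))
    except re.error:
        return False, "SLOT_VALIDATION_RULE_INVALID"
    return (True, None) if matched else (False, "SLOT_REGEX_MISMATCH")

print(validate_rule("13800138000", r"^1\d{10}$"))
print(validate_rule(-1, '{"type": "number", "minimum": 0}'))
```

Note the dispatch heuristic: any rule whose first non-blank character is `{` or `[` is treated as a schema, so a regex that must match a literal brace needs to be written without a leading `{`.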

View File

@@ -127,7 +127,8 @@ class ToolCallRecorder:
logger.info(
f"[AC-IDMP-15] Tool call recorded: tool={trace.tool_name}, "
f"type={trace.tool_type.value}, duration_ms={trace.duration_ms}, "
f"status={trace.status.value}, session={session_id}"
f"status={trace.status.value}, session={session_id}, "
f"args_digest={trace.args_digest}, result_digest={trace.result_digest}"
)
def record_success(
@@ -153,6 +154,8 @@ class ToolCallRecorder:
auth_applied=auth_applied,
args_digest=ToolCallTrace.compute_digest(args) if args else None,
result_digest=ToolCallTrace.compute_digest(result) if result else None,
arguments=args if isinstance(args, dict) else None,
result=result,
)
self.record(session_id, trace)
return trace
@@ -179,6 +182,7 @@ class ToolCallRecorder:
registry_version=registry_version,
auth_applied=auth_applied,
args_digest=ToolCallTrace.compute_digest(args) if args else None,
arguments=args if isinstance(args, dict) else None,
)
self.record(session_id, trace)
return trace
@@ -207,6 +211,7 @@ class ToolCallRecorder:
registry_version=registry_version,
auth_applied=auth_applied,
args_digest=ToolCallTrace.compute_digest(args) if args else None,
arguments=args if isinstance(args, dict) else None,
)
self.record(session_id, trace)
return trace
@@ -231,6 +236,7 @@ class ToolCallRecorder:
error_code=reason,
registry_version=registry_version,
args_digest=ToolCallTrace.compute_digest(args) if args else None,
arguments=args if isinstance(args, dict) else None,
)
self.record(session_id, trace)
return trace
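The recorder now persists the raw `arguments`/`result` alongside the digests. `ToolCallTrace.compute_digest` itself is not shown in this diff; a plausible stand-in (an assumption, not the project's actual code) hashes a canonical JSON encoding so that key order cannot change the digest:

```python
import hashlib
import json

def compute_digest(payload, length: int = 16) -> str:
    # Canonicalize first (sorted keys) so logically equal payloads hash identically;
    # default=str keeps non-JSON types (datetimes, UUIDs) from raising.
    canonical = json.dumps(payload, sort_keys=True, ensure_ascii=False, default=str)
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()[:length]

d1 = compute_digest({"city": "Beijing", "days": 3})
d2 = compute_digest({"days": 3, "city": "Beijing"})  # same digest: key order is irrelevant
```

Storing a short digest next to the full payload lets log lines stay compact while the trace object retains everything needed for replay.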

View File

@@ -0,0 +1,111 @@
"""
Tool definition converter for Function Calling.
Converts ToolRegistry definitions to LLM ToolDefinition format.
"""
import logging
from typing import Any
from app.services.llm.base import ToolDefinition
from app.services.mid.tool_registry import ToolDefinition as RegistryToolDefinition
logger = logging.getLogger(__name__)
def convert_tool_to_llm_format(tool: RegistryToolDefinition) -> ToolDefinition:
"""
Convert ToolRegistry tool definition to LLM ToolDefinition format.
Args:
tool: Tool definition from ToolRegistry
Returns:
ToolDefinition for Function Calling
"""
meta = tool.metadata or {}
parameters = meta.get("parameters", {
"type": "object",
"properties": {},
"required": [],
})
if not isinstance(parameters, dict):
parameters = {
"type": "object",
"properties": {},
"required": [],
}
if "type" not in parameters:
parameters["type"] = "object"
if "properties" not in parameters:
parameters["properties"] = {}
if "required" not in parameters:
parameters["required"] = []
properties = parameters.get("properties", {})
if "tenant_id" in properties:
properties = {k: v for k, v in properties.items() if k != "tenant_id"}
if "user_id" in properties:
properties = {k: v for k, v in properties.items() if k != "user_id"}
if "session_id" in properties:
properties = {k: v for k, v in properties.items() if k != "session_id"}
parameters["properties"] = properties
required = parameters.get("required", [])
required = [r for r in required if r not in ("tenant_id", "user_id", "session_id")]
parameters["required"] = required
return ToolDefinition(
name=tool.name,
description=tool.description,
parameters=parameters,
)
def convert_tools_to_llm_format(tools: list[RegistryToolDefinition]) -> list[ToolDefinition]:
"""
Convert multiple tool definitions to LLM format.
Args:
tools: List of tool definitions from ToolRegistry
Returns:
List of ToolDefinition for Function Calling
"""
return [convert_tool_to_llm_format(tool) for tool in tools]
def build_tool_result_message(
tool_call_id: str,
tool_name: str,
result: Any,
tool_guide: str | None = None,
) -> dict[str, str]:
"""
Build a tool result message for the conversation.
Args:
tool_call_id: ID of the tool call
tool_name: Name of the tool
result: Tool execution result
tool_guide: Optional tool usage guide to append
Returns:
Message dict with role='tool'
"""
if isinstance(result, dict):
result_copy = {k: v for k, v in result.items() if k != "_tool_guide"}
content = str(result_copy)
else:
content = str(result)
if tool_guide:
content = f"{content}\n\n---\n{tool_guide}"
return {
"role": "tool",
"tool_call_id": tool_call_id,
"content": content,
}
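The identity-field stripping in `convert_tool_to_llm_format` can be factored into a single helper; a minimal sketch (the function name is ours, not the module's):

```python
INJECTED_FIELDS = frozenset({"tenant_id", "user_id", "session_id"})

def sanitize_parameters(parameters):
    # Normalize to a well-formed JSON-Schema object and drop server-injected
    # identity fields, mirroring the filtering done before building ToolDefinition.
    if not isinstance(parameters, dict):
        parameters = {}
    params = dict(parameters)
    params.setdefault("type", "object")
    params["properties"] = {
        k: v for k, v in params.get("properties", {}).items()
        if k not in INJECTED_FIELDS
    }
    params["required"] = [
        r for r in params.get("required", []) if r not in INJECTED_FIELDS
    ]
    return params

schema = {
    "type": "object",
    "properties": {"tenant_id": {"type": "string"}, "city": {"type": "string"}},
    "required": ["tenant_id", "city"],
}
clean = sanitize_parameters(schema)
```

Stripping these fields from the schema the LLM sees prevents the model from trying to supply tenant or session identifiers that the server injects itself.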

View File

@@ -0,0 +1,313 @@
"""
Tool Guide Registry for Mid Platform.
Provides tool-based usage guidance with caching support.
Tool guides are usage manuals for tools, loaded on-demand with metadata scanning.
This separates tool definitions (Function Calling) from usage guides (Tool Guides).
"""
import logging
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__)
TOOLS_DIR = Path(__file__).parent.parent.parent.parent / "tools"
@dataclass
class ToolGuideMetadata:
"""Lightweight tool guide metadata for quick scanning (~100 tokens)."""
name: str
description: str
triggers: list[str] = field(default_factory=list)
anti_triggers: list[str] = field(default_factory=list)
tools: list[str] = field(default_factory=list)
@dataclass
class ToolGuideDefinition:
"""Full tool guide definition with complete content."""
name: str
description: str
triggers: list[str]
anti_triggers: list[str]
tools: list[str]
content: str
raw_markdown: str
class ToolGuideRegistry:
"""
Tool guide registry with caching support.
Features:
- Load tool guides from .md files
- Cache tool guides in memory for high-frequency access
- Provide lightweight metadata for quick scanning
- Provide full content on demand
"""
_instance = None
_initialized = False
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self, tools_dir: Path | None = None):
if ToolGuideRegistry._initialized:
return
self._tools_dir = tools_dir or TOOLS_DIR
self._tool_guides: dict[str, ToolGuideDefinition] = {}
self._metadata_cache: dict[str, ToolGuideMetadata] = {}
self._tool_to_guide: dict[str, str] = {}
self._loaded = False
ToolGuideRegistry._initialized = True
logger.info(f"[ToolGuideRegistry] Initialized with tools_dir={self._tools_dir}")
def load_tools(self, force_reload: bool = False) -> None:
"""
Load all tool guides from tools directory into cache.
Args:
force_reload: Force reload even if already loaded
"""
if self._loaded and not force_reload:
logger.debug("[ToolGuideRegistry] Tool guides already loaded, skipping")
return
if not self._tools_dir.exists():
logger.warning(f"[ToolGuideRegistry] Tools directory not found: {self._tools_dir}")
return
self._tool_guides.clear()
self._metadata_cache.clear()
self._tool_to_guide.clear()
for tool_file in self._tools_dir.glob("*.md"):
try:
tool_guide = self._parse_tool_file(tool_file)
if tool_guide:
self._tool_guides[tool_guide.name] = tool_guide
self._metadata_cache[tool_guide.name] = ToolGuideMetadata(
name=tool_guide.name,
description=tool_guide.description,
triggers=tool_guide.triggers,
anti_triggers=tool_guide.anti_triggers,
tools=tool_guide.tools,
)
for tool_name in tool_guide.tools:
self._tool_to_guide[tool_name] = tool_guide.name
logger.info(f"[ToolGuideRegistry] Loaded tool guide: {tool_guide.name} (tools: {tool_guide.tools})")
except Exception as e:
logger.error(f"[ToolGuideRegistry] Failed to load tool guide from {tool_file}: {e}")
self._loaded = True
logger.info(f"[ToolGuideRegistry] Loaded {len(self._tool_guides)} tool guides")
def _parse_tool_file(self, file_path: Path) -> ToolGuideDefinition | None:
"""
Parse a tool guide markdown file.
Expected format:
---
name: tool_name
description: Tool description
triggers:
- trigger 1
- trigger 2
anti_triggers:
- anti trigger 1
tools:
- tool_name
---
## Usage Guide
...
"""
content = file_path.read_text(encoding="utf-8")
frontmatter_match = re.match(r"^---\s*\n(.*?)\n---\s*\n(.*)$", content, re.DOTALL)
if not frontmatter_match:
logger.warning(f"[ToolGuideRegistry] No frontmatter found in {file_path}")
return None
frontmatter_text = frontmatter_match.group(1)
body = frontmatter_match.group(2)
metadata: dict[str, Any] = {}
current_key = None
current_list: list[str] | None = None
for line in frontmatter_text.split("\n"):
if not line.strip():
continue
key_match = re.match(r"^(\w+):\s*(.*)$", line)
if key_match:
current_key = key_match.group(1)
value = key_match.group(2).strip()
if value:
metadata[current_key] = value
current_list = None
else:
current_list = []
metadata[current_key] = current_list
elif line.startswith(" - ") and current_list is not None:
current_list.append(line[4:].strip())
name = metadata.get("name", file_path.stem)
description = metadata.get("description", "")
triggers = metadata.get("triggers", [])
anti_triggers = metadata.get("anti_triggers", [])
tools = metadata.get("tools", [])
if isinstance(triggers, str):
triggers = [triggers]
if isinstance(anti_triggers, str):
anti_triggers = [anti_triggers]
if isinstance(tools, str):
tools = [tools]
return ToolGuideDefinition(
name=name,
description=description,
triggers=triggers,
anti_triggers=anti_triggers,
tools=tools,
content=body.strip(),
raw_markdown=content,
)
def get_tool_guide(self, name: str) -> ToolGuideDefinition | None:
"""Get full tool guide definition by name."""
if not self._loaded:
self.load_tools()
return self._tool_guides.get(name)
def get_tool_metadata(self, name: str) -> ToolGuideMetadata | None:
"""Get lightweight tool guide metadata by name."""
if not self._loaded:
self.load_tools()
return self._metadata_cache.get(name)
def get_guide_for_tool(self, tool_name: str) -> ToolGuideDefinition | None:
"""Get tool guide associated with a tool."""
if not self._loaded:
self.load_tools()
guide_name = self._tool_to_guide.get(tool_name)
if guide_name:
return self._tool_guides.get(guide_name)
return None
def list_tools(self) -> list[str]:
"""List all tool guide names."""
if not self._loaded:
self.load_tools()
return list(self._tool_guides.keys())
def list_tool_metadata(self) -> list[ToolGuideMetadata]:
"""List all tool guide metadata (lightweight)."""
if not self._loaded:
self.load_tools()
return list(self._metadata_cache.values())
def build_tools_prompt_section(self, tool_names: list[str] | None = None) -> str:
"""
Build tools section for ReAct prompt.
Args:
tool_names: If provided, only include tools for these names.
If None, include all tools.
Returns:
Formatted tools section string
"""
if not self._loaded:
self.load_tools()
if not self._tool_guides:
return ""
tools_to_include: list[ToolGuideDefinition] = []
if tool_names:
for tool_name in tool_names:
tool_guide = self.get_guide_for_tool(tool_name)
if tool_guide and tool_guide not in tools_to_include:
tools_to_include.append(tool_guide)
else:
tools_to_include = list(self._tool_guides.values())
if not tools_to_include:
return ""
lines = ["## 工具使用指南", ""]
lines.append("以下是每个工具的详细使用说明:")
lines.append("")
for tool_guide in tools_to_include:
lines.append(f"### {tool_guide.name}")
lines.append(f"描述: {tool_guide.description}")
if tool_guide.triggers:
lines.append("触发条件:")
for trigger in tool_guide.triggers:
lines.append(f" - {trigger}")
if tool_guide.anti_triggers:
lines.append("不应触发:")
for anti in tool_guide.anti_triggers:
lines.append(f" - {anti}")
lines.append("")
lines.append(tool_guide.content)
lines.append("")
return "\n".join(lines)
def build_compact_tools_section(self) -> str:
"""
Build compact tools section with only name and description.
This is used for the initial tool list, with full guidance loaded separately.
"""
if not self._loaded:
self.load_tools()
if not self._metadata_cache:
return "当前没有可用的工具使用指南。"
lines = ["## 可用工具列表", "", "以下是你可以使用的工具:", ""]
for meta in self._metadata_cache.values():
lines.append(f"- **{meta.name}**: {meta.description}")
if meta.tools:
lines.append(f" 关联工具: {', '.join(meta.tools)}")
return "\n".join(lines)
_tool_guide_registry: ToolGuideRegistry | None = None
def get_tool_guide_registry() -> ToolGuideRegistry:
"""Get global tool guide registry instance."""
global _tool_guide_registry
if _tool_guide_registry is None:
_tool_guide_registry = ToolGuideRegistry()
return _tool_guide_registry
def init_tool_guide_registry(tools_dir: Path | None = None) -> ToolGuideRegistry:
"""Initialize and return tool guide registry."""
global _tool_guide_registry
_tool_guide_registry = ToolGuideRegistry(tools_dir=tools_dir)
_tool_guide_registry.load_tools()
return _tool_guide_registry
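The frontmatter format that `_parse_tool_file` expects can be exercised standalone with the same regular expressions (simplified: flat keys only, two-space list indentation, no nested YAML):

```python
import re

SAMPLE = """---
name: weather
description: Query weather by city
triggers:
  - user asks about weather
tools:
  - get_weather
---
## Usage Guide
Call get_weather with a city name.
"""

def parse_frontmatter(text):
    # Split a guide file into (metadata, body) with the same regex as _parse_tool_file.
    m = re.match(r"^---\s*\n(.*?)\n---\s*\n(.*)$", text, re.DOTALL)
    if not m:
        return None, text
    meta, current = {}, None
    for line in m.group(1).split("\n"):
        kv = re.match(r"^(\w+):\s*(.*)$", line)
        if kv:
            key, value = kv.group(1), kv.group(2).strip()
            # A bare "key:" opens a list; subsequent "  - item" lines append to it.
            meta[key] = value if value else []
            current = None if value else meta[key]
        elif line.startswith("  - ") and current is not None:
            current.append(line[4:].strip())
    return meta, m.group(2).strip()

meta, body = parse_frontmatter(SAMPLE)
```

This hand-rolled parser is what keeps the registry dependency-free; guides that need richer metadata would have to move to a real YAML loader.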

View File

@@ -117,172 +117,124 @@ class ToolRegistry:
self._tools[name] = tool
logger.info(
f"[AC-IDMP-19] Tool registered: name={name}, type={tool_type.value}, "
f"version={version}, auth_required={auth_required}"
f"[AC-IDMP-19] Registered tool: {name} v{version} "
f"(type={tool_type.value}, auth={auth_required}, timeout={timeout_ms}ms)"
)
return tool
def unregister(self, name: str) -> bool:
"""Unregister a tool."""
if name in self._tools:
del self._tools[name]
logger.info(f"[AC-IDMP-19] Tool unregistered: {name}")
return True
return False
def get_tool(self, name: str) -> ToolDefinition | None:
"""Get tool definition by name."""
"""Get tool by name."""
return self._tools.get(name)
def list_tools(
self,
tool_type: ToolType | None = None,
enabled_only: bool = True,
) -> list[ToolDefinition]:
"""List registered tools, optionally filtered."""
tools = list(self._tools.values())
def list_tools(self) -> list[str]:
"""List all registered tool names."""
return list(self._tools.keys())
if tool_type:
tools = [t for t in tools if t.tool_type == tool_type]
def get_all_tools(self) -> list[ToolDefinition]:
"""Get all registered tools."""
return list(self._tools.values())
if enabled_only:
tools = [t for t in tools if t.enabled]
def is_enabled(self, name: str) -> bool:
"""Check if tool is enabled."""
tool = self._tools.get(name)
return tool.enabled if tool else False
return tools
def enable_tool(self, name: str) -> bool:
"""Enable a tool."""
def set_enabled(self, name: str, enabled: bool) -> bool:
"""Enable or disable a tool."""
tool = self._tools.get(name)
if tool:
tool.enabled = True
logger.info(f"[AC-IDMP-19] Tool enabled: {name}")
return True
return False
def disable_tool(self, name: str) -> bool:
"""Disable a tool."""
tool = self._tools.get(name)
if tool:
tool.enabled = False
logger.info(f"[AC-IDMP-19] Tool disabled: {name}")
tool.enabled = enabled
logger.info(f"[AC-IDMP-19] Tool {name} {'enabled' if enabled else 'disabled'}")
return True
return False
async def execute(
self,
tool_name: str,
args: dict[str, Any],
auth_context: dict[str, Any] | None = None,
name: str,
**kwargs: Any,
) -> ToolExecutionResult:
"""
[AC-IDMP-19] Execute a tool with governance.
Execute a tool with governance.
Args:
tool_name: Tool name to execute
args: Tool arguments
auth_context: Authentication context
name: Tool name
**kwargs: Tool arguments
Returns:
ToolExecutionResult with output and metadata
"""
start_time = time.time()
tool = self._tools.get(tool_name)
tool = self._tools.get(name)
if not tool:
logger.warning(f"[AC-IDMP-19] Tool not found: {tool_name}")
return ToolExecutionResult(
success=False,
error=f"Tool not found: {tool_name}",
duration_ms=0,
error=f"Tool not found: {name}",
registry_version=self._version,
)
if not tool.enabled:
logger.warning(f"[AC-IDMP-19] Tool disabled: {tool_name}")
return ToolExecutionResult(
success=False,
error=f"Tool disabled: {tool_name}",
duration_ms=0,
registry_version=tool.version,
error=f"Tool is disabled: {name}",
registry_version=self._version,
)
auth_applied = False
if tool.auth_required:
if not auth_context:
logger.warning(f"[AC-IDMP-19] Auth required but no context: {tool_name}")
return ToolExecutionResult(
success=False,
error="Authentication required",
duration_ms=int((time.time() - start_time) * 1000),
auth_applied=False,
registry_version=tool.version,
)
auth_applied = True
start_time = time.time()
try:
timeout_seconds = tool.timeout_ms / 1000.0
if not tool.handler:
return ToolExecutionResult(
success=False,
error=f"Tool has no handler: {name}",
registry_version=self._version,
)
result = await asyncio.wait_for(
tool.handler(**args) if tool.handler else asyncio.sleep(0),
timeout=timeout_seconds,
result = await self._timeout_governor.execute_with_timeout(
lambda: tool.handler(**kwargs),
timeout_ms=tool.timeout_ms,
)
duration_ms = int((time.time() - start_time) * 1000)
logger.info(
f"[AC-IDMP-19] Tool executed: name={tool_name}, "
f"duration_ms={duration_ms}, success=True"
)
return ToolExecutionResult(
success=True,
output=result,
duration_ms=duration_ms,
auth_applied=auth_applied,
registry_version=tool.version,
auth_applied=tool.auth_required,
registry_version=self._version,
)
except asyncio.TimeoutError:
duration_ms = int((time.time() - start_time) * 1000)
logger.warning(
f"[AC-IDMP-19] Tool timeout: name={tool_name}, "
f"duration_ms={duration_ms}"
)
return ToolExecutionResult(
success=False,
error=f"Tool timeout after {tool.timeout_ms}ms",
error=f"Tool execution timeout after {tool.timeout_ms}ms",
duration_ms=duration_ms,
auth_applied=auth_applied,
registry_version=tool.version,
registry_version=self._version,
)
except Exception as e:
duration_ms = int((time.time() - start_time) * 1000)
logger.error(
f"[AC-IDMP-19] Tool error: name={tool_name}, error={e}"
)
logger.error(f"[AC-IDMP-19] Tool execution error: {name} - {e}")
return ToolExecutionResult(
success=False,
error=str(e),
duration_ms=duration_ms,
auth_applied=auth_applied,
registry_version=tool.version,
registry_version=self._version,
)
def create_trace(
def build_trace(
self,
tool_name: str,
args: dict[str, Any],
result: ToolExecutionResult,
args_digest: str | None = None,
) -> ToolCallTrace:
"""
[AC-IDMP-19] Create ToolCallTrace from execution result.
"""
tool = self._tools.get(tool_name)
"""Build a tool call trace from execution result."""
import hashlib
args_digest = hashlib.md5(str(args).encode()).hexdigest()[:8]
return ToolCallTrace(
tool_name=tool_name,
tool_type=tool.tool_type if tool else ToolType.INTERNAL,
tool_type=tool.tool_type if (tool := self._tools.get(tool_name)) else ToolType.INTERNAL,
registry_version=result.registry_version,
auth_applied=result.auth_applied,
duration_ms=result.duration_ms,
@@ -293,6 +245,8 @@ class ToolRegistry:
error_code=result.error if not result.success else None,
args_digest=args_digest,
result_digest=str(result.output)[:100] if result.output else None,
arguments=args,
result=result.output,
)
def get_governance_report(self) -> dict[str, Any]:
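This refactor replaces the inline `asyncio.wait_for` call with `self._timeout_governor.execute_with_timeout(...)`. The governor's API is not part of this diff; its core contract can be sketched with plain asyncio (names and return shape here are illustrative assumptions):

```python
import asyncio

async def execute_with_timeout(handler, timeout_ms: int):
    # Run a zero-arg async handler under a deadline, matching the
    # `lambda: tool.handler(**kwargs)` shape used at the call site.
    try:
        return True, await asyncio.wait_for(handler(), timeout=timeout_ms / 1000.0)
    except asyncio.TimeoutError:
        return False, f"timeout after {timeout_ms}ms"

async def fast_tool():
    return {"ok": 1}

async def slow_tool():
    await asyncio.sleep(0.2)
    return {"ok": 2}

async def main():
    ok1, r1 = await execute_with_timeout(fast_tool, 1000)
    ok2, r2 = await execute_with_timeout(slow_tool, 50)
    return ok1, r1, ok2, r2

results = asyncio.run(main())
```

Centralizing the deadline in one governor means every tool shares the same timeout semantics and the registry no longer needs per-call `wait_for` plumbing.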

View File

@@ -23,9 +23,9 @@ RAG Optimization (rag-optimization/spec.md):
"""
import logging
import uuid
from collections.abc import AsyncGenerator
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any
from sse_starlette.sse import ServerSentEvent
@@ -46,7 +46,6 @@ from app.services.flow.engine import FlowEngine
from app.services.guardrail.behavior_service import BehaviorRuleService
from app.services.guardrail.input_scanner import InputScanner
from app.services.guardrail.output_filter import OutputFilter
from app.services.guardrail.word_service import ForbiddenWordService
from app.services.intent.router import IntentRouter
from app.services.intent.rule_service import IntentRuleService
from app.services.llm.base import LLMClient, LLMConfig, LLMResponse
@@ -90,6 +89,8 @@ class GenerationContext:
10. confidence_result: Confidence calculation result
11. messages_saved: Whether messages were saved
12. final_response: Final ChatResponse
[v0.8.0] Extended with route_trace for hybrid routing observability.
"""
tenant_id: str
session_id: str
@@ -115,6 +116,11 @@ class GenerationContext:
target_kb_ids: list[str] | None = None
behavior_rules: list[str] = field(default_factory=list)
# [v0.8.0] Hybrid routing fields
route_trace: dict[str, Any] | None = None
fusion_confidence: float | None = None
fusion_decision_reason: str | None = None
diagnostics: dict[str, Any] = field(default_factory=dict)
execution_steps: list[dict[str, Any]] = field(default_factory=list)
@@ -487,7 +493,7 @@ class OrchestratorService:
finish_reason="flow_step",
)
ctx.diagnostics["flow_handled"] = True
logger.info(f"[AC-AISVC-75] Flow provided reply, skipping LLM")
logger.info("[AC-AISVC-75] Flow provided reply, skipping LLM")
else:
ctx.diagnostics["flow_check_enabled"] = True
@@ -501,8 +507,8 @@ class OrchestratorService:
"""
[AC-AISVC-69, AC-AISVC-70] Step 3: Match intent rules and route.
Routes to: fixed reply, RAG with target KBs, flow start, or transfer.
[v0.8.0] Upgraded to use match_hybrid() for hybrid routing.
"""
# Skip if flow already handled the request
if ctx.diagnostics.get("flow_handled"):
logger.info("[AC-AISVC-69] Flow already handled, skipping intent matching")
return
@@ -513,7 +519,6 @@ class OrchestratorService:
return
try:
# Load enabled rules ordered by priority
async with get_session() as session:
from app.services.intent.rule_service import IntentRuleService
rule_service = IntentRuleService(session)
@@ -524,33 +529,64 @@ class OrchestratorService:
ctx.diagnostics["intent_matched"] = False
return
# Match intent
ctx.intent_match = self._intent_router.match(
fusion_result = await self._intent_router.match_hybrid(
message=ctx.current_message,
rules=rules,
tenant_id=ctx.tenant_id,
)
if ctx.intent_match:
ctx.route_trace = fusion_result.trace.to_dict()
ctx.fusion_confidence = fusion_result.final_confidence
ctx.fusion_decision_reason = fusion_result.decision_reason
if fusion_result.final_intent:
ctx.intent_match = type(
"IntentMatchResult",
(),
{
"rule": fusion_result.final_intent,
"match_type": fusion_result.decision_reason,
"matched": "",
"to_dict": lambda self=None: {
"rule_id": str(fusion_result.final_intent.id),
"rule_name": fusion_result.final_intent.name,
"match_type": fusion_result.decision_reason,
"matched": "",
"response_type": fusion_result.final_intent.response_type,
"target_kb_ids": (
fusion_result.final_intent.target_kb_ids or []
),
"flow_id": (
str(fusion_result.final_intent.flow_id)
if fusion_result.final_intent.flow_id else None
),
"fixed_reply": fusion_result.final_intent.fixed_reply,
"transfer_message": fusion_result.final_intent.transfer_message,
},
},
)()
logger.info(
f"[AC-AISVC-69] Intent matched: rule={fusion_result.final_intent.name}, "
f"response_type={fusion_result.final_intent.response_type}, "
f"decision={fusion_result.decision_reason}, "
f"confidence={fusion_result.final_confidence:.3f}"
)
ctx.diagnostics["fusion_result"] = fusion_result.to_dict()
# Increment hit count
async with get_session() as session:
rule_service = IntentRuleService(session)
await rule_service.increment_hit_count(
tenant_id=ctx.tenant_id,
rule_id=fusion_result.final_intent.id,
)
# Route based on response_type
rule = fusion_result.final_intent
if rule.response_type == "fixed":
ctx.llm_response = LLMResponse(
content=rule.fixed_reply or "收到您的消息。",
model="intent_fixed",
usage={},
finish_reason="intent_fixed",
@ -558,20 +594,18 @@ class OrchestratorService:
ctx.diagnostics["intent_handled"] = True
logger.info("[AC-AISVC-70] Intent fixed reply, skipping LLM")
elif rule.response_type == "rag":
ctx.target_kb_ids = rule.target_kb_ids or []
logger.info(f"[AC-AISVC-70] Intent RAG, target_kb_ids={ctx.target_kb_ids}")
elif rule.response_type == "flow":
if rule.flow_id and self._flow_engine:
async with get_session() as session:
flow_engine = FlowEngine(session)
instance, first_step = await flow_engine.start(
tenant_id=ctx.tenant_id,
session_id=ctx.session_id,
flow_id=rule.flow_id,
)
if first_step:
ctx.llm_response = LLMResponse(
@ -583,10 +617,9 @@ class OrchestratorService:
ctx.diagnostics["intent_handled"] = True
logger.info("[AC-AISVC-70] Intent flow started, skipping LLM")
elif rule.response_type == "transfer":
ctx.llm_response = LLMResponse(
content=rule.transfer_message or "正在为您转接人工客服...",
model="intent_transfer",
usage={},
finish_reason="intent_transfer",
@ -600,9 +633,25 @@ class OrchestratorService:
ctx.diagnostics["intent_handled"] = True
logger.info("[AC-AISVC-70] Intent transfer, skipping LLM")
if fusion_result.need_clarify:
ctx.diagnostics["need_clarify"] = True
ctx.diagnostics["clarify_candidates"] = [
{"id": str(r.id), "name": r.name}
for r in (fusion_result.clarify_candidates or [])
]
logger.info(
f"[AC-AISVC-121] Low confidence, need clarify: "
f"confidence={fusion_result.final_confidence:.3f}, "
f"candidates={len(fusion_result.clarify_candidates or [])}"
)
else:
ctx.diagnostics["intent_match_enabled"] = True
ctx.diagnostics["intent_matched"] = False
ctx.diagnostics["fusion_result"] = fusion_result.to_dict()
logger.info(
f"[AC-AISVC-69] No intent matched, decision={fusion_result.decision_reason}"
)
except Exception as e:
logger.warning(f"[AC-AISVC-69] Intent matching failed: {e}")
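The `type("IntentMatchResult", (), {...})()` shim above builds an ad-hoc compatibility object so downstream code can keep reading `.rule`, `.match_type`, and `.to_dict()`. A bare `lambda` stored in a class dict becomes a bound method, so `instance.to_dict()` implicitly receives `self`; `types.SimpleNamespace` sidesteps that pitfall because its attributes are plain values. A hypothetical sketch of that alternative (the helper name `make_match_result` is illustrative, not from the codebase):

```python
from types import SimpleNamespace


def make_match_result(rule, decision_reason: str):
    """Illustrative alternative to the type("IntentMatchResult", ...) shim.

    SimpleNamespace attributes are plain values, so the stored lambda is
    called with no implicit self argument.
    """
    payload = {
        "rule_id": str(rule.id),
        "rule_name": rule.name,
        "match_type": decision_reason,
        "matched": "",
    }
    return SimpleNamespace(
        rule=rule,
        match_type=decision_reason,
        matched="",
        to_dict=lambda: payload,  # plain attribute, no bound-method wrapping
    )
```

This keeps the same duck-typed surface while avoiding the anonymous-class construction.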
@ -1122,6 +1171,7 @@ class OrchestratorService:
[AC-AISVC-02] Build final ChatResponse from generation context.
Step 12 of the 12-step pipeline.
Uses filtered_reply from Step 9.
[v0.8.0] Includes route_trace in response metadata.
"""
# Use filtered_reply if available, otherwise use llm_response.content
if ctx.filtered_reply:
@ -1142,6 +1192,10 @@ class OrchestratorService:
"execution_steps": ctx.execution_steps,
}
# [v0.8.0] Include route_trace in response metadata
if ctx.route_trace:
response_metadata["route_trace"] = ctx.route_trace
return ChatResponse(
reply=reply,
confidence=confidence,

View File

@ -178,6 +178,9 @@ class PromptTemplateService:
current_version = v
break
# Get latest version for current_content (not just published)
latest_version = versions[0] if versions else None
return {
"id": str(template.id),
"name": template.name,
@ -185,6 +188,8 @@ class PromptTemplateService:
"description": template.description,
"is_default": template.is_default,
"metadata": template.metadata_,
"current_content": latest_version.system_instruction if latest_version else None,
"variables": latest_version.variables if latest_version else [],
"current_version": {
"version": current_version.version,
"status": current_version.status,

View File

@ -2,6 +2,7 @@
Retrieval module for AI Service.
[AC-AISVC-16] Provides retriever implementations with plugin architecture.
RAG Optimization: Two-stage retrieval, RRF hybrid ranking, metadata filtering.
[AC-AISVC-RES-01~15] Strategy routing and mode routing for retrieval pipeline.
"""
from app.services.retrieval.base import (
@ -32,6 +33,29 @@ from app.services.retrieval.optimized_retriever import (
get_optimized_retriever,
)
from app.services.retrieval.vector_retriever import VectorRetriever, get_vector_retriever
from app.services.retrieval.routing_config import (
RagRuntimeMode,
RoutingConfig,
StrategyContext,
StrategyType,
StrategyResult,
)
from app.services.retrieval.strategy_router import (
RollbackRecord,
StrategyRouter,
get_strategy_router,
)
from app.services.retrieval.mode_router import (
ComplexityAnalyzer,
ModeRouter,
ModeRouteResult,
get_mode_router,
)
from app.services.retrieval.strategy_integration import (
RetrievalStrategyIntegration,
RetrievalStrategyResult,
get_retrieval_strategy_integration,
)
__all__ = [
"BaseRetriever",
@ -55,4 +79,19 @@ __all__ = [
"get_knowledge_indexer",
"IndexingProgress",
"IndexingResult",
"RagRuntimeMode",
"RoutingConfig",
"StrategyContext",
"StrategyType",
"StrategyResult",
"RollbackRecord",
"StrategyRouter",
"get_strategy_router",
"ComplexityAnalyzer",
"ModeRouter",
"ModeRouteResult",
"get_mode_router",
"RetrievalStrategyIntegration",
"RetrievalStrategyResult",
"get_retrieval_strategy_integration",
]

View File

@ -28,6 +28,7 @@ class RetrievalContext:
metadata: dict[str, Any] | None = None
tag_filter: "TagFilter | None" = None
kb_ids: list[str] | None = None
metadata_filter: dict[str, Any] | None = None
def get_tag_filter_dict(self) -> dict[str, str | list[str] | None] | None:
"""Return the tag filter as a dict representation."""

View File

@ -0,0 +1,438 @@
"""
Mode Router for RAG Runtime Mode Selection.
[AC-AISVC-RES-09, AC-AISVC-RES-10, AC-AISVC-RES-11] Routes to direct/react/auto mode.
Mode Descriptions:
- direct: Low-latency generic retrieval path (single KB call)
- react: Multi-step ReAct retrieval path (high accuracy)
- auto: Automatic selection based on complexity/confidence rules
"""
from __future__ import annotations
import logging
import re
import time
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from app.services.retrieval.routing_config import (
RagRuntimeMode,
RoutingConfig,
StrategyContext,
)
if TYPE_CHECKING:
from app.services.retrieval.base import RetrievalResult
logger = logging.getLogger(__name__)
@dataclass
class ComplexityAnalyzer:
"""
Analyzes query complexity for mode routing decisions.
Complexity factors:
- Query length
- Number of conditions/constraints
- Presence of logical operators (and, or, not)
- Cross-domain indicators
- Multi-step reasoning requirements
"""
short_query_threshold: int = 20
long_query_threshold: int = 100
condition_patterns: list[str] = field(default_factory=lambda: [
r"和|与|及|并且|同时",
r"或者|还是|要么",
r"但是|不过|然而",
r"如果|假如|假设",
r"既.*又",
r"不仅.*而且",
])
reasoning_patterns: list[str] = field(default_factory=lambda: [
r"为什么|原因|理由",
r"怎么|如何|怎样",
r"区别|差异|不同",
r"比较|对比|优劣",
r"分析|评估|判断",
])
cross_domain_patterns: list[str] = field(default_factory=lambda: [
r"跨|多|各个",
r"所有|全部|整体",
r"综合|汇总|统计",
])
def analyze(self, query: str) -> float:
"""
Analyze query complexity and return a score (0.0 ~ 1.0).
Higher score indicates more complex query that may benefit from ReAct mode.
Args:
query: User query text
Returns:
Complexity score (0.0 = simple, 1.0 = very complex)
"""
if not query:
return 0.0
score = 0.0
query_length = len(query)
if query_length < self.short_query_threshold:
score += 0.0
elif query_length > self.long_query_threshold:
score += 0.3
else:
score += 0.15
condition_count = 0
for pattern in self.condition_patterns:
matches = re.findall(pattern, query)
condition_count += len(matches)
if condition_count >= 3:
score += 0.3
elif condition_count >= 2:
score += 0.2
elif condition_count >= 1:
score += 0.1
for pattern in self.reasoning_patterns:
if re.search(pattern, query):
score += 0.15
break
for pattern in self.cross_domain_patterns:
if re.search(pattern, query):
score += 0.15
break
question_marks = query.count("?") + query.count("？")
if question_marks >= 2:
score += 0.1
return min(1.0, score)
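The additive scoring above can be sketched standalone. The snippet below mirrors only the length and condition-pattern contributions (the reasoning, cross-domain, and question-mark terms are omitted for brevity), so it is an approximation of `ComplexityAnalyzer.analyze`, not a drop-in replacement:

```python
import re


def complexity_score(query: str, short_len: int = 20, long_len: int = 100) -> float:
    """Partial sketch of ComplexityAnalyzer.analyze (length + conditions only)."""
    if not query:
        return 0.0
    score = 0.0
    # Length contribution: short queries add 0.0, long ones 0.3, else 0.15
    if len(query) > long_len:
        score += 0.3
    elif len(query) >= short_len:
        score += 0.15
    # Condition connectives: capped stepwise bonus for multiple matches
    conditions = sum(
        len(re.findall(p, query))
        for p in (r"和|与|及|并且|同时", r"或者|还是|要么", r"但是|不过|然而")
    )
    if conditions >= 3:
        score += 0.3
    elif conditions >= 2:
        score += 0.2
    elif conditions >= 1:
        score += 0.1
    return min(1.0, score)
```

Scores are clamped to 1.0, matching the `min(1.0, score)` in the original.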
@dataclass
class ModeRouteResult:
"""Result from mode routing decision."""
mode: RagRuntimeMode
confidence: float
complexity_score: float
should_fallback_to_react: bool = False
fallback_reason: str | None = None
diagnostics: dict[str, Any] = field(default_factory=dict)
class DirectRetrievalExecutor:
"""
[AC-AISVC-RES-09] Direct retrieval executor for low-latency path.
Single KB call without multi-step reasoning.
"""
def __init__(self):
self._retriever = None
async def execute(
self,
ctx: StrategyContext,
) -> "RetrievalResult":
"""
Execute direct retrieval (single KB call).
"""
from app.services.retrieval.optimized_retriever import get_optimized_retriever
from app.services.retrieval.base import RetrievalContext
if self._retriever is None:
self._retriever = await get_optimized_retriever()
retrieval_ctx = RetrievalContext(
tenant_id=ctx.tenant_id,
query=ctx.query,
metadata_filter=ctx.metadata_filter,
kb_ids=ctx.kb_ids,
)
return await self._retriever.retrieve(retrieval_ctx)
class ReactRetrievalExecutor:
"""
[AC-AISVC-RES-10] ReAct retrieval executor for multi-step path.
Uses AgentOrchestrator for multi-step reasoning and KB calls.
"""
def __init__(self, max_steps: int = 5):
self._max_steps = max_steps
async def execute(
self,
ctx: StrategyContext,
config: RoutingConfig,
) -> tuple[str, "RetrievalResult | None", dict[str, Any]]:
"""
Execute ReAct retrieval (multi-step reasoning).
Returns:
Tuple of (final_answer, retrieval_result, react_context)
"""
from app.services.mid.agent_orchestrator import AgentOrchestrator, AgentMode
from app.services.mid.tool_registry import ToolRegistry
from app.services.mid.timeout_governor import TimeoutGovernor
from app.services.llm.factory import get_llm_config_manager
try:
llm_manager = get_llm_config_manager()
llm_client = llm_manager.get_client()
except Exception as e:
logger.warning(f"[ModeRouter] Failed to get LLM client: {e}")
llm_client = None
timeout_governor = TimeoutGovernor()
tool_registry = ToolRegistry(timeout_governor=timeout_governor)
orchestrator = AgentOrchestrator(
max_iterations=min(config.react_max_steps, self._max_steps),
timeout_governor=timeout_governor,
llm_client=llm_client,
tool_registry=tool_registry,
tenant_id=ctx.tenant_id,
mode=AgentMode.FUNCTION_CALLING,
)
base_context = {
"query": ctx.query,
"metadata_filter": ctx.metadata_filter,
"kb_ids": ctx.kb_ids,
**ctx.additional_context,
}
final_answer, react_ctx, trace = await orchestrator.execute(
user_message=ctx.query,
context=base_context,
)
return final_answer, None, {
"iterations": react_ctx.iteration,
"tool_calls": [tc.model_dump() for tc in react_ctx.tool_calls] if react_ctx.tool_calls else [],
"final_answer": final_answer,
}
class ModeRouter:
"""
[AC-AISVC-RES-09, AC-AISVC-RES-10, AC-AISVC-RES-11]
Mode router for RAG runtime mode selection.
Mode Selection:
- direct: Low-latency generic retrieval (single KB call)
- react: Multi-step ReAct retrieval (high accuracy)
- auto: Automatic selection based on complexity/confidence
Auto Mode Rules:
- Direct conditions:
- Short query, clear intent
- High metadata confidence
- No cross-domain/multi-condition
- React conditions:
- Multi-condition/multi-constraint
- Low metadata confidence
- Need for secondary confirmation or multi-step reasoning
"""
def __init__(
self,
config: RoutingConfig | None = None,
):
self._config = config or RoutingConfig()
self._complexity_analyzer = ComplexityAnalyzer()
self._direct_executor = DirectRetrievalExecutor()
self._react_executor = ReactRetrievalExecutor(
max_steps=self._config.react_max_steps
)
@property
def config(self) -> RoutingConfig:
"""Get current configuration."""
return self._config
def update_config(self, new_config: RoutingConfig) -> None:
"""
[AC-AISVC-RES-15] Update routing configuration (hot reload).
"""
self._config = new_config
self._react_executor._max_steps = new_config.react_max_steps
logger.info(
f"[AC-AISVC-RES-15] ModeRouter config updated: "
f"mode={new_config.rag_runtime_mode.value}, "
f"react_max_steps={new_config.react_max_steps}, "
f"confidence_threshold={new_config.react_trigger_confidence_threshold}"
)
def route(
self,
ctx: StrategyContext,
) -> ModeRouteResult:
"""
[AC-AISVC-RES-09, AC-AISVC-RES-10, AC-AISVC-RES-11]
Route to appropriate mode based on configuration and context.
Args:
ctx: Strategy context with query, metadata, confidence, etc.
Returns:
ModeRouteResult with selected mode and diagnostics
"""
configured_mode = self._config.get_rag_runtime_mode()
if configured_mode == RagRuntimeMode.DIRECT:
logger.info(
f"[AC-AISVC-RES-09] Mode routing to DIRECT: tenant={ctx.tenant_id}"
)
return ModeRouteResult(
mode=RagRuntimeMode.DIRECT,
confidence=ctx.metadata_confidence,
complexity_score=ctx.complexity_score,
diagnostics={"configured_mode": "direct"},
)
if configured_mode == RagRuntimeMode.REACT:
logger.info(
f"[AC-AISVC-RES-10] Mode routing to REACT: tenant={ctx.tenant_id}"
)
return ModeRouteResult(
mode=RagRuntimeMode.REACT,
confidence=ctx.metadata_confidence,
complexity_score=ctx.complexity_score,
diagnostics={"configured_mode": "react"},
)
complexity_score = self._complexity_analyzer.analyze(ctx.query)
effective_complexity = max(complexity_score, ctx.complexity_score)
should_use_react = self._config.should_trigger_react_in_auto_mode(
confidence=ctx.metadata_confidence,
complexity_score=effective_complexity,
)
selected_mode = RagRuntimeMode.REACT if should_use_react else RagRuntimeMode.DIRECT
logger.info(
f"[AC-AISVC-RES-11, AC-AISVC-RES-12, AC-AISVC-RES-13] "
f"Auto mode routing: selected={selected_mode.value}, "
f"confidence={ctx.metadata_confidence:.2f}, "
f"complexity={effective_complexity:.2f}, "
f"tenant={ctx.tenant_id}"
)
return ModeRouteResult(
mode=selected_mode,
confidence=ctx.metadata_confidence,
complexity_score=effective_complexity,
diagnostics={
"configured_mode": "auto",
"analyzed_complexity": complexity_score,
"provided_complexity": ctx.complexity_score,
"react_trigger_confidence": self._config.react_trigger_confidence_threshold,
"react_trigger_complexity": self._config.react_trigger_complexity_score,
},
)
async def execute_direct(
self,
ctx: StrategyContext,
) -> "RetrievalResult":
"""
Execute direct retrieval mode.
"""
return await self._direct_executor.execute(ctx)
async def execute_react(
self,
ctx: StrategyContext,
) -> tuple[str, "RetrievalResult | None", dict[str, Any]]:
"""
Execute ReAct retrieval mode.
"""
return await self._react_executor.execute(ctx, self._config)
async def execute_with_fallback(
self,
ctx: StrategyContext,
) -> tuple["RetrievalResult | None", str | None, ModeRouteResult]:
"""
[AC-AISVC-RES-14] Execute with fallback from direct to react on low confidence.
Args:
ctx: Strategy context
Returns:
Tuple of (RetrievalResult or None, final_answer or None, ModeRouteResult)
"""
route_result = self.route(ctx)
if route_result.mode == RagRuntimeMode.DIRECT:
retrieval_result = await self._direct_executor.execute(ctx)
max_score = 0.0
if retrieval_result and retrieval_result.hits:
max_score = max((h.score for h in retrieval_result.hits), default=0.0)
if self._config.should_fallback_direct_to_react(max_score):
logger.info(
f"[AC-AISVC-RES-14] Direct mode low confidence fallback to react: "
f"confidence={max_score:.2f}, threshold={self._config.direct_fallback_confidence_threshold}"
)
final_answer, _, react_ctx = await self._react_executor.execute(ctx, self._config)
return (
None,
final_answer,
ModeRouteResult(
mode=RagRuntimeMode.REACT,
confidence=max_score,
complexity_score=route_result.complexity_score,
should_fallback_to_react=True,
fallback_reason="low_confidence",
diagnostics={
**route_result.diagnostics,
"fallback_from": "direct",
"direct_confidence": max_score,
},
),
)
return retrieval_result, None, route_result
final_answer, _, react_ctx = await self._react_executor.execute(ctx, self._config)
return None, final_answer, route_result
_mode_router: ModeRouter | None = None
def get_mode_router() -> ModeRouter:
"""Get or create ModeRouter singleton."""
global _mode_router
if _mode_router is None:
_mode_router = ModeRouter()
return _mode_router
def reset_mode_router() -> None:
"""Reset ModeRouter singleton (for testing)."""
global _mode_router
_mode_router = None
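The direct-to-react escalation in `execute_with_fallback` reduces to a small decision: run the cheap path, inspect the best hit score, and only pay for the multi-step path when confidence falls below the threshold. A minimal synchronous sketch (function names here are illustrative; the real method is async and returns a `ModeRouteResult`):

```python
def retrieve_with_fallback(direct_fn, react_fn, threshold: float = 0.4):
    """Sketch of the direct-to-react fallback in execute_with_fallback.

    direct_fn returns a list of (doc_id, score) pairs; react_fn returns a
    final answer string. Mirrors [AC-AISVC-RES-14].
    """
    hits = direct_fn()
    top_score = max((score for _, score in hits), default=0.0)
    if top_score < threshold:
        # Low-confidence direct result: escalate to the multi-step path
        return "react", react_fn()
    return "direct", hits
```

An empty direct result (no hits) yields a top score of 0.0 and therefore always escalates, matching the `default=0.0` in the original.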

View File

@ -0,0 +1,187 @@
"""
Retrieval and Embedding Strategy Configuration.
[AC-AISVC-RES-01~15] Configuration for strategy routing and mode routing.
"""
from dataclasses import dataclass, field
from enum import Enum
from typing import Any
class StrategyType(str, Enum):
"""Strategy type for retrieval pipeline selection."""
DEFAULT = "default"
ENHANCED = "enhanced"
class RagRuntimeMode(str, Enum):
"""RAG runtime mode for execution path selection."""
DIRECT = "direct"
REACT = "react"
AUTO = "auto"
@dataclass
class RoutingConfig:
"""
[AC-AISVC-RES-01~15] Routing configuration for strategy and mode selection.
Configuration hierarchy:
1. Strategy selection (default vs enhanced)
2. Mode selection (direct/react/auto)
3. Auto routing rules (complexity/confidence thresholds)
4. Fallback behavior
"""
enabled: bool = True
strategy: StrategyType = StrategyType.DEFAULT
grayscale_percentage: float = 0.0
grayscale_allowlist: list[str] = field(default_factory=list)
rag_runtime_mode: RagRuntimeMode = RagRuntimeMode.AUTO
react_trigger_confidence_threshold: float = 0.6
react_trigger_complexity_score: float = 0.5
react_max_steps: int = 5
direct_fallback_on_low_confidence: bool = True
direct_fallback_confidence_threshold: float = 0.4
performance_budget_ms: int = 5000
performance_degradation_threshold: float = 0.2
def should_use_enhanced_strategy(self, tenant_id: str | None = None) -> bool:
"""
[AC-AISVC-RES-02, AC-AISVC-RES-03] Determine if enhanced strategy should be used.
Priority:
1. If strategy is explicitly set to ENHANCED, use enhanced
2. Otherwise, if the tenant is on the grayscale allowlist, use enhanced
3. Otherwise, if a grayscale percentage is set, bucket the tenant by hash
4. Fall back to the default strategy
"""
if self.strategy == StrategyType.ENHANCED:
return True
if self.grayscale_allowlist and tenant_id and tenant_id in self.grayscale_allowlist:
return True
if self.grayscale_percentage > 0 and tenant_id:
import hashlib
hash_val = int(hashlib.md5(tenant_id.encode()).hexdigest()[:8], 16)
return (hash_val % 100) < (self.grayscale_percentage * 100)
return False
def get_rag_runtime_mode(self) -> RagRuntimeMode:
"""Get the configured RAG runtime mode."""
return self.rag_runtime_mode
def should_fallback_direct_to_react(self, confidence: float) -> bool:
"""
[AC-AISVC-RES-14] Determine if direct mode should fallback to react.
Args:
confidence: Retrieval confidence score (0.0 ~ 1.0)
Returns:
True if fallback should be triggered
"""
if not self.direct_fallback_on_low_confidence:
return False
return confidence < self.direct_fallback_confidence_threshold
def should_trigger_react_in_auto_mode(
self,
confidence: float,
complexity_score: float,
) -> bool:
"""
[AC-AISVC-RES-11, AC-AISVC-RES-12, AC-AISVC-RES-13]
Determine if react mode should be triggered in auto mode.
Direct conditions (checked first):
- Short query, clear intent
- High metadata confidence
- No cross-domain/multi-condition
React conditions:
- Multi-condition/multi-constraint
- Low metadata confidence
- Need for secondary confirmation or multi-step reasoning
Args:
confidence: Metadata inference confidence (0.0 ~ 1.0)
complexity_score: Query complexity score (0.0 ~ 1.0)
Returns:
True if react mode should be used
"""
if confidence < self.react_trigger_confidence_threshold:
return True
if complexity_score > self.react_trigger_complexity_score:
return True
return False
def validate(self) -> tuple[bool, list[str]]:
"""
[AC-AISVC-RES-06] Validate configuration consistency.
Returns:
(is_valid, list of error messages)
"""
errors = []
if self.grayscale_percentage < 0 or self.grayscale_percentage > 1.0:
errors.append("grayscale_percentage must be between 0.0 and 1.0")
if self.react_trigger_confidence_threshold < 0 or self.react_trigger_confidence_threshold > 1.0:
errors.append("react_trigger_confidence_threshold must be between 0.0 and 1.0")
if self.react_trigger_complexity_score < 0 or self.react_trigger_complexity_score > 1.0:
errors.append("react_trigger_complexity_score must be between 0.0 and 1.0")
if self.react_max_steps < 3 or self.react_max_steps > 10:
errors.append("react_max_steps must be between 3 and 10")
if self.direct_fallback_confidence_threshold < 0 or self.direct_fallback_confidence_threshold > 1.0:
errors.append("direct_fallback_confidence_threshold must be between 0.0 and 1.0")
if self.performance_budget_ms < 1000:
errors.append("performance_budget_ms must be at least 1000")
if self.performance_degradation_threshold < 0 or self.performance_degradation_threshold > 1.0:
errors.append("performance_degradation_threshold must be between 0.0 and 1.0")
return (len(errors) == 0, errors)
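The auto-mode trigger rule in `should_trigger_react_in_auto_mode` is a simple disjunction: escalate to ReAct when metadata confidence is low or query complexity is high. A standalone sketch with the default thresholds:

```python
def should_trigger_react(
    confidence: float,
    complexity: float,
    conf_threshold: float = 0.6,
    complexity_threshold: float = 0.5,
) -> bool:
    """Mirrors should_trigger_react_in_auto_mode [AC-AISVC-RES-11~13]:
    ReAct when confidence is below threshold OR complexity exceeds threshold.
    """
    return confidence < conf_threshold or complexity > complexity_threshold
```

Note the asymmetry, matching the original: confidence uses a strict `<`, complexity a strict `>`, so values exactly at the thresholds stay on the direct path.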
@dataclass
class StrategyContext:
"""Context for strategy routing decision."""
tenant_id: str
query: str
metadata_filter: dict[str, Any] | None = None
metadata_confidence: float = 1.0
complexity_score: float = 0.0
kb_ids: list[str] | None = None
top_k: int = 5
additional_context: dict[str, Any] = field(default_factory=dict)
@dataclass
class StrategyResult:
"""Result from strategy routing."""
strategy: StrategyType
mode: RagRuntimeMode
should_fallback: bool = False
fallback_reason: str | None = None
diagnostics: dict[str, Any] = field(default_factory=dict)

View File

@ -0,0 +1,102 @@
"""
Retrieval Strategy Module for AI Service.
[AC-AISVC-RES-01~15] Strategy-based retrieval and embedding module.
Core components:
- RetrievalStrategyConfig: strategy configuration model
- BasePipeline: abstract pipeline base class
- DefaultPipeline: default strategy that reuses the existing logic
- EnhancedPipeline: enhanced strategy with a new end-to-end flow
- MetadataInferenceService: unified entry point for metadata inference
- StrategyRouter: strategy router
- ModeRouter: mode router (direct/react/auto)
- RollbackManager: rollback manager
"""
from app.services.retrieval.strategy.config import (
FilterMode,
GrayscaleConfig,
HybridRetrievalConfig,
MetadataInferenceConfig,
ModeRouterConfig,
PipelineConfig,
RerankerConfig,
RetrievalStrategyConfig,
RuntimeMode,
StrategyType,
get_strategy_config,
set_strategy_config,
)
from app.services.retrieval.strategy.default_pipeline import (
DefaultPipeline,
get_default_pipeline,
)
from app.services.retrieval.strategy.enhanced_pipeline import (
EnhancedPipeline,
get_enhanced_pipeline,
)
from app.services.retrieval.strategy.metadata_inference import (
InferenceContext,
InferenceResult,
MetadataInferenceService,
)
from app.services.retrieval.strategy.mode_router import (
ModeDecision,
ModeRouter,
get_mode_router,
)
from app.services.retrieval.strategy.pipeline_base import (
BasePipeline,
MetadataFilterResult,
PipelineContext,
PipelineResult,
)
from app.services.retrieval.strategy.rollback_manager import (
AuditLog,
RollbackManager,
RollbackResult,
RollbackTrigger,
get_rollback_manager,
)
from app.services.retrieval.strategy.strategy_router import (
RoutingDecision,
StrategyRouter,
get_strategy_router,
)
__all__ = [
"BasePipeline",
"PipelineContext",
"PipelineResult",
"MetadataFilterResult",
"DefaultPipeline",
"get_default_pipeline",
"EnhancedPipeline",
"get_enhanced_pipeline",
"RetrievalStrategyConfig",
"GrayscaleConfig",
"PipelineConfig",
"RerankerConfig",
"ModeRouterConfig",
"HybridRetrievalConfig",
"MetadataInferenceConfig",
"StrategyType",
"FilterMode",
"RuntimeMode",
"get_strategy_config",
"set_strategy_config",
"MetadataInferenceService",
"InferenceContext",
"InferenceResult",
"StrategyRouter",
"RoutingDecision",
"get_strategy_router",
"ModeRouter",
"ModeDecision",
"get_mode_router",
"RollbackManager",
"RollbackResult",
"RollbackTrigger",
"AuditLog",
"get_rollback_manager",
]

View File

@ -0,0 +1,201 @@
"""
Retrieval Strategy Configuration.
[AC-AISVC-RES-01~15] Retrieval strategy configuration models.
"""
import random
from dataclasses import dataclass, field
from enum import Enum
from typing import Any
class StrategyType(str, Enum):
"""Strategy type."""
DEFAULT = "default"
ENHANCED = "enhanced"
class RuntimeMode(str, Enum):
"""Runtime mode."""
DIRECT = "direct"
REACT = "react"
AUTO = "auto"
class FilterMode(str, Enum):
"""Filter mode."""
HARD = "hard"
SOFT = "soft"
NONE = "none"
@dataclass
class GrayscaleConfig:
"""Grayscale rollout configuration. [AC-AISVC-RES-03]"""
enabled: bool = False
percentage: float = 0.0
allowlist: list[str] = field(default_factory=list)
def should_use_enhanced(self, tenant_id: str, user_id: str | None = None) -> bool:
"""Decide whether the enhanced strategy should be used."""
if not self.enabled:
return False
if tenant_id in self.allowlist or (user_id and user_id in self.allowlist):
return True
return random.random() * 100 < self.percentage
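Note that `GrayscaleConfig` samples `random.random()` per call, so the same tenant can flip between strategies across requests; the `RoutingConfig` variant elsewhere in this change hashes the tenant ID for a stable decision. A standalone sketch of that deterministic bucketing (the helper name is illustrative):

```python
import hashlib


def in_grayscale(tenant_id: str, percentage: float, allowlist=()) -> bool:
    """Deterministic tenant bucketing sketch, mirroring the md5 scheme in
    RoutingConfig.should_use_enhanced. percentage is a 0.0..1.0 fraction.
    """
    if tenant_id in allowlist:
        return True
    # Same tenant always lands in the same 0..99 bucket
    bucket = int(hashlib.md5(tenant_id.encode()).hexdigest()[:8], 16) % 100
    return bucket < percentage * 100
```

Hash-based bucketing keeps a tenant's experience consistent while a rollout percentage ramps up, which is usually what grayscale releases want.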
@dataclass
class HybridRetrievalConfig:
"""Hybrid retrieval configuration."""
dense_weight: float = 0.7
keyword_weight: float = 0.3
rrf_k: int = 60
enable_keyword: bool = True
keyword_top_k_multiplier: int = 2
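`HybridRetrievalConfig` carries the knobs for fusing dense and keyword rankings. The exact formula lives in the imported `RRFCombiner` (not shown in this diff), but a weighted reciprocal-rank fusion with these parameters would plausibly look like the sketch below, where each list contributes `weight / (rrf_k + rank)`:

```python
def rrf_fuse(dense_ids, keyword_ids, k=60, dense_w=0.7, kw_w=0.3):
    """Weighted reciprocal-rank fusion sketch (assumed form; the real
    RRFCombiner in optimized_retriever may differ in detail).
    """
    scores: dict[str, float] = {}
    for weight, ranked in ((dense_w, dense_ids), (kw_w, keyword_ids)):
        for rank, doc_id in enumerate(ranked, start=1):
            # Higher-ranked documents get a larger reciprocal contribution
            scores[doc_id] = scores.get(doc_id, 0.0) + weight / (k + rank)
    return sorted(scores, key=scores.get, reverse=True)
```

Documents appearing in both lists accumulate both contributions, which is why an item ranked second in each list can beat one ranked first in only one of them.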
@dataclass
class RerankerConfig:
"""Reranker configuration. [AC-AISVC-RES-08]"""
enabled: bool = False
model: str = "cross-encoder"
top_k_after_rerank: int = 5
min_score_threshold: float = 0.3
@dataclass
class ModeRouterConfig:
"""Mode routing configuration. [AC-AISVC-RES-09~15]"""
runtime_mode: RuntimeMode = RuntimeMode.DIRECT
react_trigger_confidence_threshold: float = 0.6
react_trigger_complexity_score: float = 0.5
react_max_steps: int = 5
direct_fallback_on_low_confidence: bool = True
short_query_threshold: int = 20
def should_use_react(
self,
query: str,
confidence: float | None = None,
complexity_score: float | None = None,
) -> bool:
"""Decide whether ReAct mode should be used. [AC-AISVC-RES-11~13]"""
if self.runtime_mode == RuntimeMode.REACT:
return True
if self.runtime_mode == RuntimeMode.DIRECT:
return False
if len(query) <= self.short_query_threshold and confidence is not None and confidence >= self.react_trigger_confidence_threshold:
return False
if complexity_score is not None and complexity_score >= self.react_trigger_complexity_score:
return True
if confidence is not None and confidence < self.react_trigger_confidence_threshold:
return True
return False
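The decision order in `should_use_react` matters: forced modes short-circuit, a short high-confidence query stays direct, and only then do complexity and low confidence escalate. A standalone sketch using explicit `is not None` checks, so that a 0.0 confidence (falsy in Python) still counts as low rather than being skipped:

```python
def should_use_react(mode, query, confidence=None, complexity=None,
                     short_len=20, conf_t=0.6, cplx_t=0.5):
    """Standalone sketch of ModeRouterConfig.should_use_react."""
    if mode == "react":    # forced ReAct
        return True
    if mode == "direct":   # forced direct
        return False
    # auto: short query with confident metadata stays on the direct path
    if len(query) <= short_len and confidence is not None and confidence >= conf_t:
        return False
    if complexity is not None and complexity >= cplx_t:
        return True
    if confidence is not None and confidence < conf_t:
        return True
    return False
```

With neither signal available, the sketch (like the original) defaults to the cheaper direct path.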
@dataclass
class MetadataInferenceConfig:
"""Metadata inference configuration."""
enabled: bool = True
confidence_high_threshold: float = 0.8
confidence_low_threshold: float = 0.5
default_filter_mode: FilterMode = FilterMode.SOFT
cache_ttl_seconds: int = 300
def determine_filter_mode(self, confidence: float | None) -> FilterMode:
"""Determine the filter mode from the confidence score."""
if confidence is None:
return FilterMode.NONE
if confidence >= self.confidence_high_threshold:
return FilterMode.HARD
if confidence >= self.confidence_low_threshold:
return FilterMode.SOFT
return FilterMode.NONE
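`determine_filter_mode` maps confidence bands to filter behavior: high confidence filters hard, medium filters soft, and low or unknown confidence skips filtering. A self-contained sketch (with a local enum standing in for the module's `FilterMode`):

```python
from enum import Enum


class Mode(str, Enum):
    """Local stand-in for FilterMode."""
    HARD = "hard"
    SOFT = "soft"
    NONE = "none"


def pick_filter_mode(confidence, high=0.8, low=0.5):
    """Mirrors MetadataInferenceConfig.determine_filter_mode."""
    if confidence is None:
        return Mode.NONE          # no inference available: do not filter
    if confidence >= high:
        return Mode.HARD          # confident: enforce the filter
    if confidence >= low:
        return Mode.SOFT          # uncertain: prefer, but don't require
    return Mode.NONE              # too uncertain: ignore the inference
```

Both thresholds are inclusive (`>=`), so a confidence exactly at 0.8 filters hard and exactly at 0.5 filters soft.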
@dataclass
class PipelineConfig:
"""Pipeline configuration."""
top_k: int = 5
score_threshold: float = 0.01
min_hits: int = 1
two_stage_enabled: bool = True
two_stage_expand_factor: int = 10
hybrid: HybridRetrievalConfig = field(default_factory=HybridRetrievalConfig)
@dataclass
class RetrievalStrategyConfig:
"""Top-level retrieval strategy configuration. [AC-AISVC-RES-01~15]"""
active_strategy: StrategyType = StrategyType.DEFAULT
grayscale: GrayscaleConfig = field(default_factory=GrayscaleConfig)
pipeline: PipelineConfig = field(default_factory=PipelineConfig)
reranker: RerankerConfig = field(default_factory=RerankerConfig)
mode_router: ModeRouterConfig = field(default_factory=ModeRouterConfig)
metadata_inference: MetadataInferenceConfig = field(default_factory=MetadataInferenceConfig)
performance_thresholds: dict[str, float] = field(default_factory=lambda: {
"max_latency_ms": 2000.0,
"min_success_rate": 0.95,
"max_error_rate": 0.05,
})
def is_enhanced_enabled(self, tenant_id: str, user_id: str | None = None) -> bool:
"""Decide whether the enhanced strategy is enabled."""
if self.active_strategy == StrategyType.ENHANCED:
return True
return self.grayscale.should_use_enhanced(tenant_id, user_id)
def to_dict(self) -> dict[str, Any]:
"""Convert to a dict."""
return {
"active_strategy": self.active_strategy.value,
"grayscale": {
"enabled": self.grayscale.enabled,
"percentage": self.grayscale.percentage,
"allowlist": self.grayscale.allowlist,
},
"pipeline": {
"top_k": self.pipeline.top_k,
"score_threshold": self.pipeline.score_threshold,
"min_hits": self.pipeline.min_hits,
"two_stage_enabled": self.pipeline.two_stage_enabled,
},
"reranker": {
"enabled": self.reranker.enabled,
"model": self.reranker.model,
"top_k_after_rerank": self.reranker.top_k_after_rerank,
},
"mode_router": {
"runtime_mode": self.mode_router.runtime_mode.value,
"react_trigger_confidence_threshold": self.mode_router.react_trigger_confidence_threshold,
},
"metadata_inference": {
"enabled": self.metadata_inference.enabled,
"confidence_high_threshold": self.metadata_inference.confidence_high_threshold,
},
"performance_thresholds": self.performance_thresholds,
}
_global_config: RetrievalStrategyConfig | None = None
def get_strategy_config() -> RetrievalStrategyConfig:
"""Get the global strategy configuration."""
global _global_config
if _global_config is None:
_global_config = RetrievalStrategyConfig()
return _global_config
def set_strategy_config(config: RetrievalStrategyConfig) -> None:
"""Set the global strategy configuration."""
global _global_config
_global_config = config

View File

@ -0,0 +1,117 @@
"""
Default Pipeline.
[AC-AISVC-RES-01] Default strategy pipeline (reuses existing logic).
"""
import logging
import time
from typing import Any
from app.services.retrieval.base import RetrievalContext, RetrievalResult
from app.services.retrieval.optimized_retriever import OptimizedRetriever, get_optimized_retriever
from app.services.retrieval.strategy.config import PipelineConfig
from app.services.retrieval.strategy.pipeline_base import (
BasePipeline,
PipelineContext,
PipelineResult,
)
logger = logging.getLogger(__name__)
class DefaultPipeline(BasePipeline):
"""
Default strategy pipeline. [AC-AISVC-RES-01]
Reuses the existing OptimizedRetriever logic to keep production behavior unchanged.
"""
def __init__(
self,
config: PipelineConfig | None = None,
optimized_retriever: OptimizedRetriever | None = None,
):
self._config = config or PipelineConfig()
self._optimized_retriever = optimized_retriever
@property
def name(self) -> str:
return "default_pipeline"
@property
def description(self) -> str:
return "Default retrieval strategy, reusing the existing OptimizedRetriever logic."
async def _get_retriever(self) -> OptimizedRetriever:
if self._optimized_retriever is None:
self._optimized_retriever = await get_optimized_retriever()
return self._optimized_retriever
async def retrieve(self, ctx: PipelineContext) -> PipelineResult:
"""Run the default retrieval flow. [AC-AISVC-RES-01]"""
start_time = time.time()
logger.info(
f"[DefaultPipeline] Starting retrieval: tenant={ctx.tenant_id}, "
f"query={ctx.query[:50]}..."
)
try:
retriever = await self._get_retriever()
metadata_filter = None
if ctx.metadata_filter:
metadata_filter = ctx.metadata_filter.filter_dict
retrieval_ctx = RetrievalContext(
tenant_id=ctx.tenant_id,
query=ctx.query,
session_id=ctx.session_id,
metadata_filter=metadata_filter,
kb_ids=ctx.kb_ids,
)
result = await retriever.retrieve(retrieval_ctx)
latency_ms = (time.time() - start_time) * 1000
logger.info(
f"[DefaultPipeline] Retrieval completed: hits={len(result.hits)}, "
f"latency_ms={latency_ms:.2f}"
)
return PipelineResult(
retrieval_result=result,
pipeline_name=self.name,
metadata_filter_applied=metadata_filter is not None,
latency_ms=latency_ms,
diagnostics={
"retriever": "OptimizedRetriever",
**(result.diagnostics or {}),
},
)
except Exception as e:
latency_ms = (time.time() - start_time) * 1000
logger.error(f"[DefaultPipeline] Retrieval error: {e}", exc_info=True)
return self._create_empty_result(ctx, error=str(e), latency_ms=latency_ms)
async def health_check(self) -> bool:
"""健康检查。"""
try:
retriever = await self._get_retriever()
return await retriever.health_check()
except Exception as e:
logger.error(f"[DefaultPipeline] Health check failed: {e}")
return False
_default_pipeline: DefaultPipeline | None = None
async def get_default_pipeline() -> DefaultPipeline:
"""获取 DefaultPipeline 单例。"""
global _default_pipeline
if _default_pipeline is None:
_default_pipeline = DefaultPipeline()
return _default_pipeline


@@ -0,0 +1,364 @@
"""
Enhanced Pipeline.
[AC-AISVC-RES-02] 增强策略 Pipeline新端到端流程
"""
import asyncio
import logging
import re
import time
from dataclasses import dataclass
from typing import Any
import numpy as np
from app.core.config import get_settings
from app.core.qdrant_client import QdrantClient, get_qdrant_client
from app.services.embedding.nomic_provider import NomicEmbeddingProvider
from app.services.retrieval.base import RetrievalHit, RetrievalResult
from app.services.retrieval.optimized_retriever import RRFCombiner
from app.services.retrieval.strategy.config import (
HybridRetrievalConfig,
PipelineConfig,
RerankerConfig,
)
from app.services.retrieval.strategy.pipeline_base import (
BasePipeline,
PipelineContext,
PipelineResult,
)
logger = logging.getLogger(__name__)
settings = get_settings()
@dataclass
class RetrievalCandidate:
"""检索候选结果。"""
id: str
text: str
score: float
vector_score: float = 0.0
keyword_score: float = 0.0
metadata: dict[str, Any] | None = None
def __post_init__(self):
if self.metadata is None:
self.metadata = {}
class EnhancedPipeline(BasePipeline):
"""
增强策略 PipelineAC-AISVC-RES-02
新端到端流程
1. Dense 向量检索
2. Keyword 关键词检索
3. RRF 融合排序
4. 可选重排
"""
def __init__(
self,
config: PipelineConfig | None = None,
reranker_config: RerankerConfig | None = None,
qdrant_client: QdrantClient | None = None,
):
self._config = config or PipelineConfig()
self._reranker_config = reranker_config or RerankerConfig()
self._qdrant_client = qdrant_client
self._rrf_combiner = RRFCombiner(k=self._config.hybrid.rrf_k)
self._embedding_provider: NomicEmbeddingProvider | None = None
@property
def name(self) -> str:
return "enhanced_pipeline"
@property
def description(self) -> str:
return "增强检索策略,支持 Dense + Keyword + RRF 组合检索。"
async def _get_client(self) -> QdrantClient:
if self._qdrant_client is None:
self._qdrant_client = await get_qdrant_client()
return self._qdrant_client
async def _get_embedding_provider(self) -> NomicEmbeddingProvider:
if self._embedding_provider is None:
from app.services.embedding.factory import get_embedding_config_manager
manager = get_embedding_config_manager()
provider = await manager.get_provider()
if isinstance(provider, NomicEmbeddingProvider):
self._embedding_provider = provider
else:
self._embedding_provider = NomicEmbeddingProvider(
base_url=settings.ollama_base_url,
model=settings.ollama_embedding_model,
dimension=settings.qdrant_vector_size,
)
return self._embedding_provider
async def retrieve(self, ctx: PipelineContext) -> PipelineResult:
"""执行增强检索流程。【AC-AISVC-RES-02】"""
start_time = time.time()
logger.info(
f"[EnhancedPipeline] Starting retrieval: tenant={ctx.tenant_id}, "
f"query={ctx.query[:50]}..."
)
try:
provider = await self._get_embedding_provider()
embedding_result = await provider.embed_query(ctx.query)
candidates = await self._hybrid_retrieve(
tenant_id=ctx.tenant_id,
query=ctx.query,
embedding_result=embedding_result,
metadata_filter=ctx.metadata_filter.filter_dict if ctx.metadata_filter else None,
kb_ids=ctx.kb_ids,
)
if self._reranker_config.enabled and ctx.use_reranker:
candidates = await self._rerank(
candidates=candidates,
query=ctx.query,
)
top_k = self._config.top_k
final_candidates = candidates[:top_k]
hits = [
RetrievalHit(
text=c.text,
score=c.score,
source=self.name,
metadata=c.metadata,
)
for c in final_candidates
if c.score >= self._config.score_threshold
]
latency_ms = (time.time() - start_time) * 1000
logger.info(
f"[EnhancedPipeline] Retrieval completed: hits={len(hits)}, "
f"latency_ms={latency_ms:.2f}"
)
result = RetrievalResult(
hits=hits,
diagnostics={
"total_candidates": len(candidates),
"after_rerank": self._reranker_config.enabled and ctx.use_reranker,
},
)
return PipelineResult(
retrieval_result=result,
pipeline_name=self.name,
used_reranker=self._reranker_config.enabled and ctx.use_reranker,
metadata_filter_applied=ctx.metadata_filter is not None,
latency_ms=latency_ms,
diagnostics={
"dense_weight": self._config.hybrid.dense_weight,
"keyword_weight": self._config.hybrid.keyword_weight,
},
)
except Exception as e:
latency_ms = (time.time() - start_time) * 1000
logger.error(f"[EnhancedPipeline] Retrieval error: {e}", exc_info=True)
return self._create_empty_result(ctx, error=str(e), latency_ms=latency_ms)
async def _hybrid_retrieve(
self,
tenant_id: str,
query: str,
embedding_result: Any,
metadata_filter: dict[str, Any] | None = None,
kb_ids: list[str] | None = None,
) -> list[RetrievalCandidate]:
"""混合检索Dense + Keyword + RRF。"""
client = await self._get_client()
top_k = self._config.top_k
expand_factor = self._config.hybrid.keyword_top_k_multiplier
vector_task = self._dense_search(
client=client,
tenant_id=tenant_id,
embedding=embedding_result.embedding_full,
top_k=top_k * expand_factor,
metadata_filter=metadata_filter,
kb_ids=kb_ids,
)
keyword_task = self._keyword_search(
client=client,
tenant_id=tenant_id,
query=query,
top_k=top_k * expand_factor,
metadata_filter=metadata_filter,
kb_ids=kb_ids,
) if self._config.hybrid.enable_keyword else asyncio.sleep(0, result=[])
vector_results, keyword_results = await asyncio.gather(
vector_task, keyword_task, return_exceptions=True
)
if isinstance(vector_results, Exception):
logger.warning(f"[EnhancedPipeline] Dense search failed: {vector_results}")
vector_results = []
if isinstance(keyword_results, Exception):
logger.warning(f"[EnhancedPipeline] Keyword search failed: {keyword_results}")
keyword_results = []
combined = self._rrf_combiner.combine(
vector_results=vector_results,
bm25_results=keyword_results,
vector_weight=self._config.hybrid.dense_weight,
bm25_weight=self._config.hybrid.keyword_weight,
)
candidates = []
for item in combined:
candidates.append(RetrievalCandidate(
id=item.get("id", ""),
text=item.get("payload", {}).get("text", ""),
score=item.get("score", 0.0),
vector_score=item.get("vector_score", 0.0),
keyword_score=item.get("bm25_score", 0.0),
metadata=item.get("payload", {}),
))
return candidates
async def _dense_search(
self,
client: QdrantClient,
tenant_id: str,
embedding: list[float],
top_k: int,
metadata_filter: dict[str, Any] | None = None,
kb_ids: list[str] | None = None,
) -> list[dict[str, Any]]:
"""Dense 向量检索。"""
try:
results = await client.search(
tenant_id=tenant_id,
query_vector=embedding,
limit=top_k,
vector_name="full",
metadata_filter=metadata_filter,
kb_ids=kb_ids,
)
return results
except Exception as e:
logger.error(f"[EnhancedPipeline] Dense search error: {e}")
return []
async def _keyword_search(
self,
client: QdrantClient,
tenant_id: str,
query: str,
top_k: int,
metadata_filter: dict[str, Any] | None = None,
kb_ids: list[str] | None = None,
) -> list[dict[str, Any]]:
"""Keyword 关键词检索。"""
try:
qdrant = await client.get_client()
collection_name = client.get_collection_name(tenant_id)
query_terms = set(re.findall(r'\w+', query.lower()))
results = await qdrant.scroll(
collection_name=collection_name,
limit=top_k * 3,
with_payload=True,
)
scored_results = []
for point in results[0]:
text = point.payload.get("text", "").lower()
text_terms = set(re.findall(r'\w+', text))
overlap = len(query_terms & text_terms)
if overlap > 0:
score = overlap / (len(query_terms) + len(text_terms) - overlap)
scored_results.append({
"id": str(point.id),
"score": score,
"payload": point.payload or {},
})
scored_results.sort(key=lambda x: x["score"], reverse=True)
return scored_results[:top_k]
except Exception as e:
logger.debug(f"[EnhancedPipeline] Keyword search failed: {e}")
return []
async def _rerank(
self,
candidates: list[RetrievalCandidate],
query: str,
) -> list[RetrievalCandidate]:
"""可选重排。"""
if not candidates:
return candidates
try:
provider = await self._get_embedding_provider()
query_embedding = await provider.embed_query(query)
reranked = []
for candidate in candidates:
candidate_text = candidate.text[:500]
if candidate_text:
candidate_embedding = await provider.embed(candidate_text)
similarity = self._cosine_similarity(
query_embedding.embedding_full,
candidate_embedding,
)
candidate.score = similarity
if candidate.score >= self._reranker_config.min_score_threshold:
reranked.append(candidate)
reranked.sort(key=lambda x: x.score, reverse=True)
return reranked[:self._reranker_config.top_k_after_rerank]
except Exception as e:
logger.warning(f"[EnhancedPipeline] Rerank failed: {e}")
return candidates
def _cosine_similarity(self, vec1: list[float], vec2: list[float]) -> float:
"""计算余弦相似度。"""
a = np.array(vec1)
b = np.array(vec2)
return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
async def health_check(self) -> bool:
"""健康检查。"""
try:
client = await self._get_client()
qdrant = await client.get_client()
await qdrant.get_collections()
return True
except Exception as e:
logger.error(f"[EnhancedPipeline] Health check failed: {e}")
return False
_enhanced_pipeline: EnhancedPipeline | None = None
async def get_enhanced_pipeline() -> EnhancedPipeline:
"""获取 EnhancedPipeline 单例。"""
global _enhanced_pipeline
if _enhanced_pipeline is None:
_enhanced_pipeline = EnhancedPipeline()
return _enhanced_pipeline
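`_hybrid_retrieve` delegates fusion to the existing `RRFCombiner`; the idea behind weighted Reciprocal Rank Fusion can be sketched standalone (parameter names here are assumptions, not the combiner's real signature):

```python
def rrf_fuse(
    vector_ids: list[str],
    keyword_ids: list[str],
    k: int = 60,
    vector_weight: float = 1.0,
    keyword_weight: float = 1.0,
) -> list[str]:
    """Fuse two ranked id lists with weighted reciprocal-rank scores."""
    scores: dict[str, float] = {}
    for weight, ranked in ((vector_weight, vector_ids), (keyword_weight, keyword_ids)):
        for rank, doc_id in enumerate(ranked):
            # 1-based rank; the constant k dampens the influence of top positions.
            scores[doc_id] = scores.get(doc_id, 0.0) + weight / (k + rank + 1)
    return sorted(scores, key=scores.get, reverse=True)

fused = rrf_fuse(["a", "b", "c"], ["b", "c", "d"])  # "b" ranks high in both lists
```

Documents present in both lists accumulate score from each, so agreement between dense and keyword retrieval beats a single top rank.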


@@ -0,0 +1,136 @@
"""
Metadata Inference Service.
[AC-AISVC-RES-04] 元数据推断统一入口
"""
import logging
from dataclasses import dataclass, field
from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.mid.metadata_filter_builder import (
FilterBuildResult,
MetadataFilterBuilder,
)
from app.services.retrieval.strategy.config import (
FilterMode,
MetadataInferenceConfig,
)
from app.services.retrieval.strategy.pipeline_base import MetadataFilterResult
logger = logging.getLogger(__name__)
@dataclass
class InferenceContext:
"""元数据推断上下文。"""
tenant_id: str
query: str
session_id: str | None = None
user_id: str | None = None
channel_type: str | None = None
existing_context: dict[str, Any] = field(default_factory=dict)
slot_state: Any = None
@dataclass
class InferenceResult:
"""元数据推断结果。"""
filter_result: MetadataFilterResult
inferred_fields: dict[str, Any] = field(default_factory=dict)
confidence_scores: dict[str, float] = field(default_factory=dict)
overall_confidence: float | None = None
inference_source: str = "unknown"
class MetadataInferenceService:
"""
元数据推断统一入口AC-AISVC-RES-04
职责
1. 统一的元数据推断入口策略无关
2. 根据置信度决定 hard/soft filter 模式
3. 与现有 MetadataFilterBuilder 保持一致
"""
def __init__(
self,
session: AsyncSession,
config: MetadataInferenceConfig | None = None,
):
self._session = session
self._config = config or MetadataInferenceConfig()
self._filter_builder: MetadataFilterBuilder | None = None
async def infer(self, ctx: InferenceContext) -> InferenceResult:
"""执行元数据推断。【AC-AISVC-RES-04】"""
logger.info(
f"[MetadataInference] Starting inference: tenant={ctx.tenant_id}, "
f"query={ctx.query[:50]}..."
)
if self._filter_builder is None:
self._filter_builder = MetadataFilterBuilder(self._session)
effective_context = dict(ctx.existing_context)
if ctx.slot_state:
effective_context = await self._merge_slot_state(
effective_context, ctx.slot_state
)
build_result = await self._filter_builder.build_filter(
tenant_id=ctx.tenant_id,
context=effective_context,
)
confidence = self._calculate_confidence(build_result, effective_context)
filter_mode = self._config.determine_filter_mode(confidence)
filter_result = MetadataFilterResult(
filter_dict=build_result.applied_filter,
filter_mode=filter_mode,
confidence=confidence,
missing_required_slots=build_result.missing_required_slots,
debug_info=build_result.debug_info,
)
logger.info(
f"[MetadataInference] Inference completed: filter_mode={filter_mode.value}, "
f"confidence={confidence}"
)
return InferenceResult(
filter_result=filter_result,
inferred_fields=build_result.applied_filter,
overall_confidence=confidence,
inference_source="metadata_filter_builder",
)
async def _merge_slot_state(
self, context: dict[str, Any], slot_state: Any
) -> dict[str, Any]:
"""合并槽位状态到上下文。"""
if hasattr(slot_state, 'filled_slots'):
for slot_key, slot_value in slot_state.filled_slots.items():
if slot_key not in context:
context[slot_key] = slot_value
return context
def _calculate_confidence(
self, build_result: FilterBuildResult, context: dict[str, Any]
) -> float | None:
"""计算推断置信度。"""
if build_result.missing_required_slots:
return 0.3
if not build_result.applied_filter:
return None
if not context:
return 0.5
applied_ratio = len(build_result.applied_filter) / max(len(context), 1)
if applied_ratio >= 0.8:
return 0.9
elif applied_ratio >= 0.5:
return 0.7
return 0.5


@@ -0,0 +1,118 @@
"""
Mode Router.
[AC-AISVC-RES-09~15] 模式路由器
"""
import logging
from dataclasses import dataclass
from app.services.retrieval.strategy.config import ModeRouterConfig, RuntimeMode
from app.services.retrieval.strategy.pipeline_base import PipelineResult
logger = logging.getLogger(__name__)
@dataclass
class ModeDecision:
"""模式决策结果。"""
mode: RuntimeMode
reason: str
confidence: float | None = None
complexity_score: float | None = None
class ModeRouter:
"""
模式路由器AC-AISVC-RES-09~15
职责
1. 根据 rag_runtime_mode 选择 direct/react/auto 模式
2. auto 模式下根据复杂度与置信度自动选择路由
3. direct 低置信度时触发 react 回退
"""
def __init__(self, config: ModeRouterConfig | None = None):
self._config = config or ModeRouterConfig()
def decide(
self,
query: str,
confidence: float | None = None,
complexity_score: float | None = None,
) -> ModeDecision:
"""决定使用哪种模式。【AC-AISVC-RES-09~13】"""
if self._config.runtime_mode == RuntimeMode.REACT:
return ModeDecision(mode=RuntimeMode.REACT, reason="runtime_mode=react")
if self._config.runtime_mode == RuntimeMode.DIRECT:
return ModeDecision(mode=RuntimeMode.DIRECT, reason="runtime_mode=direct")
calculated_complexity = complexity_score if complexity_score is not None else self._calculate_complexity(query)
if self._should_use_direct(query, confidence, calculated_complexity):
return ModeDecision(
mode=RuntimeMode.DIRECT,
reason="auto: short_query_high_confidence",
confidence=confidence,
complexity_score=calculated_complexity,
)
return ModeDecision(
mode=RuntimeMode.REACT,
reason="auto: complex_or_low_confidence",
confidence=confidence,
complexity_score=calculated_complexity,
)
def should_fallback_to_react(self, direct_result: PipelineResult) -> bool:
"""判断是否应该从 direct 回退到 react。【AC-AISVC-RES-14】"""
if not self._config.direct_fallback_on_low_confidence:
return False
if direct_result.is_empty:
return True
max_score = direct_result.retrieval_result.max_score
if max_score < 0.3:
return True
if direct_result.retrieval_result.hit_count < 2:
return True
return False
def _should_use_direct(
self, query: str, confidence: float | None, complexity_score: float
) -> bool:
if len(query) <= self._config.short_query_threshold:
if confidence is not None and confidence >= self._config.react_trigger_confidence_threshold:
return True
if confidence is not None and confidence < self._config.react_trigger_confidence_threshold:
return False
if complexity_score >= self._config.react_trigger_complexity_score:
return False
return True
def _calculate_complexity(self, query: str) -> float:
score = 0.0
if len(query) > 50:
score += 0.2
if len(query) > 100:
score += 0.2
condition_words = ["", "", "但是", "如果", "同时", "并且", "或者", "以及"]
for word in condition_words:
if word in query:
score += 0.1
return min(score, 1.0)
def get_config(self) -> ModeRouterConfig:
return self._config
def update_config(self, config: ModeRouterConfig) -> None:
self._config = config
_mode_router: ModeRouter | None = None
def get_mode_router() -> ModeRouter:
global _mode_router
if _mode_router is None:
_mode_router = ModeRouter()
return _mode_router
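The `_calculate_complexity` heuristic above, restated as a standalone function with the same length thresholds and connective-word list:

```python
def complexity(query: str) -> float:
    """Score 0.0-1.0: longer queries and conditional connectives read as complex."""
    condition_words = ("但是", "如果", "同时", "并且", "或者", "以及")
    score = 0.0
    if len(query) > 50:
        score += 0.2
    if len(query) > 100:
        score += 0.2
    # Each distinct connective present adds 0.1, capped at 1.0 overall.
    score += 0.1 * sum(1 for word in condition_words if word in query)
    return min(score, 1.0)

simple = complexity("退货政策")                         # short, no connectives
compound = complexity("如果超过七天,但是已拆封," + "x" * 60)  # long, two connectives
```

In `auto` mode this score is compared against `react_trigger_complexity_score` to decide whether the query warrants the ReAct route.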


@@ -0,0 +1,116 @@
"""
Pipeline Base Classes.
[AC-AISVC-RES-01~15] Pipeline 抽象基类
"""
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any
from app.services.retrieval.base import RetrievalContext, RetrievalHit, RetrievalResult
from app.services.retrieval.strategy.config import FilterMode
logger = logging.getLogger(__name__)
@dataclass
class MetadataFilterResult:
"""元数据过滤结果。"""
filter_dict: dict[str, Any] = field(default_factory=dict)
filter_mode: FilterMode = FilterMode.NONE
confidence: float | None = None
missing_required_slots: list[dict[str, str]] = field(default_factory=list)
debug_info: dict[str, Any] = field(default_factory=dict)
@dataclass
class PipelineContext:
"""Pipeline 执行上下文。"""
retrieval_ctx: RetrievalContext
metadata_filter: MetadataFilterResult | None = None
use_reranker: bool = False
use_react: bool = False
react_iteration: int = 0
previous_results: list[RetrievalHit] = field(default_factory=list)
extra: dict[str, Any] = field(default_factory=dict)
@property
def tenant_id(self) -> str:
return self.retrieval_ctx.tenant_id
@property
def query(self) -> str:
return self.retrieval_ctx.query
@property
def session_id(self) -> str | None:
return self.retrieval_ctx.session_id
@property
def kb_ids(self) -> list[str] | None:
return self.retrieval_ctx.kb_ids
@dataclass
class PipelineResult:
"""Pipeline 执行结果。"""
retrieval_result: RetrievalResult
pipeline_name: str = ""
used_reranker: bool = False
used_react: bool = False
react_iterations: int = 0
metadata_filter_applied: bool = False
fallback_triggered: bool = False
fallback_reason: str | None = None
latency_ms: float = 0.0
diagnostics: dict[str, Any] = field(default_factory=dict)
@property
def hits(self) -> list[RetrievalHit]:
return self.retrieval_result.hits
@property
def is_empty(self) -> bool:
return self.retrieval_result.is_empty
class BasePipeline(ABC):
"""Pipeline 抽象基类。【AC-AISVC-RES-01, AC-AISVC-RES-02】"""
@property
@abstractmethod
def name(self) -> str:
"""Pipeline name."""
pass
@property
@abstractmethod
def description(self) -> str:
"""Pipeline description."""
pass
@abstractmethod
async def retrieve(self, ctx: PipelineContext) -> PipelineResult:
"""Run retrieval."""
pass
@abstractmethod
async def health_check(self) -> bool:
"""Health check."""
pass
def _create_empty_result(
self,
ctx: PipelineContext,
error: str | None = None,
latency_ms: float = 0.0,
) -> PipelineResult:
"""创建空结果。"""
diagnostics = {"error": error} if error else {}
return PipelineResult(
retrieval_result=RetrievalResult(hits=[], diagnostics=diagnostics),
pipeline_name=self.name,
latency_ms=latency_ms,
diagnostics=diagnostics,
)
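A minimal concrete subclass shows the contract `BasePipeline` imposes; `EchoPipeline` is a toy, not something from the codebase:

```python
import asyncio
from abc import ABC, abstractmethod

class BasePipeline(ABC):
    """Simplified mirror of the abstract base above."""

    @property
    @abstractmethod
    def name(self) -> str: ...

    @abstractmethod
    async def retrieve(self, query: str) -> list[str]: ...

class EchoPipeline(BasePipeline):
    @property
    def name(self) -> str:
        return "echo"

    async def retrieve(self, query: str) -> list[str]:
        # A real pipeline would hit a vector store; this one just echoes the query.
        return [query]

pipeline = EchoPipeline()
hits = asyncio.run(pipeline.retrieve("hello"))
```

Forgetting either abstract member makes instantiation fail with `TypeError`, which is exactly what keeps DefaultPipeline and EnhancedPipeline interchangeable behind the router.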


@@ -0,0 +1,301 @@
"""
Retrieval Strategy - Unified Entry Point.
[AC-AISVC-RES-01~15] 检索策略统一入口
整合
- StrategyRouter: 策略路由default/enhanced
- ModeRouter: 模式路由direct/react/auto
- MetadataInferenceService: 元数据推断
- RollbackManager: 回退管理
"""
import logging
import time
from dataclasses import dataclass
from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.retrieval.base import RetrievalContext, RetrievalResult
from app.services.retrieval.strategy.config import (
RetrievalStrategyConfig,
RuntimeMode,
StrategyType,
)
from app.services.retrieval.strategy.default_pipeline import DefaultPipeline
from app.services.retrieval.strategy.enhanced_pipeline import EnhancedPipeline
from app.services.retrieval.strategy.metadata_inference import (
InferenceContext,
MetadataInferenceService,
)
from app.services.retrieval.strategy.mode_router import ModeDecision, ModeRouter
from app.services.retrieval.strategy.pipeline_base import (
MetadataFilterResult,
PipelineContext,
PipelineResult,
)
from app.services.retrieval.strategy.rollback_manager import (
RollbackManager,
RollbackTrigger,
)
from app.services.retrieval.strategy.strategy_router import (
RoutingDecision,
StrategyRouter,
)
logger = logging.getLogger(__name__)
@dataclass
class RetrievalStrategyResult:
"""检索策略执行结果。"""
retrieval_result: RetrievalResult
strategy_used: StrategyType
mode_used: RuntimeMode
metadata_filter: MetadataFilterResult | None
latency_ms: float
diagnostics: dict[str, Any]
class RetrievalStrategy:
"""
检索策略统一入口AC-AISVC-RES-01~15
整合所有策略组件
1. 元数据推断MetadataInferenceService
2. 策略路由StrategyRouter
3. 模式路由ModeRouter
4. 回退管理RollbackManager
使用方式
```python
strategy = RetrievalStrategy(session)
result = await strategy.retrieve(
tenant_id="tenant_1",
query="用户问题",
context={"user_id": "user_1"},
)
```
"""
def __init__(
self,
session: AsyncSession,
config: RetrievalStrategyConfig | None = None,
):
self._session = session
self._config = config or RetrievalStrategyConfig()
self._strategy_router = StrategyRouter(self._config)
self._mode_router = ModeRouter(self._config.mode_router)
self._rollback_manager = RollbackManager(self._config)
self._metadata_inference: MetadataInferenceService | None = None
async def retrieve(
self,
tenant_id: str,
query: str,
context: dict[str, Any] | None = None,
kb_ids: list[str] | None = None,
session_id: str | None = None,
user_id: str | None = None,
use_reranker: bool = False,
use_react: bool = False,
) -> RetrievalStrategyResult:
"""
执行检索策略AC-AISVC-RES-01~15
流程
1. 元数据推断
2. 策略路由
3. 模式路由
4. 执行检索
5. 检查是否需要回退
Args:
tenant_id: 租户 ID
query: 查询文本
context: 上下文信息
kb_ids: 知识库 ID 列表
session_id: 会话 ID
user_id: 用户 ID
use_reranker: 是否使用重排
use_react: 是否使用 ReAct 模式
Returns:
RetrievalStrategyResult 包含检索结果和诊断信息
"""
start_time = time.time()
context = context or {}
logger.info(
f"[RetrievalStrategy] Starting retrieval: tenant={tenant_id}, "
f"query={query[:50]}..."
)
try:
metadata_filter = await self._infer_metadata(
tenant_id=tenant_id,
query=query,
context=context,
session_id=session_id,
user_id=user_id,
)
routing_decision = await self._strategy_router.route(tenant_id, user_id)
mode_decision = self._mode_router.decide(
query=query,
confidence=metadata_filter.confidence if metadata_filter else None,
)
retrieval_ctx = RetrievalContext(
tenant_id=tenant_id,
query=query,
session_id=session_id,
metadata_filter=metadata_filter.filter_dict if metadata_filter else None,
kb_ids=kb_ids,
)
pipeline_ctx = PipelineContext(
retrieval_ctx=retrieval_ctx,
metadata_filter=metadata_filter,
use_reranker=use_reranker or self._config.reranker.enabled,
use_react=use_react or mode_decision.mode == RuntimeMode.REACT,
)
pipeline_result = await routing_decision.pipeline.retrieve(pipeline_ctx)
if mode_decision.mode == RuntimeMode.DIRECT and self._mode_router.should_fallback_to_react(pipeline_result):
logger.info("[RetrievalStrategy] Falling back to react mode")
pipeline_ctx.use_react = True
pipeline_result = await routing_decision.pipeline.retrieve(pipeline_ctx)
mode_decision = ModeDecision(
mode=RuntimeMode.REACT,
reason="fallback_from_direct",
)
latency_ms = (time.time() - start_time) * 1000
self._check_performance(latency_ms, tenant_id)
logger.info(
f"[RetrievalStrategy] Retrieval completed: strategy={routing_decision.strategy.value}, "
f"mode={mode_decision.mode.value}, hits={len(pipeline_result.hits)}, "
f"latency_ms={latency_ms:.2f}"
)
return RetrievalStrategyResult(
retrieval_result=pipeline_result.retrieval_result,
strategy_used=routing_decision.strategy,
mode_used=mode_decision.mode,
metadata_filter=metadata_filter,
latency_ms=latency_ms,
diagnostics={
"routing_reason": routing_decision.reason,
"mode_reason": mode_decision.reason,
"grayscale_hit": routing_decision.grayscale_hit,
**pipeline_result.diagnostics,
},
)
except Exception as e:
latency_ms = (time.time() - start_time) * 1000
logger.error(
f"[RetrievalStrategy] Retrieval error: {e}",
exc_info=True
)
self._rollback_manager.rollback(
trigger=RollbackTrigger.ERROR,
reason=str(e),
tenant_id=tenant_id,
)
return RetrievalStrategyResult(
retrieval_result=RetrievalResult(
hits=[],
diagnostics={"error": str(e)},
),
strategy_used=StrategyType.DEFAULT,
mode_used=RuntimeMode.DIRECT,
metadata_filter=None,
latency_ms=latency_ms,
diagnostics={"error": str(e)},
)
async def _infer_metadata(
self,
tenant_id: str,
query: str,
context: dict[str, Any],
session_id: str | None = None,
user_id: str | None = None,
) -> MetadataFilterResult | None:
"""执行元数据推断。"""
try:
if self._metadata_inference is None:
self._metadata_inference = MetadataInferenceService(
self._session,
self._config.metadata_inference,
)
inference_ctx = InferenceContext(
tenant_id=tenant_id,
query=query,
session_id=session_id,
user_id=user_id,
existing_context=context,
)
result = await self._metadata_inference.infer(inference_ctx)
return result.filter_result
except Exception as e:
logger.warning(f"[RetrievalStrategy] Metadata inference failed: {e}")
return None
def _check_performance(self, latency_ms: float, tenant_id: str | None) -> None:
"""检查性能指标,必要时触发回退。"""
self._rollback_manager.check_and_rollback(
metrics={"latency_ms": latency_ms},
tenant_id=tenant_id,
)
def get_config(self) -> RetrievalStrategyConfig:
"""Return the current configuration."""
return self._config
def update_config(self, config: RetrievalStrategyConfig) -> None:
"""Update the configuration."""
self._config = config
self._strategy_router.update_config(config)
self._mode_router.update_config(config.mode_router)
self._rollback_manager.update_config(config)
async def health_check(self) -> dict[str, bool]:
"""健康检查。"""
results = {}
try:
default_pipeline = await self._strategy_router._get_default_pipeline()
results["default_pipeline"] = await default_pipeline.health_check()
except Exception:
results["default_pipeline"] = False
try:
enhanced_pipeline = await self._strategy_router._get_enhanced_pipeline()
results["enhanced_pipeline"] = await enhanced_pipeline.health_check()
except Exception:
results["enhanced_pipeline"] = False
return results
async def create_retrieval_strategy(
session: AsyncSession,
config: RetrievalStrategyConfig | None = None,
) -> RetrievalStrategy:
"""创建 RetrievalStrategy 实例。"""
return RetrievalStrategy(session, config)
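The direct→react fallback inside `retrieve` is a simple try-then-re-run pattern; a sketch with stub pipelines (all names illustrative):

```python
import asyncio
from typing import Awaitable, Callable

Search = Callable[[str], Awaitable[list[str]]]

async def run_with_fallback(direct: Search, react: Search, query: str) -> tuple[list[str], str]:
    """Try the cheap direct path first; re-run via react when the result looks weak."""
    hits = await direct(query)
    if len(hits) < 2:  # low-confidence heuristic, as in should_fallback_to_react
        return await react(query), "react"
    return hits, "direct"

async def direct(query: str) -> list[str]:
    return ["only-hit"]

async def react(query: str) -> list[str]:
    return ["hit-1", "hit-2"]

hits, mode = asyncio.run(run_with_fallback(direct, react, "q"))
```

As in the source, the fallback pays for a second retrieval pass, which is why `_check_performance` tracks latency and can roll the whole strategy back.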


@@ -0,0 +1,192 @@
"""
Rollback Manager.
[AC-AISVC-RES-07] 策略回退与审计管理器
"""
import logging
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any
from app.services.retrieval.strategy.config import RetrievalStrategyConfig, StrategyType
logger = logging.getLogger(__name__)
class RollbackTrigger(str, Enum):
"""回退触发原因。"""
MANUAL = "manual"
ERROR = "error"
PERFORMANCE = "performance"
TIMEOUT = "timeout"
@dataclass
class AuditLog:
"""审计日志记录。"""
timestamp: str
action: str
from_strategy: str
to_strategy: str
trigger: str
reason: str
tenant_id: str | None = None
details: dict[str, Any] = field(default_factory=dict)
@dataclass
class RollbackResult:
"""回退结果。"""
success: bool
previous_strategy: StrategyType
current_strategy: StrategyType
trigger: RollbackTrigger
reason: str
audit_log: AuditLog | None = None
class RollbackManager:
"""
策略回退管理器AC-AISVC-RES-07
职责
1. 策略异常时回退到默认策略
2. 记录审计日志
3. 支持手动触发回退
"""
def __init__(
self,
config: RetrievalStrategyConfig | None = None,
max_audit_logs: int = 1000,
):
self._config = config or RetrievalStrategyConfig()
self._max_audit_logs = max_audit_logs
self._audit_logs: list[AuditLog] = []
self._previous_strategy: StrategyType = StrategyType.DEFAULT
def rollback(
self,
trigger: RollbackTrigger,
reason: str,
tenant_id: str | None = None,
details: dict[str, Any] | None = None,
) -> RollbackResult:
"""执行策略回退。【AC-AISVC-RES-07】"""
previous = self._config.active_strategy
current = StrategyType.DEFAULT
if previous == StrategyType.DEFAULT:
return RollbackResult(
success=False,
previous_strategy=previous,
current_strategy=current,
trigger=trigger,
reason="Already on default strategy",
)
self._previous_strategy = previous
self._config.active_strategy = current
audit_log = AuditLog(
timestamp=datetime.utcnow().isoformat(),
action="rollback",
from_strategy=previous.value,
to_strategy=current.value,
trigger=trigger.value,
reason=reason,
tenant_id=tenant_id,
details=details or {},
)
self._add_audit_log(audit_log)
logger.info(
f"[RollbackManager] Rollback executed: from={previous.value}, "
f"to={current.value}, trigger={trigger.value}"
)
return RollbackResult(
success=True,
previous_strategy=previous,
current_strategy=current,
trigger=trigger,
reason=reason,
audit_log=audit_log,
)
def check_and_rollback(
self, metrics: dict[str, float], tenant_id: str | None = None
) -> RollbackResult | None:
"""检查性能指标并自动回退。【AC-AISVC-RES-08】"""
thresholds = self._config.performance_thresholds
latency = metrics.get("latency_ms", 0)
if latency > thresholds.get("max_latency_ms", 2000):
return self.rollback(
trigger=RollbackTrigger.PERFORMANCE,
reason=f"Latency {latency}ms exceeds threshold",
tenant_id=tenant_id,
)
error_rate = metrics.get("error_rate", 0)
if error_rate > thresholds.get("max_error_rate", 0.05):
return self.rollback(
trigger=RollbackTrigger.ERROR,
reason=f"Error rate {error_rate} exceeds threshold",
tenant_id=tenant_id,
)
return None
def _add_audit_log(self, log: AuditLog) -> None:
self._audit_logs.append(log)
if len(self._audit_logs) > self._max_audit_logs:
self._audit_logs = self._audit_logs[-self._max_audit_logs:]
def record_audit(
self,
action: str,
details: dict[str, Any],
tenant_id: str | None = None,
) -> AuditLog:
"""记录审计日志。"""
audit_log = AuditLog(
timestamp=datetime.utcnow().isoformat(),
action=action,
from_strategy=self._config.active_strategy.value,
to_strategy=self._config.active_strategy.value,
trigger="n/a",
reason=details.get("reason", ""),
tenant_id=tenant_id,
details=details,
)
self._add_audit_log(audit_log)
logger.info(
f"[RollbackManager] Audit recorded: action={action}, "
f"strategy={self._config.active_strategy.value}"
)
return audit_log
def get_audit_logs(self, limit: int = 100) -> list[AuditLog]:
return self._audit_logs[-limit:]
def get_config(self) -> RetrievalStrategyConfig:
return self._config
def update_config(self, config: RetrievalStrategyConfig) -> None:
self._config = config
_rollback_manager: RollbackManager | None = None
def get_rollback_manager() -> RollbackManager:
global _rollback_manager
if _rollback_manager is None:
_rollback_manager = RollbackManager()
return _rollback_manager
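`_add_audit_log` keeps the in-memory audit trail bounded by trimming to the newest entries; standalone:

```python
class AuditTrail:
    """Keep only the newest max_logs entries (mirrors _add_audit_log)."""

    def __init__(self, max_logs: int = 1000) -> None:
        self.max_logs = max_logs
        self.logs: list[str] = []

    def add(self, entry: str) -> None:
        self.logs.append(entry)
        if len(self.logs) > self.max_logs:
            # Drop the oldest entries rather than refusing new ones.
            self.logs = self.logs[-self.max_logs:]

trail = AuditTrail(max_logs=3)
for i in range(5):
    trail.add(f"event-{i}")
```

The trade-off is that old rollback history is lost once the cap is hit; durable auditing would need these entries persisted before trimming.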


@@ -0,0 +1,109 @@
"""
Strategy Router.
[AC-AISVC-RES-01~03] 策略路由器
"""
import logging
from dataclasses import dataclass
from typing import Any
from app.services.retrieval.strategy.config import (
RetrievalStrategyConfig,
StrategyType,
)
from app.services.retrieval.strategy.default_pipeline import DefaultPipeline, get_default_pipeline
from app.services.retrieval.strategy.enhanced_pipeline import EnhancedPipeline, get_enhanced_pipeline
from app.services.retrieval.strategy.pipeline_base import BasePipeline
logger = logging.getLogger(__name__)
@dataclass
class RoutingDecision:
"""路由决策结果。"""
strategy: StrategyType
pipeline: BasePipeline
reason: str
grayscale_hit: bool = False
class StrategyRouter:
"""
策略路由器AC-AISVC-RES-01~03
职责
1. 根据配置选择默认策略或增强策略
2. 支持灰度发布percentage/allowlist
3. 不影响正在运行的默认策略请求
"""
def __init__(
self,
config: RetrievalStrategyConfig | None = None,
default_pipeline: DefaultPipeline | None = None,
enhanced_pipeline: EnhancedPipeline | None = None,
):
self._config = config or RetrievalStrategyConfig()
self._default_pipeline = default_pipeline
self._enhanced_pipeline = enhanced_pipeline
async def route(
self, tenant_id: str, user_id: str | None = None
) -> RoutingDecision:
"""路由到合适的策略。【AC-AISVC-RES-01~03】"""
if self._config.active_strategy == StrategyType.ENHANCED:
pipeline = await self._get_enhanced_pipeline()
return RoutingDecision(
strategy=StrategyType.ENHANCED,
pipeline=pipeline,
reason="active_strategy=enhanced",
)
if self._config.grayscale.should_use_enhanced(tenant_id, user_id):
pipeline = await self._get_enhanced_pipeline()
return RoutingDecision(
strategy=StrategyType.ENHANCED,
pipeline=pipeline,
reason="grayscale_hit",
grayscale_hit=True,
)
pipeline = await self._get_default_pipeline()
return RoutingDecision(
strategy=StrategyType.DEFAULT,
pipeline=pipeline,
reason="default_strategy",
)
async def _get_default_pipeline(self) -> DefaultPipeline:
if self._default_pipeline is None:
self._default_pipeline = await get_default_pipeline()
return self._default_pipeline
async def _get_enhanced_pipeline(self) -> EnhancedPipeline:
if self._enhanced_pipeline is None:
self._enhanced_pipeline = await get_enhanced_pipeline()
return self._enhanced_pipeline
def get_config(self) -> RetrievalStrategyConfig:
return self._config
def update_config(self, config: RetrievalStrategyConfig) -> None:
self._config = config
logger.info(f"[StrategyRouter] Config updated: strategy={config.active_strategy.value}")
_strategy_router: StrategyRouter | None = None
def get_strategy_router() -> StrategyRouter:
global _strategy_router
if _strategy_router is None:
_strategy_router = StrategyRouter()
return _strategy_router
def set_strategy_router(router: StrategyRouter) -> None:
"""Set the global strategy router instance."""
global _strategy_router
_strategy_router = router

@ -0,0 +1,233 @@
"""
Retrieval Strategy Integration for Dialogue Flow.
[AC-AISVC-RES-01~15] Integrates StrategyRouter and ModeRouter into dialogue pipeline.
Usage:
from app.services.retrieval.strategy_integration import RetrievalStrategyIntegration
integration = RetrievalStrategyIntegration()
result = await integration.execute(ctx)
"""
from __future__ import annotations
import logging
import time
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from app.services.retrieval.routing_config import (
RagRuntimeMode,
StrategyType,
RoutingConfig,
StrategyContext,
StrategyResult,
)
from app.services.retrieval.strategy_router import (
StrategyRouter,
get_strategy_router,
)
from app.services.retrieval.mode_router import (
ModeRouter,
ModeRouteResult,
get_mode_router,
)
if TYPE_CHECKING:
from app.services.retrieval.base import RetrievalResult
logger = logging.getLogger(__name__)
@dataclass
class RetrievalStrategyResult:
"""Combined result from strategy and mode routing."""
retrieval_result: "RetrievalResult | None"
final_answer: str | None
strategy: StrategyType
mode: RagRuntimeMode
should_fallback: bool = False
fallback_reason: str | None = None
mode_route_result: ModeRouteResult | None = None
diagnostics: dict[str, Any] = field(default_factory=dict)
duration_ms: int = 0
class RetrievalStrategyIntegration:
"""
[AC-AISVC-RES-01~15] Integration layer for retrieval strategy.
Combines StrategyRouter and ModeRouter to provide a unified interface
for the dialogue pipeline.
Flow:
1. StrategyRouter selects default or enhanced strategy
2. ModeRouter selects direct, react, or auto mode
3. Execute retrieval with selected strategy and mode
4. Handle fallback scenarios
"""
def __init__(
self,
config: RoutingConfig | None = None,
strategy_router: StrategyRouter | None = None,
mode_router: ModeRouter | None = None,
):
self._config = config or RoutingConfig()
self._strategy_router = strategy_router or get_strategy_router()
self._mode_router = mode_router or get_mode_router()
@property
def config(self) -> RoutingConfig:
"""Get current configuration."""
return self._config
def update_config(self, new_config: RoutingConfig) -> None:
"""
[AC-AISVC-RES-15] Update all routing configurations.
"""
self._config = new_config
self._strategy_router.update_config(new_config)
self._mode_router.update_config(new_config)
logger.info(
f"[AC-AISVC-RES-15] RetrievalStrategyIntegration config updated: "
f"strategy={new_config.strategy.value}, mode={new_config.rag_runtime_mode.value}"
)
async def execute(
self,
ctx: StrategyContext,
) -> RetrievalStrategyResult:
"""
Execute retrieval with strategy and mode routing.
Args:
ctx: Strategy context with tenant, query, metadata, etc.
Returns:
RetrievalStrategyResult with retrieval results and diagnostics
"""
start_time = time.time()
strategy_result = self._strategy_router.route(ctx)
mode_result = self._mode_router.route(ctx)
logger.info(
f"[AC-AISVC-RES-01~15] Strategy routing: "
f"strategy={strategy_result.strategy.value}, mode={mode_result.mode.value}, "
f"tenant={ctx.tenant_id}, query_len={len(ctx.query)}"
)
retrieval_result = None
final_answer = None
should_fallback = False
fallback_reason = None
try:
if mode_result.mode == RagRuntimeMode.DIRECT:
retrieval_result, answer, mode_result = await self._mode_router.execute_with_fallback(ctx)
if answer is not None:
final_answer = answer
should_fallback = mode_result.should_fallback_to_react
fallback_reason = mode_result.fallback_reason
elif mode_result.mode == RagRuntimeMode.REACT:
answer, retrieval_result, react_ctx = await self._mode_router.execute_react(ctx)
final_answer = answer
else:
retrieval_result, answer, mode_result = await self._mode_router.execute_with_fallback(ctx)
if answer is not None:
final_answer = answer
should_fallback = mode_result.should_fallback_to_react
fallback_reason = mode_result.fallback_reason
except Exception as e:
logger.error(
f"[AC-AISVC-RES-07] Retrieval strategy execution failed: {e}"
)
if strategy_result.strategy == StrategyType.ENHANCED:
self._strategy_router.rollback(
reason=str(e),
tenant_id=ctx.tenant_id,
)
from app.services.retrieval.optimized_retriever import get_optimized_retriever
from app.services.retrieval.base import RetrievalContext
retriever = await get_optimized_retriever()
retrieval_ctx = RetrievalContext(
tenant_id=ctx.tenant_id,
query=ctx.query,
metadata_filter=ctx.metadata_filter,
kb_ids=ctx.kb_ids,
)
retrieval_result = await retriever.retrieve(retrieval_ctx)
should_fallback = True
fallback_reason = str(e)
else:
raise
duration_ms = int((time.time() - start_time) * 1000)
return RetrievalStrategyResult(
retrieval_result=retrieval_result,
final_answer=final_answer,
strategy=strategy_result.strategy,
mode=mode_result.mode,
should_fallback=should_fallback,
fallback_reason=fallback_reason,
mode_route_result=mode_result,
diagnostics={
"strategy_diagnostics": strategy_result.diagnostics,
"mode_diagnostics": mode_result.diagnostics,
"duration_ms": duration_ms,
},
duration_ms=duration_ms,
)
def get_current_strategy(self) -> StrategyType:
"""Get current active strategy."""
return self._strategy_router.current_strategy
def get_rollback_records(self, limit: int = 10) -> list[dict[str, Any]]:
"""Get recent rollback records."""
records = self._strategy_router.get_rollback_records(limit)
return [
{
"timestamp": r.timestamp,
"from_strategy": r.from_strategy.value,
"to_strategy": r.to_strategy.value,
"reason": r.reason,
"tenant_id": r.tenant_id,
}
for r in records
]
def validate_config(self) -> tuple[bool, list[str]]:
"""Validate current configuration."""
return self._config.validate()
_integration: RetrievalStrategyIntegration | None = None
def get_retrieval_strategy_integration() -> RetrievalStrategyIntegration:
"""Get or create RetrievalStrategyIntegration singleton."""
global _integration
if _integration is None:
_integration = RetrievalStrategyIntegration()
return _integration
def reset_retrieval_strategy_integration() -> None:
"""Reset RetrievalStrategyIntegration singleton (for testing)."""
global _integration
_integration = None

@ -0,0 +1,403 @@
"""
Strategy Router for Retrieval and Embedding.
[AC-AISVC-RES-01, AC-AISVC-RES-02, AC-AISVC-RES-03] Routes to default or enhanced strategy.
Key Features:
- Default strategy preserves existing online logic
- Enhanced strategy is configurable and can be rolled back
- Supports grayscale release (percentage/allowlist)
- Supports rollback on error or performance degradation
"""
from __future__ import annotations
import logging
import time
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from app.services.retrieval.routing_config import (
RagRuntimeMode,
StrategyType,
RoutingConfig,
StrategyContext,
StrategyResult,
)
if TYPE_CHECKING:
from app.services.retrieval.base import RetrievalResult
logger = logging.getLogger(__name__)
@dataclass
class RollbackRecord:
"""Record for strategy rollback event."""
timestamp: float
from_strategy: StrategyType
to_strategy: StrategyType
reason: str
tenant_id: str | None = None
request_id: str | None = None
class RollbackManager:
"""
[AC-AISVC-RES-07] Manages strategy rollback and audit logging.
"""
def __init__(self, max_records: int = 100):
self._records: list[RollbackRecord] = []
self._max_records = max_records
def record_rollback(
self,
from_strategy: StrategyType,
to_strategy: StrategyType,
reason: str,
tenant_id: str | None = None,
request_id: str | None = None,
) -> None:
"""Record a rollback event."""
record = RollbackRecord(
timestamp=time.time(),
from_strategy=from_strategy,
to_strategy=to_strategy,
reason=reason,
tenant_id=tenant_id,
request_id=request_id,
)
self._records.append(record)
if len(self._records) > self._max_records:
self._records = self._records[-self._max_records:]
logger.info(
f"[AC-AISVC-RES-07] Rollback recorded: {from_strategy.value} -> {to_strategy.value}, "
f"reason={reason}, tenant={tenant_id}"
)
def get_recent_rollbacks(self, limit: int = 10) -> list[RollbackRecord]:
"""Get recent rollback records."""
return self._records[-limit:]
def get_rollback_count(self, since_timestamp: float | None = None) -> int:
"""Get count of rollbacks, optionally since a timestamp."""
if since_timestamp is None:
return len(self._records)
return sum(1 for r in self._records if r.timestamp >= since_timestamp)
class DefaultPipeline:
"""
[AC-AISVC-RES-01] Default pipeline that preserves existing online logic.
This pipeline uses the existing OptimizedRetriever without any new features.
"""
def __init__(self):
self._retriever = None
async def execute(
self,
ctx: StrategyContext,
) -> "RetrievalResult":
"""
Execute default retrieval strategy.
Uses existing OptimizedRetriever with current configuration.
"""
from app.services.retrieval.optimized_retriever import get_optimized_retriever
from app.services.retrieval.base import RetrievalContext
if self._retriever is None:
self._retriever = await get_optimized_retriever()
retrieval_ctx = RetrievalContext(
tenant_id=ctx.tenant_id,
query=ctx.query,
metadata_filter=ctx.metadata_filter,
kb_ids=ctx.kb_ids,
)
return await self._retriever.retrieve(retrieval_ctx)
class EnhancedPipeline:
"""
[AC-AISVC-RES-02] Enhanced pipeline with new end-to-end retrieval features.
Features:
- Document preprocessing (cleaning/normalization)
- Structured chunking (markdown/tables/FAQ)
- Metadata generation and mounting
- Embedding strategy (document/query prefix + Matryoshka)
- Metadata inference and filtering (hard/soft filter)
- Retrieval strategy (Dense + Keyword/Hybrid + RRF)
- Optional reranking
"""
def __init__(
self,
config: RoutingConfig | None = None,
):
self._config = config or RoutingConfig()
self._retriever = None
async def execute(
self,
ctx: StrategyContext,
) -> "RetrievalResult":
"""
Execute enhanced retrieval strategy.
Uses OptimizedRetriever with enhanced configuration.
"""
from app.services.retrieval.optimized_retriever import OptimizedRetriever
from app.services.retrieval.base import RetrievalContext
if self._retriever is None:
self._retriever = OptimizedRetriever(
two_stage_enabled=True,
hybrid_enabled=True,
)
retrieval_ctx = RetrievalContext(
tenant_id=ctx.tenant_id,
query=ctx.query,
metadata_filter=ctx.metadata_filter,
kb_ids=ctx.kb_ids,
)
return await self._retriever.retrieve(retrieval_ctx)
class StrategyRouter:
"""
[AC-AISVC-RES-01, AC-AISVC-RES-02, AC-AISVC-RES-03]
Strategy router for retrieval and embedding.
Decision Flow:
1. Check if enhanced strategy is enabled via configuration
2. Check grayscale rules (percentage/allowlist)
3. Route to appropriate pipeline (default/enhanced)
4. Handle rollback on error or performance degradation
Constraints:
- Default strategy MUST preserve existing online logic
- Enhanced strategy MUST be configurable and rollback-able
"""
def __init__(
self,
config: RoutingConfig | None = None,
rollback_manager: RollbackManager | None = None,
):
self._config = config or RoutingConfig()
self._rollback_manager = rollback_manager or RollbackManager()
self._default_pipeline = DefaultPipeline()
self._enhanced_pipeline = EnhancedPipeline(self._config)
self._current_strategy = StrategyType.DEFAULT
self._strategy_enabled = True
@property
def current_strategy(self) -> StrategyType:
"""Get current active strategy."""
return self._current_strategy
@property
def config(self) -> RoutingConfig:
"""Get current configuration."""
return self._config
def update_config(self, new_config: RoutingConfig) -> None:
"""
[AC-AISVC-RES-15] Update routing configuration (hot reload).
Args:
new_config: New configuration to apply
"""
old_strategy = self._config.strategy
self._config = new_config
logger.info(
f"[AC-AISVC-RES-15] Routing config updated: "
f"strategy={new_config.strategy.value}, "
f"mode={new_config.rag_runtime_mode.value}, "
f"grayscale={new_config.grayscale_percentage:.2%}"
)
if old_strategy != new_config.strategy:
logger.info(
f"[AC-AISVC-RES-02] Strategy changed: {old_strategy.value} -> {new_config.strategy.value}"
)
def route(
self,
ctx: StrategyContext,
) -> StrategyResult:
"""
[AC-AISVC-RES-01, AC-AISVC-RES-02] Route to appropriate strategy.
Args:
ctx: Strategy context with tenant, query, metadata, etc.
Returns:
StrategyResult with selected strategy and mode
"""
if not self._strategy_enabled:
logger.info("[AC-AISVC-RES-07] Strategy disabled, using default")
return StrategyResult(
strategy=StrategyType.DEFAULT,
mode=self._config.rag_runtime_mode,
should_fallback=False,
diagnostics={"reason": "strategy_disabled"},
)
use_enhanced = self._config.should_use_enhanced_strategy(ctx.tenant_id)
if use_enhanced:
self._current_strategy = StrategyType.ENHANCED
logger.info(
f"[AC-AISVC-RES-02] Routing to ENHANCED strategy: tenant={ctx.tenant_id}"
)
else:
self._current_strategy = StrategyType.DEFAULT
logger.info(
f"[AC-AISVC-RES-01] Routing to DEFAULT strategy: tenant={ctx.tenant_id}"
)
return StrategyResult(
strategy=self._current_strategy,
mode=self._config.rag_runtime_mode,
diagnostics={
"grayscale_percentage": self._config.grayscale_percentage,
"in_allowlist": ctx.tenant_id in self._config.grayscale_allowlist if ctx.tenant_id else False,
},
)
async def execute(
self,
ctx: StrategyContext,
) -> tuple["RetrievalResult", StrategyResult]:
"""
Execute retrieval with strategy routing.
Args:
ctx: Strategy context
Returns:
Tuple of (RetrievalResult, StrategyResult)
"""
start_time = time.time()
result = self.route(ctx)
try:
if result.strategy == StrategyType.ENHANCED:
retrieval_result = await self._enhanced_pipeline.execute(ctx)
else:
retrieval_result = await self._default_pipeline.execute(ctx)
duration_ms = int((time.time() - start_time) * 1000)
if duration_ms > self._config.performance_budget_ms:
degradation = (duration_ms - self._config.performance_budget_ms) / self._config.performance_budget_ms
if degradation > self._config.performance_degradation_threshold:
logger.warning(
f"[AC-AISVC-RES-08] Performance degradation detected: "
f"duration={duration_ms}ms, budget={self._config.performance_budget_ms}ms, "
f"degradation={degradation:.2%}"
)
return retrieval_result, result
except Exception as e:
logger.error(
f"[AC-AISVC-RES-07] Strategy execution failed: {e}, "
f"strategy={result.strategy.value}"
)
if result.strategy == StrategyType.ENHANCED:
self._rollback_manager.record_rollback(
from_strategy=StrategyType.ENHANCED,
to_strategy=StrategyType.DEFAULT,
reason=str(e),
tenant_id=ctx.tenant_id,
)
logger.info("[AC-AISVC-RES-07] Falling back to DEFAULT strategy")
retrieval_result = await self._default_pipeline.execute(ctx)
return retrieval_result, StrategyResult(
strategy=StrategyType.DEFAULT,
mode=result.mode,
should_fallback=True,
fallback_reason=str(e),
diagnostics=result.diagnostics,
)
raise
def rollback(
self,
reason: str,
tenant_id: str | None = None,
request_id: str | None = None,
) -> None:
"""
[AC-AISVC-RES-07] Force rollback to default strategy.
Args:
reason: Reason for rollback
tenant_id: Optional tenant ID for audit
request_id: Optional request ID for audit
"""
if self._current_strategy == StrategyType.ENHANCED:
self._rollback_manager.record_rollback(
from_strategy=StrategyType.ENHANCED,
to_strategy=StrategyType.DEFAULT,
reason=reason,
tenant_id=tenant_id,
request_id=request_id,
)
self._current_strategy = StrategyType.DEFAULT
self._config.strategy = StrategyType.DEFAULT
logger.info(
f"[AC-AISVC-RES-07] Rollback executed: reason={reason}, tenant={tenant_id}"
)
def get_rollback_records(self, limit: int = 10) -> list[RollbackRecord]:
"""Get recent rollback records."""
return self._rollback_manager.get_recent_rollbacks(limit)
def validate_config(self) -> tuple[bool, list[str]]:
"""
[AC-AISVC-RES-06] Validate current configuration.
Returns:
Tuple of (is_valid, list of error messages)
"""
return self._config.validate()
_strategy_router: StrategyRouter | None = None
def get_strategy_router() -> StrategyRouter:
"""Get or create StrategyRouter singleton."""
global _strategy_router
if _strategy_router is None:
_strategy_router = StrategyRouter()
return _strategy_router
def reset_strategy_router() -> None:
"""Reset StrategyRouter singleton (for testing)."""
global _strategy_router
_strategy_router = None
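The [AC-AISVC-RES-08] check inside `execute()` above measures overshoot relative to the performance budget and compares it against a threshold. Pulled out as a standalone helper, the arithmetic looks like this (a sketch, not the production code):

```python
def is_degraded(duration_ms: int, budget_ms: int, threshold: float) -> bool:
    """Return True when the overshoot ratio exceeds the degradation threshold."""
    if duration_ms <= budget_ms:
        # Within budget: no degradation by definition.
        return False
    # Relative overshoot, e.g. 1500 ms against a 1000 ms budget -> 0.5 (50%).
    degradation = (duration_ms - budget_ms) / budget_ms
    return degradation > threshold


print(is_degraded(1500, 1000, 0.3))  # True: 50% over budget exceeds the 30% threshold
print(is_degraded(1200, 1000, 0.3))  # False: 20% over budget is within tolerance
```

In the router this condition only emits a warning log; it does not itself trigger a rollback, which is left to `rollback()` (or the `/strategy/retrieval/rollback` API endpoint noted in the commit messages).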

Some files were not shown because too many files have changed in this diff.