Compare commits
20 Commits
95365298f2
...
1b71c29ddb
| Author | SHA1 | Date |
|---|---|---|
|
|
1b71c29ddb | |
|
|
4de51bb18a | |
|
|
c0688c2b13 | |
|
|
c628181623 | |
|
|
2476da8957 | |
|
|
7027097513 | |
|
|
9f28498b97 | |
|
|
42f55ac4d1 | |
|
|
3b354ba041 | |
|
|
812af6c7a1 | |
|
|
f4ca25b0d8 | |
|
|
fe883cfff0 | |
|
|
759eafb490 | |
|
|
9769f7ccf0 | |
|
|
248a225436 | |
|
|
d78b72ca93 | |
|
|
66902cd7c1 | |
|
|
0dfc60935d | |
|
|
3969322d34 | |
|
|
b832f372d1 |
|
|
@ -13,59 +13,66 @@
|
|||
<span class="logo-text">AI Robot</span>
|
||||
</div>
|
||||
<nav class="main-nav">
|
||||
<router-link to="/dashboard" class="nav-item" :class="{ active: isActive('/dashboard') }">
|
||||
<el-icon><Odometer /></el-icon>
|
||||
<span>控制台</span>
|
||||
</router-link>
|
||||
<router-link to="/admin/knowledge-bases" class="nav-item" :class="{ active: isActive('/admin/knowledge-bases') }">
|
||||
<el-icon><FolderOpened /></el-icon>
|
||||
<span>知识库</span>
|
||||
</router-link>
|
||||
<router-link to="/rag-lab" class="nav-item" :class="{ active: isActive('/rag-lab') }">
|
||||
<el-icon><Cpu /></el-icon>
|
||||
<span>RAG 实验室</span>
|
||||
</router-link>
|
||||
<router-link to="/monitoring" class="nav-item" :class="{ active: isActive('/monitoring') }">
|
||||
<el-icon><Monitor /></el-icon>
|
||||
<span>会话监控</span>
|
||||
</router-link>
|
||||
<div class="nav-divider"></div>
|
||||
<router-link to="/admin/embedding" class="nav-item" :class="{ active: isActive('/admin/embedding') }">
|
||||
<el-icon><Connection /></el-icon>
|
||||
<span>嵌入模型</span>
|
||||
</router-link>
|
||||
<router-link to="/admin/llm" class="nav-item" :class="{ active: isActive('/admin/llm') }">
|
||||
<el-icon><ChatDotSquare /></el-icon>
|
||||
<span>LLM 配置</span>
|
||||
</router-link>
|
||||
<router-link to="/admin/prompt-templates" class="nav-item" :class="{ active: isActive('/admin/prompt-templates') }">
|
||||
<el-icon><Document /></el-icon>
|
||||
<span>Prompt 模板</span>
|
||||
</router-link>
|
||||
<router-link to="/admin/intent-rules" class="nav-item" :class="{ active: isActive('/admin/intent-rules') }">
|
||||
<el-icon><Aim /></el-icon>
|
||||
<span>意图规则</span>
|
||||
</router-link>
|
||||
<router-link to="/admin/script-flows" class="nav-item" :class="{ active: isActive('/admin/script-flows') }">
|
||||
<el-icon><Share /></el-icon>
|
||||
<span>话术流程</span>
|
||||
</router-link>
|
||||
<router-link to="/admin/guardrails" class="nav-item" :class="{ active: isActive('/admin/guardrails') }">
|
||||
<el-icon><Warning /></el-icon>
|
||||
<span>输出护栏</span>
|
||||
</router-link>
|
||||
<router-link to="/admin/mid-platform-playground" class="nav-item" :class="{ active: isActive('/admin/mid-platform-playground') }">
|
||||
<el-icon><ChatLineRound /></el-icon>
|
||||
<span>中台联调</span>
|
||||
</router-link>
|
||||
<router-link to="/admin/metadata-schemas" class="nav-item" :class="{ active: isActive('/admin/metadata-schemas') }">
|
||||
<el-icon><Setting /></el-icon>
|
||||
<span>元数据配置</span>
|
||||
</router-link>
|
||||
<router-link to="/admin/slot-definitions" class="nav-item" :class="{ active: isActive('/admin/slot-definitions') }">
|
||||
<el-icon><Grid /></el-icon>
|
||||
<span>槽位定义</span>
|
||||
</router-link>
|
||||
<div class="nav-row">
|
||||
<router-link to="/dashboard" class="nav-item" :class="{ active: isActive('/dashboard') }">
|
||||
<el-icon><Odometer /></el-icon>
|
||||
<span>控制台</span>
|
||||
</router-link>
|
||||
<router-link to="/admin/knowledge-bases" class="nav-item" :class="{ active: isActive('/admin/knowledge-bases') }">
|
||||
<el-icon><FolderOpened /></el-icon>
|
||||
<span>知识库</span>
|
||||
</router-link>
|
||||
<router-link to="/rag-lab" class="nav-item" :class="{ active: isActive('/rag-lab') }">
|
||||
<el-icon><Cpu /></el-icon>
|
||||
<span>RAG 实验室</span>
|
||||
</router-link>
|
||||
<router-link to="/monitoring" class="nav-item" :class="{ active: isActive('/monitoring') }">
|
||||
<el-icon><Monitor /></el-icon>
|
||||
<span>会话监控</span>
|
||||
</router-link>
|
||||
</div>
|
||||
<div class="nav-row">
|
||||
<router-link to="/admin/embedding" class="nav-item" :class="{ active: isActive('/admin/embedding') }">
|
||||
<el-icon><Connection /></el-icon>
|
||||
<span>嵌入模型</span>
|
||||
</router-link>
|
||||
<router-link to="/admin/llm" class="nav-item" :class="{ active: isActive('/admin/llm') }">
|
||||
<el-icon><ChatDotSquare /></el-icon>
|
||||
<span>LLM 配置</span>
|
||||
</router-link>
|
||||
<router-link to="/admin/prompt-templates" class="nav-item" :class="{ active: isActive('/admin/prompt-templates') }">
|
||||
<el-icon><Document /></el-icon>
|
||||
<span>Prompt 模板</span>
|
||||
</router-link>
|
||||
<router-link to="/admin/intent-rules" class="nav-item" :class="{ active: isActive('/admin/intent-rules') }">
|
||||
<el-icon><Aim /></el-icon>
|
||||
<span>意图规则</span>
|
||||
</router-link>
|
||||
<router-link to="/admin/script-flows" class="nav-item" :class="{ active: isActive('/admin/script-flows') }">
|
||||
<el-icon><Share /></el-icon>
|
||||
<span>话术流程</span>
|
||||
</router-link>
|
||||
<router-link to="/admin/guardrails" class="nav-item" :class="{ active: isActive('/admin/guardrails') }">
|
||||
<el-icon><Warning /></el-icon>
|
||||
<span>输出护栏</span>
|
||||
</router-link>
|
||||
<router-link to="/admin/mid-platform-playground" class="nav-item" :class="{ active: isActive('/admin/mid-platform-playground') }">
|
||||
<el-icon><ChatLineRound /></el-icon>
|
||||
<span>中台联调</span>
|
||||
</router-link>
|
||||
<router-link to="/admin/metadata-schemas" class="nav-item" :class="{ active: isActive('/admin/metadata-schemas') }">
|
||||
<el-icon><Setting /></el-icon>
|
||||
<span>元数据配置</span>
|
||||
</router-link>
|
||||
<router-link to="/admin/slot-definitions" class="nav-item" :class="{ active: isActive('/admin/slot-definitions') }">
|
||||
<el-icon><Grid /></el-icon>
|
||||
<span>槽位定义</span>
|
||||
</router-link>
|
||||
<router-link to="/admin/scene-slot-bundles" class="nav-item" :class="{ active: isActive('/admin/scene-slot-bundles') }">
|
||||
<el-icon><Collection /></el-icon>
|
||||
<span>场景槽位包</span>
|
||||
</router-link>
|
||||
</div>
|
||||
</nav>
|
||||
</div>
|
||||
<div class="header-right">
|
||||
|
|
@ -98,7 +105,7 @@ import { ref, onMounted } from 'vue'
|
|||
import { useRoute } from 'vue-router'
|
||||
import { useTenantStore } from '@/stores/tenant'
|
||||
import { getTenantList, type Tenant } from '@/api/tenant'
|
||||
import { Odometer, FolderOpened, Cpu, Monitor, Connection, ChatDotSquare, Document, Aim, Share, Warning, Setting, ChatLineRound, Grid } from '@element-plus/icons-vue'
|
||||
import { Odometer, FolderOpened, Cpu, Monitor, Connection, ChatDotSquare, Document, Aim, Share, Warning, Setting, ChatLineRound, Grid, Collection } from '@element-plus/icons-vue'
|
||||
import { ElMessage } from 'element-plus'
|
||||
|
||||
const route = useRoute()
|
||||
|
|
@ -168,7 +175,8 @@ onMounted(() => {
|
|||
align-items: center;
|
||||
justify-content: space-between;
|
||||
padding: 0 24px;
|
||||
height: 60px;
|
||||
height: auto;
|
||||
min-height: 60px;
|
||||
background-color: var(--bg-secondary, #FFFFFF);
|
||||
border-bottom: 1px solid var(--border-color, #E2E8F0);
|
||||
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.04);
|
||||
|
|
@ -208,6 +216,12 @@ onMounted(() => {
|
|||
}
|
||||
|
||||
.main-nav {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 4px;
|
||||
}
|
||||
|
||||
.nav-row {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 4px;
|
||||
|
|
|
|||
|
|
@ -36,3 +36,19 @@ export function deleteDocument(docId: string) {
|
|||
method: 'delete'
|
||||
})
|
||||
}
|
||||
|
||||
export function batchUploadDocuments(data: FormData) {
|
||||
return request({
|
||||
url: '/admin/kb/documents/batch-upload',
|
||||
method: 'post',
|
||||
data
|
||||
})
|
||||
}
|
||||
|
||||
export function jsonBatchUpload(kbId: string, data: FormData) {
|
||||
return request({
|
||||
url: `/admin/kb/${kbId}/documents/json-batch`,
|
||||
method: 'post',
|
||||
data
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ import type {
|
|||
|
||||
export const metadataSchemaApi = {
|
||||
list: (status?: 'draft' | 'active' | 'deprecated', fieldRole?: FieldRole) =>
|
||||
request<MetadataFieldListResponse>({
|
||||
request<MetadataFieldDefinition[]>({
|
||||
method: 'GET',
|
||||
url: '/admin/metadata-schemas',
|
||||
params: {
|
||||
|
|
|
|||
|
|
@ -35,6 +35,8 @@ export interface ToolCallTrace {
|
|||
error_code?: string
|
||||
args_digest?: string
|
||||
result_digest?: string
|
||||
arguments?: Record<string, unknown>
|
||||
result?: unknown
|
||||
}
|
||||
|
||||
export interface DialogueResponse {
|
||||
|
|
@ -219,3 +221,16 @@ export function sendSharedMessage(shareToken: string, userMessage: string): Prom
|
|||
data: { user_message: userMessage }
|
||||
})
|
||||
}
|
||||
|
||||
export interface CancelFlowResponse {
|
||||
success: boolean
|
||||
message: string
|
||||
session_id: string
|
||||
}
|
||||
|
||||
export function cancelActiveFlow(sessionId: string): Promise<CancelFlowResponse> {
|
||||
return request({
|
||||
url: `/mid/sessions/${sessionId}/cancel-flow`,
|
||||
method: 'post'
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,48 @@
|
|||
import request from '@/utils/request'
|
||||
import type {
|
||||
SceneSlotBundle,
|
||||
SceneSlotBundleWithDetails,
|
||||
SceneSlotBundleCreateRequest,
|
||||
SceneSlotBundleUpdateRequest,
|
||||
} from '@/types/scene-slot-bundle'
|
||||
|
||||
export const sceneSlotBundleApi = {
|
||||
list: (status?: string) =>
|
||||
request<SceneSlotBundle[]>({
|
||||
method: 'GET',
|
||||
url: '/admin/scene-slot-bundles',
|
||||
params: status ? { status } : {},
|
||||
}),
|
||||
|
||||
get: (id: string) =>
|
||||
request<SceneSlotBundleWithDetails>({
|
||||
method: 'GET',
|
||||
url: `/admin/scene-slot-bundles/${id}`,
|
||||
}),
|
||||
|
||||
getBySceneKey: (sceneKey: string) =>
|
||||
request<SceneSlotBundle>({
|
||||
method: 'GET',
|
||||
url: `/admin/scene-slot-bundles/by-scene/${sceneKey}`,
|
||||
}),
|
||||
|
||||
create: (data: SceneSlotBundleCreateRequest) =>
|
||||
request<SceneSlotBundle>({
|
||||
method: 'POST',
|
||||
url: '/admin/scene-slot-bundles',
|
||||
data,
|
||||
}),
|
||||
|
||||
update: (id: string, data: SceneSlotBundleUpdateRequest) =>
|
||||
request<SceneSlotBundle>({
|
||||
method: 'PUT',
|
||||
url: `/admin/scene-slot-bundles/${id}`,
|
||||
data,
|
||||
}),
|
||||
|
||||
delete: (id: string) =>
|
||||
request<void>({
|
||||
method: 'DELETE',
|
||||
url: `/admin/scene-slot-bundles/${id}`,
|
||||
}),
|
||||
}
|
||||
|
|
@ -99,17 +99,19 @@ const deprecatedFields = computed(() => {
|
|||
})
|
||||
|
||||
const visibleFields = computed(() => {
|
||||
// Show all fields except deprecated in edit mode, so users can change draft to active
|
||||
if (props.isNewObject) {
|
||||
return activeFields.value
|
||||
}
|
||||
return allFields.value.filter(f => f.status !== 'draft')
|
||||
return allFields.value.filter(f => f.status !== 'deprecated')
|
||||
})
|
||||
|
||||
const loadFields = async () => {
|
||||
loading.value = true
|
||||
try {
|
||||
const res = await metadataSchemaApi.getByScope(props.scope, props.showDeprecated)
|
||||
allFields.value = res.items || []
|
||||
// Handle both formats: direct array or { items: array }
|
||||
allFields.value = Array.isArray(res) ? res : (res.items || [])
|
||||
emit('fields-loaded', allFields.value)
|
||||
|
||||
applyDefaults()
|
||||
|
|
|
|||
|
|
@ -17,12 +17,6 @@ const routes: Array<RouteRecordRaw> = [
|
|||
component: () => import('@/views/dashboard/index.vue'),
|
||||
meta: { title: '控制台' }
|
||||
},
|
||||
{
|
||||
path: '/kb',
|
||||
name: 'KBManagement',
|
||||
component: () => import('@/views/kb/index.vue'),
|
||||
meta: { title: '知识库管理' }
|
||||
},
|
||||
{
|
||||
path: '/rag-lab',
|
||||
name: 'RagLab',
|
||||
|
|
@ -71,6 +65,12 @@ const routes: Array<RouteRecordRaw> = [
|
|||
component: () => import('@/views/admin/slot-definition/index.vue'),
|
||||
meta: { title: '槽位定义管理' }
|
||||
},
|
||||
{
|
||||
path: '/admin/scene-slot-bundles',
|
||||
name: 'SceneSlotBundle',
|
||||
component: () => import('@/views/admin/scene-slot-bundle/index.vue'),
|
||||
meta: { title: '场景槽位包管理' }
|
||||
},
|
||||
{
|
||||
path: '/admin/intent-rules',
|
||||
name: 'IntentRule',
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ export interface MetadataFieldDefinition {
|
|||
scope: MetadataScope[]
|
||||
is_filterable: boolean
|
||||
is_rank_feature: boolean
|
||||
usage_description?: string
|
||||
field_roles: FieldRole[]
|
||||
status: MetadataFieldStatus
|
||||
created_at?: string
|
||||
|
|
@ -31,6 +32,7 @@ export interface MetadataFieldCreateRequest {
|
|||
scope: MetadataScope[]
|
||||
is_filterable?: boolean
|
||||
is_rank_feature?: boolean
|
||||
usage_description?: string
|
||||
field_roles?: FieldRole[]
|
||||
status: MetadataFieldStatus
|
||||
}
|
||||
|
|
@ -43,6 +45,7 @@ export interface MetadataFieldUpdateRequest {
|
|||
scope?: MetadataScope[]
|
||||
is_filterable?: boolean
|
||||
is_rank_feature?: boolean
|
||||
usage_description?: string
|
||||
field_roles?: FieldRole[]
|
||||
status?: MetadataFieldStatus
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,80 @@
|
|||
/**
|
||||
* Scene Slot Bundle Types.
|
||||
* [AC-SCENE-SLOT-01] 场景-槽位映射配置类型定义
|
||||
*/
|
||||
|
||||
export type SceneSlotBundleStatus = 'draft' | 'active' | 'deprecated'
|
||||
export type AskBackOrder = 'priority' | 'required_first' | 'parallel'
|
||||
|
||||
export interface SceneSlotBundle {
|
||||
id: string
|
||||
tenant_id: string
|
||||
scene_key: string
|
||||
scene_name: string
|
||||
description: string | null
|
||||
required_slots: string[]
|
||||
optional_slots: string[]
|
||||
slot_priority: string[] | null
|
||||
completion_threshold: number
|
||||
ask_back_order: AskBackOrder
|
||||
status: SceneSlotBundleStatus
|
||||
version: number
|
||||
created_at: string | null
|
||||
updated_at: string | null
|
||||
}
|
||||
|
||||
export interface SlotDetail {
|
||||
slot_key: string
|
||||
type: string
|
||||
required: boolean
|
||||
ask_back_prompt: string | null
|
||||
linked_field_id: string | null
|
||||
}
|
||||
|
||||
export interface SceneSlotBundleWithDetails extends SceneSlotBundle {
|
||||
required_slot_details: SlotDetail[]
|
||||
optional_slot_details: SlotDetail[]
|
||||
}
|
||||
|
||||
export interface SceneSlotBundleCreateRequest {
|
||||
scene_key: string
|
||||
scene_name: string
|
||||
description?: string
|
||||
required_slots?: string[]
|
||||
optional_slots?: string[]
|
||||
slot_priority?: string[]
|
||||
completion_threshold?: number
|
||||
ask_back_order?: AskBackOrder
|
||||
status?: SceneSlotBundleStatus
|
||||
}
|
||||
|
||||
export interface SceneSlotBundleUpdateRequest {
|
||||
scene_name?: string
|
||||
description?: string
|
||||
required_slots?: string[]
|
||||
optional_slots?: string[]
|
||||
slot_priority?: string[]
|
||||
completion_threshold?: number
|
||||
ask_back_order?: AskBackOrder
|
||||
status?: SceneSlotBundleStatus
|
||||
}
|
||||
|
||||
export const SCENE_SLOT_BUNDLE_STATUS_OPTIONS = [
|
||||
{ value: 'draft', label: '草稿', description: '配置中,未启用' },
|
||||
{ value: 'active', label: '已启用', description: '运行时可用' },
|
||||
{ value: 'deprecated', label: '已废弃', description: '不再使用' },
|
||||
]
|
||||
|
||||
export const ASK_BACK_ORDER_OPTIONS = [
|
||||
{ value: 'priority', label: '按优先级', description: '按 slot_priority 定义的顺序追问' },
|
||||
{ value: 'required_first', label: '必填优先', description: '优先追问必填槽位' },
|
||||
{ value: 'parallel', label: '并行追问', description: '一次追问多个缺失槽位' },
|
||||
]
|
||||
|
||||
export function getStatusLabel(status: SceneSlotBundleStatus): string {
|
||||
return SCENE_SLOT_BUNDLE_STATUS_OPTIONS.find(o => o.value === status)?.label || status
|
||||
}
|
||||
|
||||
export function getAskBackOrderLabel(order: AskBackOrder): string {
|
||||
return ASK_BACK_ORDER_OPTIONS.find(o => o.value === order)?.label || order
|
||||
}
|
||||
|
|
@ -39,6 +39,10 @@ export interface FlowStep {
|
|||
intent_description?: string
|
||||
script_constraints?: string[]
|
||||
expected_variables?: string[]
|
||||
allowed_kb_ids?: string[]
|
||||
preferred_kb_ids?: string[]
|
||||
kb_query_hint?: string
|
||||
max_kb_calls_per_step?: number
|
||||
}
|
||||
|
||||
export interface NextCondition {
|
||||
|
|
|
|||
|
|
@ -2,13 +2,31 @@ export type SlotType = 'string' | 'number' | 'boolean' | 'enum' | 'array_enum'
|
|||
export type ExtractStrategy = 'rule' | 'llm' | 'user_input'
|
||||
export type SlotSource = 'user_confirmed' | 'rule_extracted' | 'llm_inferred' | 'default'
|
||||
|
||||
/**
|
||||
* [AC-MRS-07-UPGRADE] 提取失败类型
|
||||
*/
|
||||
export type ExtractFailureType =
|
||||
| 'EXTRACT_EMPTY'
|
||||
| 'EXTRACT_PARSE_FAIL'
|
||||
| 'EXTRACT_VALIDATION_FAIL'
|
||||
| 'EXTRACT_RUNTIME_ERROR'
|
||||
|
||||
export interface SlotDefinition {
|
||||
id: string
|
||||
tenant_id: string
|
||||
slot_key: string
|
||||
display_name?: string
|
||||
description?: string
|
||||
type: SlotType
|
||||
required: boolean
|
||||
/**
|
||||
* [AC-MRS-07-UPGRADE] 兼容字段:单提取策略,已废弃
|
||||
*/
|
||||
extract_strategy?: ExtractStrategy
|
||||
/**
|
||||
* [AC-MRS-07-UPGRADE] 提取策略链:有序数组,按顺序执行直到成功
|
||||
*/
|
||||
extract_strategies?: ExtractStrategy[]
|
||||
validation_rule?: string
|
||||
ask_back_prompt?: string
|
||||
default_value?: string | number | boolean | string[]
|
||||
|
|
@ -29,8 +47,17 @@ export interface LinkedField {
|
|||
export interface SlotDefinitionCreateRequest {
|
||||
tenant_id?: string
|
||||
slot_key: string
|
||||
display_name?: string
|
||||
description?: string
|
||||
type: SlotType
|
||||
required: boolean
|
||||
/**
|
||||
* [AC-MRS-07-UPGRADE] 提取策略链
|
||||
*/
|
||||
extract_strategies?: ExtractStrategy[]
|
||||
/**
|
||||
* [AC-MRS-07-UPGRADE] 兼容字段
|
||||
*/
|
||||
extract_strategy?: ExtractStrategy
|
||||
validation_rule?: string
|
||||
ask_back_prompt?: string
|
||||
|
|
@ -39,8 +66,17 @@ export interface SlotDefinitionCreateRequest {
|
|||
}
|
||||
|
||||
export interface SlotDefinitionUpdateRequest {
|
||||
display_name?: string
|
||||
description?: string
|
||||
type?: SlotType
|
||||
required?: boolean
|
||||
/**
|
||||
* [AC-MRS-07-UPGRADE] 提取策略链
|
||||
*/
|
||||
extract_strategies?: ExtractStrategy[]
|
||||
/**
|
||||
* [AC-MRS-07-UPGRADE] 兼容字段
|
||||
*/
|
||||
extract_strategy?: ExtractStrategy
|
||||
validation_rule?: string
|
||||
ask_back_prompt?: string
|
||||
|
|
@ -56,6 +92,31 @@ export interface RuntimeSlotValue {
|
|||
updated_at?: string
|
||||
}
|
||||
|
||||
/**
|
||||
* [AC-MRS-07-UPGRADE] 策略链执行步骤结果
|
||||
*/
|
||||
export interface StrategyStepResult {
|
||||
strategy: ExtractStrategy
|
||||
success: boolean
|
||||
value?: string | number | boolean | string[]
|
||||
failure_type?: ExtractFailureType
|
||||
failure_reason?: string
|
||||
execution_time_ms: number
|
||||
}
|
||||
|
||||
/**
|
||||
* [AC-MRS-07-UPGRADE] 策略链执行结果
|
||||
*/
|
||||
export interface StrategyChainResult {
|
||||
slot_key: string
|
||||
success: boolean
|
||||
final_value?: string | number | boolean | string[]
|
||||
final_strategy?: ExtractStrategy
|
||||
steps: StrategyStepResult[]
|
||||
total_execution_time_ms: number
|
||||
ask_back_prompt?: string
|
||||
}
|
||||
|
||||
export const SLOT_TYPE_OPTIONS = [
|
||||
{ value: 'string', label: '文本' },
|
||||
{ value: 'number', label: '数字' },
|
||||
|
|
@ -69,3 +130,42 @@ export const EXTRACT_STRATEGY_OPTIONS = [
|
|||
{ value: 'llm', label: 'LLM 推断', description: '通过大语言模型推断槽位值' },
|
||||
{ value: 'user_input', label: '用户输入', description: '通过追问提示语让用户主动输入' }
|
||||
]
|
||||
|
||||
/**
|
||||
* [AC-MRS-07-UPGRADE] 提取策略链验证
|
||||
*/
|
||||
export function validateExtractStrategies(strategies: ExtractStrategy[]): { valid: boolean; error?: string } {
|
||||
if (!strategies || strategies.length === 0) {
|
||||
return { valid: true } // 空数组视为有效
|
||||
}
|
||||
|
||||
const validStrategies = new Set(['rule', 'llm', 'user_input'])
|
||||
|
||||
// 检查重复
|
||||
const uniqueStrategies = new Set(strategies)
|
||||
if (uniqueStrategies.size !== strategies.length) {
|
||||
return { valid: false, error: '提取策略链中不允许重复的策略' }
|
||||
}
|
||||
|
||||
// 检查有效性
|
||||
const invalid = strategies.filter(s => !validStrategies.has(s))
|
||||
if (invalid.length > 0) {
|
||||
return { valid: false, error: `无效的提取策略: ${invalid.join(', ')}` }
|
||||
}
|
||||
|
||||
return { valid: true }
|
||||
}
|
||||
|
||||
/**
|
||||
* [AC-MRS-07-UPGRADE] 获取有效的策略链
|
||||
* 优先使用 extract_strategies,如果不存在则兼容读取 extract_strategy
|
||||
*/
|
||||
export function getEffectiveStrategies(slot: SlotDefinition): ExtractStrategy[] {
|
||||
if (slot.extract_strategies && slot.extract_strategies.length > 0) {
|
||||
return slot.extract_strategies
|
||||
}
|
||||
if (slot.extract_strategy) {
|
||||
return [slot.extract_strategy]
|
||||
}
|
||||
return []
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import { useTenantStore } from '@/stores/tenant'
|
|||
|
||||
const service = axios.create({
|
||||
baseURL: import.meta.env.VITE_APP_BASE_API || '/api',
|
||||
timeout: 60000
|
||||
timeout: 180000
|
||||
})
|
||||
|
||||
service.interceptors.request.use(
|
||||
|
|
|
|||
|
|
@ -5,6 +5,14 @@
|
|||
<el-icon><Upload /></el-icon>
|
||||
上传文档
|
||||
</el-button>
|
||||
<el-button type="success" @click="handleBatchUploadClick">
|
||||
<el-icon><Upload /></el-icon>
|
||||
批量上传
|
||||
</el-button>
|
||||
<el-button type="warning" @click="handleJsonUploadClick">
|
||||
<el-icon><Upload /></el-icon>
|
||||
JSON上传
|
||||
</el-button>
|
||||
<input
|
||||
ref="fileInputRef"
|
||||
type="file"
|
||||
|
|
@ -12,6 +20,20 @@
|
|||
style="display: none"
|
||||
@change="handleFileSelect"
|
||||
/>
|
||||
<input
|
||||
ref="batchFileInputRef"
|
||||
type="file"
|
||||
accept=".zip"
|
||||
style="display: none"
|
||||
@change="handleBatchFileSelect"
|
||||
/>
|
||||
<input
|
||||
ref="jsonFileInputRef"
|
||||
type="file"
|
||||
accept=".json,.jsonl,.txt"
|
||||
style="display: none"
|
||||
@change="handleJsonFileSelect"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<el-table :data="documents" v-loading="loading" stripe>
|
||||
|
|
@ -165,6 +187,8 @@ const pagination = ref({
|
|||
})
|
||||
|
||||
const fileInputRef = ref<HTMLInputElement>()
|
||||
const batchFileInputRef = ref<HTMLInputElement>()
|
||||
const jsonFileInputRef = ref<HTMLInputElement>()
|
||||
const uploadDialogVisible = ref(false)
|
||||
const editDialogVisible = ref(false)
|
||||
const uploading = ref(false)
|
||||
|
|
@ -261,6 +285,14 @@ const handleUploadClick = () => {
|
|||
fileInputRef.value?.click()
|
||||
}
|
||||
|
||||
const handleBatchUploadClick = () => {
|
||||
batchFileInputRef.value?.click()
|
||||
}
|
||||
|
||||
const handleJsonUploadClick = () => {
|
||||
jsonFileInputRef.value?.click()
|
||||
}
|
||||
|
||||
const handleFileSelect = (event: Event) => {
|
||||
const target = event.target as HTMLInputElement
|
||||
const file = target.files?.[0]
|
||||
|
|
@ -289,6 +321,140 @@ const handleFileSelect = (event: Event) => {
|
|||
}
|
||||
}
|
||||
|
||||
const handleBatchFileSelect = async (event: Event) => {
|
||||
const target = event.target as HTMLInputElement
|
||||
const file = target.files?.[0]
|
||||
if (!file) return
|
||||
|
||||
if (!file.name.toLowerCase().endsWith('.zip')) {
|
||||
ElMessage.error('请上传 .zip 格式的压缩包')
|
||||
if (batchFileInputRef.value) {
|
||||
batchFileInputRef.value.value = ''
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
const maxSize = 100 * 1024 * 1024
|
||||
if (file.size > maxSize) {
|
||||
ElMessage.error('压缩包大小不能超过 100MB')
|
||||
if (batchFileInputRef.value) {
|
||||
batchFileInputRef.value.value = ''
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
try {
|
||||
loading.value = true
|
||||
const formData = new FormData()
|
||||
formData.append('file', file)
|
||||
formData.append('kb_id', props.kbId)
|
||||
|
||||
const baseUrl = import.meta.env.VITE_APP_BASE_API || '/api'
|
||||
const apiKey = import.meta.env.VITE_APP_API_KEY || ''
|
||||
const response = await fetch(`${baseUrl}/admin/kb/documents/batch-upload`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'X-Tenant-Id': tenantStore.currentTenantId,
|
||||
'X-API-Key': apiKey
|
||||
},
|
||||
body: formData
|
||||
})
|
||||
|
||||
const result = await response.json()
|
||||
if (result.success) {
|
||||
const { total, succeeded, failed, results } = result
|
||||
ElMessage.success(`批量上传完成!成功: ${succeeded}, 失败: ${failed}, 总计: ${total}`)
|
||||
|
||||
// 收集所有成功的任务ID,稍后串行轮询
|
||||
const jobIds: string[] = []
|
||||
for (const item of results) {
|
||||
if (item.status === 'created' && item.jobId) {
|
||||
jobIds.push(item.jobId)
|
||||
}
|
||||
}
|
||||
|
||||
// 串行轮询所有任务,避免并发请求触发限流
|
||||
if (jobIds.length > 0) {
|
||||
pollJobStatusSequential(jobIds)
|
||||
}
|
||||
|
||||
emit('upload-success')
|
||||
loadDocuments()
|
||||
} else {
|
||||
ElMessage.error(result.message || '批量上传失败')
|
||||
}
|
||||
} catch (error) {
|
||||
ElMessage.error('批量上传失败,请重试')
|
||||
console.error('Batch upload error:', error)
|
||||
} finally {
|
||||
loading.value = false
|
||||
if (batchFileInputRef.value) {
|
||||
batchFileInputRef.value.value = ''
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const handleJsonFileSelect = async (event: Event) => {
|
||||
const target = event.target as HTMLInputElement
|
||||
const file = target.files?.[0]
|
||||
if (!file) return
|
||||
|
||||
const fileName = file.name.toLowerCase()
|
||||
if (!fileName.endsWith('.json') && !fileName.endsWith('.jsonl') && !fileName.endsWith('.txt')) {
|
||||
ElMessage.error('请上传 .json、.jsonl 或 .txt 格式的文件')
|
||||
if (jsonFileInputRef.value) {
|
||||
jsonFileInputRef.value.value = ''
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
try {
|
||||
loading.value = true
|
||||
ElMessage.info('正在上传 JSON 文件...')
|
||||
|
||||
const formData = new FormData()
|
||||
formData.append('file', file)
|
||||
|
||||
const baseUrl = import.meta.env.VITE_APP_BASE_API || '/api'
|
||||
const apiKey = import.meta.env.VITE_APP_API_KEY || ''
|
||||
const response = await fetch(`${baseUrl}/admin/kb/${props.kbId}/documents/json-batch?tenant_id=${tenantStore.currentTenantId}`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'X-Tenant-Id': tenantStore.currentTenantId,
|
||||
'X-API-Key': apiKey
|
||||
},
|
||||
body: formData
|
||||
})
|
||||
|
||||
const result = await response.json()
|
||||
if (result.success) {
|
||||
ElMessage.success(`JSON 批量上传成功!成功: ${result.succeeded}, 失败: ${result.failed}`)
|
||||
|
||||
if (result.valid_metadata_fields) {
|
||||
console.log('有效的元数据字段:', result.valid_metadata_fields)
|
||||
}
|
||||
|
||||
if (result.failed > 0) {
|
||||
const failedItems = result.results.filter((r: any) => !r.success)
|
||||
console.warn('失败的项目:', failedItems)
|
||||
}
|
||||
|
||||
emit('upload-success')
|
||||
loadDocuments()
|
||||
} else {
|
||||
ElMessage.error(result.message || 'JSON 批量上传失败')
|
||||
}
|
||||
} catch (error) {
|
||||
ElMessage.error('JSON 批量上传失败,请重试')
|
||||
console.error('JSON batch upload error:', error)
|
||||
} finally {
|
||||
loading.value = false
|
||||
if (jsonFileInputRef.value) {
|
||||
jsonFileInputRef.value.value = ''
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const handleUpload = async () => {
|
||||
if (!selectedFile.value) {
|
||||
ElMessage.warning('请选择文件')
|
||||
|
|
@ -310,11 +476,13 @@ const handleUpload = async () => {
|
|||
formData.append('kb_id', props.kbId)
|
||||
formData.append('metadata', JSON.stringify(uploadForm.value.metadata))
|
||||
|
||||
const baseUrl = import.meta.env.VITE_API_BASE_URL || ''
|
||||
const baseUrl = import.meta.env.VITE_APP_BASE_API || '/api'
|
||||
const apiKey = import.meta.env.VITE_APP_API_KEY || ''
|
||||
const response = await fetch(`${baseUrl}/admin/kb/documents`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'X-Tenant-Id': tenantStore.currentTenantId
|
||||
'X-Tenant-Id': tenantStore.currentTenantId,
|
||||
'X-API-Key': apiKey
|
||||
},
|
||||
body: formData
|
||||
})
|
||||
|
|
@ -356,12 +524,14 @@ const handleSaveMetadata = async () => {
|
|||
|
||||
saving.value = true
|
||||
try {
|
||||
const baseUrl = import.meta.env.VITE_API_BASE_URL || ''
|
||||
const baseUrl = import.meta.env.VITE_APP_BASE_API || '/api'
|
||||
const apiKey = import.meta.env.VITE_APP_API_KEY || ''
|
||||
const response = await fetch(`${baseUrl}/admin/kb/documents/${currentEditDoc.value.docId}/metadata`, {
|
||||
method: 'PUT',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'X-Tenant-Id': tenantStore.currentTenantId
|
||||
'X-Tenant-Id': tenantStore.currentTenantId,
|
||||
'X-API-Key': apiKey
|
||||
},
|
||||
body: JSON.stringify({ metadata: editForm.value.metadata })
|
||||
})
|
||||
|
|
@ -398,14 +568,64 @@ const pollJobStatus = async (jobId: string) => {
|
|||
ElMessage.error(`文档处理失败: ${job.errorMsg || '未知错误'}`)
|
||||
loadDocuments()
|
||||
} else {
|
||||
setTimeout(poll, 3000)
|
||||
setTimeout(poll, 10000) // 10秒轮询一次,避免触发限流
|
||||
}
|
||||
} catch (error: any) {
|
||||
if (error.status === 429) {
|
||||
console.warn('请求过于频繁,稍后重试')
|
||||
setTimeout(poll, 15000) // 被限流后等待15秒
|
||||
} else {
|
||||
console.error('轮询任务状态失败', error)
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('轮询任务状态失败', error)
|
||||
}
|
||||
}
|
||||
|
||||
setTimeout(poll, 2000)
|
||||
setTimeout(poll, 5000) // 首次轮询等待5秒
|
||||
}
|
||||
|
||||
// 串行轮询多个任务,避免并发触发限流
|
||||
const pollJobStatusSequential = async (jobIds: string[]) => {
|
||||
const pendingJobs = new Set(jobIds)
|
||||
const maxPolls = 60
|
||||
let pollCount = 0
|
||||
|
||||
const pollAll = async () => {
|
||||
if (pollCount >= maxPolls || pendingJobs.size === 0) return
|
||||
pollCount++
|
||||
|
||||
const completedJobs: string[] = []
|
||||
|
||||
for (const jobId of pendingJobs) {
|
||||
try {
|
||||
const job: IndexJob = await getIndexJob(jobId)
|
||||
if (job.status === 'completed') {
|
||||
completedJobs.push(jobId)
|
||||
} else if (job.status === 'failed') {
|
||||
completedJobs.push(jobId)
|
||||
console.error(`文档处理失败: ${job.errorMsg || '未知错误'}`)
|
||||
}
|
||||
} catch (error: any) {
|
||||
if (error.status === 429) {
|
||||
console.warn('请求过于频繁,稍后重试')
|
||||
break // 被限流时停止本轮轮询,等待下次
|
||||
} else {
|
||||
console.error('轮询任务状态失败', error)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 移除已完成的任务
|
||||
completedJobs.forEach(jobId => pendingJobs.delete(jobId))
|
||||
|
||||
if (pendingJobs.size > 0) {
|
||||
setTimeout(pollAll, 10000) // 10秒后再次轮询
|
||||
} else {
|
||||
ElMessage.success('所有文档处理完成')
|
||||
loadDocuments()
|
||||
}
|
||||
}
|
||||
|
||||
setTimeout(pollAll, 5000) // 首次轮询等待5秒
|
||||
}
|
||||
|
||||
const handleDelete = async (row: DocumentWithMetadata) => {
|
||||
|
|
@ -438,6 +658,8 @@ onMounted(() => {
|
|||
|
||||
.list-header {
|
||||
margin-bottom: 16px;
|
||||
display: flex;
|
||||
gap: 12px;
|
||||
}
|
||||
|
||||
.pagination-wrapper {
|
||||
|
|
|
|||
|
|
@ -211,6 +211,14 @@
|
|||
</el-form-item>
|
||||
</el-col>
|
||||
</el-row>
|
||||
<el-form-item label="用途说明">
|
||||
<el-input
|
||||
v-model="formData.usage_description"
|
||||
type="textarea"
|
||||
:rows="2"
|
||||
placeholder="描述该字段的业务用途和使用场景"
|
||||
/>
|
||||
</el-form-item>
|
||||
<el-form-item label="默认值">
|
||||
<el-input v-model="formData.default_value" placeholder="可选默认值" />
|
||||
</el-form-item>
|
||||
|
|
@ -308,6 +316,7 @@ const formData = reactive({
|
|||
scope: [] as MetadataScope[],
|
||||
is_filterable: true,
|
||||
is_rank_feature: false,
|
||||
usage_description: '',
|
||||
field_roles: [] as FieldRole[],
|
||||
status: 'draft' as MetadataFieldStatus
|
||||
})
|
||||
|
|
@ -354,8 +363,11 @@ const getRoleLabel = (role: FieldRole) => {
|
|||
const fetchFields = async () => {
|
||||
loading.value = true
|
||||
try {
|
||||
const res = await metadataSchemaApi.list(filterStatus.value || undefined, filterRole.value || undefined)
|
||||
fields.value = res.items || []
|
||||
const res: any = await metadataSchemaApi.list(filterStatus.value || undefined, filterRole.value || undefined)
|
||||
// 兼容多种后端返回格式:[] / {items: []} / {schemas: []} / {data: []}
|
||||
fields.value = Array.isArray(res)
|
||||
? res
|
||||
: (res?.items || res?.schemas || res?.data || [])
|
||||
} catch (error: any) {
|
||||
ElMessage.error(error.response?.data?.message || '获取元数据字段失败')
|
||||
} finally {
|
||||
|
|
@ -376,6 +388,7 @@ const handleCreate = () => {
|
|||
scope: [],
|
||||
is_filterable: true,
|
||||
is_rank_feature: false,
|
||||
usage_description: '',
|
||||
field_roles: [],
|
||||
status: 'draft'
|
||||
})
|
||||
|
|
@ -395,6 +408,7 @@ const handleEdit = (field: MetadataFieldDefinition) => {
|
|||
scope: [...field.scope],
|
||||
is_filterable: field.is_filterable,
|
||||
is_rank_feature: field.is_rank_feature,
|
||||
usage_description: field.usage_description || '',
|
||||
field_roles: field.field_roles || [],
|
||||
status: field.status
|
||||
})
|
||||
|
|
@ -472,6 +486,7 @@ const handleSubmit = async () => {
|
|||
scope: formData.scope,
|
||||
is_filterable: formData.is_filterable,
|
||||
is_rank_feature: formData.is_rank_feature,
|
||||
usage_description: formData.usage_description || undefined,
|
||||
field_roles: formData.field_roles,
|
||||
status: formData.status
|
||||
}
|
||||
|
|
|
|||
|
|
@ -30,6 +30,12 @@
|
|||
<el-form-item>
|
||||
<el-button :loading="switchingMode" @click="handleSwitchMode">应用模式</el-button>
|
||||
</el-form-item>
|
||||
<el-form-item>
|
||||
<el-button type="warning" :loading="cancellingFlow" @click="handleCancelFlow">取消流程</el-button>
|
||||
</el-form-item>
|
||||
<el-form-item>
|
||||
<el-button type="primary" @click="handleNewSession">新建会话</el-button>
|
||||
</el-form-item>
|
||||
</el-form>
|
||||
</el-card>
|
||||
|
||||
|
|
@ -98,12 +104,34 @@
|
|||
|
||||
<div class="json-panel">
|
||||
<div class="json-title">tool_calls</div>
|
||||
<pre>{{ JSON.stringify(lastTrace.tool_calls || [], null, 2) }}</pre>
|
||||
<div class="json-content">
|
||||
<el-button
|
||||
class="copy-btn"
|
||||
size="small"
|
||||
text
|
||||
type="primary"
|
||||
@click="copyJson(lastTrace.tool_calls || [], 'tool_calls')"
|
||||
>
|
||||
复制
|
||||
</el-button>
|
||||
<pre>{{ JSON.stringify(lastTrace.tool_calls || [], null, 2) }}</pre>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="json-panel">
|
||||
<div class="json-title">metrics_snapshot</div>
|
||||
<pre>{{ JSON.stringify(lastTrace.metrics_snapshot || {}, null, 2) }}</pre>
|
||||
<div class="json-content">
|
||||
<el-button
|
||||
class="copy-btn"
|
||||
size="small"
|
||||
text
|
||||
type="primary"
|
||||
@click="copyJson(lastTrace.metrics_snapshot || {}, 'metrics_snapshot')"
|
||||
>
|
||||
复制
|
||||
</el-button>
|
||||
<pre>{{ JSON.stringify(lastTrace.metrics_snapshot || {}, null, 2) }}</pre>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<el-empty v-else description="暂无 Trace" :image-size="90" />
|
||||
|
|
@ -180,6 +208,7 @@ import {
|
|||
createPublicShareToken,
|
||||
listShares,
|
||||
deleteShare,
|
||||
cancelActiveFlow,
|
||||
type DialogueMessage,
|
||||
type DialogueResponse,
|
||||
type SessionMode,
|
||||
|
|
@ -205,6 +234,7 @@ const rollbackToLegacy = ref(false)
|
|||
const sending = ref(false)
|
||||
const switchingMode = ref(false)
|
||||
const reporting = ref(false)
|
||||
const cancellingFlow = ref(false)
|
||||
|
||||
const userInput = ref('')
|
||||
const conversation = ref<ChatItem[]>([])
|
||||
|
|
@ -221,11 +251,11 @@ const shareForm = ref({
|
|||
|
||||
const now = () => new Date().toLocaleTimeString()
|
||||
|
||||
const tagType = (role: ChatRole) => {
|
||||
const tagType = (role: ChatRole): 'primary' | 'success' | 'info' | 'warning' | 'danger' => {
|
||||
if (role === 'user') return 'info'
|
||||
if (role === 'assistant') return 'success'
|
||||
if (role === 'human') return 'warning'
|
||||
return ''
|
||||
return 'primary'
|
||||
}
|
||||
|
||||
const toHistory = (): DialogueMessage[] => {
|
||||
|
|
@ -289,6 +319,32 @@ const handleSwitchMode = async () => {
|
|||
}
|
||||
}
|
||||
|
||||
const handleCancelFlow = async () => {
|
||||
cancellingFlow.value = true
|
||||
try {
|
||||
const result = await cancelActiveFlow(sessionId.value)
|
||||
if (result.success) {
|
||||
ElMessage.success(result.message)
|
||||
} else {
|
||||
ElMessage.warning(result.message)
|
||||
}
|
||||
} catch {
|
||||
ElMessage.error('取消流程失败')
|
||||
} finally {
|
||||
cancellingFlow.value = false
|
||||
}
|
||||
}
|
||||
|
||||
const handleNewSession = () => {
|
||||
const newSessionId = `sess_${Date.now()}`
|
||||
const newUserId = `user_${Date.now()}`
|
||||
sessionId.value = newSessionId
|
||||
userId.value = newUserId
|
||||
conversation.value = []
|
||||
lastTrace.value = null
|
||||
ElMessage.success(`新会话已创建: ${newSessionId}`)
|
||||
}
|
||||
|
||||
const handleReportMessages = async () => {
|
||||
if (!conversation.value.length) {
|
||||
ElMessage.warning('暂无可上报消息')
|
||||
|
|
@ -370,6 +426,12 @@ const copyShareUrl = (url: string) => {
|
|||
ElMessage.success('链接已复制到剪贴板')
|
||||
}
|
||||
|
||||
const copyJson = (data: unknown, name: string) => {
|
||||
const jsonStr = JSON.stringify(data, null, 2)
|
||||
navigator.clipboard.writeText(jsonStr)
|
||||
ElMessage.success(`${name} 已复制到剪贴板`)
|
||||
}
|
||||
|
||||
const formatShareExpires = (expiresAt: string) => {
|
||||
const expires = new Date(expiresAt)
|
||||
return expires.toLocaleDateString()
|
||||
|
|
@ -462,6 +524,23 @@ onMounted(() => {
|
|||
margin-bottom: 4px;
|
||||
}
|
||||
|
||||
.json-content {
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.json-content .copy-btn {
|
||||
position: absolute;
|
||||
top: 4px;
|
||||
right: 4px;
|
||||
z-index: 10;
|
||||
background: var(--el-bg-color);
|
||||
opacity: 0.8;
|
||||
}
|
||||
|
||||
.json-content .copy-btn:hover {
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
pre {
|
||||
background: var(--el-fill-color-lighter);
|
||||
padding: 8px;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,660 @@
|
|||
<template>
|
||||
<div class="scene-slot-bundle-page">
|
||||
<div class="page-header">
|
||||
<div class="header-content">
|
||||
<div class="title-section">
|
||||
<h1 class="page-title">场景槽位包管理</h1>
|
||||
<p class="page-desc">配置场景与槽位的映射关系,定义每个场景需要采集的槽位集合。[AC-SCENE-SLOT-01]</p>
|
||||
</div>
|
||||
<div class="header-actions">
|
||||
<el-select v-model="filterStatus" placeholder="按状态筛选" clearable style="width: 140px;">
|
||||
<el-option
|
||||
v-for="opt in SCENE_SLOT_BUNDLE_STATUS_OPTIONS"
|
||||
:key="opt.value"
|
||||
:label="opt.label"
|
||||
:value="opt.value"
|
||||
/>
|
||||
</el-select>
|
||||
<el-button type="primary" @click="handleCreate">
|
||||
<el-icon><Plus /></el-icon>
|
||||
新建场景槽位包
|
||||
</el-button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<el-card shadow="hover" class="bundle-card" v-loading="loading">
|
||||
<el-table :data="bundles" stripe style="width: 100%">
|
||||
<el-table-column prop="scene_key" label="场景标识" min-width="140">
|
||||
<template #default="{ row }">
|
||||
<code class="scene-key">{{ row.scene_key }}</code>
|
||||
</template>
|
||||
</el-table-column>
|
||||
<el-table-column prop="scene_name" label="场景名称" min-width="120" />
|
||||
<el-table-column prop="status" label="状态" width="100">
|
||||
<template #default="{ row }">
|
||||
<el-tag :type="getStatusTagType(row.status)" size="small">
|
||||
{{ getStatusLabel(row.status) }}
|
||||
</el-tag>
|
||||
</template>
|
||||
</el-table-column>
|
||||
<el-table-column prop="required_slots" label="必填槽位" min-width="160">
|
||||
<template #default="{ row }">
|
||||
<div class="slot-tags">
|
||||
<el-tag
|
||||
v-for="slotKey in row.required_slots.slice(0, 3)"
|
||||
:key="slotKey"
|
||||
size="small"
|
||||
type="danger"
|
||||
class="slot-tag"
|
||||
>
|
||||
{{ slotKey }}
|
||||
</el-tag>
|
||||
<el-tag v-if="row.required_slots.length > 3" size="small" type="info">
|
||||
+{{ row.required_slots.length - 3 }}
|
||||
</el-tag>
|
||||
</div>
|
||||
</template>
|
||||
</el-table-column>
|
||||
<el-table-column prop="optional_slots" label="可选槽位" min-width="140">
|
||||
<template #default="{ row }">
|
||||
<div class="slot-tags" v-if="row.optional_slots.length > 0">
|
||||
<el-tag
|
||||
v-for="slotKey in row.optional_slots.slice(0, 3)"
|
||||
:key="slotKey"
|
||||
size="small"
|
||||
type="info"
|
||||
class="slot-tag"
|
||||
>
|
||||
{{ slotKey }}
|
||||
</el-tag>
|
||||
<el-tag v-if="row.optional_slots.length > 3" size="small" type="info">
|
||||
+{{ row.optional_slots.length - 3 }}
|
||||
</el-tag>
|
||||
</div>
|
||||
<span v-else class="no-value">-</span>
|
||||
</template>
|
||||
</el-table-column>
|
||||
<el-table-column prop="completion_threshold" label="完成阈值" width="100">
|
||||
<template #default="{ row }">
|
||||
<span>{{ (row.completion_threshold * 100).toFixed(0) }}%</span>
|
||||
</template>
|
||||
</el-table-column>
|
||||
<el-table-column prop="ask_back_order" label="追问策略" width="100">
|
||||
<template #default="{ row }">
|
||||
<span>{{ getAskBackOrderLabel(row.ask_back_order) }}</span>
|
||||
</template>
|
||||
</el-table-column>
|
||||
<el-table-column label="操作" width="150" fixed="right">
|
||||
<template #default="{ row }">
|
||||
<el-button type="primary" link size="small" @click="handleEdit(row)">
|
||||
<el-icon><Edit /></el-icon>
|
||||
编辑
|
||||
</el-button>
|
||||
<el-button type="danger" link size="small" @click="handleDelete(row)">
|
||||
<el-icon><Delete /></el-icon>
|
||||
删除
|
||||
</el-button>
|
||||
</template>
|
||||
</el-table-column>
|
||||
</el-table>
|
||||
<el-empty v-if="!loading && bundles.length === 0" description="暂无场景槽位包" />
|
||||
</el-card>
|
||||
|
||||
<el-dialog
|
||||
v-model="dialogVisible"
|
||||
:title="isEdit ? '编辑场景槽位包' : '新建场景槽位包'"
|
||||
width="800px"
|
||||
:close-on-click-modal="false"
|
||||
destroy-on-close
|
||||
>
|
||||
<el-form :model="formData" :rules="formRules" ref="formRef" label-width="100px">
|
||||
<el-row :gutter="20">
|
||||
<el-col :span="12">
|
||||
<el-form-item label="场景标识" prop="scene_key">
|
||||
<el-input
|
||||
v-model="formData.scene_key"
|
||||
placeholder="如:open_consult, refund_apply"
|
||||
:disabled="isEdit"
|
||||
/>
|
||||
<div class="field-hint">唯一标识,创建后不可修改</div>
|
||||
</el-form-item>
|
||||
</el-col>
|
||||
<el-col :span="12">
|
||||
<el-form-item label="场景名称" prop="scene_name">
|
||||
<el-input v-model="formData.scene_name" placeholder="如:开放咨询、退款申请" />
|
||||
</el-form-item>
|
||||
</el-col>
|
||||
</el-row>
|
||||
|
||||
<el-form-item label="场景描述" prop="description">
|
||||
<el-input
|
||||
v-model="formData.description"
|
||||
type="textarea"
|
||||
:rows="2"
|
||||
placeholder="描述该场景的业务背景和用途"
|
||||
/>
|
||||
</el-form-item>
|
||||
|
||||
<el-row :gutter="20">
|
||||
<el-col :span="12">
|
||||
<el-form-item label="状态" prop="status">
|
||||
<el-select v-model="formData.status" style="width: 100%;">
|
||||
<el-option
|
||||
v-for="opt in SCENE_SLOT_BUNDLE_STATUS_OPTIONS"
|
||||
:key="opt.value"
|
||||
:label="opt.label"
|
||||
:value="opt.value"
|
||||
>
|
||||
<div class="status-option">
|
||||
<span>{{ opt.label }}</span>
|
||||
<span class="status-desc">{{ opt.description }}</span>
|
||||
</div>
|
||||
</el-option>
|
||||
</el-select>
|
||||
</el-form-item>
|
||||
</el-col>
|
||||
<el-col :span="12">
|
||||
<el-form-item label="完成阈值" prop="completion_threshold">
|
||||
<el-slider
|
||||
v-model="formData.completion_threshold"
|
||||
:min="0"
|
||||
:max="1"
|
||||
:step="0.1"
|
||||
:format-tooltip="(val: number) => `${(val * 100).toFixed(0)}%`"
|
||||
/>
|
||||
<div class="field-hint">必填槽位填充比例达到此值视为完成</div>
|
||||
</el-form-item>
|
||||
</el-col>
|
||||
</el-row>
|
||||
|
||||
<el-form-item label="必填槽位" prop="required_slots">
|
||||
<el-select
|
||||
v-model="formData.required_slots"
|
||||
multiple
|
||||
filterable
|
||||
style="width: 100%;"
|
||||
placeholder="选择必填槽位"
|
||||
>
|
||||
<el-option
|
||||
v-for="slot in availableSlots"
|
||||
:key="slot.id"
|
||||
:label="`${slot.slot_key} (${slot.ask_back_prompt || '无追问提示'})`"
|
||||
:value="slot.slot_key"
|
||||
:disabled="formData.optional_slots.includes(slot.slot_key)"
|
||||
>
|
||||
<div class="slot-option">
|
||||
<span class="slot-key-label">{{ slot.slot_key }}</span>
|
||||
<el-tag size="small" type="info">{{ slot.type }}</el-tag>
|
||||
</div>
|
||||
</el-option>
|
||||
</el-select>
|
||||
<div class="field-hint">缺失必填槽位时会触发追问,不会直接进行 KB 检索</div>
|
||||
</el-form-item>
|
||||
|
||||
<el-form-item label="可选槽位" prop="optional_slots">
|
||||
<el-select
|
||||
v-model="formData.optional_slots"
|
||||
multiple
|
||||
filterable
|
||||
style="width: 100%;"
|
||||
placeholder="选择可选槽位"
|
||||
>
|
||||
<el-option
|
||||
v-for="slot in availableSlots"
|
||||
:key="slot.id"
|
||||
:label="slot.slot_key"
|
||||
:value="slot.slot_key"
|
||||
:disabled="formData.required_slots.includes(slot.slot_key)"
|
||||
>
|
||||
<div class="slot-option">
|
||||
<span class="slot-key-label">{{ slot.slot_key }}</span>
|
||||
<el-tag size="small" type="info">{{ slot.type }}</el-tag>
|
||||
</div>
|
||||
</el-option>
|
||||
</el-select>
|
||||
<div class="field-hint">可选槽位用于增强检索效果,缺失时不阻塞流程</div>
|
||||
</el-form-item>
|
||||
|
||||
<el-form-item label="槽位优先级" prop="slot_priority">
|
||||
<div class="priority-editor">
|
||||
<div class="priority-header">
|
||||
<span class="priority-hint">拖拽调整槽位采集和追问的优先级顺序</span>
|
||||
<el-button
|
||||
v-if="formData.slot_priority && formData.slot_priority.length > 0"
|
||||
type="danger"
|
||||
link
|
||||
size="small"
|
||||
@click="formData.slot_priority = null"
|
||||
>
|
||||
清空
|
||||
</el-button>
|
||||
</div>
|
||||
|
||||
<div v-if="allSlotsForPriority.length > 0" class="priority-list-wrapper">
|
||||
<draggable
|
||||
v-model="allSlotsForPriority"
|
||||
item-key="slot_key"
|
||||
handle=".drag-handle"
|
||||
class="priority-list"
|
||||
ghost-class="ghost"
|
||||
>
|
||||
<template #item="{ element, index }">
|
||||
<div class="priority-item" :class="{ 'in-priority': isInPriority(element.slot_key) }">
|
||||
<el-icon class="drag-handle"><Rank /></el-icon>
|
||||
<span class="priority-order">{{ index + 1 }}</span>
|
||||
<el-tag
|
||||
size="small"
|
||||
:type="formData.required_slots.includes(element.slot_key) ? 'danger' : 'info'"
|
||||
>
|
||||
{{ element.slot_key }}
|
||||
</el-tag>
|
||||
<el-tag size="small" type="warning" v-if="formData.required_slots.includes(element.slot_key)">
|
||||
必填
|
||||
</el-tag>
|
||||
</div>
|
||||
</template>
|
||||
</draggable>
|
||||
</div>
|
||||
<div v-else class="priority-empty">
|
||||
<el-text type="info">请先添加必填或可选槽位</el-text>
|
||||
</div>
|
||||
</div>
|
||||
</el-form-item>
|
||||
|
||||
<el-form-item label="追问策略" prop="ask_back_order">
|
||||
<el-radio-group v-model="formData.ask_back_order">
|
||||
<el-radio-button
|
||||
v-for="opt in ASK_BACK_ORDER_OPTIONS"
|
||||
:key="opt.value"
|
||||
:value="opt.value"
|
||||
>
|
||||
{{ opt.label }}
|
||||
</el-radio-button>
|
||||
</el-radio-group>
|
||||
<div class="field-hint">
|
||||
{{ ASK_BACK_ORDER_OPTIONS.find(o => o.value === formData.ask_back_order)?.description }}
|
||||
</div>
|
||||
</el-form-item>
|
||||
</el-form>
|
||||
<template #footer>
|
||||
<el-button @click="dialogVisible = false">取消</el-button>
|
||||
<el-button type="primary" :loading="submitting" @click="handleSubmit">
|
||||
{{ isEdit ? '保存' : '创建' }}
|
||||
</el-button>
|
||||
</template>
|
||||
</el-dialog>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, reactive, onMounted, watch, computed } from 'vue'
|
||||
import { ElMessage, ElMessageBox } from 'element-plus'
|
||||
import { Plus, Edit, Delete, Rank } from '@element-plus/icons-vue'
|
||||
import type { FormInstance, FormRules } from 'element-plus'
|
||||
import draggable from 'vuedraggable'
|
||||
import { sceneSlotBundleApi } from '@/api/scene-slot-bundle'
|
||||
import { slotDefinitionApi } from '@/api/slot-definition'
|
||||
import {
|
||||
SCENE_SLOT_BUNDLE_STATUS_OPTIONS,
|
||||
ASK_BACK_ORDER_OPTIONS,
|
||||
type SceneSlotBundle,
|
||||
type SceneSlotBundleCreateRequest,
|
||||
type SceneSlotBundleUpdateRequest,
|
||||
type SceneSlotBundleStatus,
|
||||
type AskBackOrder,
|
||||
getStatusLabel,
|
||||
getAskBackOrderLabel,
|
||||
} from '@/types/scene-slot-bundle'
|
||||
import type { SlotDefinition } from '@/types/slot-definition'
|
||||
|
||||
const loading = ref(false)
|
||||
const bundles = ref<SceneSlotBundle[]>([])
|
||||
const availableSlots = ref<SlotDefinition[]>([])
|
||||
const filterStatus = ref<SceneSlotBundleStatus | ''>('')
|
||||
const dialogVisible = ref(false)
|
||||
const isEdit = ref(false)
|
||||
const submitting = ref(false)
|
||||
const formRef = ref<FormInstance>()
|
||||
|
||||
const formData = reactive({
|
||||
id: '',
|
||||
scene_key: '',
|
||||
scene_name: '',
|
||||
description: '',
|
||||
required_slots: [] as string[],
|
||||
optional_slots: [] as string[],
|
||||
slot_priority: null as string[] | null,
|
||||
completion_threshold: 1.0,
|
||||
ask_back_order: 'priority' as AskBackOrder,
|
||||
status: 'draft' as SceneSlotBundleStatus,
|
||||
})
|
||||
|
||||
const formRules: FormRules = {
|
||||
scene_key: [
|
||||
{ required: true, message: '请输入场景标识', trigger: 'blur' },
|
||||
{ pattern: /^[a-z][a-z0-9_]*$/, message: '以小写字母开头,仅允许小写字母、数字、下划线', trigger: 'blur' }
|
||||
],
|
||||
scene_name: [{ required: true, message: '请输入场景名称', trigger: 'blur' }],
|
||||
status: [{ required: true, message: '请选择状态', trigger: 'change' }],
|
||||
}
|
||||
|
||||
const getStatusTagType = (status: SceneSlotBundleStatus): 'success' | 'warning' | 'info' => {
|
||||
const typeMap: Record<SceneSlotBundleStatus, 'success' | 'warning' | 'info'> = {
|
||||
'active': 'success',
|
||||
'draft': 'warning',
|
||||
'deprecated': 'info',
|
||||
}
|
||||
return typeMap[status] || 'info'
|
||||
}
|
||||
|
||||
const allSlotsForPriority = computed({
|
||||
get: () => {
|
||||
const allKeys = [...formData.required_slots, ...formData.optional_slots]
|
||||
const priority = formData.slot_priority || allKeys
|
||||
|
||||
const orderedSlots = priority
|
||||
.filter(key => allKeys.includes(key))
|
||||
.map(key => availableSlots.value.find(s => s.slot_key === key) || { slot_key: key })
|
||||
|
||||
const remainingKeys = allKeys.filter(key => !priority.includes(key))
|
||||
const remainingSlots = remainingKeys
|
||||
.map(key => availableSlots.value.find(s => s.slot_key === key) || { slot_key: key })
|
||||
|
||||
return [...orderedSlots, ...remainingSlots]
|
||||
},
|
||||
set: (value: { slot_key: string }[]) => {
|
||||
formData.slot_priority = value.map(s => s.slot_key)
|
||||
}
|
||||
})
|
||||
|
||||
const isInPriority = (slotKey: string): boolean => {
|
||||
if (!formData.slot_priority) return true
|
||||
return formData.slot_priority.includes(slotKey)
|
||||
}
|
||||
|
||||
const fetchBundles = async () => {
|
||||
loading.value = true
|
||||
try {
|
||||
const res = await sceneSlotBundleApi.list(filterStatus.value || undefined)
|
||||
bundles.value = res || []
|
||||
} catch (error: any) {
|
||||
ElMessage.error(error.response?.data?.message || '获取场景槽位包列表失败')
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
const fetchAvailableSlots = async () => {
|
||||
try {
|
||||
const res = await slotDefinitionApi.list()
|
||||
availableSlots.value = res || []
|
||||
} catch (error: any) {
|
||||
console.error('获取槽位定义失败', error)
|
||||
}
|
||||
}
|
||||
|
||||
const handleCreate = () => {
|
||||
isEdit.value = false
|
||||
Object.assign(formData, {
|
||||
id: '',
|
||||
scene_key: '',
|
||||
scene_name: '',
|
||||
description: '',
|
||||
required_slots: [],
|
||||
optional_slots: [],
|
||||
slot_priority: null,
|
||||
completion_threshold: 1.0,
|
||||
ask_back_order: 'priority',
|
||||
status: 'draft',
|
||||
})
|
||||
dialogVisible.value = true
|
||||
}
|
||||
|
||||
const handleEdit = (bundle: SceneSlotBundle) => {
|
||||
isEdit.value = true
|
||||
Object.assign(formData, {
|
||||
id: bundle.id,
|
||||
scene_key: bundle.scene_key,
|
||||
scene_name: bundle.scene_name,
|
||||
description: bundle.description || '',
|
||||
required_slots: [...bundle.required_slots],
|
||||
optional_slots: [...bundle.optional_slots],
|
||||
slot_priority: bundle.slot_priority ? [...bundle.slot_priority] : null,
|
||||
completion_threshold: bundle.completion_threshold,
|
||||
ask_back_order: bundle.ask_back_order,
|
||||
status: bundle.status,
|
||||
})
|
||||
dialogVisible.value = true
|
||||
}
|
||||
|
||||
const handleDelete = async (bundle: SceneSlotBundle) => {
|
||||
try {
|
||||
await ElMessageBox.confirm(
|
||||
`确定要删除场景槽位包「${bundle.scene_name}」吗?`,
|
||||
'删除确认',
|
||||
{ type: 'warning' }
|
||||
)
|
||||
await sceneSlotBundleApi.delete(bundle.id)
|
||||
ElMessage.success('删除成功')
|
||||
fetchBundles()
|
||||
} catch (error: any) {
|
||||
if (error !== 'cancel') {
|
||||
ElMessage.error(error.response?.data?.message || '删除失败')
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const handleSubmit = async () => {
|
||||
if (!formRef.value) return
|
||||
|
||||
await formRef.value.validate(async (valid) => {
|
||||
if (!valid) return
|
||||
|
||||
submitting.value = true
|
||||
try {
|
||||
const data: SceneSlotBundleCreateRequest | SceneSlotBundleUpdateRequest = {
|
||||
scene_name: formData.scene_name,
|
||||
description: formData.description || undefined,
|
||||
required_slots: formData.required_slots.length > 0 ? formData.required_slots : undefined,
|
||||
optional_slots: formData.optional_slots.length > 0 ? formData.optional_slots : undefined,
|
||||
slot_priority: formData.slot_priority || undefined,
|
||||
completion_threshold: formData.completion_threshold,
|
||||
ask_back_order: formData.ask_back_order,
|
||||
status: formData.status,
|
||||
}
|
||||
|
||||
if (isEdit.value) {
|
||||
await sceneSlotBundleApi.update(formData.id, data as SceneSlotBundleUpdateRequest)
|
||||
ElMessage.success('更新成功')
|
||||
} else {
|
||||
const createData = data as SceneSlotBundleCreateRequest
|
||||
createData.scene_key = formData.scene_key
|
||||
await sceneSlotBundleApi.create(createData)
|
||||
ElMessage.success('创建成功')
|
||||
}
|
||||
|
||||
dialogVisible.value = false
|
||||
fetchBundles()
|
||||
} catch (error: any) {
|
||||
ElMessage.error(error.response?.data?.message || '操作失败')
|
||||
} finally {
|
||||
submitting.value = false
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
watch(filterStatus, () => {
|
||||
fetchBundles()
|
||||
})
|
||||
|
||||
onMounted(() => {
|
||||
fetchBundles()
|
||||
fetchAvailableSlots()
|
||||
})
|
||||
</script>
|
||||
|
||||
<style scoped lang="scss">
|
||||
.scene-slot-bundle-page {
|
||||
padding: 20px;
|
||||
}
|
||||
|
||||
.page-header {
|
||||
margin-bottom: 24px;
|
||||
}
|
||||
|
||||
.header-content {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: flex-start;
|
||||
}
|
||||
|
||||
.title-section {
|
||||
.page-title {
|
||||
font-size: 24px;
|
||||
font-weight: 600;
|
||||
margin: 0 0 8px 0;
|
||||
color: var(--el-text-color-primary);
|
||||
}
|
||||
|
||||
.page-desc {
|
||||
font-size: 14px;
|
||||
color: var(--el-text-color-secondary);
|
||||
margin: 0;
|
||||
}
|
||||
}
|
||||
|
||||
.header-actions {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 12px;
|
||||
}
|
||||
|
||||
.bundle-card {
|
||||
border-radius: 8px;
|
||||
}
|
||||
|
||||
.scene-key {
|
||||
padding: 2px 6px;
|
||||
background-color: var(--el-fill-color-light);
|
||||
border-radius: 4px;
|
||||
font-family: monospace;
|
||||
font-size: 12px;
|
||||
}
|
||||
|
||||
.slot-tags {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
gap: 4px;
|
||||
}
|
||||
|
||||
.slot-tag {
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
.no-value {
|
||||
color: var(--el-text-color-placeholder);
|
||||
}
|
||||
|
||||
.field-hint {
|
||||
margin-top: 4px;
|
||||
font-size: 12px;
|
||||
color: var(--el-text-color-secondary);
|
||||
}
|
||||
|
||||
.status-option {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 2px;
|
||||
|
||||
.status-desc {
|
||||
font-size: 12px;
|
||||
color: var(--el-text-color-secondary);
|
||||
}
|
||||
}
|
||||
|
||||
.slot-option {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
|
||||
.slot-key-label {
|
||||
font-weight: 500;
|
||||
font-family: monospace;
|
||||
}
|
||||
}
|
||||
|
||||
.priority-editor {
|
||||
border: 1px solid var(--el-border-color);
|
||||
border-radius: 4px;
|
||||
padding: 12px;
|
||||
background-color: var(--el-fill-color-light);
|
||||
}
|
||||
|
||||
.priority-header {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
margin-bottom: 12px;
|
||||
}
|
||||
|
||||
.priority-hint {
|
||||
font-size: 12px;
|
||||
color: var(--el-text-color-secondary);
|
||||
}
|
||||
|
||||
.priority-list-wrapper {
|
||||
max-height: 300px;
|
||||
overflow-y: auto;
|
||||
}
|
||||
|
||||
.priority-list {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 8px;
|
||||
}
|
||||
|
||||
.priority-item {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
padding: 8px 12px;
|
||||
background-color: var(--el-bg-color);
|
||||
border-radius: 4px;
|
||||
border: 1px solid var(--el-border-color-light);
|
||||
cursor: move;
|
||||
transition: all 0.2s;
|
||||
|
||||
&:hover {
|
||||
border-color: var(--el-color-primary-light-5);
|
||||
}
|
||||
|
||||
&.in-priority {
|
||||
background-color: var(--el-color-primary-light-9);
|
||||
}
|
||||
}
|
||||
|
||||
.drag-handle {
|
||||
cursor: move;
|
||||
color: var(--el-text-color-secondary);
|
||||
font-size: 16px;
|
||||
}
|
||||
|
||||
.priority-order {
|
||||
width: 20px;
|
||||
text-align: center;
|
||||
font-weight: 600;
|
||||
color: var(--el-text-color-secondary);
|
||||
}
|
||||
|
||||
.priority-empty {
|
||||
text-align: center;
|
||||
padding: 20px;
|
||||
color: var(--el-text-color-secondary);
|
||||
}
|
||||
|
||||
.ghost {
|
||||
opacity: 0.5;
|
||||
background: var(--el-color-primary-light-8);
|
||||
}
|
||||
</style>
|
||||
|
|
@ -185,11 +185,23 @@
|
|||
v-model="element.expected_variables"
|
||||
multiple
|
||||
filterable
|
||||
allow-create
|
||||
default-first-option
|
||||
placeholder="输入变量名后回车添加"
|
||||
placeholder="选择期望提取的槽位"
|
||||
style="width: 100%"
|
||||
/>
|
||||
>
|
||||
<el-option
|
||||
v-for="slot in availableSlots"
|
||||
:key="slot.id"
|
||||
:label="`${slot.slot_key} (${slot.type})`"
|
||||
:value="slot.slot_key"
|
||||
>
|
||||
<div class="slot-option">
|
||||
<span class="slot-key">{{ slot.slot_key }}</span>
|
||||
<el-tag size="small" type="info">{{ slot.type }}</el-tag>
|
||||
<el-tag size="small" v-if="slot.required" type="danger">必填</el-tag>
|
||||
</div>
|
||||
</el-option>
|
||||
</el-select>
|
||||
<div class="field-hint">期望变量必须引用已定义的槽位,步骤完成时会检查这些槽位是否已填充</div>
|
||||
</el-form-item>
|
||||
</template>
|
||||
|
||||
|
|
@ -218,14 +230,101 @@
|
|||
v-model="element.expected_variables"
|
||||
multiple
|
||||
filterable
|
||||
allow-create
|
||||
default-first-option
|
||||
placeholder="输入变量名后回车添加"
|
||||
placeholder="选择期望提取的槽位"
|
||||
style="width: 100%"
|
||||
/>
|
||||
>
|
||||
<el-option
|
||||
v-for="slot in availableSlots"
|
||||
:key="slot.id"
|
||||
:label="`${slot.slot_key} (${slot.type})`"
|
||||
:value="slot.slot_key"
|
||||
>
|
||||
<div class="slot-option">
|
||||
<span class="slot-key">{{ slot.slot_key }}</span>
|
||||
<el-tag size="small" type="info">{{ slot.type }}</el-tag>
|
||||
<el-tag size="small" v-if="slot.required" type="danger">必填</el-tag>
|
||||
</div>
|
||||
</el-option>
|
||||
</el-select>
|
||||
<div class="field-hint">期望变量必须引用已定义的槽位,步骤完成时会检查这些槽位是否已填充</div>
|
||||
</el-form-item>
|
||||
</template>
|
||||
|
||||
<el-divider content-position="left">知识库范围</el-divider>
|
||||
|
||||
<div class="kb-binding-section">
|
||||
<el-form-item label="允许的知识库">
|
||||
<el-select
|
||||
v-model="element.allowed_kb_ids"
|
||||
multiple
|
||||
filterable
|
||||
clearable
|
||||
placeholder="选择允许检索的知识库(为空则使用默认策略)"
|
||||
style="width: 100%"
|
||||
>
|
||||
<el-option
|
||||
v-for="kb in availableKnowledgeBases"
|
||||
:key="kb.id"
|
||||
:label="kb.name"
|
||||
:value="kb.id"
|
||||
>
|
||||
<div class="kb-option">
|
||||
<span>{{ kb.name }}</span>
|
||||
<el-tag size="small" :type="getKbTypeTagType(kb.kbType)">{{ getKbTypeLabel(kb.kbType) }}</el-tag>
|
||||
</div>
|
||||
</el-option>
|
||||
</el-select>
|
||||
<div class="field-hint">限制该步骤只能从选定的知识库中检索信息,为空则使用默认检索策略</div>
|
||||
</el-form-item>
|
||||
|
||||
<el-form-item label="优先知识库">
|
||||
<el-select
|
||||
v-model="element.preferred_kb_ids"
|
||||
multiple
|
||||
filterable
|
||||
clearable
|
||||
placeholder="选择优先检索的知识库"
|
||||
style="width: 100%"
|
||||
>
|
||||
<el-option
|
||||
v-for="kb in availableKnowledgeBases"
|
||||
:key="kb.id"
|
||||
:label="kb.name"
|
||||
:value="kb.id"
|
||||
:disabled="element.allowed_kb_ids && element.allowed_kb_ids.length > 0 && !element.allowed_kb_ids.includes(kb.id)"
|
||||
>
|
||||
<div class="kb-option">
|
||||
<span>{{ kb.name }}</span>
|
||||
<el-tag size="small" :type="getKbTypeTagType(kb.kbType)">{{ getKbTypeLabel(kb.kbType) }}</el-tag>
|
||||
</div>
|
||||
</el-option>
|
||||
</el-select>
|
||||
<div class="field-hint">检索时优先搜索这些知识库</div>
|
||||
</el-form-item>
|
||||
|
||||
<el-row :gutter="16">
|
||||
<el-col :span="16">
|
||||
<el-form-item label="检索提示">
|
||||
<el-input
|
||||
v-model="element.kb_query_hint"
|
||||
placeholder="可选:描述本步骤的检索意图,帮助提高检索准确性"
|
||||
/>
|
||||
</el-form-item>
|
||||
</el-col>
|
||||
<el-col :span="8">
|
||||
<el-form-item label="最大检索次数">
|
||||
<el-input-number
|
||||
v-model="element.max_kb_calls_per_step"
|
||||
:min="1"
|
||||
:max="5"
|
||||
placeholder="默认1"
|
||||
style="width: 100%"
|
||||
/>
|
||||
</el-form-item>
|
||||
</el-col>
|
||||
</el-row>
|
||||
</div>
|
||||
|
||||
<el-row :gutter="16">
|
||||
<el-col :span="8">
|
||||
<el-form-item label="等待输入">
|
||||
|
|
@ -362,17 +461,24 @@ import {
|
|||
createScriptFlow,
|
||||
updateScriptFlow,
|
||||
deleteScriptFlow,
|
||||
getScriptFlow
|
||||
getScriptFlow,
|
||||
} from '@/api/script-flow'
|
||||
import { slotDefinitionApi } from '@/api/slot-definition'
|
||||
import { listKnowledgeBases } from '@/api/knowledge-base'
|
||||
import { MetadataForm } from '@/components/metadata'
|
||||
import { TIMEOUT_ACTION_OPTIONS, SCRIPT_MODE_OPTIONS } from '@/types/script-flow'
|
||||
import type { ScriptFlow, ScriptFlowDetail, ScriptFlowCreate, ScriptFlowUpdate, FlowStep, ScriptMode } from '@/types/script-flow'
|
||||
import type { SlotDefinition } from '@/types/slot-definition'
|
||||
import type { KnowledgeBase } from '@/types/knowledge-base'
|
||||
import { KB_TYPE_MAP } from '@/types/knowledge-base'
|
||||
import FlowPreview from './components/FlowPreview.vue'
|
||||
import SimulateDialog from './components/SimulateDialog.vue'
|
||||
import ConstraintManager from './components/ConstraintManager.vue'
|
||||
|
||||
const loading = ref(false)
|
||||
const flows = ref<ScriptFlow[]>([])
|
||||
const availableSlots = ref<SlotDefinition[]>([])
|
||||
const availableKnowledgeBases = ref<KnowledgeBase[]>([])
|
||||
const dialogVisible = ref(false)
|
||||
const previewDrawer = ref(false)
|
||||
const simulateDialogVisible = ref(false)
|
||||
|
|
@ -425,16 +531,55 @@ const loadFlows = async () => {
|
|||
}
|
||||
}
|
||||
|
||||
const loadAvailableSlots = async () => {
|
||||
try {
|
||||
const res = await slotDefinitionApi.list()
|
||||
availableSlots.value = res || []
|
||||
} catch (error) {
|
||||
console.error('加载槽位定义失败', error)
|
||||
availableSlots.value = []
|
||||
}
|
||||
}
|
||||
|
||||
const loadAvailableKnowledgeBases = async () => {
|
||||
try {
|
||||
const res = await listKnowledgeBases({ is_enabled: true })
|
||||
availableKnowledgeBases.value = res.data || []
|
||||
} catch (error) {
|
||||
console.error('加载知识库列表失败', error)
|
||||
availableKnowledgeBases.value = []
|
||||
}
|
||||
}
|
||||
|
||||
const getKbTypeLabel = (kbType: string): string => {
|
||||
return KB_TYPE_MAP[kbType]?.label || kbType
|
||||
}
|
||||
|
||||
const getKbTypeTagType = (kbType: string): string => {
|
||||
const typeMap: Record<string, string> = {
|
||||
product: 'primary',
|
||||
faq: 'success',
|
||||
script: 'warning',
|
||||
policy: 'danger',
|
||||
general: 'info'
|
||||
}
|
||||
return typeMap[kbType] || 'info'
|
||||
}
|
||||
|
||||
const handleCreate = () => {
|
||||
isEdit.value = false
|
||||
currentEditId.value = ''
|
||||
formData.value = defaultFormData()
|
||||
loadAvailableSlots()
|
||||
loadAvailableKnowledgeBases()
|
||||
dialogVisible.value = true
|
||||
}
|
||||
|
||||
const handleEdit = async (row: ScriptFlow) => {
|
||||
isEdit.value = true
|
||||
currentEditId.value = row.id
|
||||
await loadAvailableSlots()
|
||||
await loadAvailableKnowledgeBases()
|
||||
try {
|
||||
const detail = await getScriptFlow(row.id)
|
||||
formData.value = {
|
||||
|
|
@ -444,7 +589,11 @@ const handleEdit = async (row: ScriptFlow) => {
|
|||
...step,
|
||||
script_mode: step.script_mode || 'fixed',
|
||||
script_constraints: step.script_constraints || [],
|
||||
expected_variables: step.expected_variables || []
|
||||
expected_variables: step.expected_variables || [],
|
||||
allowed_kb_ids: step.allowed_kb_ids || [],
|
||||
preferred_kb_ids: step.preferred_kb_ids || [],
|
||||
kb_query_hint: step.kb_query_hint || '',
|
||||
max_kb_calls_per_step: step.max_kb_calls_per_step || null
|
||||
})),
|
||||
is_enabled: detail.is_enabled,
|
||||
metadata: detail.metadata || {}
|
||||
|
|
@ -739,4 +888,35 @@ onMounted(() => {
|
|||
border-radius: 4px;
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
|
||||
.slot-option {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
}
|
||||
|
||||
.slot-key {
|
||||
font-weight: 500;
|
||||
font-family: monospace;
|
||||
}
|
||||
|
||||
.field-hint {
|
||||
margin-top: 4px;
|
||||
font-size: 12px;
|
||||
color: var(--el-text-color-secondary);
|
||||
}
|
||||
|
||||
.kb-binding-section {
|
||||
padding: 12px;
|
||||
background-color: var(--el-fill-color-lighter);
|
||||
border-radius: 6px;
|
||||
margin-bottom: 12px;
|
||||
}
|
||||
|
||||
.kb-option {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
width: 100%;
|
||||
}
|
||||
</style>
|
||||
|
|
|
|||
|
|
@ -23,7 +23,16 @@
|
|||
<el-table :data="slots" stripe style="width: 100%">
|
||||
<el-table-column prop="slot_key" label="槽位标识" min-width="140">
|
||||
<template #default="{ row }">
|
||||
<code class="slot-key">{{ row.slot_key }}</code>
|
||||
<div class="slot-key-cell">
|
||||
<code class="slot-key">{{ row.slot_key }}</code>
|
||||
<span v-if="row.display_name" class="display-name">{{ row.display_name }}</span>
|
||||
</div>
|
||||
</template>
|
||||
</el-table-column>
|
||||
<el-table-column prop="description" label="说明" min-width="180">
|
||||
<template #default="{ row }">
|
||||
<span v-if="row.description" class="slot-description">{{ row.description }}</span>
|
||||
<span v-else class="no-value">-</span>
|
||||
</template>
|
||||
</el-table-column>
|
||||
<el-table-column prop="type" label="类型" width="100">
|
||||
|
|
@ -38,11 +47,20 @@
|
|||
</el-tag>
|
||||
</template>
|
||||
</el-table-column>
|
||||
<el-table-column prop="extract_strategy" label="提取策略" width="120">
|
||||
<!-- [AC-MRS-07-UPGRADE] 策略链展示 -->
|
||||
<el-table-column prop="extract_strategies" label="提取策略链" min-width="180">
|
||||
<template #default="{ row }">
|
||||
<el-tag v-if="row.extract_strategy" size="small">
|
||||
{{ getExtractStrategyLabel(row.extract_strategy) }}
|
||||
</el-tag>
|
||||
<div v-if="getEffectiveStrategies(row).length > 0" class="strategy-chain">
|
||||
<el-tag
|
||||
v-for="(strategy, idx) in getEffectiveStrategies(row)"
|
||||
:key="idx"
|
||||
size="small"
|
||||
:type="getStrategyTagType(strategy)"
|
||||
class="strategy-tag"
|
||||
>
|
||||
{{ idx + 1 }}. {{ getExtractStrategyLabel(strategy) }}
|
||||
</el-tag>
|
||||
</div>
|
||||
<span v-else class="no-value">-</span>
|
||||
</template>
|
||||
</el-table-column>
|
||||
|
|
@ -80,7 +98,7 @@
|
|||
<el-dialog
|
||||
v-model="dialogVisible"
|
||||
:title="isEdit ? '编辑槽位定义' : '新建槽位定义'"
|
||||
width="650px"
|
||||
width="700px"
|
||||
:close-on-click-modal="false"
|
||||
destroy-on-close
|
||||
>
|
||||
|
|
@ -96,6 +114,25 @@
|
|||
<div class="field-hint">仅允许小写字母、数字、下划线</div>
|
||||
</el-form-item>
|
||||
</el-col>
|
||||
<el-col :span="12">
|
||||
<el-form-item label="槽位名称" prop="display_name">
|
||||
<el-input
|
||||
v-model="formData.display_name"
|
||||
placeholder="如:当前年级"
|
||||
/>
|
||||
<div class="field-hint">给运营/教研看的中文名</div>
|
||||
</el-form-item>
|
||||
</el-col>
|
||||
</el-row>
|
||||
<el-form-item label="槽位说明" prop="description">
|
||||
<el-input
|
||||
v-model="formData.description"
|
||||
type="textarea"
|
||||
:rows="2"
|
||||
placeholder="解释这个槽位采集什么、用于哪里,如:用于课程分层推荐和知识库过滤"
|
||||
/>
|
||||
</el-form-item>
|
||||
<el-row :gutter="20">
|
||||
<el-col :span="12">
|
||||
<el-form-item label="槽位类型" prop="type">
|
||||
<el-select v-model="formData.type" style="width: 100%;">
|
||||
|
|
@ -108,21 +145,74 @@
|
|||
</el-select>
|
||||
</el-form-item>
|
||||
</el-col>
|
||||
</el-row>
|
||||
<el-row :gutter="20">
|
||||
<el-col :span="12">
|
||||
<el-form-item label="是否必填" prop="required">
|
||||
<el-switch v-model="formData.required" />
|
||||
</el-form-item>
|
||||
</el-col>
|
||||
<el-col :span="12">
|
||||
<el-form-item label="提取策略" prop="extract_strategy">
|
||||
<el-select v-model="formData.extract_strategy" style="width: 100%;" clearable placeholder="选择提取策略">
|
||||
</el-row>
|
||||
|
||||
<!-- [AC-MRS-07-UPGRADE] 提取策略链配置 -->
|
||||
<el-form-item label="提取策略链" prop="extract_strategies">
|
||||
<div class="strategy-chain-editor">
|
||||
<div class="strategy-chain-header">
|
||||
<span class="chain-hint">按优先级排序,系统将依次尝试直到成功</span>
|
||||
<el-button
|
||||
v-if="formData.extract_strategies.length > 0"
|
||||
type="danger"
|
||||
link
|
||||
size="small"
|
||||
@click="clearStrategies"
|
||||
>
|
||||
清空
|
||||
</el-button>
|
||||
</div>
|
||||
|
||||
<div v-if="formData.extract_strategies.length > 0" class="strategy-chain-list">
|
||||
<draggable
|
||||
v-model="formData.extract_strategies"
|
||||
:item-key="(item: ExtractStrategy) => item"
|
||||
handle=".drag-handle"
|
||||
class="strategy-list"
|
||||
>
|
||||
<template #item="{ element, index }">
|
||||
<div class="strategy-item">
|
||||
<el-icon class="drag-handle"><Rank /></el-icon>
|
||||
<span class="strategy-order">{{ index + 1 }}</span>
|
||||
<el-tag size="small" :type="getStrategyTagType(element)">
|
||||
{{ getExtractStrategyLabel(element) }}
|
||||
</el-tag>
|
||||
<el-button
|
||||
type="danger"
|
||||
link
|
||||
size="small"
|
||||
class="remove-btn"
|
||||
@click="removeStrategy(index)"
|
||||
>
|
||||
<el-icon><Close /></el-icon>
|
||||
</el-button>
|
||||
</div>
|
||||
</template>
|
||||
</draggable>
|
||||
</div>
|
||||
|
||||
<div v-else class="strategy-empty">
|
||||
<el-text type="info">暂无策略,请从下方添加</el-text>
|
||||
</div>
|
||||
|
||||
<div class="strategy-add-section">
|
||||
<el-select
|
||||
v-model="selectedStrategy"
|
||||
placeholder="选择要添加的策略"
|
||||
style="width: 200px;"
|
||||
clearable
|
||||
>
|
||||
<el-option
|
||||
v-for="opt in EXTRACT_STRATEGY_OPTIONS"
|
||||
v-for="opt in availableStrategies"
|
||||
:key="opt.value"
|
||||
:label="opt.label"
|
||||
:value="opt.value"
|
||||
:value="opt.value as ExtractStrategy"
|
||||
:disabled="isStrategySelected(opt.value as ExtractStrategy)"
|
||||
>
|
||||
<div class="extract-option">
|
||||
<span>{{ opt.label }}</span>
|
||||
|
|
@ -130,9 +220,22 @@
|
|||
</div>
|
||||
</el-option>
|
||||
</el-select>
|
||||
</el-form-item>
|
||||
</el-col>
|
||||
</el-row>
|
||||
<el-button
|
||||
type="primary"
|
||||
:disabled="!selectedStrategy"
|
||||
@click="addStrategy"
|
||||
>
|
||||
<el-icon><Plus /></el-icon>
|
||||
添加
|
||||
</el-button>
|
||||
</div>
|
||||
|
||||
<div v-if="strategyError" class="strategy-error">
|
||||
<el-text type="danger">{{ strategyError }}</el-text>
|
||||
</div>
|
||||
</div>
|
||||
</el-form-item>
|
||||
|
||||
<el-form-item label="关联字段" prop="linked_field_id">
|
||||
<el-select
|
||||
v-model="formData.linked_field_id"
|
||||
|
|
@ -186,10 +289,11 @@
|
|||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, reactive, onMounted, watch } from 'vue'
|
||||
import { ref, reactive, onMounted, watch, computed } from 'vue'
|
||||
import { ElMessage, ElMessageBox } from 'element-plus'
|
||||
import { Plus, Edit, Delete } from '@element-plus/icons-vue'
|
||||
import { Plus, Edit, Delete, Rank, Close } from '@element-plus/icons-vue'
|
||||
import type { FormInstance, FormRules } from 'element-plus'
|
||||
import draggable from 'vuedraggable'
|
||||
import { slotDefinitionApi } from '@/api/slot-definition'
|
||||
import { metadataSchemaApi } from '@/api/metadata-schema'
|
||||
import {
|
||||
|
|
@ -199,7 +303,9 @@ import {
|
|||
type SlotDefinitionCreateRequest,
|
||||
type SlotDefinitionUpdateRequest,
|
||||
type SlotType,
|
||||
type ExtractStrategy
|
||||
type ExtractStrategy,
|
||||
validateExtractStrategies,
|
||||
getEffectiveStrategies
|
||||
} from '@/types/slot-definition'
|
||||
import type { MetadataFieldDefinition } from '@/types/metadata'
|
||||
|
||||
|
|
@ -212,12 +318,19 @@ const isEdit = ref(false)
|
|||
const submitting = ref(false)
|
||||
const formRef = ref<FormInstance>()
|
||||
|
||||
// [AC-MRS-07-UPGRADE] 策略链相关
|
||||
const selectedStrategy = ref<ExtractStrategy | ''>('')
|
||||
const strategyError = ref('')
|
||||
|
||||
const formData = reactive({
|
||||
id: '',
|
||||
slot_key: '',
|
||||
display_name: '',
|
||||
description: '',
|
||||
type: 'string' as SlotType,
|
||||
required: false,
|
||||
extract_strategy: '' as ExtractStrategy | '',
|
||||
// [AC-MRS-07-UPGRADE] 使用策略链数组
|
||||
extract_strategies: [] as ExtractStrategy[],
|
||||
validation_rule: '',
|
||||
ask_back_prompt: '',
|
||||
default_value: '',
|
||||
|
|
@ -233,6 +346,42 @@ const formRules: FormRules = {
|
|||
required: [{ required: true, message: '请选择是否必填', trigger: 'change' }]
|
||||
}
|
||||
|
||||
// [AC-MRS-07-UPGRADE] 计算可用策略(未选择的)
|
||||
const availableStrategies = computed(() => {
|
||||
return EXTRACT_STRATEGY_OPTIONS
|
||||
})
|
||||
|
||||
// [AC-MRS-07-UPGRADE] 检查策略是否已选择
|
||||
const isStrategySelected = (strategy: ExtractStrategy) => {
|
||||
return formData.extract_strategies.includes(strategy)
|
||||
}
|
||||
|
||||
// [AC-MRS-07-UPGRADE] 添加策略到链
|
||||
const addStrategy = () => {
|
||||
if (!selectedStrategy.value) return
|
||||
|
||||
if (isStrategySelected(selectedStrategy.value)) {
|
||||
strategyError.value = '该策略已存在于链中'
|
||||
return
|
||||
}
|
||||
|
||||
formData.extract_strategies.push(selectedStrategy.value)
|
||||
selectedStrategy.value = ''
|
||||
strategyError.value = ''
|
||||
}
|
||||
|
||||
// [AC-MRS-07-UPGRADE] 从链中移除策略
|
||||
const removeStrategy = (index: number) => {
|
||||
formData.extract_strategies.splice(index, 1)
|
||||
strategyError.value = ''
|
||||
}
|
||||
|
||||
// [AC-MRS-07-UPGRADE] 清空策略链
|
||||
const clearStrategies = () => {
|
||||
formData.extract_strategies = []
|
||||
strategyError.value = ''
|
||||
}
|
||||
|
||||
const getTypeLabel = (type: SlotType) => {
|
||||
return SLOT_TYPE_OPTIONS.find(o => o.value === type)?.label || type
|
||||
}
|
||||
|
|
@ -241,6 +390,16 @@ const getExtractStrategyLabel = (strategy: ExtractStrategy) => {
|
|||
return EXTRACT_STRATEGY_OPTIONS.find(o => o.value === strategy)?.label || strategy
|
||||
}
|
||||
|
||||
// [AC-MRS-07-UPGRADE] 获取策略标签类型
|
||||
const getStrategyTagType = (strategy: ExtractStrategy): any => {
|
||||
const typeMap: Record<ExtractStrategy, any> = {
|
||||
'rule': 'success',
|
||||
'llm': 'warning',
|
||||
'user_input': 'info'
|
||||
}
|
||||
return typeMap[strategy] || 'info'
|
||||
}
|
||||
|
||||
const fetchSlots = async () => {
|
||||
loading.value = true
|
||||
try {
|
||||
|
|
@ -256,7 +415,8 @@ const fetchSlots = async () => {
|
|||
const fetchSlotFields = async () => {
|
||||
try {
|
||||
const res = await metadataSchemaApi.getByRole('slot', true)
|
||||
slotFields.value = res.items || []
|
||||
// 后端返回数组格式,兼容处理
|
||||
slotFields.value = Array.isArray(res) ? res : (res.items || [])
|
||||
} catch (error: any) {
|
||||
console.error('获取槽位角色字段失败', error)
|
||||
}
|
||||
|
|
@ -267,14 +427,19 @@ const handleCreate = () => {
|
|||
Object.assign(formData, {
|
||||
id: '',
|
||||
slot_key: '',
|
||||
display_name: '',
|
||||
description: '',
|
||||
type: 'string',
|
||||
required: false,
|
||||
extract_strategy: '',
|
||||
// [AC-MRS-07-UPGRADE] 初始化为空数组
|
||||
extract_strategies: [],
|
||||
validation_rule: '',
|
||||
ask_back_prompt: '',
|
||||
default_value: '',
|
||||
linked_field_id: ''
|
||||
})
|
||||
selectedStrategy.value = ''
|
||||
strategyError.value = ''
|
||||
dialogVisible.value = true
|
||||
}
|
||||
|
||||
|
|
@ -283,14 +448,19 @@ const handleEdit = (slot: SlotDefinition) => {
|
|||
Object.assign(formData, {
|
||||
id: slot.id,
|
||||
slot_key: slot.slot_key,
|
||||
display_name: slot.display_name || '',
|
||||
description: slot.description || '',
|
||||
type: slot.type,
|
||||
required: slot.required,
|
||||
extract_strategy: slot.extract_strategy || '',
|
||||
// [AC-MRS-07-UPGRADE] 使用有效的策略链
|
||||
extract_strategies: [...getEffectiveStrategies(slot)],
|
||||
validation_rule: slot.validation_rule || '',
|
||||
ask_back_prompt: slot.ask_back_prompt || '',
|
||||
default_value: slot.default_value ?? '',
|
||||
linked_field_id: slot.linked_field_id || ''
|
||||
})
|
||||
selectedStrategy.value = ''
|
||||
strategyError.value = ''
|
||||
dialogVisible.value = true
|
||||
}
|
||||
|
||||
|
|
@ -317,13 +487,25 @@ const handleSubmit = async () => {
|
|||
await formRef.value.validate(async (valid) => {
|
||||
if (!valid) return
|
||||
|
||||
// [AC-MRS-07-UPGRADE] 校验策略链
|
||||
if (formData.extract_strategies.length > 0) {
|
||||
const validation = validateExtractStrategies(formData.extract_strategies)
|
||||
if (!validation.valid) {
|
||||
strategyError.value = validation.error || '策略链校验失败'
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
submitting.value = true
|
||||
try {
|
||||
const data: SlotDefinitionCreateRequest | SlotDefinitionUpdateRequest = {
|
||||
slot_key: formData.slot_key,
|
||||
display_name: formData.display_name || undefined,
|
||||
description: formData.description || undefined,
|
||||
type: formData.type,
|
||||
required: formData.required,
|
||||
extract_strategy: formData.extract_strategy || undefined,
|
||||
// [AC-MRS-07-UPGRADE] 提交策略链
|
||||
extract_strategies: formData.extract_strategies.length > 0 ? formData.extract_strategies : undefined,
|
||||
validation_rule: formData.validation_rule || undefined,
|
||||
ask_back_prompt: formData.ask_back_prompt || undefined,
|
||||
linked_field_id: formData.linked_field_id || undefined
|
||||
|
|
@ -409,6 +591,27 @@ onMounted(() => {
|
|||
font-size: 12px;
|
||||
}
|
||||
|
||||
.slot-key-cell {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 2px;
|
||||
}
|
||||
|
||||
.display-name {
|
||||
font-size: 12px;
|
||||
color: var(--el-text-color-secondary);
|
||||
}
|
||||
|
||||
.slot-description {
|
||||
color: var(--el-text-color-regular);
|
||||
font-size: 13px;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
white-space: nowrap;
|
||||
display: block;
|
||||
max-width: 180px;
|
||||
}
|
||||
|
||||
.no-value {
|
||||
color: var(--el-text-color-placeholder);
|
||||
}
|
||||
|
|
@ -439,6 +642,91 @@ onMounted(() => {
|
|||
color: var(--el-text-color-secondary);
|
||||
}
|
||||
|
||||
// [AC-MRS-07-UPGRADE] 策略链样式
|
||||
.strategy-chain {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
gap: 4px;
|
||||
}
|
||||
|
||||
.strategy-tag {
|
||||
margin-right: 4px;
|
||||
}
|
||||
|
||||
.strategy-chain-editor {
|
||||
border: 1px solid var(--el-border-color);
|
||||
border-radius: 4px;
|
||||
padding: 12px;
|
||||
background-color: var(--el-fill-color-light);
|
||||
}
|
||||
|
||||
.strategy-chain-header {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
margin-bottom: 12px;
|
||||
}
|
||||
|
||||
.chain-hint {
|
||||
font-size: 12px;
|
||||
color: var(--el-text-color-secondary);
|
||||
}
|
||||
|
||||
.strategy-chain-list {
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
|
||||
.strategy-list {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 8px;
|
||||
}
|
||||
|
||||
.strategy-item {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
padding: 8px 12px;
|
||||
background-color: var(--el-bg-color);
|
||||
border-radius: 4px;
|
||||
border: 1px solid var(--el-border-color-light);
|
||||
}
|
||||
|
||||
.drag-handle {
|
||||
cursor: move;
|
||||
color: var(--el-text-color-secondary);
|
||||
font-size: 16px;
|
||||
}
|
||||
|
||||
.strategy-order {
|
||||
width: 20px;
|
||||
text-align: center;
|
||||
font-weight: 600;
|
||||
color: var(--el-text-color-secondary);
|
||||
}
|
||||
|
||||
.remove-btn {
|
||||
margin-left: auto;
|
||||
}
|
||||
|
||||
.strategy-empty {
|
||||
text-align: center;
|
||||
padding: 20px;
|
||||
color: var(--el-text-color-secondary);
|
||||
}
|
||||
|
||||
.strategy-add-section {
|
||||
display: flex;
|
||||
gap: 8px;
|
||||
align-items: center;
|
||||
padding-top: 12px;
|
||||
border-top: 1px dashed var(--el-border-color);
|
||||
}
|
||||
|
||||
.strategy-error {
|
||||
margin-top: 8px;
|
||||
}
|
||||
|
||||
.extract-option {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
|
|
|
|||
|
|
@ -410,6 +410,24 @@
|
|||
<p>配置知识库、意图规则等的动态元数据字段定义。</p>
|
||||
</div>
|
||||
</div>
|
||||
<div class="help-item" @click="navigateTo('/admin/slot-definitions')">
|
||||
<div class="help-icon success">
|
||||
<el-icon><Grid /></el-icon>
|
||||
</div>
|
||||
<div class="help-text">
|
||||
<h4>槽位定义</h4>
|
||||
<p>定义槽位类型、提取策略、校验规则和追问提示。</p>
|
||||
</div>
|
||||
</div>
|
||||
<div class="help-item" @click="navigateTo('/admin/scene-slot-bundles')">
|
||||
<div class="help-icon warning">
|
||||
<el-icon><Collection /></el-icon>
|
||||
</div>
|
||||
<div class="help-text">
|
||||
<h4>场景槽位包</h4>
|
||||
<p>配置场景与槽位的映射关系,定义每个场景需要采集的槽位。</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</el-card>
|
||||
</el-col>
|
||||
|
|
@ -420,7 +438,7 @@
|
|||
<script setup lang="ts">
|
||||
import { ref, reactive, computed, onMounted } from 'vue'
|
||||
import { useRouter } from 'vue-router'
|
||||
import { FolderOpened, Document, ChatDotSquare, Monitor, Cpu, InfoFilled, Connection, Timer, DataLine, Aim, DocumentCopy, Share, Warning, Setting } from '@element-plus/icons-vue'
|
||||
import { FolderOpened, Document, ChatDotSquare, Monitor, Cpu, InfoFilled, Connection, Timer, DataLine, Aim, DocumentCopy, Share, Warning, Setting, Grid, Collection } from '@element-plus/icons-vue'
|
||||
import { getDashboardStats, type DashboardStats } from '@/api/dashboard'
|
||||
|
||||
const router = useRouter()
|
||||
|
|
|
|||
|
|
@ -11,6 +11,14 @@
|
|||
<el-icon><Upload /></el-icon>
|
||||
上传文档
|
||||
</el-button>
|
||||
<el-button type="success" @click="handleBatchUploadClick">
|
||||
<el-icon><Upload /></el-icon>
|
||||
批量上传
|
||||
</el-button>
|
||||
<el-button type="warning" @click="handleJsonUploadClick">
|
||||
<el-icon><Upload /></el-icon>
|
||||
JSON上传
|
||||
</el-button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
|
@ -85,6 +93,8 @@
|
|||
</el-dialog>
|
||||
|
||||
<input ref="fileInput" type="file" style="display: none" @change="handleFileChange" />
|
||||
<input ref="batchFileInput" type="file" accept=".zip" style="display: none" @change="handleBatchFileChange" />
|
||||
<input ref="jsonFileInput" type="file" accept=".json,.jsonl,.txt" style="display: none" @change="handleJsonFileChange" />
|
||||
</div>
|
||||
</template>
|
||||
|
||||
|
|
@ -92,7 +102,7 @@
|
|||
import { ref, onMounted, onUnmounted } from 'vue'
|
||||
import { ElMessage, ElMessageBox } from 'element-plus'
|
||||
import { Upload, Document, View, Delete } from '@element-plus/icons-vue'
|
||||
import { uploadDocument, listDocuments, getIndexJob, deleteDocument } from '@/api/kb'
|
||||
import { uploadDocument, listDocuments, getIndexJob, deleteDocument, batchUploadDocuments, jsonBatchUpload } from '@/api/kb'
|
||||
|
||||
interface DocumentItem {
|
||||
docId: string
|
||||
|
|
@ -242,11 +252,21 @@ onUnmounted(() => {
|
|||
})
|
||||
|
||||
const fileInput = ref<HTMLInputElement | null>(null)
|
||||
const batchFileInput = ref<HTMLInputElement | null>(null)
|
||||
const jsonFileInput = ref<HTMLInputElement | null>(null)
|
||||
|
||||
const handleUploadClick = () => {
|
||||
fileInput.value?.click()
|
||||
}
|
||||
|
||||
const handleBatchUploadClick = () => {
|
||||
batchFileInput.value?.click()
|
||||
}
|
||||
|
||||
const handleJsonUploadClick = () => {
|
||||
jsonFileInput.value?.click()
|
||||
}
|
||||
|
||||
const handleFileChange = async (event: Event) => {
|
||||
const target = event.target as HTMLInputElement
|
||||
const file = target.files?.[0]
|
||||
|
|
@ -281,6 +301,111 @@ const handleFileChange = async (event: Event) => {
|
|||
target.value = ''
|
||||
}
|
||||
}
|
||||
|
||||
const handleBatchFileChange = async (event: Event) => {
|
||||
const target = event.target as HTMLInputElement
|
||||
const file = target.files?.[0]
|
||||
if (!file) return
|
||||
|
||||
if (!file.name.toLowerCase().endsWith('.zip')) {
|
||||
ElMessage.error('请上传 .zip 格式的压缩包')
|
||||
target.value = ''
|
||||
return
|
||||
}
|
||||
|
||||
const formData = new FormData()
|
||||
formData.append('file', file)
|
||||
formData.append('kb_id', 'kb_default')
|
||||
|
||||
try {
|
||||
loading.value = true
|
||||
const res: any = await batchUploadDocuments(formData)
|
||||
|
||||
const { total, succeeded, failed, results } = res
|
||||
|
||||
if (succeeded > 0) {
|
||||
ElMessage.success(`批量上传成功!成功: ${succeeded}, 失败: ${failed}, 总计: ${total}`)
|
||||
|
||||
for (const result of results) {
|
||||
if (result.status === 'created') {
|
||||
const newDoc: DocumentItem = {
|
||||
docId: result.docId,
|
||||
name: result.fileName,
|
||||
status: 'pending',
|
||||
jobId: result.jobId,
|
||||
createTime: new Date().toLocaleString('zh-CN')
|
||||
}
|
||||
tableData.value.unshift(newDoc)
|
||||
startPolling(result.jobId)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
ElMessage.error('批量上传失败,请检查压缩包格式')
|
||||
}
|
||||
|
||||
if (failed > 0) {
|
||||
const failedItems = results.filter((r: any) => r.status === 'failed')
|
||||
console.error('Failed uploads:', failedItems)
|
||||
}
|
||||
|
||||
fetchDocuments()
|
||||
} catch (error: any) {
|
||||
ElMessage.error(error.message || '批量上传失败')
|
||||
console.error('Batch upload error:', error)
|
||||
} finally {
|
||||
loading.value = false
|
||||
target.value = ''
|
||||
}
|
||||
}
|
||||
|
||||
const handleJsonFileChange = async (event: Event) => {
|
||||
const target = event.target as HTMLInputElement
|
||||
const file = target.files?.[0]
|
||||
if (!file) return
|
||||
|
||||
const fileName = file.name.toLowerCase()
|
||||
if (!fileName.endsWith('.json') && !fileName.endsWith('.jsonl') && !fileName.endsWith('.txt')) {
|
||||
ElMessage.error('请上传 .json、.jsonl 或 .txt 格式的文件')
|
||||
target.value = ''
|
||||
return
|
||||
}
|
||||
|
||||
const kbId = '75c465fe-277d-455d-a30b-4b168adcc03b'
|
||||
|
||||
const formData = new FormData()
|
||||
formData.append('file', file)
|
||||
formData.append('tenant_id', 'szmp@ash@2026')
|
||||
|
||||
try {
|
||||
loading.value = true
|
||||
ElMessage.info('正在上传 JSON 文件...')
|
||||
|
||||
const res: any = await jsonBatchUpload(kbId, formData)
|
||||
|
||||
if (res.success) {
|
||||
ElMessage.success(`JSON 批量上传成功!成功: ${res.succeeded}, 失败: ${res.failed}`)
|
||||
|
||||
if (res.valid_metadata_fields) {
|
||||
console.log('有效的元数据字段:', res.valid_metadata_fields)
|
||||
}
|
||||
|
||||
if (res.failed > 0) {
|
||||
const failedItems = res.results.filter((r: any) => !r.success)
|
||||
console.warn('失败的项目:', failedItems)
|
||||
}
|
||||
|
||||
fetchDocuments()
|
||||
} else {
|
||||
ElMessage.error('JSON 批量上传失败')
|
||||
}
|
||||
} catch (error: any) {
|
||||
ElMessage.error(error.message || 'JSON 批量上传失败')
|
||||
console.error('JSON batch upload error:', error)
|
||||
} finally {
|
||||
loading.value = false
|
||||
target.value = ''
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
|
|
@ -324,6 +449,7 @@ const handleFileChange = async (event: Event) => {
|
|||
.header-actions {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 12px;
|
||||
}
|
||||
|
||||
.table-card {
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@
|
|||
Admin API routes for AI Service management.
|
||||
[AC-ASA-01, AC-ASA-02, AC-ASA-05, AC-ASA-07, AC-ASA-08, AC-AISVC-50] Admin management endpoints.
|
||||
[AC-MRS-07,08,16] Slot definition management endpoints.
|
||||
[AC-SCENE-SLOT-01] Scene slot bundle management endpoints.
|
||||
[AC-AISVC-RES-01~15] Retrieval strategy management endpoints.
|
||||
"""
|
||||
|
||||
from app.api.admin.api_key import router as api_key_router
|
||||
|
|
@ -18,6 +20,8 @@ from app.api.admin.metadata_schema import router as metadata_schema_router
|
|||
from app.api.admin.monitoring import router as monitoring_router
|
||||
from app.api.admin.prompt_templates import router as prompt_templates_router
|
||||
from app.api.admin.rag import router as rag_router
|
||||
from app.api.admin.retrieval_strategy import router as retrieval_strategy_router
|
||||
from app.api.admin.scene_slot_bundle import router as scene_slot_bundle_router
|
||||
from app.api.admin.script_flows import router as script_flows_router
|
||||
from app.api.admin.sessions import router as sessions_router
|
||||
from app.api.admin.slot_definition import router as slot_definition_router
|
||||
|
|
@ -38,6 +42,8 @@ __all__ = [
|
|||
"monitoring_router",
|
||||
"prompt_templates_router",
|
||||
"rag_router",
|
||||
"retrieval_strategy_router",
|
||||
"scene_slot_bundle_router",
|
||||
"script_flows_router",
|
||||
"sessions_router",
|
||||
"slot_definition_router",
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@
|
|||
Intent Rule Management API.
|
||||
[AC-AISVC-65~AC-AISVC-68] Intent rule CRUD endpoints.
|
||||
[AC-AISVC-96] Intent rule testing endpoint.
|
||||
[AC-AISVC-116] Fusion config management endpoints.
|
||||
[AC-AISVC-114] Intent vector generation endpoint.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
|
@ -14,6 +16,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from app.core.database import get_session
|
||||
from app.models.entities import IntentRuleCreate, IntentRuleUpdate
|
||||
from app.services.intent.models import DEFAULT_FUSION_CONFIG, FusionConfig
|
||||
from app.services.intent.rule_service import IntentRuleService
|
||||
from app.services.intent.tester import IntentRuleTester
|
||||
|
||||
|
|
@ -21,6 +24,8 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
router = APIRouter(prefix="/admin/intent-rules", tags=["Intent Rules"])
|
||||
|
||||
_fusion_config = FusionConfig()
|
||||
|
||||
|
||||
def get_tenant_id(x_tenant_id: str = Header(..., alias="X-Tenant-Id")) -> str:
|
||||
"""Extract tenant ID from header."""
|
||||
|
|
@ -204,3 +209,109 @@ async def test_rule(
|
|||
result = await tester.test_rule(rule, [body.message], all_rules)
|
||||
|
||||
return result.to_dict()
|
||||
|
||||
|
||||
class FusionConfigUpdate(BaseModel):
|
||||
"""Request body for updating fusion config."""
|
||||
|
||||
w_rule: float | None = None
|
||||
w_semantic: float | None = None
|
||||
w_llm: float | None = None
|
||||
semantic_threshold: float | None = None
|
||||
conflict_threshold: float | None = None
|
||||
gray_zone_threshold: float | None = None
|
||||
min_trigger_threshold: float | None = None
|
||||
clarify_threshold: float | None = None
|
||||
multi_intent_threshold: float | None = None
|
||||
llm_judge_enabled: bool | None = None
|
||||
semantic_matcher_enabled: bool | None = None
|
||||
semantic_matcher_timeout_ms: int | None = None
|
||||
llm_judge_timeout_ms: int | None = None
|
||||
semantic_top_k: int | None = None
|
||||
|
||||
|
||||
@router.get("/fusion-config")
|
||||
async def get_fusion_config() -> dict[str, Any]:
|
||||
"""
|
||||
[AC-AISVC-116] Get current fusion configuration.
|
||||
"""
|
||||
logger.info("[AC-AISVC-116] Getting fusion config")
|
||||
return _fusion_config.to_dict()
|
||||
|
||||
|
||||
@router.put("/fusion-config")
|
||||
async def update_fusion_config(
|
||||
body: FusionConfigUpdate,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
[AC-AISVC-116] Update fusion configuration.
|
||||
"""
|
||||
global _fusion_config
|
||||
|
||||
logger.info(f"[AC-AISVC-116] Updating fusion config: {body.model_dump()}")
|
||||
|
||||
current_dict = _fusion_config.to_dict()
|
||||
update_dict = body.model_dump(exclude_none=True)
|
||||
current_dict.update(update_dict)
|
||||
_fusion_config = FusionConfig.from_dict(current_dict)
|
||||
|
||||
return _fusion_config.to_dict()
|
||||
|
||||
|
||||
@router.post("/{rule_id}/generate-vector")
|
||||
async def generate_intent_vector(
|
||||
rule_id: uuid.UUID,
|
||||
tenant_id: str = Depends(get_tenant_id),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
[AC-AISVC-114] Generate intent vector for a rule.
|
||||
|
||||
Uses the rule's semantic_examples to generate an average vector.
|
||||
If no semantic_examples exist, returns an error.
|
||||
"""
|
||||
logger.info(
|
||||
f"[AC-AISVC-114] Generating intent vector for tenant={tenant_id}, rule_id={rule_id}"
|
||||
)
|
||||
|
||||
service = IntentRuleService(session)
|
||||
rule = await service.get_rule(tenant_id, rule_id)
|
||||
|
||||
if not rule:
|
||||
raise HTTPException(status_code=404, detail="Intent rule not found")
|
||||
|
||||
if not rule.semantic_examples:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Rule has no semantic_examples. Please add semantic_examples first."
|
||||
)
|
||||
|
||||
try:
|
||||
from app.core.dependencies import get_embedding_provider
|
||||
embedding_provider = get_embedding_provider()
|
||||
|
||||
vectors = await embedding_provider.embed_batch(rule.semantic_examples)
|
||||
|
||||
import numpy as np
|
||||
avg_vector = np.mean(vectors, axis=0).tolist()
|
||||
|
||||
update_data = IntentRuleUpdate(intent_vector=avg_vector)
|
||||
updated_rule = await service.update_rule(tenant_id, rule_id, update_data)
|
||||
|
||||
logger.info(
|
||||
f"[AC-AISVC-114] Generated intent vector for rule={rule_id}, "
|
||||
f"dimension={len(avg_vector)}"
|
||||
)
|
||||
|
||||
return {
|
||||
"id": str(updated_rule.id),
|
||||
"intent_vector": updated_rule.intent_vector,
|
||||
"semantic_examples": updated_rule.semantic_examples,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[AC-AISVC-114] Failed to generate intent vector: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to generate intent vector: {str(e)}"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -6,28 +6,35 @@ Knowledge Base management endpoints.
|
|||
|
||||
import logging
|
||||
import uuid
|
||||
import json
|
||||
import hashlib
|
||||
from dataclasses import dataclass
|
||||
from typing import Annotated, Any, Optional
|
||||
|
||||
import tiktoken
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, File, Form, Query, UploadFile
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, File, Form, HTTPException, Query, UploadFile
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.core.database import get_session
|
||||
from app.core.exceptions import MissingTenantIdException
|
||||
from app.core.tenant import get_tenant_id
|
||||
from app.models import ErrorResponse
|
||||
from app.models.entities import (
|
||||
Document,
|
||||
DocumentStatus,
|
||||
IndexJob,
|
||||
IndexJobStatus,
|
||||
KBType,
|
||||
KnowledgeBase,
|
||||
KnowledgeBaseCreate,
|
||||
KnowledgeBaseUpdate,
|
||||
)
|
||||
from app.services.kb import KBService
|
||||
from app.services.knowledge_base_service import KnowledgeBaseService
|
||||
from app.services.metadata_field_definition_service import MetadataFieldDefinitionService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -457,6 +464,7 @@ async def list_documents(
|
|||
"kbId": doc.kb_id,
|
||||
"fileName": doc.file_name,
|
||||
"status": doc.status,
|
||||
"metadata": doc.doc_metadata,
|
||||
"jobId": str(latest_job.id) if latest_job else None,
|
||||
"createdAt": doc.created_at.isoformat() + "Z",
|
||||
"updatedAt": doc.updated_at.isoformat() + "Z",
|
||||
|
|
@ -585,6 +593,7 @@ async def upload_document(
|
|||
file_name=file.filename or "unknown",
|
||||
file_content=file_content,
|
||||
file_type=file.content_type,
|
||||
metadata=metadata_dict,
|
||||
)
|
||||
|
||||
await kb_service.update_doc_count(tenant_id, kb_id, delta=1)
|
||||
|
|
@ -915,3 +924,488 @@ async def delete_document(
|
|||
"message": "Document deleted",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@router.put(
    "/documents/{doc_id}/metadata",
    operation_id="updateDocumentMetadata",
    summary="Update document metadata",
    description="[AC-ASA-08] Update metadata for a specific document.",
    responses={
        200: {"description": "Metadata updated"},
        404: {"description": "Document not found"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def update_document_metadata(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
    doc_id: str,
    body: dict,
) -> JSONResponse:
    """
    [AC-ASA-08] Update document metadata.

    Accepts ``{"metadata": ...}`` where the value may be a dict or a JSON
    string (parsed here). Returns 400 on unparseable JSON, 404 when the
    document does not exist for this tenant, 200 with the stored metadata
    on success.
    """
    # json, select and Document are imported at module level; the previous
    # function-local re-imports were redundant and have been removed.
    metadata = body.get("metadata")

    # Tolerate clients that send metadata as a JSON-encoded string.
    if metadata is not None and not isinstance(metadata, dict):
        try:
            metadata = json.loads(metadata) if isinstance(metadata, str) else metadata
        except json.JSONDecodeError:
            return JSONResponse(
                status_code=400,
                content={
                    "code": "INVALID_METADATA",
                    "message": "Invalid JSON format for metadata",
                },
            )

    logger.info(
        f"[AC-ASA-08] Updating document metadata: tenant={tenant_id}, doc_id={doc_id}"
    )

    # Scope the lookup by tenant so one tenant cannot touch another's docs.
    stmt = select(Document).where(
        Document.tenant_id == tenant_id,
        Document.id == doc_id,
    )
    result = await session.execute(stmt)
    document = result.scalar_one_or_none()

    if not document:
        return JSONResponse(
            status_code=404,
            content={
                "code": "DOCUMENT_NOT_FOUND",
                "message": f"Document {doc_id} not found",
            },
        )

    # NOTE(review): a body without a "metadata" key clears stored metadata
    # (sets it to None) — confirm this is the intended contract.
    document.doc_metadata = metadata
    await session.commit()

    return JSONResponse(
        content={
            "success": True,
            "message": "Metadata updated",
            "metadata": document.doc_metadata,
        }
    )
|
||||
|
||||
|
||||
@router.post(
    "/documents/batch-upload",
    operation_id="batchUploadDocuments",
    summary="Batch upload documents from zip",
    description="Upload a zip file containing multiple folders, each with a markdown file and metadata.json",
    responses={
        200: {"description": "Batch upload result"},
        400: {"description": "Bad Request - invalid zip or missing files"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def batch_upload_documents(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    kb_id: str = Form(...),
) -> JSONResponse:
    """
    Batch upload documents from a zip file.

    Zip structure:
    - Each folder contains one .md/.markdown/.txt file and one metadata.json
    - metadata.json uses field_key from MetadataFieldDefinition as keys

    Example metadata.json:
    {
        "grade": "高一",
        "subject": "数学",
        "type": "痛点"
    }

    Each folder is processed independently: a per-folder failure is recorded
    in the result list and does not abort the batch. Indexing is scheduled as
    a background task after each successful upload.
    """
    import tempfile
    import zipfile
    from pathlib import Path

    # json and MetadataFieldDefinitionService are imported at module level;
    # the previous function-local re-imports were redundant.

    logger.info(
        f"[BATCH-UPLOAD] Starting batch upload: tenant={tenant_id}, "
        f"kb_id={kb_id}, filename={file.filename}"
    )

    if not file.filename or not file.filename.lower().endswith('.zip'):
        return JSONResponse(
            status_code=400,
            content={
                "code": "INVALID_FORMAT",
                "message": "Only .zip files are supported",
            },
        )

    kb_service = KnowledgeBaseService(session)
    kb = await kb_service.get_knowledge_base(tenant_id, kb_id)
    if not kb:
        return JSONResponse(
            status_code=404,
            content={
                "code": "KB_NOT_FOUND",
                "message": f"Knowledge base {kb_id} not found",
            },
        )

    file_content = await file.read()

    results = []
    succeeded = 0
    failed = 0

    # Loop-invariant services: create once, not per folder (previously
    # re-instantiated on every iteration).
    field_def_service = MetadataFieldDefinitionService(session)
    doc_kb_service = KBService(session)

    with tempfile.TemporaryDirectory() as temp_dir:
        zip_path = Path(temp_dir) / "upload.zip"
        with open(zip_path, "wb") as f:
            f.write(file_content)

        try:
            with zipfile.ZipFile(zip_path, 'r') as zf:
                zf.extractall(temp_dir)
        except zipfile.BadZipFile as e:
            return JSONResponse(
                status_code=400,
                content={
                    "code": "INVALID_ZIP",
                    "message": f"Invalid zip file: {str(e)}",
                },
            )

        extracted_path = Path(temp_dir)

        # Log everything that was extracted, for debugging bad archives.
        all_items = list(extracted_path.iterdir())
        logger.info(f"[BATCH-UPLOAD] Extracted items: {[item.name for item in all_items]}")

        def find_document_folders(path: Path) -> list[Path]:
            """Recursively collect folders containing .md/.markdown/.txt files."""
            doc_folders = []

            content_files = (
                list(path.glob("*.md")) +
                list(path.glob("*.markdown")) +
                list(path.glob("*.txt"))
            )

            if content_files:
                # This folder holds document content; record it.
                doc_folders.append(path)
                logger.info(f"[BATCH-UPLOAD] Found document folder: {path.name}, files: {[f.name for f in content_files]}")

            # Recurse into subfolders.
            for subfolder in [p for p in path.iterdir() if p.is_dir()]:
                doc_folders.extend(find_document_folders(subfolder))

            return doc_folders

        folders = find_document_folders(extracted_path)

        if not folders:
            logger.error(f"[BATCH-UPLOAD] No document folders found in zip. Items found: {[item.name for item in all_items]}")
            return JSONResponse(
                status_code=400,
                content={
                    "code": "NO_DOCUMENTS_FOUND",
                    "message": "压缩包中没有找到包含 .txt/.md 文件的文件夹",
                    "details": {
                        "expected_structure": "每个文件夹应包含 content.txt (或 .md) 和 metadata.json",
                        "found_items": [item.name for item in all_items],
                    },
                },
            )

        logger.info(f"[BATCH-UPLOAD] Found {len(folders)} document folders")

        for folder in folders:
            folder_name = folder.name if folder != extracted_path else "root"

            content_files = (
                list(folder.glob("*.md")) +
                list(folder.glob("*.markdown")) +
                list(folder.glob("*.txt"))
            )

            if not content_files:
                # Defensive: find_document_folders only returns folders
                # that contained files, so this should not happen.
                failed += 1
                results.append({
                    "folder": folder_name,
                    "status": "failed",
                    "error": "No content file found",
                })
                continue

            # Only the first content file per folder is ingested.
            content_file = content_files[0]
            metadata_file = folder / "metadata.json"

            metadata_dict = {}
            if metadata_file.exists():
                try:
                    with open(metadata_file, 'r', encoding='utf-8') as f:
                        metadata_dict = json.load(f)
                except json.JSONDecodeError as e:
                    failed += 1
                    results.append({
                        "folder": folder_name,
                        "status": "failed",
                        "error": f"Invalid metadata.json: {str(e)}",
                    })
                    continue
            else:
                logger.warning(f"[BATCH-UPLOAD] No metadata.json in folder {folder_name}, using empty metadata")

            is_valid, validation_errors = await field_def_service.validate_metadata_for_create(
                tenant_id, metadata_dict, "kb_document"
            )

            if not is_valid:
                failed += 1
                results.append({
                    "folder": folder_name,
                    "status": "failed",
                    "error": f"Metadata validation failed: {validation_errors}",
                })
                continue

            try:
                with open(content_file, 'rb') as f:
                    doc_content = f.read()

                file_ext = content_file.suffix.lower()
                if file_ext == '.txt':
                    file_type = "text/plain"
                else:
                    file_type = "text/markdown"

                document, job = await doc_kb_service.upload_document(
                    tenant_id=tenant_id,
                    kb_id=kb_id,
                    file_name=content_file.name,
                    file_content=doc_content,
                    file_type=file_type,
                    metadata=metadata_dict,
                )

                await kb_service.update_doc_count(tenant_id, kb_id, delta=1)
                await session.commit()

                # Indexing runs after the response has been sent.
                background_tasks.add_task(
                    _index_document,
                    tenant_id,
                    kb_id,
                    str(job.id),
                    str(document.id),
                    doc_content,
                    content_file.name,
                    metadata_dict,
                )

                succeeded += 1
                results.append({
                    "folder": folder_name,
                    "docId": str(document.id),
                    "jobId": str(job.id),
                    "status": "created",
                    "fileName": content_file.name,
                })

                logger.info(
                    f"[BATCH-UPLOAD] Created document: folder={folder_name}, "
                    f"doc_id={document.id}, job_id={job.id}"
                )

            except Exception as e:
                failed += 1
                results.append({
                    "folder": folder_name,
                    "status": "failed",
                    "error": str(e),
                })
                logger.error(f"[BATCH-UPLOAD] Failed to create document: folder={folder_name}, error={e}")

    logger.info(
        f"[BATCH-UPLOAD] Completed: total={len(results)}, succeeded={succeeded}, failed={failed}"
    )

    return JSONResponse(
        content={
            "success": True,
            "total": len(results),
            "succeeded": succeeded,
            "failed": failed,
            "results": results,
        }
    )
|
||||
|
||||
|
||||
@router.post(
    "/{kb_id}/documents/json-batch",
    summary="[AC-KB-03] JSON批量上传文档",
    description="上传JSONL格式文件,每行一个JSON对象,包含text和元数据字段",
)
async def upload_json_batch(
    kb_id: str,
    tenant_id: str = Query(..., description="租户ID"),
    file: UploadFile = File(..., description="JSONL格式文件,每行一个JSON对象"),
    session: AsyncSession = Depends(get_session),
    background_tasks: BackgroundTasks = None,
):
    """
    [AC-KB-03] Batch-upload documents from a JSONL file.

    File format: JSONL (one JSON object per line).
    Required field: ``text`` — the content to ingest into the knowledge base.
    Optional fields: metadata keys (e.g. grade, subject, kb_scene); keys not
    registered as metadata field definitions for the tenant are dropped.

    Example lines:
    {"text": "课程内容...", "grade": "初二", "subject": "数学", "kb_scene": "课程咨询"}
    {"text": "另一条课程内容...", "grade": "初三", "info_type": "课程概述"}

    Each line is processed independently; a bad line is recorded in the
    per-line results and does not abort the batch.
    """
    # Verify the KB exists and belongs to the requesting tenant.
    kb = await session.get(KnowledgeBase, kb_id)
    if not kb:
        raise HTTPException(status_code=404, detail="知识库不存在")

    if kb.tenant_id != tenant_id:
        raise HTTPException(status_code=403, detail="无权访问此知识库")

    # Best-effort load of the tenant's allowed metadata keys; an empty set
    # means "no filtering" further below.
    valid_field_keys = set()
    try:
        field_defs = await MetadataFieldDefinitionService(session).get_fields(
            tenant_id=tenant_id,
            include_inactive=False,
        )
        valid_field_keys = {f.field_key for f in field_defs}
        logger.info(f"[AC-KB-03] Valid metadata fields for tenant {tenant_id}: {valid_field_keys}")
    except Exception as e:
        logger.warning(f"[AC-KB-03] Failed to get metadata fields: {e}")

    # Decode as UTF-8 first, then fall back to GBK before giving up.
    content = await file.read()
    try:
        text_content = content.decode("utf-8")
    except UnicodeDecodeError:
        try:
            text_content = content.decode("gbk")
        except UnicodeDecodeError:
            raise HTTPException(status_code=400, detail="文件编码不支持,请使用UTF-8编码")

    lines = text_content.strip().split("\n")
    if not lines:
        raise HTTPException(status_code=400, detail="文件内容为空")

    results = []
    succeeded = 0
    failed = 0

    kb_service = KBService(session)

    # Line numbers are 1-based in the per-line results.
    for line_num, line in enumerate(lines, 1):
        line = line.strip()
        if not line:
            # Blank lines are silently skipped (not counted as failures).
            continue

        try:
            json_obj = json.loads(line)
        except json.JSONDecodeError as e:
            failed += 1
            results.append({
                "line": line_num,
                "success": False,
                "error": f"JSON解析失败: {e}",
            })
            continue

        text = json_obj.get("text")
        if not text:
            failed += 1
            results.append({
                "line": line_num,
                "success": False,
                "error": "缺少必填字段: text",
            })
            continue

        # Keep only registered, non-null metadata keys. An empty
        # valid_field_keys (lookup failed / none defined) disables filtering.
        metadata = {}
        for key, value in json_obj.items():
            if key == "text":
                continue
            if valid_field_keys and key not in valid_field_keys:
                logger.debug(f"[AC-KB-03] Skipping invalid metadata field: {key}")
                continue
            if value is not None:
                metadata[key] = value

        try:
            # Each line becomes its own synthetic .txt document.
            file_name = f"json_batch_line_{line_num}.txt"
            file_content = text.encode("utf-8")

            document, job = await kb_service.upload_document(
                tenant_id=tenant_id,
                kb_id=kb_id,
                file_name=file_name,
                file_content=file_content,
                file_type="text/plain",
                metadata=metadata,
            )

            # Indexing is deferred to a background task when one is available.
            if background_tasks:
                background_tasks.add_task(
                    _index_document,
                    tenant_id,
                    kb_id,
                    str(job.id),
                    str(document.id),
                    file_content,
                    file_name,
                    metadata,
                )

            succeeded += 1
            results.append({
                "line": line_num,
                "success": True,
                "doc_id": str(document.id),
                "job_id": str(job.id),
                "metadata": metadata,
            })
        except Exception as e:
            failed += 1
            results.append({
                "line": line_num,
                "success": False,
                "error": str(e),
            })
            logger.error(f"[AC-KB-03] Failed to upload document at line {line_num}: {e}")

    # Single commit for the whole batch, after all lines are processed.
    await session.commit()

    logger.info(f"[AC-KB-03] JSON batch upload completed: kb_id={kb_id}, total={len(lines)}, succeeded={succeeded}, failed={failed}")

    return JSONResponse(
        content={
            "success": True,
            "total": len(lines),
            "succeeded": succeeded,
            "failed": failed,
            "valid_metadata_fields": list(valid_field_keys) if valid_field_keys else [],
            "results": results,
        }
    )
|
||||
|
|
|
|||
|
|
@ -51,6 +51,7 @@ def _field_to_dict(f: MetadataFieldDefinition) -> dict[str, Any]:
|
|||
"scope": f.scope,
|
||||
"is_filterable": f.is_filterable,
|
||||
"is_rank_feature": f.is_rank_feature,
|
||||
"usage_description": f.usage_description,
|
||||
"field_roles": f.field_roles or [],
|
||||
"status": f.status,
|
||||
"version": f.version,
|
||||
|
|
|
|||
|
|
@ -407,6 +407,7 @@ async def get_conversation_detail(
|
|||
"guardrailTriggered": user_msg.guardrail_triggered,
|
||||
"guardrailWords": user_msg.guardrail_words,
|
||||
"executionSteps": execution_steps,
|
||||
"routeTrace": user_msg.route_trace,
|
||||
"createdAt": user_msg.created_at.isoformat(),
|
||||
}
|
||||
|
||||
|
|
@ -659,8 +660,56 @@ async def _process_export(
|
|||
except Exception as e:
|
||||
logger.error(f"[AC-AISVC-110] Export failed: task_id={task_id}, error={e}")
|
||||
|
||||
task = await session.get(ExportTask, task_id)
|
||||
task = task_status.get(ExportTask, task_id)
|
||||
if task:
|
||||
task.status = ExportTaskStatus.FAILED.value
|
||||
task.error_message = str(e)
|
||||
await session.commit()
|
||||
|
||||
|
||||
@router.get("/clarification-metrics")
async def get_clarification_metrics(
    tenant_id: str = Depends(get_tenant_id),
    total_requests: int = Query(100, ge=1, description="Total requests for rate calculation"),
) -> dict[str, Any]:
    """
    [AC-CLARIFY] Report clarification metrics.

    Returns raw counters ("counts"), ratios derived from ``total_requests``
    ("rates": clarification trigger rate, post-clarification convergence
    rate, misroute rate), and the static decision thresholds.
    """
    from app.services.intent.clarification import get_clarify_metrics

    collector = get_clarify_metrics()

    # Static thresholds used by the clarification decision logic, exposed so
    # dashboards can render the rates against them.
    thresholds = {
        "t_high": 0.75,
        "t_low": 0.45,
        "max_retry": 3,
    }

    return {
        "counts": collector.get_metrics(),
        "rates": collector.get_rates(total_requests),
        "thresholds": thresholds,
    }
|
||||
|
||||
|
||||
@router.post("/clarification-metrics/reset")
async def reset_clarification_metrics(
    tenant_id: str = Depends(get_tenant_id),
) -> dict[str, Any]:
    """
    [AC-CLARIFY] Zero out all clarification counters.
    """
    from app.services.intent.clarification import get_clarify_metrics

    # The metrics collector is a shared singleton; resetting it affects all
    # subsequent reads.
    get_clarify_metrics().reset()

    return {
        "status": "reset",
        "message": "Clarification metrics have been reset.",
    }
|
||||
|
|
|
|||
|
|
@ -17,6 +17,17 @@ from app.models.entities import PromptTemplateCreate, PromptTemplateUpdate
|
|||
from app.services.prompt.template_service import PromptTemplateService
|
||||
from app.services.monitoring.prompt_monitor import PromptMonitor
|
||||
|
||||
|
||||
class PromptTemplateUpdateAPI(BaseModel):
    """API model for updating prompt template with metadata field mapping.

    Mirrors the entity-level ``PromptTemplateUpdate`` but exposes the wire
    field name ``metadata``; the endpoint handler maps it to the entity's
    ``metadata_`` attribute when building the update.
    """
    # All fields are optional: only the fields provided are applied.
    name: str | None = None
    scene: str | None = None
    description: str | None = None
    system_instruction: str | None = None
    variables: list[dict[str, Any]] | None = None
    is_default: bool | None = None
    # Mapped to PromptTemplateUpdate.metadata_ by the endpoint.
    metadata: dict[str, Any] | None = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/admin/prompt-templates", tags=["Prompt Management"])
|
||||
|
|
@ -108,29 +119,47 @@ async def get_template_detail(
|
|||
@router.put("/{tpl_id}")
|
||||
async def update_template(
|
||||
tpl_id: uuid.UUID,
|
||||
body: PromptTemplateUpdate,
|
||||
body: PromptTemplateUpdateAPI,
|
||||
tenant_id: str = Depends(get_tenant_id),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
[AC-AISVC-53] Update prompt template (creates a new version).
|
||||
[AC-IDSMETA-16] Return current_content and metadata for frontend display.
|
||||
"""
|
||||
logger.info(f"[AC-AISVC-53] Updating template for tenant={tenant_id}, id={tpl_id}")
|
||||
|
||||
service = PromptTemplateService(session)
|
||||
template = await service.update_template(tenant_id, tpl_id, body)
|
||||
|
||||
# Convert API model to entity model (metadata -> metadata_)
|
||||
update_data = PromptTemplateUpdate(
|
||||
name=body.name,
|
||||
scene=body.scene,
|
||||
description=body.description,
|
||||
system_instruction=body.system_instruction,
|
||||
variables=body.variables,
|
||||
is_default=body.is_default,
|
||||
metadata_=body.metadata,
|
||||
)
|
||||
|
||||
template = await service.update_template(tenant_id, tpl_id, update_data)
|
||||
|
||||
if not template:
|
||||
raise HTTPException(status_code=404, detail="Template not found")
|
||||
|
||||
published_version = await service.get_published_version_info(tenant_id, template.id)
|
||||
|
||||
# Get latest version content for current_content
|
||||
latest_version = await service._get_latest_version(template.id)
|
||||
|
||||
return {
|
||||
"id": str(template.id),
|
||||
"name": template.name,
|
||||
"scene": template.scene,
|
||||
"description": template.description,
|
||||
"is_default": template.is_default,
|
||||
"current_content": latest_version.system_instruction if latest_version else None,
|
||||
"metadata": template.metadata_,
|
||||
"published_version": published_version,
|
||||
"created_at": template.created_at.isoformat(),
|
||||
"updated_at": template.updated_at.isoformat(),
|
||||
|
|
|
|||
|
|
@ -0,0 +1,265 @@
|
|||
"""
|
||||
Scene Slot Bundle API.
|
||||
[AC-SCENE-SLOT-01] 场景-槽位映射配置管理接口
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Annotated, Any
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import get_session
|
||||
from app.core.exceptions import MissingTenantIdException
|
||||
from app.core.tenant import get_tenant_id
|
||||
from app.models.entities import SceneSlotBundleCreate, SceneSlotBundleUpdate
|
||||
from app.services.scene_slot_bundle_service import SceneSlotBundleService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/admin/scene-slot-bundles", tags=["SceneSlotBundle"])
|
||||
|
||||
|
||||
def get_current_tenant_id() -> str:
    """Resolve the tenant ID from the request context, failing fast if absent."""
    resolved = get_tenant_id()
    if resolved:
        return resolved
    # A missing tenant is a hard error for every endpoint in this router.
    raise MissingTenantIdException()
|
||||
|
||||
|
||||
def _bundle_to_dict(bundle: Any) -> dict[str, Any]:
|
||||
"""Convert bundle to dict"""
|
||||
return {
|
||||
"id": str(bundle.id),
|
||||
"tenant_id": str(bundle.tenant_id),
|
||||
"scene_key": bundle.scene_key,
|
||||
"scene_name": bundle.scene_name,
|
||||
"description": bundle.description,
|
||||
"required_slots": bundle.required_slots,
|
||||
"optional_slots": bundle.optional_slots,
|
||||
"slot_priority": bundle.slot_priority,
|
||||
"completion_threshold": bundle.completion_threshold,
|
||||
"ask_back_order": bundle.ask_back_order,
|
||||
"status": bundle.status,
|
||||
"version": bundle.version,
|
||||
"created_at": bundle.created_at.isoformat() if bundle.created_at else None,
|
||||
"updated_at": bundle.updated_at.isoformat() if bundle.updated_at else None,
|
||||
}
|
||||
|
||||
|
||||
@router.get(
    "",
    operation_id="listSceneSlotBundles",
    summary="List scene slot bundles",
    description="获取场景槽位包列表",
)
async def list_scene_slot_bundles(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
    status: Annotated[str | None, Query(
        description="按状态过滤: draft/active/deprecated"
    )] = None,
) -> JSONResponse:
    """List the tenant's scene slot bundles, optionally filtered by status."""
    logger.info(
        f"Listing scene slot bundles: tenant={tenant_id}, status={status}"
    )

    bundle_service = SceneSlotBundleService(session)
    matching = await bundle_service.list_bundles(tenant_id, status)

    # Serialize each ORM bundle into its JSON representation.
    serialized = [_bundle_to_dict(item) for item in matching]
    return JSONResponse(content=serialized)
|
||||
|
||||
|
||||
@router.post(
    "",
    operation_id="createSceneSlotBundle",
    summary="Create scene slot bundle",
    description="[AC-SCENE-SLOT-01] 创建新的场景槽位包",
    status_code=201,
)
async def create_scene_slot_bundle(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
    bundle_create: SceneSlotBundleCreate,
) -> JSONResponse:
    """
    [AC-SCENE-SLOT-01] Create a scene slot bundle.

    Returns 201 with the serialized bundle on success, or 400 with a
    VALIDATION_ERROR payload when the service rejects the input.
    """
    logger.info(
        f"[AC-SCENE-SLOT-01] Creating scene slot bundle: "
        f"tenant={tenant_id}, scene_key={bundle_create.scene_key}"
    )

    service = SceneSlotBundleService(session)

    # Commit inside the try: a ValueError from create_bundle (or from the
    # commit itself) is mapped to a 400 without persisting anything.
    try:
        bundle = await service.create_bundle(tenant_id, bundle_create)
        await session.commit()
    except ValueError as e:
        return JSONResponse(
            status_code=400,
            content={
                "error_code": "VALIDATION_ERROR",
                "message": str(e),
            }
        )

    return JSONResponse(
        status_code=201,
        content=_bundle_to_dict(bundle)
    )
|
||||
|
||||
|
||||
@router.get(
    "/by-scene/{scene_key}",
    operation_id="getSceneSlotBundleBySceneKey",
    summary="Get scene slot bundle by scene key",
    description="根据场景标识获取槽位包",
)
async def get_scene_slot_bundle_by_scene_key(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
    scene_key: str,
) -> JSONResponse:
    """Look up a single bundle by its business scene key; 404 when absent."""
    logger.info(
        f"Getting scene slot bundle by scene_key: tenant={tenant_id}, scene_key={scene_key}"
    )

    found = await SceneSlotBundleService(session).get_bundle_by_scene_key(
        tenant_id, scene_key
    )

    if not found:
        return JSONResponse(
            status_code=404,
            content={
                "error_code": "NOT_FOUND",
                "message": f"Scene slot bundle with scene_key '{scene_key}' not found",
            }
        )

    return JSONResponse(content=_bundle_to_dict(found))
|
||||
|
||||
|
||||
@router.get(
    "/{id}",
    operation_id="getSceneSlotBundle",
    summary="Get scene slot bundle by ID",
    description="获取单个场景槽位包详情(含槽位详情)",
)
async def get_scene_slot_bundle(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
    id: str,
) -> JSONResponse:
    """Fetch one bundle by ID, including its slot details; 404 when absent."""
    logger.info(
        f"Getting scene slot bundle: tenant={tenant_id}, id={id}"
    )

    # NOTE(review): the service result is passed straight to JSONResponse,
    # so it is presumably an already JSON-serializable payload — confirm.
    detail = await SceneSlotBundleService(session).get_bundle_with_slot_details(
        tenant_id, id
    )

    if not detail:
        return JSONResponse(
            status_code=404,
            content={
                "error_code": "NOT_FOUND",
                "message": f"Scene slot bundle {id} not found",
            }
        )

    return JSONResponse(content=detail)
|
||||
|
||||
|
||||
@router.put(
    "/{id}",
    operation_id="updateSceneSlotBundle",
    summary="Update scene slot bundle",
    description="更新场景槽位包",
)
async def update_scene_slot_bundle(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
    id: str,
    bundle_update: SceneSlotBundleUpdate,
) -> JSONResponse:
    """
    Update a scene slot bundle.

    Returns 400 with VALIDATION_ERROR when the service raises ValueError,
    404 when no bundle matches, otherwise commits and returns the updated
    bundle.
    """
    logger.info(
        f"Updating scene slot bundle: tenant={tenant_id}, id={id}"
    )

    service = SceneSlotBundleService(session)

    try:
        bundle = await service.update_bundle(tenant_id, id, bundle_update)
    except ValueError as e:
        # Validation failures are surfaced as 400; nothing is committed here.
        return JSONResponse(
            status_code=400,
            content={
                "error_code": "VALIDATION_ERROR",
                "message": str(e),
            }
        )

    if not bundle:
        return JSONResponse(
            status_code=404,
            content={
                "error_code": "NOT_FOUND",
                "message": f"Scene slot bundle {id} not found",
            }
        )

    # Commit only after a successful, non-empty update.
    await session.commit()

    return JSONResponse(content=_bundle_to_dict(bundle))
|
||||
|
||||
|
||||
@router.delete(
    "/{id}",
    operation_id="deleteSceneSlotBundle",
    summary="Delete scene slot bundle",
    description="[AC-SCENE-SLOT-01] 删除场景槽位包",
    status_code=204,
)
async def delete_scene_slot_bundle(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
    id: str,
) -> JSONResponse:
    """
    [AC-SCENE-SLOT-01] Delete a scene slot bundle.

    Returns 204 (no body) on success, or 404 when the bundle does not exist
    for this tenant.
    """
    # Response is not imported at module level in this file; keep the import
    # local to avoid touching the file header.
    from fastapi import Response

    logger.info(
        f"[AC-SCENE-SLOT-01] Deleting scene slot bundle: tenant={tenant_id}, id={id}"
    )

    service = SceneSlotBundleService(session)
    success = await service.delete_bundle(tenant_id, id)

    if not success:
        return JSONResponse(
            status_code=404,
            content={
                "error_code": "NOT_FOUND",
                "message": f"Scene slot bundle not found: {id}",
            }
        )

    await session.commit()

    # Bug fix: JSONResponse(status_code=204, content=None) serializes the
    # literal "null" as the body, but a 204 response must carry no body
    # (ASGI servers reject/strip it). Return a bare Response instead.
    return Response(status_code=204)
|
||||
|
|
@ -38,9 +38,13 @@ def _slot_to_dict(slot: dict[str, Any] | Any) -> dict[str, Any]:
|
|||
"id": str(slot.id),
|
||||
"tenant_id": str(slot.tenant_id),
|
||||
"slot_key": slot.slot_key,
|
||||
"display_name": slot.display_name,
|
||||
"description": slot.description,
|
||||
"type": slot.type,
|
||||
"required": slot.required,
|
||||
# [AC-MRS-07-UPGRADE] 返回新旧字段
|
||||
"extract_strategy": slot.extract_strategy,
|
||||
"extract_strategies": slot.extract_strategies,
|
||||
"validation_rule": slot.validation_rule,
|
||||
"ask_back_prompt": slot.ask_back_prompt,
|
||||
"default_value": slot.default_value,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,342 @@
|
|||
"""
|
||||
Retrieval Strategy API Endpoints.
|
||||
[AC-AISVC-RES-01~15] API for strategy management and configuration.
|
||||
|
||||
Endpoints:
|
||||
- GET /strategy/retrieval/current - Get current strategy configuration
|
||||
- POST /strategy/retrieval/switch - Switch strategy configuration
|
||||
- POST /strategy/retrieval/validate - Validate strategy configuration
|
||||
- POST /strategy/retrieval/rollback - Rollback to default strategy
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import get_session
|
||||
from app.core.tenant import get_tenant_id
|
||||
from app.services.retrieval.routing_config import (
|
||||
RagRuntimeMode,
|
||||
StrategyType,
|
||||
RoutingConfig,
|
||||
)
|
||||
from app.services.retrieval.strategy_router import (
|
||||
get_strategy_router,
|
||||
RollbackRecord,
|
||||
)
|
||||
from app.services.retrieval.mode_router import get_mode_router
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/strategy/retrieval", tags=["Retrieval Strategy"])
|
||||
|
||||
|
||||
class RoutingConfigRequest(BaseModel):
    """Request model for routing configuration.

    Every field is optional; only the fields provided are applied to the
    current configuration (partial update semantics).
    """
    enabled: bool | None = Field(default=None, description="Enable strategy routing")
    strategy: StrategyType | None = Field(default=None, description="Retrieval strategy")
    # Grayscale rollout controls: fraction of traffic plus explicit allowlist.
    grayscale_percentage: float | None = Field(default=None, ge=0.0, le=1.0, description="Grayscale percentage")
    grayscale_allowlist: list[str] | None = Field(default=None, description="Grayscale allowlist")
    rag_runtime_mode: RagRuntimeMode | None = Field(default=None, description="RAG runtime mode")
    # ReAct-mode trigger thresholds and step budget (3..10 steps).
    react_trigger_confidence_threshold: float | None = Field(default=None, ge=0.0, le=1.0)
    react_trigger_complexity_score: float | None = Field(default=None, ge=0.0, le=1.0)
    react_max_steps: int | None = Field(default=None, ge=3, le=10)
    # Fallback to direct retrieval when confidence drops below the threshold.
    direct_fallback_on_low_confidence: bool | None = Field(default=None)
    direct_fallback_confidence_threshold: float | None = Field(default=None, ge=0.0, le=1.0)
    # Performance guardrails (budget in milliseconds, minimum 1000).
    performance_budget_ms: int | None = Field(default=None, ge=1000)
    performance_degradation_threshold: float | None = Field(default=None, ge=0.0, le=1.0)
|
||||
|
||||
|
||||
class RoutingConfigResponse(BaseModel):
    """Response model for routing configuration.

    Mirrors ``RoutingConfigRequest`` field-for-field, but with all values
    resolved (no ``None``): this is the effective configuration currently
    held by the strategy router.
    """

    enabled: bool
    strategy: StrategyType
    grayscale_percentage: float
    grayscale_allowlist: list[str]
    rag_runtime_mode: RagRuntimeMode
    react_trigger_confidence_threshold: float
    react_trigger_complexity_score: float
    react_max_steps: int
    direct_fallback_on_low_confidence: bool
    direct_fallback_confidence_threshold: float
    performance_budget_ms: int
    performance_degradation_threshold: float
||||
|
||||
|
||||
class ValidationResponse(BaseModel):
    """Response model for configuration validation.

    ``errors`` blocks acceptance (``is_valid`` is False); ``warnings``
    are advisory only and never fail validation.
    """

    is_valid: bool
    errors: list[str]
    warnings: list[str] = Field(default_factory=list)
||||
|
||||
|
||||
class RollbackRequest(BaseModel):
    """Request model for strategy rollback.

    ``reason`` is mandatory so every rollback is attributable in the
    audit trail; ``tenant_id`` is recorded alongside it when provided.
    """

    reason: str = Field(..., description="Reason for rollback")
    tenant_id: str | None = Field(default=None, description="Optional tenant ID for audit")
||||
|
||||
|
||||
class RollbackResponse(BaseModel):
    """Response model for strategy rollback.

    ``rollback_records`` holds the most recent audit entries as plain
    dicts (serialized in the endpoint, not via RollbackRecordResponse).
    """

    success: bool
    # Strategy that was active before the rollback.
    previous_strategy: StrategyType
    # Strategy active after the rollback (always DEFAULT per the endpoint).
    current_strategy: StrategyType
    reason: str
    rollback_records: list[dict[str, Any]] = Field(default_factory=list)
||||
|
||||
|
||||
class RollbackRecordResponse(BaseModel):
    """Response model for rollback record.

    One audit entry describing a single strategy transition.
    """

    # Unix timestamp of the transition (float seconds).
    timestamp: float
    from_strategy: StrategyType
    to_strategy: StrategyType
    reason: str
    tenant_id: str | None
    request_id: str | None
||||
|
||||
|
||||
class CurrentStrategyResponse(BaseModel):
    """Response model for current strategy.

    Combines the effective configuration, the strategy actually in use,
    and the most recent rollback audit records.
    """

    config: RoutingConfigResponse
    current_strategy: StrategyType
    rollback_records: list[RollbackRecordResponse] = Field(default_factory=list)
||||
|
||||
|
||||
@router.get(
    "/current",
    operation_id="getCurrentRetrievalStrategy",
    summary="Get current retrieval strategy configuration",
    description="[AC-AISVC-RES-01] Returns the current strategy configuration and recent rollback records.",
    response_model=CurrentStrategyResponse,
)
async def get_current_strategy() -> CurrentStrategyResponse:
    """
    [AC-AISVC-RES-01] Get current retrieval strategy configuration.

    Reads the singleton strategy router's live config and its five most
    recent rollback audit records; performs no mutation.
    """
    strategy_router = get_strategy_router()
    cfg = strategy_router.config
    recent_records = strategy_router.get_rollback_records(limit=5)

    # Snapshot the live config into the response model field-by-field.
    config_response = RoutingConfigResponse(
        enabled=cfg.enabled,
        strategy=cfg.strategy,
        grayscale_percentage=cfg.grayscale_percentage,
        grayscale_allowlist=cfg.grayscale_allowlist,
        rag_runtime_mode=cfg.rag_runtime_mode,
        react_trigger_confidence_threshold=cfg.react_trigger_confidence_threshold,
        react_trigger_complexity_score=cfg.react_trigger_complexity_score,
        react_max_steps=cfg.react_max_steps,
        direct_fallback_on_low_confidence=cfg.direct_fallback_on_low_confidence,
        direct_fallback_confidence_threshold=cfg.direct_fallback_confidence_threshold,
        performance_budget_ms=cfg.performance_budget_ms,
        performance_degradation_threshold=cfg.performance_degradation_threshold,
    )

    record_responses: list[RollbackRecordResponse] = []
    for record in recent_records:
        record_responses.append(
            RollbackRecordResponse(
                timestamp=record.timestamp,
                from_strategy=record.from_strategy,
                to_strategy=record.to_strategy,
                reason=record.reason,
                tenant_id=record.tenant_id,
                request_id=record.request_id,
            )
        )

    return CurrentStrategyResponse(
        config=config_response,
        current_strategy=strategy_router.current_strategy,
        rollback_records=record_responses,
    )
||||
|
||||
|
||||
@router.post(
    "/switch",
    operation_id="switchRetrievalStrategy",
    summary="Switch retrieval strategy configuration",
    description="[AC-AISVC-RES-02, AC-AISVC-RES-03] Update strategy configuration with hot reload support.",
    response_model=RoutingConfigResponse,
)
async def switch_strategy(
    request: RoutingConfigRequest,
    session: AsyncSession = Depends(get_session),
) -> RoutingConfigResponse:
    """
    [AC-AISVC-RES-02, AC-AISVC-RES-03, AC-AISVC-RES-15]
    Switch retrieval strategy configuration.

    Supports:
    - Strategy selection (default/enhanced)
    - Grayscale release (percentage/allowlist)
    - Mode selection (direct/react/auto)
    - Hot reload of routing parameters

    Fields left as ``None`` in the request keep their current value.
    Raises HTTPException(400) when the merged config fails validation.
    """
    strategy_router = get_strategy_router()
    mode_router = get_mode_router()
    current_config = strategy_router.config

    # The twelve routing fields shared by request, config, and response.
    field_names = (
        "enabled",
        "strategy",
        "grayscale_percentage",
        "grayscale_allowlist",
        "rag_runtime_mode",
        "react_trigger_confidence_threshold",
        "react_trigger_complexity_score",
        "react_max_steps",
        "direct_fallback_on_low_confidence",
        "direct_fallback_confidence_threshold",
        "performance_budget_ms",
        "performance_degradation_threshold",
    )

    def merged(name: str):
        # Explicitly supplied value wins; None means "keep current".
        requested = getattr(request, name)
        return getattr(current_config, name) if requested is None else requested

    new_config = RoutingConfig(**{name: merged(name) for name in field_names})

    # Reject inconsistent configurations before touching any router.
    is_valid, errors = new_config.validate()
    if not is_valid:
        raise HTTPException(
            status_code=400,
            detail={"errors": errors},
        )

    # Hot-reload both routers with the same merged config.
    strategy_router.update_config(new_config)
    mode_router.update_config(new_config)

    logger.info(
        f"[AC-AISVC-RES-02, AC-AISVC-RES-15] Strategy switched: "
        f"strategy={new_config.strategy.value}, mode={new_config.rag_runtime_mode.value}"
    )

    return RoutingConfigResponse(
        **{name: getattr(new_config, name) for name in field_names}
    )
||||
|
||||
|
||||
@router.post(
    "/validate",
    operation_id="validateRetrievalStrategy",
    summary="Validate retrieval strategy configuration",
    description="[AC-AISVC-RES-06] Validate strategy configuration for completeness and consistency.",
    response_model=ValidationResponse,
)
async def validate_strategy(
    request: RoutingConfigRequest,
) -> ValidationResponse:
    """
    [AC-AISVC-RES-06] Validate strategy configuration.

    Checks:
    - Parameter value ranges
    - Configuration consistency
    - Performance budget constraints

    Missing request fields are filled with safe defaults before
    validation; advisory warnings never cause ``is_valid`` to be False.
    """
    warnings: list[str] = []

    # Defaults applied when the corresponding request field is None.
    defaults = {
        "enabled": True,
        "strategy": StrategyType.DEFAULT,
        "grayscale_percentage": 0.0,
        "rag_runtime_mode": RagRuntimeMode.AUTO,
        "react_trigger_confidence_threshold": 0.6,
        "react_trigger_complexity_score": 0.5,
        "react_max_steps": 5,
        "direct_fallback_on_low_confidence": True,
        "direct_fallback_confidence_threshold": 0.4,
        "performance_budget_ms": 5000,
        "performance_degradation_threshold": 0.2,
    }
    values = {}
    for name, fallback in defaults.items():
        supplied = getattr(request, name)
        values[name] = fallback if supplied is None else supplied
    # Allowlist intentionally uses `or`: an empty list also becomes [].
    values["grayscale_allowlist"] = request.grayscale_allowlist or []

    config = RoutingConfig(**values)

    is_valid, errors = config.validate()

    # Advisory heuristics on top of hard validation.
    if config.grayscale_percentage > 0.5:
        warnings.append("grayscale_percentage > 50% may affect production stability")

    if config.react_max_steps > 7:
        warnings.append("react_max_steps > 7 may cause high latency")

    if config.direct_fallback_confidence_threshold > config.react_trigger_confidence_threshold:
        warnings.append(
            "direct_fallback_confidence_threshold > react_trigger_confidence_threshold "
            "may cause frequent fallbacks"
        )

    if config.performance_budget_ms < 3000:
        warnings.append("performance_budget_ms < 3000ms may be too aggressive for complex queries")

    logger.info(
        f"[AC-AISVC-RES-06] Strategy validation: is_valid={is_valid}, "
        f"errors={len(errors)}, warnings={len(warnings)}"
    )

    return ValidationResponse(
        is_valid=is_valid,
        errors=errors,
        warnings=warnings,
    )
||||
|
||||
|
||||
@router.post(
    "/rollback",
    operation_id="rollbackRetrievalStrategy",
    summary="Rollback to default retrieval strategy",
    description="[AC-AISVC-RES-07] Force rollback to default strategy with audit logging.",
    response_model=RollbackResponse,
)
async def rollback_strategy(
    request: RollbackRequest,
    session: AsyncSession = Depends(get_session),
) -> RollbackResponse:
    """
    [AC-AISVC-RES-07] Rollback to default strategy.

    Records the rollback event for audit and monitoring.

    Steps: capture current strategy -> roll the strategy router back
    -> reset the mode router to defaults -> return recent audit records.
    """
    strategy_router = get_strategy_router()
    mode_router = get_mode_router()

    # Captured before rollback so it can be reported in the response.
    previous_strategy = strategy_router.current_strategy

    strategy_router.rollback(
        reason=request.reason,
        tenant_id=request.tenant_id,
    )

    # NOTE(review): this fresh RoutingConfig only pins strategy and
    # rag_runtime_mode; all other tuning fields take RoutingConfig's
    # class defaults, so any previously hot-reloaded thresholds are
    # discarded from the mode router. Confirm this reset is intended.
    new_config = RoutingConfig(
        strategy=StrategyType.DEFAULT,
        rag_runtime_mode=RagRuntimeMode.AUTO,
    )
    mode_router.update_config(new_config)

    rollback_records = strategy_router.get_rollback_records(limit=5)

    logger.info(
        f"[AC-AISVC-RES-07] Strategy rollback: "
        f"from={previous_strategy.value}, to=DEFAULT, reason={request.reason}"
    )

    return RollbackResponse(
        success=True,
        previous_strategy=previous_strategy,
        current_strategy=StrategyType.DEFAULT,
        reason=request.reason,
        # Serialized as plain dicts (enum values unwrapped via .value).
        rollback_records=[
            {
                "timestamp": r.timestamp,
                "from_strategy": r.from_strategy.value,
                "to_strategy": r.to_strategy.value,
                "reason": r.reason,
                "tenant_id": r.tenant_id,
            }
            for r in rollback_records
        ],
    )
||||
|
|
@ -55,6 +55,18 @@ from app.services.mid.segment_humanizer import HumanizeConfig, SegmentHumanizer
|
|||
from app.services.mid.timeout_governor import TimeoutGovernor
|
||||
from app.services.mid.tool_registry import ToolRegistry
|
||||
from app.services.mid.trace_logger import TraceLogger
|
||||
from app.services.prompt.template_service import PromptTemplateService
|
||||
from app.services.prompt.variable_resolver import VariableResolver
|
||||
from app.services.intent.clarification import (
|
||||
ClarificationEngine,
|
||||
ClarifyReason,
|
||||
ClarifySessionManager,
|
||||
ClarifyState,
|
||||
HybridIntentResult,
|
||||
IntentCandidate,
|
||||
T_HIGH,
|
||||
T_LOW,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -118,6 +130,7 @@ _kb_search_dynamic_registered: bool = False
|
|||
_intent_hint_registered: bool = False
|
||||
_high_risk_check_registered: bool = False
|
||||
_memory_recall_registered: bool = False
|
||||
_metadata_discovery_registered: bool = False
|
||||
|
||||
|
||||
def ensure_kb_search_dynamic_registered(
|
||||
|
|
@ -134,7 +147,7 @@ def ensure_kb_search_dynamic_registered(
|
|||
config = KbSearchDynamicConfig(
|
||||
enabled=True,
|
||||
top_k=5,
|
||||
timeout_ms=2000,
|
||||
timeout_ms=10000,
|
||||
min_score_threshold=0.5,
|
||||
)
|
||||
|
||||
|
|
@ -222,6 +235,26 @@ def ensure_memory_recall_registered(
|
|||
logger.info("[AC-IDMP-13] memory_recall tool registered to registry")
|
||||
|
||||
|
||||
def ensure_metadata_discovery_registered(
    registry: ToolRegistry,
    session: AsyncSession,
) -> None:
    """Ensure metadata_discovery tool is registered.

    Idempotent per process: a module-level flag guards against repeated
    registration. The import is deferred to avoid paying it on module load.

    NOTE(review): the check-then-set on the global flag is not guarded by
    a lock, so two concurrent first calls could both register — confirm
    whether double registration is harmless for the registry.
    """
    global _metadata_discovery_registered
    if _metadata_discovery_registered:
        return

    from app.services.mid.metadata_discovery_tool import register_metadata_discovery_tool

    register_metadata_discovery_tool(
        registry=registry,
        session=session,
        timeout_governor=get_timeout_governor(),
    )
    _metadata_discovery_registered = True
    logger.info("[MetadataDiscovery] metadata_discovery tool registered to registry")
||||
|
||||
|
||||
def get_output_guardrail_executor() -> OutputGuardrailExecutor:
|
||||
"""Get or create OutputGuardrailExecutor instance."""
|
||||
if "output_guardrail_executor" not in _mid_services:
|
||||
|
|
@ -259,6 +292,16 @@ def get_runtime_observer() -> RuntimeObserver:
|
|||
return _mid_services["runtime_observer"]
|
||||
|
||||
|
||||
def get_clarification_engine() -> ClarificationEngine:
    """Return the process-wide ClarificationEngine, creating it lazily.

    Cached in the shared ``_mid_services`` registry under the
    "clarification_engine" key; thresholds come from the module-level
    T_HIGH / T_LOW constants.
    """
    engine = _mid_services.get("clarification_engine")
    if engine is None:
        engine = ClarificationEngine(
            t_high=T_HIGH,
            t_low=T_LOW,
        )
        _mid_services["clarification_engine"] = engine
    return engine
||||
|
||||
|
||||
@router.post(
|
||||
"/dialogue/respond",
|
||||
operation_id="respondDialogue",
|
||||
|
|
@ -288,6 +331,7 @@ async def respond_dialogue(
|
|||
default_kb_tool_runner: DefaultKbToolRunner = Depends(get_default_kb_tool_runner),
|
||||
segment_humanizer: SegmentHumanizer = Depends(get_segment_humanizer),
|
||||
runtime_observer: RuntimeObserver = Depends(get_runtime_observer),
|
||||
clarification_engine: ClarificationEngine = Depends(get_clarification_engine),
|
||||
) -> DialogueResponse:
|
||||
"""
|
||||
[AC-MARH-01~12] Generate dialogue response with segments and trace.
|
||||
|
|
@ -316,7 +360,8 @@ async def respond_dialogue(
|
|||
|
||||
logger.info(
|
||||
f"[AC-MARH-01] Dialogue request: tenant={tenant_id}, "
|
||||
f"session={dialogue_request.session_id}, request_id={request_id}"
|
||||
f"session={dialogue_request.session_id}, request_id={request_id}, "
|
||||
f"user_message={dialogue_request.user_message[:100] if dialogue_request.user_message else 'None'}..."
|
||||
)
|
||||
|
||||
runtime_ctx = runtime_observer.start_observation(
|
||||
|
|
@ -340,6 +385,7 @@ async def respond_dialogue(
|
|||
ensure_intent_hint_registered(tool_registry, session)
|
||||
ensure_high_risk_check_registered(tool_registry, session)
|
||||
ensure_memory_recall_registered(tool_registry, session)
|
||||
ensure_metadata_discovery_registered(tool_registry, session)
|
||||
|
||||
try:
|
||||
interrupt_ctx = interrupt_context_enricher.enrich(
|
||||
|
|
@ -541,7 +587,15 @@ async def respond_dialogue(
|
|||
|
||||
except Exception as e:
|
||||
latency_ms = int((time.time() - start_time) * 1000)
|
||||
logger.error(f"[AC-IDMP-06] Dialogue error: {e}")
|
||||
import traceback
|
||||
logger.error(
|
||||
f"[AC-IDMP-06] Dialogue error: {e}\n"
|
||||
f"Exception type: {type(e).__name__}\n"
|
||||
f"Request details: session_id={dialogue_request.session_id}, "
|
||||
f"user_message={dialogue_request.user_message[:200] if dialogue_request.user_message else 'None'}, "
|
||||
f"scene={dialogue_request.scene}, user_id={dialogue_request.user_id}\n"
|
||||
f"Traceback:\n{traceback.format_exc()}"
|
||||
)
|
||||
|
||||
trace_logger.update_trace(
|
||||
request_id=request_id,
|
||||
|
|
@ -593,7 +647,7 @@ async def _match_intent(
|
|||
return IntentMatch(
|
||||
intent_id=str(result.rule.id),
|
||||
intent_name=result.rule.name,
|
||||
confidence=0.8,
|
||||
confidence=1.0,
|
||||
response_type=result.rule.response_type,
|
||||
target_kb_ids=result.rule.target_kb_ids,
|
||||
flow_id=str(result.rule.flow_id) if result.rule.flow_id else None,
|
||||
|
|
@ -608,6 +662,227 @@ async def _match_intent(
|
|||
return None
|
||||
|
||||
|
||||
async def _match_intent_hybrid(
    tenant_id: str,
    request: DialogueRequest,
    session: AsyncSession,
    clarification_engine: ClarificationEngine,
) -> HybridIntentResult:
    """
    [AC-CLARIFY] Hybrid intent matching with clarification support.

    Returns HybridIntentResult with unified confidence and candidates.

    Fallback ladder: hybrid (rule+semantic+LLM) match first; on failure,
    plain rule matching; on any other error, an empty result. This
    function never raises — all failures degrade to "no intent".
    """
    try:
        # Deferred imports keep module load cheap and avoid cycles.
        from app.services.intent.router import IntentRouter
        from app.services.intent.rule_service import IntentRuleService

        rule_service = IntentRuleService(session)
        rules = await rule_service.get_enabled_rules_for_matching(tenant_id)

        # No enabled rules for this tenant: nothing can match.
        if not rules:
            return HybridIntentResult(
                intent=None,
                confidence=0.0,
                candidates=[],
            )

        # NOTE(review): local name shadows any module-level FastAPI
        # `router` within this function scope — intentional but fragile.
        router = IntentRouter()

        try:
            fusion_result = await router.match_hybrid(
                message=request.user_message,
                rules=rules,
                tenant_id=tenant_id,
            )

            hybrid_result = HybridIntentResult.from_fusion_result(fusion_result)

            logger.info(
                f"[AC-CLARIFY] Hybrid intent match: "
                f"intent={hybrid_result.intent.intent_name if hybrid_result.intent else None}, "
                f"confidence={hybrid_result.confidence:.3f}, "
                f"need_clarify={hybrid_result.need_clarify}"
            )

            return hybrid_result

        except Exception as e:
            # Hybrid path failed (e.g. LLM/semantic backend error):
            # degrade to pure rule matching.
            logger.warning(f"[AC-CLARIFY] Hybrid match failed, fallback to rule: {e}")

            result = router.match(request.user_message, rules)

            if result:
                # Rule-only hit: confidence is fused from a perfect rule
                # score with zero semantic/LLM contribution.
                confidence = clarification_engine.compute_confidence(
                    rule_score=1.0,
                    semantic_score=0.0,
                    llm_score=0.0,
                )
                candidate = IntentCandidate(
                    intent_id=str(result.rule.id),
                    intent_name=result.rule.name,
                    confidence=confidence,
                    response_type=result.rule.response_type,
                    target_kb_ids=result.rule.target_kb_ids,
                    flow_id=str(result.rule.flow_id) if result.rule.flow_id else None,
                    fixed_reply=result.rule.fixed_reply,
                    transfer_message=result.rule.transfer_message,
                )
                return HybridIntentResult(
                    intent=candidate,
                    confidence=confidence,
                    candidates=[candidate],
                )

            # Rule fallback also found nothing.
            return HybridIntentResult(
                intent=None,
                confidence=0.0,
                candidates=[],
            )

    except Exception as e:
        # Outer guard: any unexpected failure yields an empty result.
        logger.warning(f"[AC-CLARIFY] Intent match failed: {e}")
        return HybridIntentResult(
            intent=None,
            confidence=0.0,
            candidates=[],
        )
||||
|
||||
|
||||
async def _handle_clarification(
    tenant_id: str,
    request: DialogueRequest,
    hybrid_result: HybridIntentResult,
    clarification_engine: ClarificationEngine,
    trace: TraceInfo,
    session: AsyncSession,
    required_slots: list[str] | None = None,
    filled_slots: dict[str, Any] | None = None,
) -> DialogueResponse | None:
    """
    [AC-CLARIFY] Handle clarification logic.

    Returns DialogueResponse if clarification is needed, None otherwise.

    Two phases:
    1. If this session is already mid-clarification (and under the retry
       cap), treat the incoming message as an answer to the previous
       clarify prompt.
    2. Otherwise, ask the engine whether the fresh hybrid result should
       trigger a new clarification round.
    """
    existing_state = ClarifySessionManager.get_session(request.session_id)

    # --- Phase 1: continue an in-flight clarification round -------------
    if existing_state and not existing_state.is_max_retry():
        logger.info(
            f"[AC-CLARIFY] Processing clarify response: "
            f"session={request.session_id}, retry={existing_state.retry_count}"
        )

        new_result = clarification_engine.process_clarify_response(
            user_message=request.user_message,
            state=existing_state,
        )

        if not new_result.need_clarify:
            # Clarification resolved (or abandoned): drop session state.
            ClarifySessionManager.clear_session(request.session_id)

            if new_result.intent:
                # Intent resolved — let the caller continue the normal flow.
                return None

        # Still ambiguous: re-prompt with the existing state (retry count
        # is tracked inside the state object).
        clarify_prompt = clarification_engine.generate_clarify_prompt(existing_state)

        return DialogueResponse(
            segments=[Segment(text=clarify_prompt, delay_after=0)],
            trace=TraceInfo(
                mode=ExecutionMode.FIXED,
                request_id=trace.request_id,
                generation_id=trace.generation_id,
                fallback_reason_code=f"clarify_retry_{existing_state.retry_count}",
                intent=existing_state.candidates[0].intent_name if existing_state.candidates else None,
            ),
        )

    # --- Phase 2: decide whether to start a new clarification round -----
    should_clarify, clarify_state = clarification_engine.should_trigger_clarify(
        result=hybrid_result,
        required_slots=required_slots,
        filled_slots=filled_slots,
    )

    if not should_clarify or not clarify_state:
        # No clarification needed: caller proceeds with hybrid_result.
        return None

    logger.info(
        f"[AC-CLARIFY] Clarification triggered: "
        f"session={request.session_id}, reason={clarify_state.reason}, "
        f"confidence={hybrid_result.confidence:.3f}"
    )

    # Persist the state so the next message is handled by phase 1.
    ClarifySessionManager.set_session(request.session_id, clarify_state)

    clarify_prompt = clarification_engine.generate_clarify_prompt(clarify_state)

    return DialogueResponse(
        segments=[Segment(text=clarify_prompt, delay_after=0)],
        trace=TraceInfo(
            mode=ExecutionMode.FIXED,
            request_id=trace.request_id,
            generation_id=trace.generation_id,
            fallback_reason_code=f"clarify_{clarify_state.reason.value}",
            intent=hybrid_result.intent.intent_name if hybrid_result.intent else None,
        ),
    )
||||
|
||||
|
||||
async def _check_hard_block(
    tenant_id: str,
    request: DialogueRequest,
    hybrid_result: HybridIntentResult,
    clarification_engine: ClarificationEngine,
    trace: TraceInfo,
    required_slots: list[str] | None = None,
    filled_slots: dict[str, Any] | None = None,
) -> DialogueResponse | None:
    """
    [AC-CLARIFY] Check hard block conditions.

    Hard blocks:
    1. confidence < T_high: block entering new flow
    2. required_slots missing: block flow progression

    Returns DialogueResponse if blocked, None otherwise.
    """
    is_blocked, block_reason = clarification_engine.check_hard_block(
        result=hybrid_result,
        required_slots=required_slots,
        filled_slots=filled_slots,
    )

    if not is_blocked:
        return None

    logger.info(
        f"[AC-CLARIFY] Hard block triggered: "
        f"session={request.session_id}, reason={block_reason}, "
        f"confidence={hybrid_result.confidence:.3f}"
    )

    # When the block is a missing slot, ask about the first required one.
    clarify_state = ClarifyState(
        reason=block_reason,
        asked_slot=required_slots[0] if required_slots and block_reason == ClarifyReason.MISSING_SLOT else None,
        candidates=hybrid_result.candidates,
    )

    # Persist state so the follow-up message resumes the clarify flow.
    ClarifySessionManager.set_session(request.session_id, clarify_state)
    # NOTE(review): reaches into the engine's private `_metrics` — a
    # public record_clarify_trigger() wrapper on the engine would be
    # cleaner; confirm no such API exists before changing.
    clarification_engine._metrics.record_clarify_trigger()

    clarify_prompt = clarification_engine.generate_clarify_prompt(clarify_state)

    return DialogueResponse(
        segments=[Segment(text=clarify_prompt, delay_after=0)],
        trace=TraceInfo(
            mode=ExecutionMode.FIXED,
            request_id=trace.request_id,
            generation_id=trace.generation_id,
            fallback_reason_code=f"hard_block_{block_reason.value}",
            intent=hybrid_result.intent.intent_name if hybrid_result.intent else None,
        ),
    )
||||
|
||||
|
||||
async def _handle_legacy_response(
|
||||
tenant_id: str,
|
||||
request: DialogueRequest,
|
||||
|
|
@ -793,18 +1068,32 @@ async def _execute_agent_mode(
|
|||
session: AsyncSession | None = None,
|
||||
tool_registry: ToolRegistry | None = None,
|
||||
) -> DialogueResponse:
|
||||
"""[AC-MARH-05, AC-MARH-06, AC-MARH-07, AC-IDMP-13] Execute agent mode with ReAct loop, KB tool, and memory recall."""
|
||||
"""
|
||||
[AC-MARH-05, AC-MARH-06, AC-MARH-07, AC-IDMP-13, AC-MRS-SLOT-META-03]
|
||||
Execute agent mode with ReAct loop, KB tool, memory recall, and slot state aggregation.
|
||||
"""
|
||||
from app.services.llm.factory import get_llm_config_manager
|
||||
from app.services.mid.slot_state_aggregator import SlotStateAggregator
|
||||
|
||||
logger.info(
|
||||
f"[DEBUG-AGENT] Starting _execute_agent_mode: tenant={tenant_id}, "
|
||||
f"session={request.session_id}, scene={request.scene}, user_id={request.user_id}, "
|
||||
f"user_message={request.user_message[:100] if request.user_message else 'None'}..."
|
||||
)
|
||||
|
||||
try:
|
||||
llm_manager = get_llm_config_manager()
|
||||
llm_client = llm_manager.get_client()
|
||||
logger.info(f"[DEBUG-AGENT] LLM client obtained successfully")
|
||||
except Exception as e:
|
||||
logger.warning(f"[AC-MARH-07] Failed to get LLM client: {e}")
|
||||
llm_client = None
|
||||
|
||||
base_context = {"history": [h.model_dump() for h in request.history]} if request.history else {}
|
||||
|
||||
if request.scene:
|
||||
base_context["scene"] = request.scene
|
||||
|
||||
if interrupt_ctx and interrupt_ctx.consumed:
|
||||
base_context["interrupted_content"] = interrupt_ctx.interrupted_content
|
||||
base_context["interrupted_segment_ids"] = interrupt_ctx.interrupted_segment_ids
|
||||
|
|
@ -813,8 +1102,50 @@ async def _execute_agent_mode(
|
|||
f"{len(interrupt_ctx.interrupted_content or '')} chars"
|
||||
)
|
||||
|
||||
# [AC-MRS-SLOT-META-03] 初始化槽位状态聚合器
|
||||
slot_state = None
|
||||
memory_context = ""
|
||||
memory_missing_slots: list[str] = []
|
||||
scene_slot_context = None
|
||||
|
||||
flow_status = None
|
||||
is_flow_runtime = False
|
||||
if session:
|
||||
from app.services.flow.engine import FlowEngine
|
||||
flow_engine = FlowEngine(session)
|
||||
flow_status = await flow_engine.get_flow_status(tenant_id, request.session_id)
|
||||
is_flow_runtime = bool(flow_status and flow_status.get("status") == "active")
|
||||
|
||||
logger.info(
|
||||
f"[AC-IDMP-02] 运行轨道判定: session={request.session_id}, "
|
||||
f"is_flow_runtime={is_flow_runtime}, flow_id={flow_status.get('flow_id') if flow_status else None}, "
|
||||
f"current_step={flow_status.get('current_step') if flow_status else None}"
|
||||
)
|
||||
|
||||
# [AC-SCENE-SLOT-02] 加载场景槽位包(仅流程态加载)
|
||||
if is_flow_runtime and session and request.scene:
|
||||
from app.services.mid.scene_slot_bundle_loader import SceneSlotBundleLoader
|
||||
|
||||
scene_loader = SceneSlotBundleLoader(session)
|
||||
scene_slot_context = await scene_loader.load_scene_context(
|
||||
tenant_id=tenant_id,
|
||||
scene_key=request.scene,
|
||||
)
|
||||
|
||||
if scene_slot_context:
|
||||
logger.info(
|
||||
f"[AC-SCENE-SLOT-02] 场景槽位包加载成功: "
|
||||
f"scene={request.scene}, required={len(scene_slot_context.required_slots)}, "
|
||||
f"optional={len(scene_slot_context.optional_slots)}, "
|
||||
f"threshold={scene_slot_context.completion_threshold}"
|
||||
)
|
||||
base_context["scene_slot_context"] = {
|
||||
"scene_key": scene_slot_context.scene_key,
|
||||
"scene_name": scene_slot_context.scene_name,
|
||||
"required_slots": scene_slot_context.get_required_slot_keys(),
|
||||
"optional_slots": scene_slot_context.get_optional_slot_keys(),
|
||||
}
|
||||
|
||||
if session and request.user_id:
|
||||
memory_recall_tool = MemoryRecallTool(
|
||||
session=session,
|
||||
|
|
@ -855,11 +1186,111 @@ async def _execute_agent_mode(
|
|||
f"[AC-IDMP-13] Memory recall fallback: reason={memory_result.fallback_reason_code}"
|
||||
)
|
||||
|
||||
# [AC-MRS-SLOT-META-03] 仅流程态执行槽位聚合/提取
|
||||
if is_flow_runtime:
|
||||
slot_aggregator = SlotStateAggregator(
|
||||
session=session,
|
||||
tenant_id=tenant_id,
|
||||
session_id=request.session_id,
|
||||
)
|
||||
slot_state = await slot_aggregator.aggregate(
|
||||
memory_slots=memory_result.slots,
|
||||
current_input_slots=None, # 可从 request 中解析
|
||||
context=base_context,
|
||||
scene_slot_context=scene_slot_context, # [AC-SCENE-SLOT-02] 传入场景槽位上下文
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[AC-MRS-SLOT-META-03] 流程态槽位聚合完成: "
|
||||
f"filled={len(slot_state.filled_slots)}, "
|
||||
f"missing={len(slot_state.missing_required_slots)}, "
|
||||
f"mappings={slot_state.slot_to_field_map}"
|
||||
)
|
||||
|
||||
# [AC-MRS-SLOT-EXTRACT-01] 自动提取槽位
|
||||
if slot_state and slot_state.missing_required_slots:
|
||||
from app.services.mid.slot_extraction_integration import SlotExtractionIntegration
|
||||
|
||||
extraction_integration = SlotExtractionIntegration(
|
||||
session=session,
|
||||
tenant_id=tenant_id,
|
||||
session_id=request.session_id,
|
||||
)
|
||||
|
||||
extraction_result = await extraction_integration.extract_and_fill(
|
||||
user_input=request.user_message,
|
||||
slot_state=slot_state,
|
||||
)
|
||||
|
||||
if extraction_result.extracted_slots:
|
||||
slot_state = await slot_aggregator.aggregate(
|
||||
memory_slots=memory_result.slots,
|
||||
current_input_slots=extraction_result.extracted_slots,
|
||||
context=base_context,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[AC-MRS-SLOT-EXTRACT-01] 流程态自动提槽完成: "
|
||||
f"extracted={list(extraction_result.extracted_slots.keys())}, "
|
||||
f"time_ms={extraction_result.total_execution_time_ms:.2f}"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"[AC-MRS-SLOT-META-03] 当前为通用问答轨道,跳过槽位聚合与缺槽追问: "
|
||||
f"session={request.session_id}"
|
||||
)
|
||||
|
||||
kb_hits = []
|
||||
kb_success = False
|
||||
kb_fallback_reason = None
|
||||
kb_applied_filter = {}
|
||||
kb_missing_slots = []
|
||||
kb_dynamic_result = None
|
||||
step_kb_binding_trace: dict[str, Any] | None = None
|
||||
|
||||
# [Step-KB-Binding] 获取当前流程步骤的 KB 配置(仅流程态)
|
||||
step_kb_config = None
|
||||
if is_flow_runtime and session and flow_status:
|
||||
current_step_no = flow_status.get("current_step")
|
||||
flow_id = flow_status.get("flow_id")
|
||||
|
||||
if flow_id and current_step_no:
|
||||
from app.models.entities import ScriptFlow
|
||||
from sqlalchemy import select
|
||||
|
||||
stmt = select(ScriptFlow).where(ScriptFlow.id == flow_id)
|
||||
result = await session.execute(stmt)
|
||||
flow = result.scalar_one_or_none()
|
||||
|
||||
if flow and flow.steps:
|
||||
step_idx = current_step_no - 1
|
||||
if 0 <= step_idx < len(flow.steps):
|
||||
current_step = flow.steps[step_idx]
|
||||
|
||||
# 构建 StepKbConfig
|
||||
from app.services.mid.kb_search_dynamic_tool import StepKbConfig
|
||||
step_kb_config = StepKbConfig(
|
||||
allowed_kb_ids=current_step.get("allowed_kb_ids"),
|
||||
preferred_kb_ids=current_step.get("preferred_kb_ids"),
|
||||
kb_query_hint=current_step.get("kb_query_hint"),
|
||||
max_kb_calls=current_step.get("max_kb_calls_per_step", 1),
|
||||
step_id=f"{flow_id}_step_{current_step_no}",
|
||||
)
|
||||
|
||||
step_kb_binding_trace = {
|
||||
"flow_id": str(flow_id),
|
||||
"flow_name": flow_status.get("flow_name"),
|
||||
"current_step": current_step_no,
|
||||
"step_id": step_kb_config.step_id,
|
||||
"allowed_kb_ids": step_kb_config.allowed_kb_ids,
|
||||
"preferred_kb_ids": step_kb_config.preferred_kb_ids,
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"[Step-KB-Binding] 步骤知识库配置加载成功: "
|
||||
f"flow={flow_status.get('flow_name')}, step={current_step_no}, "
|
||||
f"allowed_kb_ids={step_kb_config.allowed_kb_ids}"
|
||||
)
|
||||
|
||||
if session and tool_registry:
|
||||
kb_tool = KbSearchDynamicTool(
|
||||
|
|
@ -868,17 +1299,21 @@ async def _execute_agent_mode(
|
|||
config=KbSearchDynamicConfig(
|
||||
enabled=True,
|
||||
top_k=5,
|
||||
timeout_ms=2000,
|
||||
timeout_ms=10000,
|
||||
min_score_threshold=0.5,
|
||||
),
|
||||
)
|
||||
|
||||
# [AC-MRS-SLOT-META-03] 传入 slot_state 进行 KB 检索
|
||||
# [Step-KB-Binding] 传入 step_kb_config 进行步骤级别的 KB 约束
|
||||
kb_dynamic_result = await kb_tool.execute(
|
||||
query=request.user_message,
|
||||
tenant_id=tenant_id,
|
||||
scene="open_consult",
|
||||
top_k=5,
|
||||
context=base_context,
|
||||
slot_state=slot_state,
|
||||
step_kb_config=step_kb_config,
|
||||
slot_policy="flow_strict" if is_flow_runtime else "agent_relaxed",
|
||||
)
|
||||
|
||||
kb_success = kb_dynamic_result.success
|
||||
|
|
@ -894,10 +1329,44 @@ async def _execute_agent_mode(
|
|||
)
|
||||
|
||||
logger.info(
|
||||
f"[AC-MARH-05] KB dynamic search: success={kb_success}, "
|
||||
f"[AC-MARH-05] KB动态检索完成: success={kb_success}, "
|
||||
f"hits={len(kb_hits)}, filter={kb_applied_filter}, "
|
||||
f"missing_slots={kb_missing_slots}"
|
||||
f"missing_slots={kb_missing_slots}, track={'flow' if is_flow_runtime else 'agent'}"
|
||||
)
|
||||
|
||||
# [AC-MRS-SLOT-META-03] 处理缺失必填槽位 -> 追问闭环(仅流程态)
|
||||
if is_flow_runtime and kb_fallback_reason == "MISSING_REQUIRED_SLOTS" and kb_missing_slots:
|
||||
ask_back_text = await _generate_ask_back_for_missing_slots(
|
||||
slot_state=slot_state,
|
||||
missing_slots=kb_missing_slots,
|
||||
session=session,
|
||||
tenant_id=tenant_id,
|
||||
session_id=request.session_id,
|
||||
scene_slot_context=scene_slot_context, # [AC-SCENE-SLOT-02] 传入场景槽位上下文
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[AC-MRS-SLOT-META-03] 流程态缺槽追问文案已生成: "
|
||||
f"{ask_back_text[:50]}..."
|
||||
)
|
||||
|
||||
return DialogueResponse(
|
||||
segments=[Segment(text=ask_back_text, delay_after=0)],
|
||||
trace=TraceInfo(
|
||||
mode=ExecutionMode.AGENT,
|
||||
request_id=trace.request_id,
|
||||
generation_id=trace.generation_id,
|
||||
fallback_reason_code="missing_required_slots",
|
||||
kb_tool_called=True,
|
||||
kb_hit=False,
|
||||
# [AC-SCENE-SLOT-02] 场景槽位追踪
|
||||
scene=request.scene,
|
||||
scene_slot_context=base_context.get("scene_slot_context"),
|
||||
missing_slots=[s.get("slot_key") for s in kb_missing_slots] if kb_missing_slots else None,
|
||||
ask_back_triggered=True,
|
||||
slot_sources=slot_state.slot_sources if slot_state else None,
|
||||
),
|
||||
)
|
||||
else:
|
||||
kb_result = await default_kb_tool_runner.execute(
|
||||
tenant_id=tenant_id,
|
||||
|
|
@ -933,7 +1402,18 @@ async def _execute_agent_mode(
|
|||
timeout_governor=timeout_governor,
|
||||
llm_client=llm_client,
|
||||
tool_registry=tool_registry,
|
||||
template_service=PromptTemplateService,
|
||||
variable_resolver=VariableResolver(),
|
||||
tenant_id=tenant_id,
|
||||
user_id=request.user_id,
|
||||
session_id=request.session_id,
|
||||
scene=request.scene,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[DEBUG-AGENT] Calling orchestrator.execute with: "
|
||||
f"user_message={request.user_message[:100] if request.user_message else 'None'}, "
|
||||
f"context_keys={list(base_context.keys())}, llm_client={llm_client is not None}"
|
||||
)
|
||||
|
||||
final_answer, react_ctx, agent_trace = await orchestrator.execute(
|
||||
|
|
@ -941,6 +1421,12 @@ async def _execute_agent_mode(
|
|||
context=base_context,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[DEBUG-AGENT] orchestrator.execute completed: "
|
||||
f"final_answer={final_answer[:200] if final_answer else 'None'}, "
|
||||
f"iterations={react_ctx.iteration}, tool_calls={len(react_ctx.tool_calls) if react_ctx.tool_calls else 0}"
|
||||
)
|
||||
|
||||
runtime_observer.record_react(request_id, react_ctx.iteration, react_ctx.tool_calls)
|
||||
|
||||
trace_logger.update_trace(
|
||||
|
|
@ -964,10 +1450,114 @@ async def _execute_agent_mode(
|
|||
kb_tool_called=True,
|
||||
kb_hit=kb_success and len(kb_hits) > 0,
|
||||
fallback_reason_code=kb_fallback_reason,
|
||||
# [AC-SCENE-SLOT-02] 场景槽位追踪
|
||||
scene=request.scene,
|
||||
scene_slot_context=base_context.get("scene_slot_context"),
|
||||
missing_slots=[s.get("slot_key") for s in kb_missing_slots] if kb_missing_slots else None,
|
||||
ask_back_triggered=False,
|
||||
slot_sources=slot_state.slot_sources if slot_state else None,
|
||||
kb_filter_sources=kb_dynamic_result.filter_sources if kb_dynamic_result else None,
|
||||
# [Step-KB-Binding] 步骤知识库绑定追踪
|
||||
step_kb_binding=kb_dynamic_result.step_kb_binding if kb_dynamic_result else step_kb_binding_trace,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
async def _generate_ask_back_for_missing_slots(
|
||||
slot_state: Any,
|
||||
missing_slots: list[dict[str, str]],
|
||||
session: AsyncSession,
|
||||
tenant_id: str,
|
||||
session_id: str | None = None,
|
||||
max_ask_back_slots: int = 2,
|
||||
scene_slot_context: Any = None, # [AC-SCENE-SLOT-02] 场景槽位上下文
|
||||
) -> str:
|
||||
"""
|
||||
[AC-MRS-SLOT-META-03, AC-MRS-SLOT-ASKBACK-01] 为缺失的必填槽位生成追问响应
|
||||
[AC-SCENE-SLOT-02] 支持场景槽位包配置的追问策略
|
||||
|
||||
支持批量追问多个缺失槽位,优先追问必填槽位
|
||||
"""
|
||||
if not missing_slots:
|
||||
return "请提供更多信息以便我更好地帮助您。"
|
||||
|
||||
# [AC-SCENE-SLOT-02] 如果有场景槽位上下文,使用场景配置的追问策略
|
||||
if scene_slot_context:
|
||||
ask_back_order = getattr(scene_slot_context, 'ask_back_order', 'priority')
|
||||
|
||||
if ask_back_order == "parallel":
|
||||
prompts = []
|
||||
for missing in missing_slots[:max_ask_back_slots]:
|
||||
ask_back_prompt = missing.get("ask_back_prompt")
|
||||
if ask_back_prompt:
|
||||
prompts.append(ask_back_prompt)
|
||||
else:
|
||||
slot_key = missing.get("slot_key", "相关信息")
|
||||
prompts.append(f"您的{slot_key}")
|
||||
|
||||
if len(prompts) == 1:
|
||||
return prompts[0]
|
||||
elif len(prompts) == 2:
|
||||
return f"为了更好地为您服务,请告诉我{prompts[0]}和{prompts[1]}。"
|
||||
else:
|
||||
all_but_last = "、".join(prompts[:-1])
|
||||
return f"为了更好地为您服务,请告诉我{all_but_last},以及{prompts[-1]}。"
|
||||
|
||||
if len(missing_slots) == 1:
|
||||
first_missing = missing_slots[0]
|
||||
ask_back_prompt = first_missing.get("ask_back_prompt")
|
||||
if ask_back_prompt:
|
||||
return ask_back_prompt
|
||||
slot_key = first_missing.get("slot_key", "相关信息")
|
||||
label = first_missing.get("label", slot_key)
|
||||
return f"为了更好地为您提供帮助,请告诉我您的{label}。"
|
||||
|
||||
try:
|
||||
from app.services.mid.batch_ask_back_service import (
|
||||
BatchAskBackConfig,
|
||||
BatchAskBackService,
|
||||
)
|
||||
|
||||
config = BatchAskBackConfig(
|
||||
max_ask_back_slots_per_turn=max_ask_back_slots,
|
||||
prefer_required=True,
|
||||
merge_prompts=True,
|
||||
)
|
||||
|
||||
ask_back_service = BatchAskBackService(
|
||||
session=session,
|
||||
tenant_id=tenant_id,
|
||||
session_id=session_id or "",
|
||||
config=config,
|
||||
)
|
||||
|
||||
result = await ask_back_service.generate_batch_ask_back(
|
||||
missing_slots=missing_slots,
|
||||
)
|
||||
|
||||
if result.has_ask_back():
|
||||
return result.get_prompt()
|
||||
except Exception as e:
|
||||
logger.warning(f"[AC-MRS-SLOT-ASKBACK-01] Batch ask-back failed, fallback to single: {e}")
|
||||
|
||||
prompts = []
|
||||
for missing in missing_slots[:max_ask_back_slots]:
|
||||
ask_back_prompt = missing.get("ask_back_prompt")
|
||||
if ask_back_prompt:
|
||||
prompts.append(ask_back_prompt)
|
||||
else:
|
||||
label = missing.get("label", missing.get("slot_key", "相关信息"))
|
||||
prompts.append(f"您的{label}")
|
||||
|
||||
if len(prompts) == 1:
|
||||
return prompts[0]
|
||||
elif len(prompts) == 2:
|
||||
return f"为了更好地为您服务,请告诉我{prompts[0]}和{prompts[1]}。"
|
||||
else:
|
||||
all_but_last = "、".join(prompts[:-1])
|
||||
return f"为了更好地为您服务,请告诉我{all_but_last},以及{prompts[-1]}。"
|
||||
|
||||
|
||||
async def _execute_micro_flow_mode(
|
||||
tenant_id: str,
|
||||
request: DialogueRequest,
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from typing import Annotated
|
|||
|
||||
from fastapi import APIRouter, Depends, Path
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import get_session
|
||||
|
|
@ -26,6 +27,82 @@ router = APIRouter(prefix="/mid", tags=["Mid Platform Sessions"])
|
|||
_session_modes: dict[str, SessionMode] = {}
|
||||
|
||||
|
||||
class CancelFlowResponse(BaseModel):
    """Result payload returned by the cancel-flow endpoint."""

    # Whether the cancel operation completed without errors.
    success: bool
    # Human-readable outcome description.
    message: str
    # Echo of the session whose flow was targeted.
    session_id: str
|
||||
|
||||
|
||||
@router.post(
    "/sessions/{sessionId}/cancel-flow",
    operation_id="cancelActiveFlow",
    summary="Cancel active flow",
    description="""
    Cancel the active flow for a session.

    Use this when you encounter "Session already has an active flow" error.
    """,
    responses={
        200: {"description": "Flow cancelled successfully", "model": CancelFlowResponse},
        404: {"description": "No active flow found"},
    },
)
async def cancel_active_flow(
    sessionId: Annotated[str, Path(description="Session ID")],
    session: Annotated[AsyncSession, Depends(get_session)],
) -> CancelFlowResponse:
    """
    Cancel the active flow for a session.

    Resolves the tenant from the request context, then asks the FlowEngine to
    cancel any active flow instance bound to (tenant_id, sessionId) so that a
    new flow can be started.

    Returns:
        CancelFlowResponse: success=True both when a flow was cancelled and
        when no active flow existed (the message disambiguates); success=False
        only when cancellation raised an error.

    Raises:
        MissingTenantIdException: If no tenant id is present in the context.

    NOTE(review): the declared 404 response is never actually produced — the
    "no active flow" case returns 200 with success=True. Confirm whether the
    OpenAPI spec or the handler behavior should change.
    """
    tenant_id = get_tenant_id()

    if not tenant_id:
        from app.core.exceptions import MissingTenantIdException
        raise MissingTenantIdException()

    logger.info(
        f"[Cancel Flow] Cancelling active flow: tenant={tenant_id}, session={sessionId}"
    )

    try:
        from app.services.flow.engine import FlowEngine

        flow_engine = FlowEngine(session)
        cancelled = await flow_engine.cancel_flow(
            tenant_id=tenant_id,
            session_id=sessionId,
            reason="User requested cancellation via API",
        )

        if cancelled:
            logger.info(f"[Cancel Flow] Flow cancelled: session={sessionId}")
            return CancelFlowResponse(
                success=True,
                message="Active flow cancelled successfully",
                session_id=sessionId,
            )
        else:
            logger.info(f"[Cancel Flow] No active flow found: session={sessionId}")
            return CancelFlowResponse(
                success=True,
                message="No active flow found for this session",
                session_id=sessionId,
            )

    except Exception as e:
        # Fix: logger.exception (instead of logger.error) preserves the stack
        # trace for diagnostics. The failure is reported in-band with
        # success=False rather than via an HTTP error status.
        logger.exception(f"[Cancel Flow] Failed to cancel flow: {e}")
        return CancelFlowResponse(
            success=False,
            message=f"Failed to cancel flow: {str(e)}",
            session_id=sessionId,
        )
|
||||
|
||||
|
||||
@router.post(
|
||||
"/sessions/{sessionId}/mode",
|
||||
operation_id="switchSessionMode",
|
||||
|
|
|
|||
|
|
@ -64,6 +64,7 @@ class Settings(BaseSettings):
|
|||
redis_enabled: bool = True
|
||||
dashboard_cache_ttl: int = 60
|
||||
stats_counter_ttl: int = 7776000
|
||||
slot_state_cache_ttl: int = 1800
|
||||
|
||||
frontend_base_url: str = "http://localhost:3000"
|
||||
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ engine = create_async_engine(
|
|||
settings.database_url,
|
||||
pool_size=settings.database_pool_size,
|
||||
max_overflow=settings.database_max_overflow,
|
||||
echo=settings.debug,
|
||||
echo=False,
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -114,7 +114,12 @@ class ApiKeyMiddleware(BaseHTTPMiddleware):
|
|||
from app.core.database import async_session_maker
|
||||
async with async_session_maker() as session:
|
||||
await service.initialize(session)
|
||||
logger.info(f"[AC-AISVC-50] API key service lazy initialized with {len(service._keys_cache)} keys")
|
||||
if service._initialized and len(service._keys_cache) > 0:
|
||||
logger.info(f"[AC-AISVC-50] API key service lazy initialized with {len(service._keys_cache)} keys")
|
||||
elif service._initialized and len(service._keys_cache) == 0:
|
||||
logger.warning("[AC-AISVC-50] API key service initialized but no keys found in database")
|
||||
else:
|
||||
logger.error("[AC-AISVC-50] API key service lazy initialization failed")
|
||||
except Exception as e:
|
||||
logger.error(f"[AC-AISVC-50] Failed to initialize API key service: {e}")
|
||||
|
||||
|
|
|
|||
|
|
@ -272,20 +272,24 @@ class QdrantClient:
|
|||
score_threshold: float | None = None,
|
||||
vector_name: str = "full",
|
||||
with_vectors: bool = False,
|
||||
metadata_filter: dict[str, Any] | None = None,
|
||||
kb_ids: list[str] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
[AC-AISVC-10] Search vectors in tenant's collection.
|
||||
[AC-AISVC-10] Search vectors in tenant's collections.
|
||||
Returns results with score >= score_threshold if specified.
|
||||
Searches both old format (with @) and new format (with _) for backward compatibility.
|
||||
Searches all collections for the tenant (multi-KB support).
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
query_vector: Query vector for similarity search
|
||||
limit: Maximum number of results
|
||||
limit: Maximum number of results per collection
|
||||
score_threshold: Minimum score threshold for results
|
||||
vector_name: Name of the vector to search (for multi-vector collections)
|
||||
Default is "full" for 768-dim vectors in Matryoshka setup.
|
||||
with_vectors: Whether to return vectors in results (for two-stage reranking)
|
||||
metadata_filter: Optional metadata filter to apply during search
|
||||
kb_ids: Optional list of knowledge base IDs to restrict search to specific KBs
|
||||
"""
|
||||
client = await self.get_client()
|
||||
|
||||
|
|
@ -293,21 +297,36 @@ class QdrantClient:
|
|||
f"[AC-AISVC-10] Starting search: tenant_id={tenant_id}, "
|
||||
f"limit={limit}, score_threshold={score_threshold}, vector_dim={len(query_vector)}, vector_name={vector_name}"
|
||||
)
|
||||
if metadata_filter:
|
||||
logger.info(f"[AC-AISVC-10] Metadata filter: {metadata_filter}")
|
||||
|
||||
collection_names = [self.get_collection_name(tenant_id)]
|
||||
if '@' in tenant_id:
|
||||
old_format = f"{self._collection_prefix}{tenant_id}"
|
||||
new_format = f"{self._collection_prefix}{tenant_id.replace('@', '_')}"
|
||||
collection_names = [new_format, old_format]
|
||||
# 构建 Qdrant filter
|
||||
qdrant_filter = None
|
||||
if metadata_filter:
|
||||
qdrant_filter = self._build_qdrant_filter(metadata_filter)
|
||||
logger.info(f"[AC-AISVC-10] Qdrant filter: {qdrant_filter}")
|
||||
|
||||
logger.info(f"[AC-AISVC-10] Will search in collections: {collection_names}")
|
||||
# 获取该租户的所有 collections
|
||||
collection_names = await self._get_tenant_collections(client, tenant_id)
|
||||
|
||||
# 如果指定了 kb_ids,则只搜索指定的知识库 collections
|
||||
if kb_ids:
|
||||
target_collections = []
|
||||
for kb_id in kb_ids:
|
||||
kb_collection_name = self.get_kb_collection_name(tenant_id, kb_id)
|
||||
if kb_collection_name in collection_names:
|
||||
target_collections.append(kb_collection_name)
|
||||
else:
|
||||
logger.warning(f"[AC-AISVC-10] KB collection not found: {kb_collection_name} for kb_id={kb_id}")
|
||||
collection_names = target_collections
|
||||
logger.info(f"[AC-AISVC-10] Restricted to {len(collection_names)} KB collections: {collection_names}")
|
||||
else:
|
||||
logger.info(f"[AC-AISVC-10] Will search in {len(collection_names)} collections: {collection_names}")
|
||||
|
||||
all_hits = []
|
||||
|
||||
for collection_name in collection_names:
|
||||
try:
|
||||
logger.info(f"[AC-AISVC-10] Searching in collection: {collection_name}")
|
||||
|
||||
exists = await client.collection_exists(collection_name)
|
||||
if not exists:
|
||||
logger.warning(f"[AC-AISVC-10] Collection {collection_name} does not exist")
|
||||
|
|
@ -321,6 +340,7 @@ class QdrantClient:
|
|||
limit=limit,
|
||||
with_vectors=with_vectors,
|
||||
score_threshold=score_threshold,
|
||||
query_filter=qdrant_filter,
|
||||
)
|
||||
except Exception as e:
|
||||
if "vector name" in str(e).lower() or "Not existing vector" in str(e) or "using" in str(e).lower():
|
||||
|
|
@ -334,6 +354,7 @@ class QdrantClient:
|
|||
limit=limit,
|
||||
with_vectors=with_vectors,
|
||||
score_threshold=score_threshold,
|
||||
query_filter=qdrant_filter,
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
|
@ -348,6 +369,7 @@ class QdrantClient:
|
|||
"id": str(result.id),
|
||||
"score": result.score,
|
||||
"payload": result.payload or {},
|
||||
"collection": collection_name, # 添加 collection 信息
|
||||
}
|
||||
if with_vectors and result.vector:
|
||||
hit["vector"] = result.vector
|
||||
|
|
@ -358,10 +380,6 @@ class QdrantClient:
|
|||
logger.info(
|
||||
f"[AC-AISVC-10] Search in collection {collection_name}: {len(hits)} results for tenant={tenant_id}"
|
||||
)
|
||||
for i, h in enumerate(hits[:3]):
|
||||
logger.debug(
|
||||
f"[AC-AISVC-10] Hit {i+1}: id={h['id']}, score={h['score']:.4f}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"[AC-AISVC-10] Collection {collection_name} returned no hits (filtered or empty)"
|
||||
|
|
@ -370,9 +388,10 @@ class QdrantClient:
|
|||
logger.warning(
|
||||
f"[AC-AISVC-10] Collection {collection_name} not found or error: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
all_hits = sorted(all_hits, key=lambda x: x["score"], reverse=True)[:limit]
|
||||
# 按分数排序并返回 top results
|
||||
all_hits.sort(key=lambda x: x["score"], reverse=True)
|
||||
all_hits = all_hits[:limit]
|
||||
|
||||
logger.info(
|
||||
f"[AC-AISVC-10] Search returned {len(all_hits)} total results for tenant={tenant_id}"
|
||||
|
|
@ -386,6 +405,113 @@ class QdrantClient:
|
|||
|
||||
return all_hits
|
||||
|
||||
async def _get_tenant_collections(
    self,
    client: AsyncQdrantClient,
    tenant_id: str,
) -> list[str]:
    """
    Return all Qdrant collection names belonging to a tenant.

    Tries the Redis-backed metadata cache first; on a miss, lists collections
    from Qdrant, filters by the tenant's collection-name prefix, caches the
    result for 5 minutes, and returns it. On any Qdrant failure, falls back
    to the tenant's single default collection name so retrieval can proceed.

    Args:
        client: Connected async Qdrant client.
        tenant_id: Tenant identifier. A '@' in the id is normalized to '_'
            when building the collection-name prefix.

    Returns:
        Sorted list of collection names for the tenant (possibly from cache).
    """
    # Hoisted: json was previously imported separately in two branches.
    import json
    import time

    # perf_counter is the correct monotonic clock for elapsed-time logging
    # (time.time can jump with wall-clock adjustments).
    start_time = time.perf_counter()

    # 1. Try the cache first.
    from app.services.metadata_cache_service import get_metadata_cache_service
    cache_service = await get_metadata_cache_service()
    cache_key = f"collections:{tenant_id}"

    try:
        # Ensure the Redis connection is initialized before use.
        redis_client = await cache_service._get_redis()
        if redis_client and cache_service._enabled:
            cached = await redis_client.get(cache_key)
            if cached:
                collections = json.loads(cached)
                logger.info(
                    f"[AC-AISVC-10] Cache hit: Found {len(collections)} collections "
                    f"for tenant={tenant_id} in {(time.perf_counter() - start_time)*1000:.2f}ms"
                )
                return collections
    except Exception as e:
        # Cache problems must never break retrieval; fall through to Qdrant.
        logger.warning(f"[AC-AISVC-10] Cache get error: {e}")

    # 2. Cache miss: query Qdrant and filter by the tenant prefix.
    safe_tenant_id = tenant_id.replace('@', '_')
    prefix = f"{self._collection_prefix}{safe_tenant_id}"

    try:
        collections = await client.get_collections()
        tenant_collections = [
            c.name for c in collections.collections
            if c.name.startswith(prefix)
        ]

        # Deterministic ordering for logging and downstream iteration.
        tenant_collections.sort()

        db_time = (time.perf_counter() - start_time) * 1000
        logger.info(
            f"[AC-AISVC-10] Found {len(tenant_collections)} collections from Qdrant "
            f"for tenant={tenant_id} in {db_time:.2f}ms: {tenant_collections}"
        )

        # 3. Cache the result with a 5-minute TTL (best effort).
        try:
            redis_client = await cache_service._get_redis()
            if redis_client and cache_service._enabled:
                await redis_client.setex(
                    cache_key,
                    300,  # 5 minutes
                    json.dumps(tenant_collections)
                )
                logger.info(f"[AC-AISVC-10] Cached collections for tenant={tenant_id}")
        except Exception as e:
            logger.warning(f"[AC-AISVC-10] Cache set error: {e}")

        return tenant_collections
    except Exception as e:
        # Fail open: searching the tenant's default collection is better than
        # failing the whole request.
        logger.error(f"[AC-AISVC-10] Failed to get collections for tenant={tenant_id}: {e}")
        return [self.get_collection_name(tenant_id)]
|
||||
|
||||
def _build_qdrant_filter(
    self,
    metadata_filter: dict[str, Any],
) -> Any:
    """
    Build a Qdrant ``Filter`` from a flat metadata dict.

    Every (key, value) pair becomes an exact-match condition on the nested
    ``metadata.<key>`` payload field; conditions are combined with AND
    (``must``) semantics.

    Args:
        metadata_filter: Flat filter dict, e.g. {"grade": "三年级", "subject": "语文"}.

    Returns:
        A ``Filter`` with one ``FieldCondition`` per entry, or ``None`` when
        the dict is empty.
    """
    from qdrant_client.models import FieldCondition, Filter, MatchValue

    conditions = [
        FieldCondition(
            # Payload fields are nested under "metadata." in stored points.
            key=f"metadata.{field}",
            match=MatchValue(value=expected),
        )
        for field, expected in metadata_filter.items()
    ]

    if not conditions:
        return None
    return Filter(must=conditions)
|
||||
|
||||
async def delete_collection(self, tenant_id: str) -> bool:
|
||||
"""
|
||||
[AC-AISVC-10] Delete tenant's collection.
|
||||
|
|
|
|||
|
|
@ -29,6 +29,8 @@ from app.api.admin import (
|
|||
monitoring_router,
|
||||
prompt_templates_router,
|
||||
rag_router,
|
||||
retrieval_strategy_router,
|
||||
scene_slot_bundle_router,
|
||||
script_flows_router,
|
||||
sessions_router,
|
||||
slot_definition_router,
|
||||
|
|
@ -55,6 +57,11 @@ logging.basicConfig(
|
|||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
|
||||
logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)
|
||||
logging.getLogger("sqlalchemy.pool").setLevel(logging.WARNING)
|
||||
logging.getLogger("sqlalchemy.dialects").setLevel(logging.WARNING)
|
||||
logging.getLogger("sqlalchemy.orm").setLevel(logging.WARNING)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
@ -88,6 +95,28 @@ async def lifespan(app: FastAPI):
|
|||
except Exception as e:
|
||||
logger.error(f"[AC-AISVC-50] API key initialization FAILED: {e}", exc_info=True)
|
||||
|
||||
try:
|
||||
from app.services.mid.tool_guide_registry import init_tool_guide_registry
|
||||
|
||||
logger.info("[ToolGuideRegistry] Starting tool guides initialization...")
|
||||
tool_guide_registry = init_tool_guide_registry()
|
||||
logger.info(f"[ToolGuideRegistry] Tool guides loaded: {tool_guide_registry.list_tools()}")
|
||||
except Exception as e:
|
||||
logger.error(f"[ToolRegistry] Tools initialization FAILED: {e}", exc_info=True)
|
||||
|
||||
# [AC-AISVC-29] 预初始化 Embedding 服务,避免首次查询时的延迟
|
||||
try:
|
||||
from app.services.embedding import get_embedding_provider
|
||||
|
||||
logger.info("[AC-AISVC-29] Pre-initializing embedding service...")
|
||||
embedding_provider = await get_embedding_provider()
|
||||
logger.info(
|
||||
f"[AC-AISVC-29] Embedding service pre-initialized: "
|
||||
f"provider={embedding_provider.PROVIDER_NAME}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[AC-AISVC-29] Embedding service pre-initialization FAILED: {e}", exc_info=True)
|
||||
|
||||
yield
|
||||
|
||||
await close_db()
|
||||
|
|
@ -171,6 +200,8 @@ app.include_router(metadata_schema_router)
|
|||
app.include_router(monitoring_router)
|
||||
app.include_router(prompt_templates_router)
|
||||
app.include_router(rag_router)
|
||||
app.include_router(retrieval_strategy_router)
|
||||
app.include_router(scene_slot_bundle_router)
|
||||
app.include_router(script_flows_router)
|
||||
app.include_router(sessions_router)
|
||||
app.include_router(slot_definition_router)
|
||||
|
|
|
|||
|
|
@ -41,6 +41,7 @@ class ChatMessage(SQLModel, table=True):
|
|||
[AC-AISVC-13] Chat message entity with tenant isolation.
|
||||
Messages are scoped by (tenant_id, session_id) for multi-tenant security.
|
||||
[v0.7.0] Extended with monitoring fields for Dashboard statistics.
|
||||
[v0.8.0] Extended with route_trace for hybrid routing observability.
|
||||
"""
|
||||
|
||||
__tablename__ = "chat_messages"
|
||||
|
|
@ -90,6 +91,11 @@ class ChatMessage(SQLModel, table=True):
|
|||
sa_column=Column("guardrail_words", JSON, nullable=True),
|
||||
description="[v0.7.0] Guardrail trigger details: words, categories, strategy"
|
||||
)
|
||||
route_trace: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
sa_column=Column("route_trace", JSON, nullable=True),
|
||||
description="[v0.8.0] Intent routing trace log for hybrid routing observability"
|
||||
)
|
||||
|
||||
|
||||
class ChatSessionCreate(SQLModel):
|
||||
|
|
@ -227,6 +233,7 @@ class Document(SQLModel, table=True):
|
|||
file_type: str | None = Field(default=None, description="File MIME type")
|
||||
status: str = Field(default=DocumentStatus.PENDING.value, description="Document status")
|
||||
error_msg: str | None = Field(default=None, description="Error message if failed")
|
||||
doc_metadata: dict | None = Field(default=None, sa_type=JSON, description="Document metadata as JSON")
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow, description="Upload time")
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow, description="Last update time")
|
||||
|
||||
|
|
@ -421,6 +428,7 @@ class IntentRule(SQLModel, table=True):
|
|||
[AC-AISVC-65] Intent rule entity with tenant isolation.
|
||||
Supports keyword and regex matching for intent recognition.
|
||||
[AC-IDSMETA-16] Extended with metadata field for unified storage structure.
|
||||
[v0.8.0] Extended with intent_vector and semantic_examples for hybrid routing.
|
||||
"""
|
||||
|
||||
__tablename__ = "intent_rules"
|
||||
|
|
@ -458,6 +466,16 @@ class IntentRule(SQLModel, table=True):
|
|||
sa_column=Column("metadata", JSON, nullable=True),
|
||||
description="[AC-IDSMETA-16] Structured metadata for the intent rule"
|
||||
)
|
||||
intent_vector: list[float] | None = Field(
|
||||
default=None,
|
||||
sa_column=Column("intent_vector", JSON, nullable=True),
|
||||
description="[v0.8.0] Pre-computed intent vector for semantic matching"
|
||||
)
|
||||
semantic_examples: list[str] | None = Field(
|
||||
default=None,
|
||||
sa_column=Column("semantic_examples", JSON, nullable=True),
|
||||
description="[v0.8.0] Semantic example sentences for dynamic vector computation"
|
||||
)
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow, description="Creation time")
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow, description="Last update time")
|
||||
|
||||
|
|
@ -475,6 +493,8 @@ class IntentRuleCreate(SQLModel):
|
|||
fixed_reply: str | None = None
|
||||
transfer_message: str | None = None
|
||||
metadata_: dict[str, Any] | None = None
|
||||
intent_vector: list[float] | None = None
|
||||
semantic_examples: list[str] | None = None
|
||||
|
||||
|
||||
class IntentRuleUpdate(SQLModel):
|
||||
|
|
@ -491,6 +511,8 @@ class IntentRuleUpdate(SQLModel):
|
|||
transfer_message: str | None = None
|
||||
is_enabled: bool | None = None
|
||||
metadata_: dict[str, Any] | None = None
|
||||
intent_vector: list[float] | None = None
|
||||
semantic_examples: list[str] | None = None
|
||||
|
||||
|
||||
class IntentMatchResult:
|
||||
|
|
@ -810,6 +832,24 @@ class FlowStep(SQLModel):
|
|||
default=None,
|
||||
description="RAG configuration for this step: {'enabled': true, 'tag_filter': {'grade': '${context.grade}', 'type': '痛点'}}"
|
||||
)
|
||||
allowed_kb_ids: list[str] | None = Field(
|
||||
default=None,
|
||||
description="[Step-KB-Binding] Allowed knowledge base IDs for this step. If set, KB search will be restricted to these KBs."
|
||||
)
|
||||
preferred_kb_ids: list[str] | None = Field(
|
||||
default=None,
|
||||
description="[Step-KB-Binding] Preferred knowledge base IDs for this step. These KBs will be searched first."
|
||||
)
|
||||
kb_query_hint: str | None = Field(
|
||||
default=None,
|
||||
description="[Step-KB-Binding] Query hint for KB search in this step, helps improve retrieval accuracy."
|
||||
)
|
||||
max_kb_calls_per_step: int | None = Field(
|
||||
default=None,
|
||||
ge=1,
|
||||
le=5,
|
||||
description="[Step-KB-Binding] Max KB calls allowed per step. Default is 1 if not set."
|
||||
)
|
||||
|
||||
|
||||
class ScriptFlowCreate(SQLModel):
|
||||
|
|
@ -1078,6 +1118,7 @@ class MetadataFieldDefinition(SQLModel, table=True):
|
|||
)
|
||||
is_filterable: bool = Field(default=True, description="是否可用于过滤")
|
||||
is_rank_feature: bool = Field(default=False, description="是否用于排序特征")
|
||||
usage_description: str | None = Field(default=None, description="用途说明")
|
||||
field_roles: list[str] = Field(
|
||||
default_factory=list,
|
||||
sa_column=Column("field_roles", JSON, nullable=False, server_default="'[]'"),
|
||||
|
|
@ -1104,6 +1145,7 @@ class MetadataFieldDefinitionCreate(SQLModel):
|
|||
scope: list[str] = Field(default_factory=lambda: [MetadataScope.KB_DOCUMENT.value])
|
||||
is_filterable: bool = Field(default=True)
|
||||
is_rank_feature: bool = Field(default=False)
|
||||
usage_description: str | None = None
|
||||
field_roles: list[str] = Field(default_factory=list)
|
||||
status: str = Field(default=MetadataFieldStatus.DRAFT.value)
|
||||
|
||||
|
|
@ -1118,6 +1160,7 @@ class MetadataFieldDefinitionUpdate(SQLModel):
|
|||
scope: list[str] | None = None
|
||||
is_filterable: bool | None = None
|
||||
is_rank_feature: bool | None = None
|
||||
usage_description: str | None = None
|
||||
field_roles: list[str] | None = None
|
||||
status: str | None = None
|
||||
|
||||
|
|
@ -1131,6 +1174,17 @@ class ExtractStrategy(str, Enum):
|
|||
USER_INPUT = "user_input"
|
||||
|
||||
|
||||
class ExtractFailureType(str, Enum):
    """
    [AC-MRS-07-UPGRADE] Slot-extraction failure categories.

    Unified failure taxonomy used for tracing and logging.
    """
    # Extraction produced no value at all.
    EXTRACT_EMPTY = "EXTRACT_EMPTY"
    # The raw extractor output could not be parsed.
    EXTRACT_PARSE_FAIL = "EXTRACT_PARSE_FAIL"
    # The parsed value failed validation rules.
    EXTRACT_VALIDATION_FAIL = "EXTRACT_VALIDATION_FAIL"
    # An unexpected error occurred while extracting.
    EXTRACT_RUNTIME_ERROR = "EXTRACT_RUNTIME_ERROR"
|
||||
|
||||
class SlotValueSource(str, Enum):
|
||||
"""
|
||||
[AC-MRS-09] 槽位值来源
|
||||
|
|
@ -1145,6 +1199,7 @@ class SlotDefinition(SQLModel, table=True):
|
|||
"""
|
||||
[AC-MRS-07,08] 槽位定义表
|
||||
独立的槽位定义模型,与元数据字段解耦但可复用
|
||||
[AC-MRS-07-UPGRADE] 支持提取策略链 extract_strategies
|
||||
"""
|
||||
|
||||
__tablename__ = "slot_definitions"
|
||||
|
|
@ -1162,14 +1217,31 @@ class SlotDefinition(SQLModel, table=True):
|
|||
min_length=1,
|
||||
max_length=100,
|
||||
)
|
||||
display_name: str | None = Field(
|
||||
default=None,
|
||||
description="槽位名称,给运营/教研看的中文名,例:grade -> '当前年级'",
|
||||
max_length=100,
|
||||
)
|
||||
description: str | None = Field(
|
||||
default=None,
|
||||
description="槽位说明,解释这个槽位采集什么、用于哪里",
|
||||
max_length=500,
|
||||
)
|
||||
type: str = Field(
|
||||
default=MetadataFieldType.STRING.value,
|
||||
description="槽位类型: string/number/boolean/enum/array_enum"
|
||||
)
|
||||
required: bool = Field(default=False, description="是否必填槽位")
|
||||
# [AC-MRS-07-UPGRADE] 保留旧字段用于兼容读取
|
||||
extract_strategy: str | None = Field(
|
||||
default=None,
|
||||
description="提取策略: rule/llm/user_input"
|
||||
description="[兼容字段] 提取策略: rule/llm/user_input,已废弃,请使用 extract_strategies"
|
||||
)
|
||||
# [AC-MRS-07-UPGRADE] 新增策略链字段
|
||||
extract_strategies: list[str] | None = Field(
|
||||
default=None,
|
||||
sa_column=Column("extract_strategies", JSON, nullable=True),
|
||||
description="[AC-MRS-07-UPGRADE] 提取策略链:有序数组,元素为 rule/llm/user_input,按顺序执行直到成功"
|
||||
)
|
||||
validation_rule: str | None = Field(
|
||||
default=None,
|
||||
|
|
@ -1192,14 +1264,72 @@ class SlotDefinition(SQLModel, table=True):
|
|||
created_at: datetime = Field(default_factory=datetime.utcnow, description="创建时间")
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow, description="更新时间")
|
||||
|
||||
def get_effective_strategies(self) -> list[str]:
|
||||
"""
|
||||
[AC-MRS-07-UPGRADE] 获取有效的提取策略链
|
||||
优先使用 extract_strategies,如果不存在则兼容读取 extract_strategy
|
||||
"""
|
||||
if self.extract_strategies and len(self.extract_strategies) > 0:
|
||||
return self.extract_strategies
|
||||
if self.extract_strategy:
|
||||
return [self.extract_strategy]
|
||||
return []
|
||||
|
||||
def validate_strategies(self) -> tuple[bool, str]:
|
||||
"""
|
||||
[AC-MRS-07-UPGRADE] 校验提取策略链的有效性
|
||||
|
||||
Returns:
|
||||
Tuple of (是否有效, 错误信息)
|
||||
"""
|
||||
valid_strategies = {"rule", "llm", "user_input"}
|
||||
strategies = self.get_effective_strategies()
|
||||
|
||||
if not strategies:
|
||||
return True, "" # 空策略链视为有效(使用默认行为)
|
||||
|
||||
# 校验至少1个策略
|
||||
if len(strategies) == 0:
|
||||
return False, "提取策略链不能为空"
|
||||
|
||||
# 校验不允许重复策略
|
||||
if len(strategies) != len(set(strategies)):
|
||||
return False, "提取策略链中不允许重复的策略"
|
||||
|
||||
# 校验策略值有效
|
||||
invalid = [s for s in strategies if s not in valid_strategies]
|
||||
if invalid:
|
||||
return False, f"无效的提取策略: {invalid},有效值为: {list(valid_strategies)}"
|
||||
|
||||
return True, ""
|
||||
|
||||
|
||||
class SlotDefinitionCreate(SQLModel):
|
||||
"""[AC-MRS-07,08] 创建槽位定义"""
|
||||
|
||||
slot_key: str = Field(..., min_length=1, max_length=100)
|
||||
display_name: str | None = Field(
|
||||
default=None,
|
||||
description="槽位名称,给运营/教研看的中文名",
|
||||
max_length=100,
|
||||
)
|
||||
description: str | None = Field(
|
||||
default=None,
|
||||
description="槽位说明,解释这个槽位采集什么、用于哪里",
|
||||
max_length=500,
|
||||
)
|
||||
type: str = Field(default=MetadataFieldType.STRING.value)
|
||||
required: bool = Field(default=False)
|
||||
extract_strategy: str | None = None
|
||||
# [AC-MRS-07-UPGRADE] 支持策略链
|
||||
extract_strategies: list[str] | None = Field(
|
||||
default=None,
|
||||
description="提取策略链:有序数组,元素为 rule/llm/user_input,按顺序执行直到成功"
|
||||
)
|
||||
# [AC-MRS-07-UPGRADE] 保留旧字段用于兼容
|
||||
extract_strategy: str | None = Field(
|
||||
default=None,
|
||||
description="[兼容字段] 单提取策略,已废弃,请使用 extract_strategies"
|
||||
)
|
||||
validation_rule: str | None = None
|
||||
ask_back_prompt: str | None = None
|
||||
default_value: dict[str, Any] | None = None
|
||||
|
|
@ -1209,9 +1339,28 @@ class SlotDefinitionCreate(SQLModel):
|
|||
class SlotDefinitionUpdate(SQLModel):
|
||||
"""[AC-MRS-07] 更新槽位定义"""
|
||||
|
||||
display_name: str | None = Field(
|
||||
default=None,
|
||||
description="槽位名称,给运营/教研看的中文名",
|
||||
max_length=100,
|
||||
)
|
||||
description: str | None = Field(
|
||||
default=None,
|
||||
description="槽位说明,解释这个槽位采集什么、用于哪里",
|
||||
max_length=500,
|
||||
)
|
||||
type: str | None = None
|
||||
required: bool | None = None
|
||||
extract_strategy: str | None = None
|
||||
# [AC-MRS-07-UPGRADE] 支持策略链
|
||||
extract_strategies: list[str] | None = Field(
|
||||
default=None,
|
||||
description="提取策略链:有序数组,元素为 rule/llm/user_input,按顺序执行直到成功"
|
||||
)
|
||||
# [AC-MRS-07-UPGRADE] 保留旧字段用于兼容
|
||||
extract_strategy: str | None = Field(
|
||||
default=None,
|
||||
description="[兼容字段] 单提取策略,已废弃,请使用 extract_strategies"
|
||||
)
|
||||
validation_rule: str | None = None
|
||||
ask_back_prompt: str | None = None
|
||||
default_value: dict[str, Any] | None = None
|
||||
|
|
@ -1522,3 +1671,107 @@ class MidAuditLog(SQLModel, table=True):
|
|||
high_risk_scenario: str | None = Field(default=None, description="触发的高风险场景")
|
||||
latency_ms: int | None = Field(default=None, description="总耗时(ms)")
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow, description="创建时间", index=True)
|
||||
|
||||
|
||||
class SceneSlotBundleStatus(str, Enum):
|
||||
"""[AC-SCENE-SLOT-01] 场景槽位包状态"""
|
||||
DRAFT = "draft"
|
||||
ACTIVE = "active"
|
||||
DEPRECATED = "deprecated"
|
||||
|
||||
|
||||
class SceneSlotBundle(SQLModel, table=True):
|
||||
"""
|
||||
[AC-SCENE-SLOT-01] 场景-槽位映射配置
|
||||
定义每个场景需要采集的槽位集合
|
||||
|
||||
三层关系:
|
||||
- 层1:slot ↔ metadata(通过 linked_field_id)
|
||||
- 层2:scene ↔ slot_bundle(本模型)
|
||||
- 层3:step.expected_variables ↔ slot_key(话术步骤引用)
|
||||
"""
|
||||
|
||||
__tablename__ = "scene_slot_bundles"
|
||||
__table_args__ = (
|
||||
Index("ix_scene_slot_bundles_tenant", "tenant_id"),
|
||||
Index("ix_scene_slot_bundles_tenant_scene", "tenant_id", "scene_key", unique=True),
|
||||
Index("ix_scene_slot_bundles_tenant_status", "tenant_id", "status"),
|
||||
)
|
||||
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
||||
tenant_id: str = Field(..., description="Tenant ID for multi-tenant isolation", index=True)
|
||||
scene_key: str = Field(
|
||||
...,
|
||||
description="场景标识,如 'open_consult', 'refund_apply', 'course_recommend'",
|
||||
min_length=1,
|
||||
max_length=100,
|
||||
)
|
||||
scene_name: str = Field(
|
||||
...,
|
||||
description="场景名称,如 '开放咨询', '退款申请', '课程推荐'",
|
||||
min_length=1,
|
||||
max_length=100,
|
||||
)
|
||||
description: str | None = Field(
|
||||
default=None,
|
||||
description="场景描述"
|
||||
)
|
||||
required_slots: list[str] = Field(
|
||||
default_factory=list,
|
||||
sa_column=Column("required_slots", JSON, nullable=False),
|
||||
description="必填槽位 slot_key 列表"
|
||||
)
|
||||
optional_slots: list[str] = Field(
|
||||
default_factory=list,
|
||||
sa_column=Column("optional_slots", JSON, nullable=False),
|
||||
description="可选槽位 slot_key 列表"
|
||||
)
|
||||
slot_priority: list[str] | None = Field(
|
||||
default=None,
|
||||
sa_column=Column("slot_priority", JSON, nullable=True),
|
||||
description="槽位采集优先级顺序(slot_key 列表)"
|
||||
)
|
||||
completion_threshold: float = Field(
|
||||
default=1.0,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="完成阈值(0.0-1.0),必填槽位填充比例达到此值视为完成"
|
||||
)
|
||||
ask_back_order: str = Field(
|
||||
default="priority",
|
||||
description="追问顺序策略: priority/required_first/parallel"
|
||||
)
|
||||
status: str = Field(
|
||||
default=SceneSlotBundleStatus.DRAFT.value,
|
||||
description="状态: draft/active/deprecated"
|
||||
)
|
||||
version: int = Field(default=1, description="版本号")
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow, description="创建时间")
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow, description="更新时间")
|
||||
|
||||
|
||||
class SceneSlotBundleCreate(SQLModel):
|
||||
"""[AC-SCENE-SLOT-01] 创建场景槽位包"""
|
||||
|
||||
scene_key: str = Field(..., min_length=1, max_length=100)
|
||||
scene_name: str = Field(..., min_length=1, max_length=100)
|
||||
description: str | None = None
|
||||
required_slots: list[str] = Field(default_factory=list)
|
||||
optional_slots: list[str] = Field(default_factory=list)
|
||||
slot_priority: list[str] | None = None
|
||||
completion_threshold: float = Field(default=1.0, ge=0.0, le=1.0)
|
||||
ask_back_order: str = Field(default="priority")
|
||||
status: str = Field(default=SceneSlotBundleStatus.DRAFT.value)
|
||||
|
||||
|
||||
class SceneSlotBundleUpdate(SQLModel):
|
||||
"""[AC-SCENE-SLOT-01] 更新场景槽位包"""
|
||||
|
||||
scene_name: str | None = Field(default=None, min_length=1, max_length=100)
|
||||
description: str | None = None
|
||||
required_slots: list[str] | None = None
|
||||
optional_slots: list[str] | None = None
|
||||
slot_priority: list[str] | None = None
|
||||
completion_threshold: float | None = Field(default=None, ge=0.0, le=1.0)
|
||||
ask_back_order: str | None = None
|
||||
status: str | None = None
|
||||
|
|
|
|||
|
|
@ -73,6 +73,7 @@ class DialogueRequest(BaseModel):
|
|||
history: list[HistoryMessage] = Field(default_factory=list, description="已送达历史")
|
||||
interrupted_segments: list[InterruptedSegment] | None = Field(default=None, description="打断的分段")
|
||||
feature_flags: FeatureFlags | None = Field(default=None, description="特性开关")
|
||||
scene: str | None = Field(default=None, description="场景标识,用于KB过滤,如 'open_consult', 'after_sale'")
|
||||
|
||||
|
||||
class Segment(BaseModel):
|
||||
|
|
@ -127,6 +128,7 @@ class TraceInfo(BaseModel):
|
|||
)
|
||||
tools_used: list[str] | None = Field(default=None, description="使用的工具列表")
|
||||
tool_calls: list[ToolCallTraceModel] | None = Field(default=None, description="工具调用追踪")
|
||||
step_kb_binding: dict[str, Any] | None = Field(default=None, description="[Step-KB-Binding] 步骤知识库绑定信息")
|
||||
|
||||
|
||||
class DialogueResponse(BaseModel):
|
||||
|
|
|
|||
|
|
@ -86,6 +86,7 @@ class DialogueRequest(BaseModel):
|
|||
humanize_config: HumanizeConfigRequest | None = Field(
|
||||
default=None, description="Humanize config for segment delay"
|
||||
)
|
||||
scene: str | None = Field(default=None, description="Scene identifier for KB filtering, e.g., 'open_consult', 'after_sale'")
|
||||
|
||||
|
||||
class Segment(BaseModel):
|
||||
|
|
@ -122,6 +123,8 @@ class ToolCallTrace(BaseModel):
|
|||
error_code: str | None = Field(default=None, description="Error code if failed")
|
||||
args_digest: str | None = Field(default=None, description="Arguments digest for logging")
|
||||
result_digest: str | None = Field(default=None, description="Result digest for logging")
|
||||
arguments: dict[str, Any] | None = Field(default=None, description="Full tool call arguments")
|
||||
result: Any = Field(default=None, description="Full tool call result")
|
||||
|
||||
|
||||
class SegmentStats(BaseModel):
|
||||
|
|
@ -133,7 +136,7 @@ class SegmentStats(BaseModel):
|
|||
|
||||
class TraceInfo(BaseModel):
|
||||
"""[AC-MARH-02, AC-MARH-03, AC-MARH-05, AC-MARH-06, AC-MARH-07, AC-MARH-11,
|
||||
AC-MARH-12, AC-MARH-18, AC-MARH-19, AC-MARH-20] Trace info for observability."""
|
||||
AC-MARH-12, AC-MARH-18, AC-MARH-19, AC-MARH-20, AC-SCENE-SLOT-02] Trace info for observability."""
|
||||
mode: ExecutionMode = Field(..., description="Execution mode")
|
||||
intent: str | None = Field(default=None, description="Matched intent")
|
||||
request_id: str | None = Field(
|
||||
|
|
@ -156,6 +159,17 @@ class TraceInfo(BaseModel):
|
|||
high_risk_policy_set: list[HighRiskScenario] | None = Field(default=None, description="Active high-risk policy set")
|
||||
tools_used: list[str] | None = Field(default=None, description="Tools used in this request")
|
||||
tool_calls: list[ToolCallTrace] | None = Field(default=None, description="Tool call traces")
|
||||
duration_ms: int = Field(default=0, ge=0, description="Execution duration in milliseconds")
|
||||
created_at: str | None = Field(default=None, description="Creation timestamp")
|
||||
# [AC-SCENE-SLOT-02] 场景槽位追踪字段
|
||||
scene: str | None = Field(default=None, description="当前场景标识")
|
||||
scene_slot_context: dict[str, Any] | None = Field(default=None, description="场景槽位上下文信息")
|
||||
missing_slots: list[str] | None = Field(default=None, description="缺失的必填槽位列表")
|
||||
ask_back_triggered: bool | None = Field(default=False, description="是否触发了追问")
|
||||
slot_sources: dict[str, str] | None = Field(default=None, description="槽位值来源映射")
|
||||
kb_filter_sources: dict[str, str] | None = Field(default=None, description="KB 过滤条件来源映射")
|
||||
# [Step-KB-Binding] 步骤知识库绑定追踪
|
||||
step_kb_binding: dict[str, Any] | None = Field(default=None, description="步骤知识库绑定信息,包含 step_id, allowed_kb_ids, used_kb_ids 等")
|
||||
|
||||
|
||||
class DialogueResponse(BaseModel):
|
||||
|
|
|
|||
|
|
@ -43,6 +43,8 @@ class ToolCallTrace:
|
|||
- error_code: 错误码
|
||||
- args_digest: 参数摘要(脱敏)
|
||||
- result_digest: 结果摘要
|
||||
- arguments: 完整参数
|
||||
- result: 完整结果
|
||||
"""
|
||||
tool_name: str
|
||||
duration_ms: int
|
||||
|
|
@ -53,6 +55,8 @@ class ToolCallTrace:
|
|||
error_code: str | None = None
|
||||
args_digest: str | None = None
|
||||
result_digest: str | None = None
|
||||
arguments: dict[str, Any] | None = None
|
||||
result: Any = None
|
||||
started_at: datetime = field(default_factory=datetime.utcnow)
|
||||
completed_at: datetime | None = None
|
||||
|
||||
|
|
@ -74,6 +78,10 @@ class ToolCallTrace:
|
|||
result["args_digest"] = self.args_digest
|
||||
if self.result_digest:
|
||||
result["result_digest"] = self.result_digest
|
||||
if self.arguments:
|
||||
result["arguments"] = self.arguments
|
||||
if self.result is not None:
|
||||
result["result"] = self.result
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -139,7 +139,16 @@ class SlotDefinitionResponse(BaseModel):
|
|||
slot_key: str = Field(..., description="槽位键名")
|
||||
type: str = Field(..., description="槽位类型")
|
||||
required: bool = Field(default=False, description="是否必填槽位")
|
||||
extract_strategy: str | None = Field(default=None, description="提取策略")
|
||||
# [AC-MRS-07-UPGRADE] 保留旧字段用于兼容
|
||||
extract_strategy: str | None = Field(
|
||||
default=None,
|
||||
description="[兼容字段] 单提取策略,已废弃"
|
||||
)
|
||||
# [AC-MRS-07-UPGRADE] 新增策略链字段
|
||||
extract_strategies: list[str] | None = Field(
|
||||
default=None,
|
||||
description="[AC-MRS-07-UPGRADE] 提取策略链:有序数组,元素为 rule/llm/user_input"
|
||||
)
|
||||
validation_rule: str | None = Field(default=None, description="校验规则")
|
||||
ask_back_prompt: str | None = Field(default=None, description="追问提示语模板")
|
||||
default_value: dict[str, Any] | None = Field(default=None, description="默认值")
|
||||
|
|
@ -157,9 +166,15 @@ class SlotDefinitionCreateRequest(BaseModel):
|
|||
slot_key: str = Field(..., min_length=1, max_length=100, description="槽位键名")
|
||||
type: str = Field(default="string", description="槽位类型")
|
||||
required: bool = Field(default=False, description="是否必填槽位")
|
||||
# [AC-MRS-07-UPGRADE] 支持策略链
|
||||
extract_strategies: list[str] | None = Field(
|
||||
default=None,
|
||||
description="[AC-MRS-07-UPGRADE] 提取策略链:有序数组,元素为 rule/llm/user_input,按顺序执行直到成功"
|
||||
)
|
||||
# [AC-MRS-07-UPGRADE] 保留旧字段用于兼容
|
||||
extract_strategy: str | None = Field(
|
||||
default=None,
|
||||
description="提取策略: rule/llm/user_input"
|
||||
description="[兼容字段] 单提取策略,已废弃,请使用 extract_strategies"
|
||||
)
|
||||
validation_rule: str | None = Field(default=None, description="校验规则")
|
||||
ask_back_prompt: str | None = Field(default=None, description="追问提示语模板")
|
||||
|
|
@ -172,7 +187,16 @@ class SlotDefinitionUpdateRequest(BaseModel):
|
|||
|
||||
type: str | None = None
|
||||
required: bool | None = None
|
||||
extract_strategy: str | None = None
|
||||
# [AC-MRS-07-UPGRADE] 支持策略链
|
||||
extract_strategies: list[str] | None = Field(
|
||||
default=None,
|
||||
description="[AC-MRS-07-UPGRADE] 提取策略链:有序数组,元素为 rule/llm/user_input,按顺序执行直到成功"
|
||||
)
|
||||
# [AC-MRS-07-UPGRADE] 保留旧字段用于兼容
|
||||
extract_strategy: str | None = Field(
|
||||
default=None,
|
||||
description="[兼容字段] 单提取策略,已废弃,请使用 extract_strategies"
|
||||
)
|
||||
validation_rule: str | None = None
|
||||
ask_back_prompt: str | None = None
|
||||
default_value: dict[str, Any] | None = None
|
||||
|
|
|
|||
|
|
@ -0,0 +1,198 @@
|
|||
"""
|
||||
Retrieval Strategy Schemas for AI Service.
|
||||
[AC-AISVC-RES-01~15] Request/Response models aligned with OpenAPI contract.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
|
||||
class StrategyType(str, Enum):
|
||||
DEFAULT = "default"
|
||||
ENHANCED = "enhanced"
|
||||
|
||||
|
||||
class ReactMode(str, Enum):
|
||||
REACT = "react"
|
||||
NON_REACT = "non_react"
|
||||
|
||||
|
||||
class RolloutMode(str, Enum):
|
||||
OFF = "off"
|
||||
PERCENTAGE = "percentage"
|
||||
ALLOWLIST = "allowlist"
|
||||
|
||||
|
||||
class ValidationCheckType(str, Enum):
|
||||
METADATA_CONSISTENCY = "metadata_consistency"
|
||||
EMBEDDING_PREFIX = "embedding_prefix"
|
||||
RRF_CONFIG = "rrf_config"
|
||||
PERFORMANCE_BUDGET = "performance_budget"
|
||||
|
||||
|
||||
class RolloutConfig(BaseModel):
|
||||
"""
|
||||
[AC-AISVC-RES-03] Grayscale rollout configuration.
|
||||
"""
|
||||
|
||||
mode: RolloutMode = Field(..., description="Rollout mode: off, percentage, or allowlist")
|
||||
percentage: float | None = Field(
|
||||
default=None,
|
||||
ge=0,
|
||||
le=100,
|
||||
description="Percentage of traffic for grayscale (0-100)",
|
||||
)
|
||||
allowlist: list[str] | None = Field(
|
||||
default=None,
|
||||
description="List of tenant IDs in allowlist",
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_rollout_config(self) -> "RolloutConfig":
|
||||
if self.mode == RolloutMode.PERCENTAGE and self.percentage is None:
|
||||
raise ValueError("percentage is required when mode is 'percentage'")
|
||||
if self.mode == RolloutMode.ALLOWLIST and (
|
||||
self.allowlist is None or len(self.allowlist) == 0
|
||||
):
|
||||
raise ValueError("allowlist is required when mode is 'allowlist'")
|
||||
return self
|
||||
|
||||
|
||||
class RetrievalStrategyStatus(BaseModel):
|
||||
"""
|
||||
[AC-AISVC-RES-01] Current retrieval strategy status.
|
||||
"""
|
||||
|
||||
active_strategy: StrategyType = Field(
|
||||
...,
|
||||
alias="active_strategy",
|
||||
description="Current active strategy: default or enhanced",
|
||||
)
|
||||
react_mode: ReactMode = Field(
|
||||
...,
|
||||
alias="react_mode",
|
||||
description="ReAct mode: react or non_react",
|
||||
)
|
||||
rollout: RolloutConfig = Field(..., description="Grayscale rollout configuration")
|
||||
|
||||
model_config = {"populate_by_name": True}
|
||||
|
||||
|
||||
class RetrievalStrategySwitchRequest(BaseModel):
|
||||
"""
|
||||
[AC-AISVC-RES-02, AC-AISVC-RES-03, AC-AISVC-RES-05] Request to switch retrieval strategy.
|
||||
"""
|
||||
|
||||
target_strategy: StrategyType = Field(
|
||||
...,
|
||||
alias="target_strategy",
|
||||
description="Target strategy to switch to",
|
||||
)
|
||||
react_mode: ReactMode | None = Field(
|
||||
default=None,
|
||||
alias="react_mode",
|
||||
description="ReAct mode to use",
|
||||
)
|
||||
rollout: RolloutConfig | None = Field(
|
||||
default=None,
|
||||
description="Grayscale rollout configuration",
|
||||
)
|
||||
reason: str | None = Field(
|
||||
default=None,
|
||||
description="Reason for strategy switch",
|
||||
)
|
||||
|
||||
model_config = {"populate_by_name": True}
|
||||
|
||||
|
||||
class RetrievalStrategySwitchResponse(BaseModel):
|
||||
"""
|
||||
[AC-AISVC-RES-02] Response after strategy switch.
|
||||
"""
|
||||
|
||||
previous: RetrievalStrategyStatus = Field(..., description="Previous strategy status")
|
||||
current: RetrievalStrategyStatus = Field(..., description="Current strategy status")
|
||||
|
||||
|
||||
class RetrievalStrategyValidationRequest(BaseModel):
|
||||
"""
|
||||
[AC-AISVC-RES-04, AC-AISVC-RES-06, AC-AISVC-RES-08] Request to validate strategy.
|
||||
"""
|
||||
|
||||
strategy: StrategyType = Field(..., description="Strategy to validate")
|
||||
react_mode: ReactMode | None = Field(
|
||||
default=None,
|
||||
description="ReAct mode to validate",
|
||||
)
|
||||
checks: list[ValidationCheckType] | None = Field(
|
||||
default=None,
|
||||
description="List of checks to perform",
|
||||
)
|
||||
|
||||
model_config = {"populate_by_name": True}
|
||||
|
||||
|
||||
class ValidationResult(BaseModel):
|
||||
"""
|
||||
[AC-AISVC-RES-06] Single validation check result.
|
||||
"""
|
||||
|
||||
check: str = Field(..., description="Check name")
|
||||
passed: bool = Field(..., description="Whether the check passed")
|
||||
message: str | None = Field(default=None, description="Additional message")
|
||||
|
||||
|
||||
class RetrievalStrategyValidationResponse(BaseModel):
|
||||
"""
|
||||
[AC-AISVC-RES-06] Validation response with all check results.
|
||||
"""
|
||||
|
||||
passed: bool = Field(..., description="Whether all checks passed")
|
||||
results: list[ValidationResult] = Field(..., description="Individual check results")
|
||||
|
||||
|
||||
class RetrievalStrategyRollbackResponse(BaseModel):
|
||||
"""
|
||||
[AC-AISVC-RES-07] Response after strategy rollback.
|
||||
"""
|
||||
|
||||
current: RetrievalStrategyStatus = Field(..., description="Current strategy status before rollback")
|
||||
rollback_to: RetrievalStrategyStatus = Field(..., description="Strategy status after rollback")
|
||||
|
||||
|
||||
class StrategyAuditLog(BaseModel):
|
||||
"""
|
||||
[AC-AISVC-RES-07] Audit log entry for strategy operations.
|
||||
"""
|
||||
|
||||
timestamp: str = Field(..., description="ISO timestamp of the operation")
|
||||
operation: str = Field(..., description="Operation type: switch, rollback, validate")
|
||||
previous_strategy: str | None = Field(default=None, description="Previous strategy")
|
||||
new_strategy: str | None = Field(default=None, description="New strategy")
|
||||
previous_react_mode: str | None = Field(default=None, description="Previous react mode")
|
||||
new_react_mode: str | None = Field(default=None, description="New react mode")
|
||||
reason: str | None = Field(default=None, description="Reason for the operation")
|
||||
operator: str | None = Field(default=None, description="Operator who performed the operation")
|
||||
tenant_id: str | None = Field(default=None, description="Tenant ID if applicable")
|
||||
metadata: dict[str, Any] | None = Field(default=None, description="Additional metadata")
|
||||
|
||||
|
||||
class StrategyMetrics(BaseModel):
|
||||
"""
|
||||
[AC-AISVC-RES-03, AC-AISVC-RES-08] Metrics for strategy operations.
|
||||
"""
|
||||
|
||||
strategy: StrategyType = Field(..., description="Current strategy")
|
||||
react_mode: ReactMode = Field(..., description="Current react mode")
|
||||
total_requests: int = Field(default=0, description="Total requests count")
|
||||
successful_requests: int = Field(default=0, description="Successful requests count")
|
||||
failed_requests: int = Field(default=0, description="Failed requests count")
|
||||
avg_latency_ms: float = Field(default=0.0, description="Average latency in ms")
|
||||
p99_latency_ms: float = Field(default=0.0, description="P99 latency in ms")
|
||||
direct_route_count: int = Field(default=0, description="Direct route count")
|
||||
react_route_count: int = Field(default=0, description="React route count")
|
||||
auto_route_count: int = Field(default=0, description="Auto route count")
|
||||
fallback_count: int = Field(default=0, description="Fallback to default count")
|
||||
last_updated: str | None = Field(default=None, description="Last update timestamp")
|
||||
|
|
@ -81,7 +81,6 @@ class ApiKeyService:
|
|||
return
|
||||
except Exception as e:
|
||||
logger.warning(f"[AC-AISVC-50] Full API key schema load failed, fallback to legacy columns: {e}")
|
||||
await session.rollback()
|
||||
|
||||
# Backward-compat fallback for environments without new columns
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,274 @@
|
|||
"""
|
||||
Scene Slot Bundle Cache Service.
|
||||
[AC-SCENE-SLOT-03] 场景槽位包缓存服务
|
||||
|
||||
职责:
|
||||
1. 缓存场景槽位包配置,减少数据库查询
|
||||
2. 支持缓存失效和刷新
|
||||
3. 支持租户隔离
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
import redis.asyncio as redis
|
||||
REDIS_AVAILABLE = True
|
||||
except ImportError:
|
||||
REDIS_AVAILABLE = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class CachedSceneSlotBundle:
|
||||
"""缓存的场景槽位包"""
|
||||
scene_key: str
|
||||
scene_name: str
|
||||
description: str | None
|
||||
required_slots: list[str]
|
||||
optional_slots: list[str]
|
||||
slot_priority: list[str] | None
|
||||
completion_threshold: float
|
||||
ask_back_order: str
|
||||
status: str
|
||||
version: int
|
||||
cached_at: datetime = field(default_factory=datetime.utcnow)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"scene_key": self.scene_key,
|
||||
"scene_name": self.scene_name,
|
||||
"description": self.description,
|
||||
"required_slots": self.required_slots,
|
||||
"optional_slots": self.optional_slots,
|
||||
"slot_priority": self.slot_priority,
|
||||
"completion_threshold": self.completion_threshold,
|
||||
"ask_back_order": self.ask_back_order,
|
||||
"status": self.status,
|
||||
"version": self.version,
|
||||
"cached_at": self.cached_at.isoformat(),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "CachedSceneSlotBundle":
|
||||
return cls(
|
||||
scene_key=data["scene_key"],
|
||||
scene_name=data["scene_name"],
|
||||
description=data.get("description"),
|
||||
required_slots=data.get("required_slots", []),
|
||||
optional_slots=data.get("optional_slots", []),
|
||||
slot_priority=data.get("slot_priority"),
|
||||
completion_threshold=data.get("completion_threshold", 1.0),
|
||||
ask_back_order=data.get("ask_back_order", "priority"),
|
||||
status=data.get("status", "draft"),
|
||||
version=data.get("version", 1),
|
||||
cached_at=datetime.fromisoformat(data["cached_at"]) if data.get("cached_at") else datetime.utcnow(),
|
||||
)
|
||||
|
||||
|
||||
class SceneSlotBundleCache:
|
||||
"""
|
||||
[AC-SCENE-SLOT-03] 场景槽位包缓存
|
||||
|
||||
使用 Redis 或内存缓存场景槽位包配置
|
||||
"""
|
||||
|
||||
CACHE_PREFIX = "scene_slot_bundle"
|
||||
CACHE_TTL_SECONDS = 300 # 5分钟缓存
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redis_client: Any = None,
|
||||
ttl_seconds: int = 300,
|
||||
):
|
||||
self._redis = redis_client
|
||||
self._ttl = ttl_seconds or self.CACHE_TTL_SECONDS
|
||||
self._memory_cache: dict[str, tuple[CachedSceneSlotBundle, datetime]] = {}
|
||||
|
||||
def _get_cache_key(self, tenant_id: str, scene_key: str) -> str:
|
||||
"""生成缓存键"""
|
||||
return f"{self.CACHE_PREFIX}:{tenant_id}:{scene_key}"
|
||||
|
||||
async def get(
|
||||
self,
|
||||
tenant_id: str,
|
||||
scene_key: str,
|
||||
) -> CachedSceneSlotBundle | None:
|
||||
"""
|
||||
获取缓存的场景槽位包
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
scene_key: 场景标识
|
||||
|
||||
Returns:
|
||||
缓存的场景槽位包或 None
|
||||
"""
|
||||
cache_key = self._get_cache_key(tenant_id, scene_key)
|
||||
|
||||
if self._redis and REDIS_AVAILABLE:
|
||||
try:
|
||||
cached_data = await self._redis.get(cache_key)
|
||||
if cached_data:
|
||||
data = json.loads(cached_data)
|
||||
return CachedSceneSlotBundle.from_dict(data)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[AC-SCENE-SLOT-03] Redis get failed: {e}, falling back to memory cache"
|
||||
)
|
||||
|
||||
if cache_key in self._memory_cache:
|
||||
cached_bundle, cached_at = self._memory_cache[cache_key]
|
||||
if datetime.utcnow() - cached_at < timedelta(seconds=self._ttl):
|
||||
return cached_bundle
|
||||
else:
|
||||
del self._memory_cache[cache_key]
|
||||
|
||||
return None
|
||||
|
||||
async def set(
|
||||
self,
|
||||
tenant_id: str,
|
||||
scene_key: str,
|
||||
bundle: CachedSceneSlotBundle,
|
||||
) -> bool:
|
||||
"""
|
||||
设置缓存
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
scene_key: 场景标识
|
||||
bundle: 要缓存的场景槽位包
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
cache_key = self._get_cache_key(tenant_id, scene_key)
|
||||
|
||||
if self._redis and REDIS_AVAILABLE:
|
||||
try:
|
||||
await self._redis.setex(
|
||||
cache_key,
|
||||
self._ttl,
|
||||
json.dumps(bundle.to_dict()),
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[AC-SCENE-SLOT-03] Redis set failed: {e}, falling back to memory cache"
|
||||
)
|
||||
|
||||
self._memory_cache[cache_key] = (bundle, datetime.utcnow())
|
||||
return True
|
||||
|
||||
async def delete(
|
||||
self,
|
||||
tenant_id: str,
|
||||
scene_key: str,
|
||||
) -> bool:
|
||||
"""
|
||||
删除缓存
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
scene_key: 场景标识
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
cache_key = self._get_cache_key(tenant_id, scene_key)
|
||||
|
||||
if self._redis and REDIS_AVAILABLE:
|
||||
try:
|
||||
await self._redis.delete(cache_key)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[AC-SCENE-SLOT-03] Redis delete failed: {e}"
|
||||
)
|
||||
|
||||
if cache_key in self._memory_cache:
|
||||
del self._memory_cache[cache_key]
|
||||
|
||||
return True
|
||||
|
||||
async def delete_by_tenant(self, tenant_id: str) -> bool:
|
||||
"""
|
||||
删除租户下所有缓存
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
pattern = f"{self.CACHE_PREFIX}:{tenant_id}:*"
|
||||
|
||||
if self._redis and REDIS_AVAILABLE:
|
||||
try:
|
||||
keys = []
|
||||
async for key in self._redis.scan_iter(match=pattern):
|
||||
keys.append(key)
|
||||
if keys:
|
||||
await self._redis.delete(*keys)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[AC-SCENE-SLOT-03] Redis delete by tenant failed: {e}"
|
||||
)
|
||||
|
||||
keys_to_delete = [
|
||||
k for k in self._memory_cache
|
||||
if k.startswith(f"{self.CACHE_PREFIX}:{tenant_id}:")
|
||||
]
|
||||
for key in keys_to_delete:
|
||||
del self._memory_cache[key]
|
||||
|
||||
return True
|
||||
|
||||
async def invalidate_on_update(
|
||||
self,
|
||||
tenant_id: str,
|
||||
scene_key: str,
|
||||
) -> bool:
|
||||
"""
|
||||
当场景槽位包更新时使缓存失效
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
scene_key: 场景标识
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
logger.info(
|
||||
f"[AC-SCENE-SLOT-03] Invalidating cache for scene slot bundle: "
|
||||
f"tenant={tenant_id}, scene={scene_key}"
|
||||
)
|
||||
return await self.delete(tenant_id, scene_key)
|
||||
|
||||
|
||||
_scene_slot_bundle_cache: SceneSlotBundleCache | None = None
|
||||
|
||||
|
||||
def get_scene_slot_bundle_cache() -> SceneSlotBundleCache:
|
||||
"""获取场景槽位包缓存实例"""
|
||||
global _scene_slot_bundle_cache
|
||||
if _scene_slot_bundle_cache is None:
|
||||
_scene_slot_bundle_cache = SceneSlotBundleCache()
|
||||
return _scene_slot_bundle_cache
|
||||
|
||||
|
||||
def init_scene_slot_bundle_cache(redis_client: Any = None, ttl_seconds: int = 300) -> None:
|
||||
"""初始化场景槽位包缓存"""
|
||||
global _scene_slot_bundle_cache
|
||||
_scene_slot_bundle_cache = SceneSlotBundleCache(
|
||||
redis_client=redis_client,
|
||||
ttl_seconds=ttl_seconds,
|
||||
)
|
||||
logger.info(
|
||||
f"[AC-SCENE-SLOT-03] Scene slot bundle cache initialized: "
|
||||
f"ttl={ttl_seconds}s, redis={redis_client is not None}"
|
||||
)
|
||||
|
|
@ -0,0 +1,397 @@
|
|||
"""
|
||||
Slot State Cache Layer.
|
||||
槽位状态缓存层 - 提供会话级槽位状态持久化
|
||||
|
||||
[AC-MRS-SLOT-CACHE-01] 多轮状态持久化
|
||||
|
||||
Features:
|
||||
- L1: In-memory cache (process-level, 5 min TTL)
|
||||
- L2: Redis cache (shared, configurable TTL)
|
||||
- Automatic fallback on cache miss
|
||||
- Support for slot value source priority
|
||||
|
||||
Key format: slot_state:{tenant_id}:{session_id}
|
||||
TTL: Configurable (default 30 minutes)
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import redis.asyncio as redis
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.models.mid.schemas import SlotSource
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class CachedSlotValue:
    """A single cached slot value together with its provenance.

    Attributes:
        value: The slot value itself.
        source: Where the value came from (user_confirmed, rule_extracted,
            llm_inferred, default, context).
        confidence: Confidence score for the value.
        updated_at: Unix timestamp of the last update.
    """

    value: Any
    source: str
    confidence: float = 1.0
    updated_at: float = field(default_factory=time.time)

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dict."""
        payload: dict[str, Any] = {}
        payload["value"] = self.value
        payload["source"] = self.source
        payload["confidence"] = self.confidence
        payload["updated_at"] = self.updated_at
        return payload

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "CachedSlotValue":
        """Rebuild an instance from a dict produced by ``to_dict``.

        Missing confidence defaults to 1.0; missing timestamp defaults to now.
        """
        confidence = data.get("confidence", 1.0)
        updated_at = data.get("updated_at", time.time())
        return cls(
            value=data["value"],
            source=data["source"],
            confidence=confidence,
            updated_at=updated_at,
        )
|
||||
|
||||
|
||||
@dataclass
class CachedSlotState:
    """Cached per-session slot state.

    Attributes:
        filled_slots: Filled slot values keyed by slot key.
        slot_to_field_map: Mapping from slot key to metadata field name.
        created_at: Creation timestamp.
        updated_at: Last-update timestamp.
    """

    filled_slots: dict[str, CachedSlotValue] = field(default_factory=dict)
    slot_to_field_map: dict[str, str] = field(default_factory=dict)
    created_at: float = field(default_factory=time.time)
    updated_at: float = field(default_factory=time.time)

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a JSON-compatible dict."""
        serialized_slots = {key: slot.to_dict() for key, slot in self.filled_slots.items()}
        return {
            "filled_slots": serialized_slots,
            "slot_to_field_map": self.slot_to_field_map,
            "created_at": self.created_at,
            "updated_at": self.updated_at,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "CachedSlotState":
        """Rebuild from ``to_dict`` output; tolerates bare (non-dict) slot values."""
        slots: dict[str, CachedSlotValue] = {}
        for key, raw in data.get("filled_slots", {}).items():
            # Legacy/simplified payloads may store the raw value directly.
            slots[key] = (
                CachedSlotValue.from_dict(raw)
                if isinstance(raw, dict)
                else CachedSlotValue(value=raw, source="unknown")
            )

        return cls(
            filled_slots=slots,
            slot_to_field_map=data.get("slot_to_field_map", {}),
            created_at=data.get("created_at", time.time()),
            updated_at=data.get("updated_at", time.time()),
        )

    def get_simple_filled_slots(self) -> dict[str, Any]:
        """Return only the values of the filled slots."""
        return {key: slot.value for key, slot in self.filled_slots.items()}

    def get_slot_sources(self) -> dict[str, str]:
        """Return the source label of each filled slot."""
        return {key: slot.source for key, slot in self.filled_slots.items()}

    def get_slot_confidence(self) -> dict[str, float]:
        """Return the confidence score of each filled slot."""
        return {key: slot.confidence for key, slot in self.filled_slots.items()}
|
||||
|
||||
|
||||
class SlotStateCache:
    """
    [AC-MRS-SLOT-CACHE-01] Slot state cache layer.

    Session-level slot state persistence supporting:
    - L1: in-process memory cache (process-level, 5 minute TTL)
    - L2: Redis cache (shared, configurable TTL)
    - Automatic degradation (memory cache only when Redis is unavailable)
    - Source-priority based merging of slot values

    Key format: slot_state:{tenant_id}:{session_id}
    TTL: Configurable via settings.slot_state_cache_ttl (default 1800 seconds = 30 minutes)
    """

    # L1 cache shared across all instances (class attribute):
    # maps "tenant:session" -> (state, time the entry was stored).
    _local_cache: dict[str, tuple[CachedSlotState, float]] = {}
    _local_cache_ttl = 300  # seconds an L1 entry remains valid

    # Merge priority per value source: during merge_and_set a new value
    # replaces an existing one only when its priority is >= the existing
    # value's priority. Both the enum values and their raw string forms are
    # listed so lookups also work for deserialized states.
    SOURCE_PRIORITY = {
        SlotSource.USER_CONFIRMED.value: 100,
        "user_confirmed": 100,
        SlotSource.RULE_EXTRACTED.value: 80,
        "rule_extracted": 80,
        SlotSource.LLM_INFERRED.value: 60,
        "llm_inferred": 60,
        "context": 40,
        SlotSource.DEFAULT.value: 20,
        "default": 20,
        "unknown": 0,
    }

    def __init__(self, redis_client: redis.Redis | None = None):
        # An externally supplied client is used as-is; otherwise one is
        # created lazily from settings.redis_url on first use.
        self._redis = redis_client
        self._settings = get_settings()
        self._enabled = self._settings.redis_enabled
        # Falls back to 30 minutes when the setting is absent.
        self._cache_ttl = getattr(self._settings, "slot_state_cache_ttl", 1800)

    async def _get_client(self) -> redis.Redis | None:
        """Get or create Redis client."""
        if not self._enabled:
            return None
        if self._redis is None:
            try:
                self._redis = redis.from_url(
                    self._settings.redis_url,
                    encoding="utf-8",
                    decode_responses=True,
                )
            except Exception as e:
                # A failed connect permanently disables L2 for this instance.
                logger.warning(f"[SlotStateCache] Failed to connect to Redis: {e}")
                self._enabled = False
                return None
        return self._redis

    def _make_key(self, tenant_id: str, session_id: str) -> str:
        """Generate cache key."""
        return f"slot_state:{tenant_id}:{session_id}"

    def _make_local_key(self, tenant_id: str, session_id: str) -> str:
        """Generate local cache key."""
        return f"{tenant_id}:{session_id}"

    def _get_source_priority(self, source: str) -> int:
        """Get priority for a source (unknown sources rank lowest)."""
        return self.SOURCE_PRIORITY.get(source, 0)

    async def get(
        self,
        tenant_id: str,
        session_id: str,
    ) -> CachedSlotState | None:
        """
        Get cached slot state (L1 -> L2).

        Args:
            tenant_id: Tenant ID for isolation
            session_id: Session ID

        Returns:
            CachedSlotState or None if not found
        """
        local_key = self._make_local_key(tenant_id, session_id)
        if local_key in self._local_cache:
            state, timestamp = self._local_cache[local_key]
            if time.time() - timestamp < self._local_cache_ttl:
                logger.debug(f"[SlotStateCache] L1 hit: {local_key}")
                return state
            else:
                # Expired L1 entry: evict before falling through to L2.
                del self._local_cache[local_key]

        client = await self._get_client()
        if client is None:
            return None

        key = self._make_key(tenant_id, session_id)

        try:
            data = await client.get(key)
            if data:
                logger.debug(f"[SlotStateCache] L2 hit: {key}")
                state_dict = json.loads(data)
                state = CachedSlotState.from_dict(state_dict)

                # Backfill L1 so the next read avoids Redis.
                self._local_cache[local_key] = (state, time.time())

                return state
            return None
        except Exception as e:
            # Treat any Redis/JSON failure as a miss.
            logger.warning(f"[SlotStateCache] Failed to get from cache: {e}")
            return None

    async def set(
        self,
        tenant_id: str,
        session_id: str,
        state: CachedSlotState,
    ) -> bool:
        """
        Set slot state to cache (L1 + L2).

        Args:
            tenant_id: Tenant ID for isolation
            session_id: Session ID
            state: CachedSlotState to cache

        Returns:
            True if successful
        """
        local_key = self._make_local_key(tenant_id, session_id)
        state.updated_at = time.time()
        self._local_cache[local_key] = (state, time.time())

        client = await self._get_client()
        if client is None:
            # NOTE(review): returns False here even though the L1 write above
            # succeeded — confirm callers treat False as "not shared", not "lost".
            return False

        key = self._make_key(tenant_id, session_id)

        try:
            await client.setex(
                key,
                self._cache_ttl,
                # default=str stringifies any non-JSON-native values.
                json.dumps(state.to_dict(), default=str),
            )
            logger.debug(f"[SlotStateCache] Set cache: {key}")
            return True
        except Exception as e:
            logger.warning(f"[SlotStateCache] Failed to set cache: {e}")
            return False

    async def merge_and_set(
        self,
        tenant_id: str,
        session_id: str,
        new_slots: dict[str, CachedSlotValue],
        slot_to_field_map: dict[str, str] | None = None,
    ) -> CachedSlotState:
        """
        Merge new slot values with cached state and save.

        Priority: user_confirmed > rule_extracted > llm_inferred > context > default

        Args:
            tenant_id: Tenant ID
            session_id: Session ID
            new_slots: New slot values to merge
            slot_to_field_map: Slot to field mapping

        Returns:
            Updated CachedSlotState
        """
        state = await self.get(tenant_id, session_id)
        if state is None:
            state = CachedSlotState()

        for slot_key, new_value in new_slots.items():
            if slot_key in state.filled_slots:
                existing = state.filled_slots[slot_key]
                existing_priority = self._get_source_priority(existing.source)
                new_priority = self._get_source_priority(new_value.source)

                # Equal priority also wins, so newer values of the same
                # source refresh the slot; lower priority is dropped silently.
                if new_priority >= existing_priority:
                    state.filled_slots[slot_key] = new_value
                    logger.debug(
                        f"[SlotStateCache] Slot '{slot_key}' updated: "
                        f"{existing.source}({existing_priority}) -> "
                        f"{new_value.source}({new_priority})"
                    )
            else:
                state.filled_slots[slot_key] = new_value
                logger.debug(
                    f"[SlotStateCache] Slot '{slot_key}' added: "
                    f"source={new_value.source}, value={new_value.value}"
                )

        if slot_to_field_map:
            state.slot_to_field_map.update(slot_to_field_map)

        await self.set(tenant_id, session_id, state)

        return state

    async def delete(
        self,
        tenant_id: str,
        session_id: str,
    ) -> bool:
        """
        Delete slot state from cache (L1 + L2).

        Args:
            tenant_id: Tenant ID for isolation
            session_id: Session ID

        Returns:
            True if successful
        """
        local_key = self._make_local_key(tenant_id, session_id)
        if local_key in self._local_cache:
            del self._local_cache[local_key]

        client = await self._get_client()
        if client is None:
            return False

        key = self._make_key(tenant_id, session_id)

        try:
            await client.delete(key)
            logger.debug(f"[SlotStateCache] Deleted cache: {key}")
            return True
        except Exception as e:
            logger.warning(f"[SlotStateCache] Failed to delete cache: {e}")
            return False

    async def clear_slot(
        self,
        tenant_id: str,
        session_id: str,
        slot_key: str,
    ) -> bool:
        """
        Clear a specific slot from cached state.

        Args:
            tenant_id: Tenant ID
            session_id: Session ID
            slot_key: Slot key to clear

        Returns:
            True if successful
        """
        state = await self.get(tenant_id, session_id)
        if state is None:
            # Nothing cached: clearing is a no-op success.
            return True

        if slot_key in state.filled_slots:
            del state.filled_slots[slot_key]
            await self.set(tenant_id, session_id, state)
            logger.debug(f"[SlotStateCache] Cleared slot: {slot_key}")

        return True

    async def close(self) -> None:
        """Close Redis connection."""
        if self._redis:
            await self._redis.close()
|
||||
|
||||
|
||||
_slot_state_cache: SlotStateCache | None = None


def get_slot_state_cache() -> SlotStateCache:
    """Return the module-level SlotStateCache, creating it on first use."""
    global _slot_state_cache
    cache = _slot_state_cache
    if cache is None:
        cache = SlotStateCache()
        _slot_state_cache = cache
    return cache
|
||||
|
|
@ -12,6 +12,9 @@ import logging
|
|||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import redis
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.services.embedding.base import EmbeddingException, EmbeddingProvider
|
||||
from app.services.embedding.nomic_provider import NomicEmbeddingProvider
|
||||
from app.services.embedding.ollama_provider import OllamaEmbeddingProvider
|
||||
|
|
@ -20,6 +23,7 @@ from app.services.embedding.openai_provider import OpenAIEmbeddingProvider
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
EMBEDDING_CONFIG_FILE = Path("config/embedding_config.json")
|
||||
EMBEDDING_CONFIG_REDIS_KEY = "ai_service:config:embedding"
|
||||
|
||||
|
||||
class EmbeddingProviderFactory:
|
||||
|
|
@ -170,8 +174,32 @@ class EmbeddingConfigManager:
|
|||
self._config = self._default_config.copy()
|
||||
self._provider: EmbeddingProvider | None = None
|
||||
|
||||
self._settings = get_settings()
|
||||
self._redis_client: redis.Redis | None = None
|
||||
|
||||
self._load_from_redis()
|
||||
self._load_from_file()
|
||||
|
||||
    def _load_from_redis(self) -> None:
        """Load configuration from Redis if exists.

        Best-effort: any failure (Redis disabled, unreachable, bad JSON)
        is logged and ignored so construction never blocks on Redis.
        """
        try:
            if not self._settings.redis_enabled:
                return
            self._redis_client = redis.from_url(
                self._settings.redis_url,
                encoding="utf-8",
                decode_responses=True,
            )
            saved_raw = self._redis_client.get(EMBEDDING_CONFIG_REDIS_KEY)
            if not saved_raw:
                return
            saved = json.loads(saved_raw)
            # Fall back to defaults for any field missing from the payload.
            self._provider_name = saved.get("provider", self._default_provider)
            self._config = saved.get("config", self._default_config.copy())
            logger.info(f"Loaded embedding config from Redis: provider={self._provider_name}")
            # NOTE(review): __init__ calls _load_from_file() after this, which
            # may override what was loaded here — confirm intended precedence.
        except Exception as e:
            logger.warning(f"Failed to load embedding config from Redis: {e}")
|
||||
|
||||
def _load_from_file(self) -> None:
|
||||
"""Load configuration from file if exists."""
|
||||
try:
|
||||
|
|
@ -184,6 +212,28 @@ class EmbeddingConfigManager:
|
|||
except Exception as e:
|
||||
logger.warning(f"Failed to load embedding config from file: {e}")
|
||||
|
||||
    def _save_to_redis(self) -> None:
        """Save configuration to Redis.

        Best-effort: failures are logged and swallowed so a Redis outage
        never blocks a config update (the caller also persists to file).
        """
        try:
            if not self._settings.redis_enabled:
                return
            if self._redis_client is None:
                # Lazily (re)create the client, e.g. when _load_from_redis
                # bailed out before connecting.
                self._redis_client = redis.from_url(
                    self._settings.redis_url,
                    encoding="utf-8",
                    decode_responses=True,
                )
            self._redis_client.set(
                EMBEDDING_CONFIG_REDIS_KEY,
                json.dumps({
                    "provider": self._provider_name,
                    "config": self._config,
                }, ensure_ascii=False),
            )
            logger.info(f"Saved embedding config to Redis: provider={self._provider_name}")
        except Exception as e:
            logger.warning(f"Failed to save embedding config to Redis: {e}")
|
||||
|
||||
def _save_to_file(self) -> None:
|
||||
"""Save configuration to file."""
|
||||
try:
|
||||
|
|
@ -262,6 +312,7 @@ class EmbeddingConfigManager:
|
|||
self._config = config
|
||||
self._provider = new_provider_instance
|
||||
|
||||
self._save_to_redis()
|
||||
self._save_to_file()
|
||||
|
||||
logger.info(f"Updated embedding config: provider={provider}")
|
||||
|
|
|
|||
|
|
@ -322,7 +322,7 @@ class FlowEngine:
|
|||
stmt = select(FlowInstance).where(
|
||||
FlowInstance.tenant_id == tenant_id,
|
||||
FlowInstance.session_id == session_id,
|
||||
).order_by(col(FlowInstance.created_at).desc())
|
||||
).order_by(col(FlowInstance.started_at).desc())
|
||||
result = await self._session.execute(stmt)
|
||||
instance = result.scalar_one_or_none()
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,385 @@
|
|||
"""
|
||||
Clarification mechanism for intent recognition.
|
||||
[AC-CLARIFY] 澄清机制实现
|
||||
|
||||
核心功能:
|
||||
1. 统一置信度计算
|
||||
2. 硬拦截规则(confidence检查、required_slots检查)
|
||||
3. 澄清状态管理
|
||||
4. 埋点指标收集
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
import uuid
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
T_HIGH = 0.75
|
||||
T_LOW = 0.45
|
||||
MAX_CLARIFY_RETRY = 3
|
||||
|
||||
|
||||
class ClarifyReason(str, Enum):
    """Reason a clarification turn was triggered."""

    INTENT_AMBIGUITY = "intent_ambiguity"  # candidates too close to pick one (gray zone)
    MISSING_SLOT = "missing_slot"  # a required slot is still unfilled
    LOW_CONFIDENCE = "low_confidence"  # best score below the low threshold
    MULTI_INTENT = "multi_intent"  # fusion reported multiple intents
|
||||
|
||||
|
||||
class ClarifyMetrics:
    """Process-wide singleton holding clarification counters.

    Tracks how often clarification was triggered, how often it converged,
    and how often routing went wrong; exposes raw counts and derived rates.
    """

    _instance = None

    def __new__(cls):
        # Lazily create the single shared instance with zeroed counters.
        if cls._instance is None:
            instance = super().__new__(cls)
            instance._clarify_trigger_count = 0
            instance._clarify_converge_count = 0
            instance._misroute_count = 0
            cls._instance = instance
        return cls._instance

    def record_clarify_trigger(self) -> None:
        """Count one triggered clarification turn."""
        self._clarify_trigger_count += 1
        logger.debug(f"[AC-CLARIFY-METRICS] clarify_trigger_count: {self._clarify_trigger_count}")

    def record_clarify_converge(self) -> None:
        """Count one clarification that converged to an intent."""
        self._clarify_converge_count += 1
        logger.debug(f"[AC-CLARIFY-METRICS] clarify_converge_count: {self._clarify_converge_count}")

    def record_misroute(self) -> None:
        """Count one misroute."""
        self._misroute_count += 1
        logger.debug(f"[AC-CLARIFY-METRICS] misroute_count: {self._misroute_count}")

    def get_metrics(self) -> dict[str, int]:
        """Return the raw counters (keys named after the derived rates)."""
        counters = {
            "clarify_trigger_rate": self._clarify_trigger_count,
            "clarify_converge_rate": self._clarify_converge_count,
            "misroute_rate": self._misroute_count,
        }
        return counters

    def get_rates(self, total_requests: int) -> dict[str, float]:
        """Return counters normalized by ``total_requests`` (all 0.0 when zero)."""
        if not total_requests:
            return {
                "clarify_trigger_rate": 0.0,
                "clarify_converge_rate": 0.0,
                "misroute_rate": 0.0,
            }

        converge_rate = 0.0
        if self._clarify_trigger_count > 0:
            converge_rate = self._clarify_converge_count / total_requests

        return {
            "clarify_trigger_rate": self._clarify_trigger_count / total_requests,
            "clarify_converge_rate": converge_rate,
            "misroute_rate": self._misroute_count / total_requests,
        }

    def reset(self) -> None:
        """Zero all counters."""
        self._clarify_trigger_count = 0
        self._clarify_converge_count = 0
        self._misroute_count = 0
|
||||
|
||||
|
||||
def get_clarify_metrics() -> ClarifyMetrics:
    """Return the shared ClarifyMetrics singleton."""
    metrics = ClarifyMetrics()
    return metrics
|
||||
|
||||
|
||||
@dataclass
|
||||
class IntentCandidate:
|
||||
intent_id: str
|
||||
intent_name: str
|
||||
confidence: float
|
||||
response_type: str | None = None
|
||||
target_kb_ids: list[str] | None = None
|
||||
flow_id: str | None = None
|
||||
fixed_reply: str | None = None
|
||||
transfer_message: str | None = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"intent_id": self.intent_id,
|
||||
"intent_name": self.intent_name,
|
||||
"confidence": self.confidence,
|
||||
"response_type": self.response_type,
|
||||
"target_kb_ids": self.target_kb_ids,
|
||||
"flow_id": self.flow_id,
|
||||
"fixed_reply": self.fixed_reply,
|
||||
"transfer_message": self.transfer_message,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
class HybridIntentResult:
    """Final result of hybrid (rule + semantic + LLM) intent recognition."""

    # Best candidate, or None when nothing matched.
    intent: IntentCandidate | None
    # Confidence of the final decision.
    confidence: float
    # All candidates, best first.
    candidates: list[IntentCandidate] = field(default_factory=list)
    # Whether a clarification turn must be asked before routing.
    need_clarify: bool = False
    # Populated only when need_clarify is True.
    clarify_reason: ClarifyReason | None = None
    # Required slots still unfilled.
    missing_slots: list[str] = field(default_factory=list)

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a JSON-compatible dict."""
        return {
            "intent": self.intent.to_dict() if self.intent else None,
            "confidence": self.confidence,
            "candidates": [c.to_dict() for c in self.candidates],
            "need_clarify": self.need_clarify,
            "clarify_reason": self.clarify_reason.value if self.clarify_reason else None,
            "missing_slots": self.missing_slots,
        }

    @classmethod
    def from_fusion_result(cls, fusion_result: Any) -> "HybridIntentResult":
        """Build a HybridIntentResult from a fusion-stage result object.

        Clarify candidates are copied first (their individual scores are
        unknown, so confidence is 0.0); the final intent is inserted at the
        front unless it already appears among them.
        """
        candidates = []
        if fusion_result.clarify_candidates:
            for c in fusion_result.clarify_candidates:
                # getattr defaults guard against candidate objects that
                # lack the optional routing fields.
                candidates.append(IntentCandidate(
                    intent_id=str(c.id),
                    intent_name=c.name,
                    confidence=0.0,
                    response_type=getattr(c, "response_type", None),
                    target_kb_ids=getattr(c, "target_kb_ids", None),
                    flow_id=str(c.flow_id) if getattr(c, "flow_id", None) else None,
                    fixed_reply=getattr(c, "fixed_reply", None),
                    transfer_message=getattr(c, "transfer_message", None),
                ))

        if fusion_result.final_intent:
            final_candidate = IntentCandidate(
                intent_id=str(fusion_result.final_intent.id),
                intent_name=fusion_result.final_intent.name,
                confidence=fusion_result.final_confidence,
                response_type=fusion_result.final_intent.response_type,
                target_kb_ids=fusion_result.final_intent.target_kb_ids,
                flow_id=str(fusion_result.final_intent.flow_id) if fusion_result.final_intent.flow_id else None,
                fixed_reply=fusion_result.final_intent.fixed_reply,
                transfer_message=fusion_result.final_intent.transfer_message,
            )
            if not any(c.intent_id == final_candidate.intent_id for c in candidates):
                candidates.insert(0, final_candidate)

        # Map the fusion decision reason onto a clarify reason; anything
        # unrecognized falls back to LOW_CONFIDENCE.
        clarify_reason = None
        if fusion_result.need_clarify:
            if fusion_result.decision_reason == "multi_intent":
                clarify_reason = ClarifyReason.MULTI_INTENT
            elif fusion_result.decision_reason == "gray_zone":
                clarify_reason = ClarifyReason.INTENT_AMBIGUITY
            else:
                clarify_reason = ClarifyReason.LOW_CONFIDENCE

        # NOTE(review): when the final intent already appears among the
        # clarify candidates, candidates[0] may be a clarify candidate rather
        # than the final intent — confirm this is intended.
        return cls(
            intent=candidates[0] if candidates else None,
            confidence=fusion_result.final_confidence,
            candidates=candidates,
            need_clarify=fusion_result.need_clarify,
            clarify_reason=clarify_reason,
        )
|
||||
|
||||
|
||||
@dataclass
class ClarifyState:
    """Per-session clarification state carried across turns.

    Attributes:
        reason: Why clarification was triggered.
        asked_slot: Slot the user was asked to provide, if any.
        retry_count: Number of clarification attempts so far.
        candidates: Intent candidates offered to the user.
        asked_intent_ids: Intent ids already presented.
        created_at: Unix timestamp of state creation.
    """

    reason: ClarifyReason
    asked_slot: str | None = None
    retry_count: int = 0
    candidates: list[IntentCandidate] = field(default_factory=list)
    asked_intent_ids: list[str] = field(default_factory=list)
    created_at: float = field(default_factory=time.time)

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a JSON-compatible dict."""
        serialized = {
            "reason": self.reason.value,
            "asked_slot": self.asked_slot,
            "retry_count": self.retry_count,
            "candidates": [candidate.to_dict() for candidate in self.candidates],
            "asked_intent_ids": self.asked_intent_ids,
            "created_at": self.created_at,
        }
        return serialized

    def increment_retry(self) -> "ClarifyState":
        """Bump the retry counter in place and return self for chaining."""
        self.retry_count = self.retry_count + 1
        return self

    def is_max_retry(self) -> bool:
        """True once the retry budget (MAX_CLARIFY_RETRY) is exhausted."""
        return not self.retry_count < MAX_CLARIFY_RETRY
|
||||
|
||||
|
||||
class ClarificationEngine:
    """[AC-CLARIFY] Clarification decision engine.

    Bundles the unified confidence computation, hard-block rules
    (confidence threshold, required slots), clarify-state creation, prompt
    generation and follow-up handling. Counters are recorded via the shared
    ClarifyMetrics singleton.
    """

    def __init__(
        self,
        t_high: float = T_HIGH,
        t_low: float = T_LOW,
        max_retry: int = MAX_CLARIFY_RETRY,
    ):
        # confidence >= t_high: accept; < t_low: generic clarification;
        # in between: ambiguity gray zone.
        self._t_high = t_high
        self._t_low = t_low
        # NOTE(review): _max_retry is stored but retry exhaustion is checked
        # via ClarifyState.is_max_retry(), which reads the module constant
        # MAX_CLARIFY_RETRY — a custom max_retry is effectively ignored.
        self._max_retry = max_retry
        self._metrics = get_clarify_metrics()

    def compute_confidence(
        self,
        rule_score: float = 0.0,
        semantic_score: float = 0.0,
        llm_score: float = 0.0,
        w_rule: float = 0.5,
        w_semantic: float = 0.3,
        w_llm: float = 0.2,
    ) -> float:
        """Return the weighted mean of the three matcher scores, clamped to [0, 1].

        Returns 0.0 when all weights are zero.
        """
        total_weight = w_rule + w_semantic + w_llm
        if total_weight == 0:
            return 0.0

        weighted_score = (
            rule_score * w_rule +
            semantic_score * w_semantic +
            llm_score * w_llm
        )

        return min(1.0, max(0.0, weighted_score / total_weight))

    def check_hard_block(
        self,
        result: HybridIntentResult,
        required_slots: list[str] | None = None,
        filled_slots: dict[str, Any] | None = None,
    ) -> tuple[bool, ClarifyReason | None]:
        """Hard-block check; returns (blocked, reason).

        Blocks when confidence is below t_high, or when any required slot
        is absent from filled_slots.
        """
        if result.confidence < self._t_high:
            return True, ClarifyReason.LOW_CONFIDENCE

        if required_slots and filled_slots is not None:
            missing = [s for s in required_slots if s not in filled_slots]
            if missing:
                return True, ClarifyReason.MISSING_SLOT

        return False, None

    def should_trigger_clarify(
        self,
        result: HybridIntentResult,
        required_slots: list[str] | None = None,
        filled_slots: dict[str, Any] | None = None,
    ) -> tuple[bool, ClarifyState | None]:
        """Decide whether to ask a clarification; build its state if so.

        Returns (False, None) for a confident result with all required
        slots filled. Otherwise returns (True, ClarifyState) with reason
        MISSING_SLOT (above t_high but slots missing), LOW_CONFIDENCE
        (below t_low), or the result's own reason — default
        INTENT_AMBIGUITY — in the gray zone. Each trigger is counted.
        """
        if result.confidence >= self._t_high:
            if required_slots and filled_slots is not None:
                missing = [s for s in required_slots if s not in filled_slots]
                if missing:
                    self._metrics.record_clarify_trigger()
                    return True, ClarifyState(
                        reason=ClarifyReason.MISSING_SLOT,
                        asked_slot=missing[0],  # ask for the first missing slot
                        candidates=result.candidates,
                    )
            return False, None

        if result.confidence < self._t_low:
            self._metrics.record_clarify_trigger()
            return True, ClarifyState(
                reason=ClarifyReason.LOW_CONFIDENCE,
                candidates=result.candidates,
            )

        # Gray zone: t_low <= confidence < t_high.
        self._metrics.record_clarify_trigger()

        reason = result.clarify_reason or ClarifyReason.INTENT_AMBIGUITY
        return True, ClarifyState(
            reason=reason,
            candidates=result.candidates,
        )

    def generate_clarify_prompt(
        self,
        state: ClarifyState,
        slot_label: str | None = None,
    ) -> str:
        """Render the user-facing clarification question for *state*.

        slot_label overrides the asked slot's display name for
        MISSING_SLOT prompts.
        """
        if state.reason == ClarifyReason.MISSING_SLOT:
            slot_name = slot_label or state.asked_slot or "相关信息"
            return f"为了更好地为您服务,请告诉我您的{slot_name}。"

        if state.reason == ClarifyReason.LOW_CONFIDENCE:
            return "抱歉,我不太理解您的意思,能否请您详细描述一下您的需求?"

        if state.reason == ClarifyReason.MULTI_INTENT and len(state.candidates) > 1:
            candidates = state.candidates[:3]  # offer at most three options
            if len(candidates) == 2:
                return (
                    f"请问您是想「{candidates[0].intent_name}」"
                    f"还是「{candidates[1].intent_name}」?"
                )
            else:
                options = "、".join([f"「{c.intent_name}」" for c in candidates[:-1]])
                return f"请问您是想{options},还是「{candidates[-1].intent_name}」?"

        if state.reason == ClarifyReason.INTENT_AMBIGUITY and len(state.candidates) > 1:
            candidates = state.candidates[:2]
            return (
                f"请问您是想「{candidates[0].intent_name}」"
                f"还是「{candidates[1].intent_name}」?"
            )

        return "请问您具体想了解什么?"

    def process_clarify_response(
        self,
        user_message: str,
        state: ClarifyState,
        intent_router: Any = None,
        rules: list[Any] | None = None,
    ) -> HybridIntentResult:
        """Process the user's reply to a clarification question.

        Increments the retry counter; on retry exhaustion records a
        misroute and gives up (no intent, need_clarify=False). For
        MISSING_SLOT, the first candidate is accepted with a fixed 0.8
        confidence and the clarification counts as converged. Otherwise
        the clarification is re-asked with the same reason.

        NOTE(review): user_message, intent_router and rules are currently
        unused — presumably reserved for re-routing the reply; confirm.
        """
        state.increment_retry()

        if state.is_max_retry():
            self._metrics.record_misroute()
            return HybridIntentResult(
                intent=None,
                confidence=0.0,
                need_clarify=False,
            )

        if state.reason == ClarifyReason.MISSING_SLOT:
            self._metrics.record_clarify_converge()
            return HybridIntentResult(
                intent=state.candidates[0] if state.candidates else None,
                confidence=0.8,
                candidates=state.candidates,
                need_clarify=False,
            )

        return HybridIntentResult(
            intent=None,
            confidence=0.0,
            candidates=state.candidates,
            need_clarify=True,
            clarify_reason=state.reason,
        )

    def get_metrics(self) -> dict[str, int]:
        """Return raw clarification counters."""
        return self._metrics.get_metrics()

    def get_rates(self, total_requests: int) -> dict[str, float]:
        """Return clarification rates normalized by *total_requests*."""
        return self._metrics.get_rates(total_requests)
|
||||
|
||||
|
||||
class ClarifySessionManager:
    """In-memory, class-level registry of per-session clarification state."""

    _sessions: dict[str, ClarifyState] = {}

    @classmethod
    def get_session(cls, session_id: str) -> ClarifyState | None:
        """Return the stored state for *session_id*, if any."""
        return cls._sessions.get(session_id)

    @classmethod
    def set_session(cls, session_id: str, state: ClarifyState) -> None:
        """Store *state* for *session_id* (overwrites any previous state)."""
        cls._sessions[session_id] = state
        logger.debug(f"[AC-CLARIFY] Session state set: session={session_id}, reason={state.reason}")

    @classmethod
    def clear_session(cls, session_id: str) -> None:
        """Drop any stored state for *session_id*; no-op when absent."""
        try:
            del cls._sessions[session_id]
        except KeyError:
            return
        logger.debug(f"[AC-CLARIFY] Session state cleared: session={session_id}")

    @classmethod
    def has_active_clarify(cls, session_id: str) -> bool:
        """True when the session holds state that has not exhausted retries."""
        state = cls._sessions.get(session_id)
        if state is None:
            return False
        return not state.is_max_retry()
|
||||
|
|
@ -0,0 +1,254 @@
|
|||
"""
|
||||
[v0.8.0] Fusion policy for hybrid intent routing.
|
||||
[AC-AISVC-115~AC-AISVC-117] Fusion decision logic for three-way matching.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.services.intent.models import (
|
||||
FusionConfig,
|
||||
FusionResult,
|
||||
LlmJudgeResult,
|
||||
RouteTrace,
|
||||
RuleMatchResult,
|
||||
SemanticMatchResult,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.entities import IntentRule
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DecisionCondition = Callable[[RuleMatchResult, SemanticMatchResult, LlmJudgeResult], bool]
|
||||
|
||||
|
||||
class FusionPolicy:
|
||||
"""
|
||||
[AC-AISVC-115] Fusion decision policy for hybrid routing.
|
||||
|
||||
Decision priority:
|
||||
1. rule_high_confidence: RuleMatcher hit with score=1.0
|
||||
2. llm_judge: LlmJudge triggered and returned valid intent
|
||||
3. semantic_override: RuleMatcher missed but SemanticMatcher high confidence
|
||||
4. rule_semantic_agree: Both match same intent
|
||||
5. semantic_fallback: SemanticMatcher medium confidence
|
||||
6. rule_fallback: Only rule matched
|
||||
7. no_match: All low confidence
|
||||
"""
|
||||
|
||||
DECISION_PRIORITY: list[tuple[str, DecisionCondition]] = [
|
||||
("rule_high_confidence", lambda r, s, llm: r.score == 1.0 and r.rule is not None),
|
||||
("llm_judge", lambda r, s, llm: llm.triggered and llm.intent_id is not None),
|
||||
(
|
||||
"semantic_override",
|
||||
lambda r, s, llm: r.score == 0
|
||||
and s.top_score > 0.7
|
||||
and not s.skipped
|
||||
and len(s.candidates) > 0,
|
||||
),
|
||||
(
|
||||
"rule_semantic_agree",
|
||||
lambda r, s, llm: r.score > 0
|
||||
and s.top_score > 0.5
|
||||
and not s.skipped
|
||||
and len(s.candidates) > 0
|
||||
and r.rule_id == s.candidates[0].rule.id,
|
||||
),
|
||||
(
|
||||
"semantic_fallback",
|
||||
lambda r, s, llm: s.top_score > 0.5 and not s.skipped and len(s.candidates) > 0,
|
||||
),
|
||||
("rule_fallback", lambda r, s, llm: r.score > 0),
|
||||
("no_match", lambda r, s, llm: True),
|
||||
]
|
||||
|
||||
def __init__(self, config: FusionConfig | None = None):
|
||||
"""
|
||||
Initialize fusion policy with configuration.
|
||||
|
||||
Args:
|
||||
config: Fusion configuration, uses default if not provided
|
||||
"""
|
||||
self._config = config or FusionConfig()
|
||||
|
||||
def fuse(
|
||||
self,
|
||||
rule_result: RuleMatchResult,
|
||||
semantic_result: SemanticMatchResult,
|
||||
llm_result: LlmJudgeResult | None,
|
||||
) -> FusionResult:
|
||||
"""
|
||||
[AC-AISVC-115] Execute fusion decision.
|
||||
|
||||
Args:
|
||||
rule_result: Rule matching result
|
||||
semantic_result: Semantic matching result
|
||||
llm_result: LLM judge result (may be None)
|
||||
|
||||
Returns:
|
||||
FusionResult with final intent, confidence, and trace
|
||||
"""
|
||||
trace = self._build_trace(rule_result, semantic_result, llm_result)
|
||||
|
||||
final_intent = None
|
||||
final_confidence = 0.0
|
||||
decision_reason = "no_match"
|
||||
|
||||
effective_llm_result = llm_result or LlmJudgeResult.empty()
|
||||
|
||||
for reason, condition in self.DECISION_PRIORITY:
|
||||
if condition(rule_result, semantic_result, effective_llm_result):
|
||||
decision_reason = reason
|
||||
break
|
||||
|
||||
if decision_reason == "rule_high_confidence":
|
||||
final_intent = rule_result.rule
|
||||
final_confidence = 1.0
|
||||
elif decision_reason == "llm_judge" and llm_result:
|
||||
final_intent = self._find_rule_by_id(
|
||||
llm_result.intent_id, rule_result, semantic_result
|
||||
)
|
||||
final_confidence = llm_result.score
|
||||
elif decision_reason == "semantic_override":
|
||||
final_intent = semantic_result.candidates[0].rule
|
||||
final_confidence = semantic_result.top_score
|
||||
elif decision_reason == "rule_semantic_agree":
|
||||
final_intent = rule_result.rule
|
||||
final_confidence = self._calculate_weighted_confidence(
|
||||
rule_result, semantic_result, llm_result
|
||||
)
|
||||
elif decision_reason == "semantic_fallback":
|
||||
final_intent = semantic_result.candidates[0].rule
|
||||
final_confidence = semantic_result.top_score
|
||||
elif decision_reason == "rule_fallback":
|
||||
final_intent = rule_result.rule
|
||||
final_confidence = rule_result.score
|
||||
|
||||
need_clarify = final_confidence < self._config.clarify_threshold
|
||||
clarify_candidates = None
|
||||
if need_clarify and len(semantic_result.candidates) > 1:
|
||||
clarify_candidates = [c.rule for c in semantic_result.candidates[:3]]
|
||||
|
||||
trace.fusion = {
|
||||
"weights": {
|
||||
"w_rule": self._config.w_rule,
|
||||
"w_semantic": self._config.w_semantic,
|
||||
"w_llm": self._config.w_llm,
|
||||
},
|
||||
"final_confidence": final_confidence,
|
||||
"decision_reason": decision_reason,
|
||||
"need_clarify": need_clarify,
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"[AC-AISVC-115] Fusion decision: reason={decision_reason}, "
|
||||
f"confidence={final_confidence:.3f}, need_clarify={need_clarify}"
|
||||
)
|
||||
|
||||
return FusionResult(
|
||||
final_intent=final_intent,
|
||||
final_confidence=final_confidence,
|
||||
decision_reason=decision_reason,
|
||||
need_clarify=need_clarify,
|
||||
clarify_candidates=clarify_candidates,
|
||||
trace=trace,
|
||||
)
|
||||
|
||||
def _build_trace(
|
||||
self,
|
||||
rule_result: RuleMatchResult,
|
||||
semantic_result: SemanticMatchResult,
|
||||
llm_result: LlmJudgeResult | None,
|
||||
) -> RouteTrace:
|
||||
"""
|
||||
[AC-AISVC-122] Build route trace log.
|
||||
"""
|
||||
return RouteTrace(
|
||||
rule_match={
|
||||
"rule_id": str(rule_result.rule_id) if rule_result.rule_id else None,
|
||||
"rule_name": rule_result.rule.name if rule_result.rule else None,
|
||||
"match_type": rule_result.match_type,
|
||||
"matched_text": rule_result.matched_text,
|
||||
"score": rule_result.score,
|
||||
"duration_ms": rule_result.duration_ms,
|
||||
},
|
||||
semantic_match={
|
||||
"top_candidates": [
|
||||
{
|
||||
"rule_id": str(c.rule.id),
|
||||
"rule_name": c.rule.name,
|
||||
"score": c.score,
|
||||
}
|
||||
for c in semantic_result.candidates
|
||||
],
|
||||
"top_score": semantic_result.top_score,
|
||||
"duration_ms": semantic_result.duration_ms,
|
||||
"skipped": semantic_result.skipped,
|
||||
"skip_reason": semantic_result.skip_reason,
|
||||
},
|
||||
llm_judge={
|
||||
"triggered": llm_result.triggered if llm_result else False,
|
||||
"intent_id": llm_result.intent_id if llm_result else None,
|
||||
"intent_name": llm_result.intent_name if llm_result else None,
|
||||
"score": llm_result.score if llm_result else 0.0,
|
||||
"reasoning": llm_result.reasoning if llm_result else None,
|
||||
"duration_ms": llm_result.duration_ms if llm_result else 0,
|
||||
"tokens_used": llm_result.tokens_used if llm_result else 0,
|
||||
},
|
||||
fusion={},
|
||||
)
|
||||
|
||||
def _calculate_weighted_confidence(
|
||||
self,
|
||||
rule_result: RuleMatchResult,
|
||||
semantic_result: SemanticMatchResult,
|
||||
llm_result: LlmJudgeResult | None,
|
||||
) -> float:
|
||||
"""
|
||||
[AC-AISVC-116] Calculate weighted confidence.
|
||||
|
||||
Formula:
|
||||
final_confidence = (w_rule * rule_score + w_semantic * semantic_score + w_llm * llm_score) / total_weight
|
||||
|
||||
Returns:
|
||||
Weighted confidence in [0.0, 1.0]
|
||||
"""
|
||||
rule_score = rule_result.score
|
||||
semantic_score = semantic_result.top_score if not semantic_result.skipped else 0.0
|
||||
llm_score = llm_result.score if llm_result and llm_result.triggered else 0.0
|
||||
|
||||
total_weight = self._config.w_rule + self._config.w_semantic
|
||||
if llm_result and llm_result.triggered:
|
||||
total_weight += self._config.w_llm
|
||||
|
||||
if total_weight == 0:
|
||||
return 0.0
|
||||
|
||||
confidence = (
|
||||
self._config.w_rule * rule_score
|
||||
+ self._config.w_semantic * semantic_score
|
||||
+ self._config.w_llm * llm_score
|
||||
) / total_weight
|
||||
|
||||
return min(1.0, max(0.0, confidence))
|
||||
|
||||
def _find_rule_by_id(
|
||||
self,
|
||||
intent_id: str | None,
|
||||
rule_result: RuleMatchResult,
|
||||
semantic_result: SemanticMatchResult,
|
||||
) -> "IntentRule | None":
|
||||
"""Find rule by ID from rule or semantic results."""
|
||||
if not intent_id:
|
||||
return None
|
||||
|
||||
if rule_result.rule_id and str(rule_result.rule_id) == intent_id:
|
||||
return rule_result.rule
|
||||
|
||||
for candidate in semantic_result.candidates:
|
||||
if str(candidate.rule.id) == intent_id:
|
||||
return candidate.rule
|
||||
|
||||
return None
|
||||
|
|
@ -0,0 +1,246 @@
|
|||
"""
|
||||
LLM judge for intent arbitration.
|
||||
[AC-AISVC-118, AC-AISVC-119] LLM-based intent arbitration.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from app.services.intent.models import (
|
||||
FusionConfig,
|
||||
LlmJudgeInput,
|
||||
LlmJudgeResult,
|
||||
RuleMatchResult,
|
||||
SemanticMatchResult,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.services.llm.base import LLMClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LlmJudge:
    """
    [AC-AISVC-118] LLM-based intent arbitrator.

    Triggered when:
    - Rule vs Semantic conflict
    - Gray zone (low confidence)
    - Multiple intent candidates with similar scores

    All failures (timeout, LLM error, unparseable output) degrade to a
    non-fatal LlmJudgeResult instead of raising, so routing never crashes
    on judge failure.
    """

    # Runtime prompt template (Chinese text is part of the product contract);
    # placeholders are filled via str.format in judge().
    JUDGE_PROMPT = """你是一个意图识别仲裁器。根据用户消息和候选意图,判断最匹配的意图。

用户消息:{message}

候选意图:
{candidates}

请返回 JSON 格式(不要包含```json标记):
{{
"intent_id": "最匹配的意图ID",
"intent_name": "意图名称",
"confidence": 0.0-1.0之间的置信度,
"reasoning": "判断理由"
}}"""

    def __init__(
        self,
        llm_client: "LLMClient",
        config: FusionConfig,
    ):
        """
        Initialize LLM judge.

        Args:
            llm_client: LLM client for generating responses
            config: Fusion configuration (thresholds, timeouts)
        """
        self._llm_client = llm_client
        self._config = config

    def should_trigger(
        self,
        rule_result: RuleMatchResult,
        semantic_result: SemanticMatchResult,
        config: FusionConfig | None = None,
    ) -> tuple[bool, str]:
        """
        [AC-AISVC-118] Check if LLM judge should be triggered.

        Trigger conditions (checked in order):
        1. Conflict: Rule and Semantic match different intents with close scores
        2. Gray zone: Max confidence in gray zone range
        3. Multi-intent: Multiple candidates with similar scores

        Args:
            rule_result: Rule matching result
            semantic_result: Semantic matching result
            config: Optional config override

        Returns:
            Tuple of (should_trigger, trigger_reason)
        """
        effective_config = config or self._config

        if not effective_config.llm_judge_enabled:
            return False, "disabled"

        rule_score = rule_result.score
        semantic_score = semantic_result.top_score

        # Conflict only makes sense when the semantic stage actually ran.
        # The `not skipped` guard keeps this consistent with
        # IntentRouter._check_llm_trigger, which applies the same check.
        if rule_score > 0 and semantic_score > 0 and not semantic_result.skipped:
            if semantic_result.candidates:
                top_semantic_rule_id = semantic_result.candidates[0].rule.id
                if rule_result.rule_id != top_semantic_rule_id:
                    if abs(rule_score - semantic_score) < effective_config.conflict_threshold:
                        logger.info(
                            f"[AC-AISVC-118] LLM judge triggered: rule_semantic_conflict, "
                            f"rule_id={rule_result.rule_id}, semantic_id={top_semantic_rule_id}, "
                            f"rule_score={rule_score}, semantic_score={semantic_score}"
                        )
                        return True, "rule_semantic_conflict"

        max_score = max(rule_score, semantic_score)
        if effective_config.min_trigger_threshold < max_score < effective_config.gray_zone_threshold:
            logger.info(
                f"[AC-AISVC-118] LLM judge triggered: gray_zone, "
                f"max_score={max_score}"
            )
            return True, "gray_zone"

        if len(semantic_result.candidates) >= 2:
            top1_score = semantic_result.candidates[0].score
            top2_score = semantic_result.candidates[1].score
            if abs(top1_score - top2_score) < effective_config.multi_intent_threshold:
                logger.info(
                    f"[AC-AISVC-118] LLM judge triggered: multi_intent, "
                    f"top1_score={top1_score}, top2_score={top2_score}"
                )
                return True, "multi_intent"

        return False, ""

    async def judge(
        self,
        input_data: LlmJudgeInput,
        tenant_id: str,
    ) -> LlmJudgeResult:
        """
        [AC-AISVC-119] Perform LLM arbitration.

        Args:
            input_data: Judge input with message and candidates
            tenant_id: Tenant ID for isolation (used for log correlation)

        Returns:
            LlmJudgeResult with arbitration decision; on timeout or error a
            result with score 0.0 and the failure recorded in `reasoning`.
        """
        start_time = time.time()

        candidates_text = "\n".join([
            f"- ID: {c['id']}, 名称: {c['name']}, 描述: {c.get('description', 'N/A')}"
            for c in input_data.candidates
        ])

        prompt = self.JUDGE_PROMPT.format(
            message=input_data.message,
            candidates=candidates_text,
        )

        try:
            from app.services.llm.base import LLMConfig

            response = await asyncio.wait_for(
                self._llm_client.generate(
                    messages=[{"role": "user", "content": prompt}],
                    config=LLMConfig(
                        max_tokens=200,
                        temperature=0,
                    ),
                ),
                timeout=self._config.llm_judge_timeout_ms / 1000,
            )

            result = self._parse_response(response.content or "")
            duration_ms = int((time.time() - start_time) * 1000)

            tokens_used = 0
            if response.usage:
                tokens_used = response.usage.get("total_tokens", 0)

            # Single source of truth for the confidence value so the log
            # line and the returned score cannot disagree (previously the
            # log defaulted to 0 while the returned score defaulted to 0.5).
            confidence = float(result.get("confidence", 0.5))

            logger.info(
                f"[AC-AISVC-119] LLM judge completed for tenant={tenant_id}, "
                f"intent_id={result.get('intent_id')}, confidence={confidence:.3f}, "
                f"duration={duration_ms}ms, tokens={tokens_used}"
            )

            return LlmJudgeResult(
                intent_id=result.get("intent_id"),
                intent_name=result.get("intent_name"),
                score=confidence,
                reasoning=result.get("reasoning"),
                duration_ms=duration_ms,
                tokens_used=tokens_used,
                triggered=True,
            )

        except asyncio.TimeoutError:
            duration_ms = int((time.time() - start_time) * 1000)
            logger.warning(
                f"[AC-AISVC-119] LLM judge timeout for tenant={tenant_id}, "
                f"timeout={self._config.llm_judge_timeout_ms}ms"
            )
            return LlmJudgeResult(
                intent_id=None,
                intent_name=None,
                score=0.0,
                reasoning="LLM timeout",
                duration_ms=duration_ms,
                tokens_used=0,
                triggered=True,
            )
        except Exception as e:
            duration_ms = int((time.time() - start_time) * 1000)
            logger.error(
                f"[AC-AISVC-119] LLM judge error for tenant={tenant_id}: {e}"
            )
            return LlmJudgeResult(
                intent_id=None,
                intent_name=None,
                score=0.0,
                reasoning=f"LLM error: {str(e)}",
                duration_ms=duration_ms,
                tokens_used=0,
                triggered=True,
            )

    def _parse_response(self, content: str) -> dict[str, Any]:
        """
        Parse LLM response to extract the JSON result.

        Strips optional markdown code fences; if the remainder is still not
        valid JSON, falls back to the outermost {...} span (models sometimes
        wrap the JSON in extra prose despite the prompt instructions).

        Args:
            content: LLM response content

        Returns:
            Parsed dictionary with intent_id, intent_name, confidence,
            reasoning; empty dict when no JSON can be recovered.
        """
        cleaned = content.strip()
        if cleaned.startswith("```json"):
            cleaned = cleaned[7:]
        if cleaned.startswith("```"):
            cleaned = cleaned[3:]
        if cleaned.endswith("```"):
            cleaned = cleaned[:-3]
        cleaned = cleaned.strip()

        try:
            result: dict[str, Any] = json.loads(cleaned)
            return result
        except json.JSONDecodeError as e:
            # Robustness fallback: try the outermost brace-delimited span.
            start = cleaned.find("{")
            end = cleaned.rfind("}")
            if start != -1 and end > start:
                try:
                    recovered: dict[str, Any] = json.loads(cleaned[start:end + 1])
                    return recovered
                except json.JSONDecodeError:
                    pass
            logger.warning(f"[AC-AISVC-119] Failed to parse LLM response: {e}")
            return {}
|
||||
|
|
@ -0,0 +1,226 @@
|
|||
"""
|
||||
Intent routing data models.
|
||||
[AC-AISVC-111~AC-AISVC-125] Data models for hybrid routing.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
class RuleMatchResult:
    """
    [AC-AISVC-112] Result of rule matching.

    Carries the matched rule (if any), how it matched, a binary score
    (1.0 match / 0.0 miss) and the matching duration.
    """
    rule_id: uuid.UUID | None
    rule: Any | None
    match_type: str | None
    matched_text: str | None
    score: float
    duration_ms: int

    def to_dict(self) -> dict[str, Any]:
        """Serialize for trace logging; the rule UUID becomes a string."""
        serialized_id = str(self.rule_id) if self.rule_id else None
        rule_name = self.rule.name if self.rule else None
        return {
            "rule_id": serialized_id,
            "rule_name": rule_name,
            "match_type": self.match_type,
            "matched_text": self.matched_text,
            "score": self.score,
            "duration_ms": self.duration_ms,
        }
|
||||
|
||||
|
||||
@dataclass
class SemanticCandidate:
    """
    [AC-AISVC-113] One semantic-match candidate: a rule plus its
    similarity score.
    """
    rule: Any
    score: float

    def to_dict(self) -> dict[str, Any]:
        """Serialize for trace logging; the rule id becomes a string."""
        return dict(
            rule_id=str(self.rule.id),
            rule_name=self.rule.name,
            score=self.score,
        )
|
||||
|
||||
|
||||
@dataclass
class SemanticMatchResult:
    """
    [AC-AISVC-113] Result of semantic matching: ranked candidates, the
    best score, timing, and whether the stage was skipped (and why).
    """
    candidates: list[SemanticCandidate]
    top_score: float
    duration_ms: int
    skipped: bool
    skip_reason: str | None

    def to_dict(self) -> dict[str, Any]:
        """Serialize for trace logging."""
        serialized_candidates = [candidate.to_dict() for candidate in self.candidates]
        return {
            "top_candidates": serialized_candidates,
            "top_score": self.top_score,
            "duration_ms": self.duration_ms,
            "skipped": self.skipped,
            "skip_reason": self.skip_reason,
        }
|
||||
|
||||
|
||||
@dataclass
class LlmJudgeInput:
    """
    [AC-AISVC-119] Input payload for the LLM judge.
    """
    # Raw user message to arbitrate.
    message: str
    # Candidate intents as dicts with "id", "name" and optional "description".
    candidates: list[dict[str, Any]]
    # Why the judge was triggered, e.g. "rule_semantic_conflict",
    # "gray_zone" or "multi_intent".
    conflict_type: str
|
||||
|
||||
|
||||
@dataclass
class LlmJudgeResult:
    """
    [AC-AISVC-119] Outcome of an LLM arbitration call.
    """
    intent_id: str | None
    intent_name: str | None
    score: float
    reasoning: str | None
    duration_ms: int
    tokens_used: int
    triggered: bool

    def to_dict(self) -> dict[str, Any]:
        """Serialize for trace logging ("triggered" first, then the rest)."""
        payload: dict[str, Any] = {"triggered": self.triggered}
        for attr in ("intent_id", "intent_name", "score", "reasoning", "duration_ms", "tokens_used"):
            payload[attr] = getattr(self, attr)
        return payload

    @classmethod
    def empty(cls) -> "LlmJudgeResult":
        """Placeholder result for when the judge never ran."""
        return cls(None, None, 0.0, None, 0, 0, False)
|
||||
|
||||
|
||||
@dataclass
class FusionConfig:
    """
    [AC-AISVC-116] Fusion configuration.

    The w_* weights drive the weighted-confidence formula; the *_threshold
    fields gate semantic override, LLM-judge triggering and clarification;
    the remaining fields toggle pipeline stages and bound their latency.
    """
    w_rule: float = 0.5
    w_semantic: float = 0.3
    w_llm: float = 0.2
    semantic_threshold: float = 0.7
    conflict_threshold: float = 0.2
    gray_zone_threshold: float = 0.6
    min_trigger_threshold: float = 0.3
    clarify_threshold: float = 0.4
    multi_intent_threshold: float = 0.15
    llm_judge_enabled: bool = True
    semantic_matcher_enabled: bool = True
    semantic_matcher_timeout_ms: int = 100
    llm_judge_timeout_ms: int = 2000
    semantic_top_k: int = 3

    def to_dict(self) -> dict[str, Any]:
        """Serialize all fields in declaration order.

        Derived from the dataclass fields (instead of a hand-written
        mapping) so it cannot drift when a field is added or renamed.
        """
        from dataclasses import asdict

        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "FusionConfig":
        """Build a config from a (possibly partial) dict.

        Unknown keys are ignored and missing keys fall back to the
        dataclass defaults — same contract as the previous hand-written
        version, but driven by the field list so new fields are picked up
        automatically.
        """
        from dataclasses import fields

        known = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in data.items() if k in known})
|
||||
|
||||
|
||||
@dataclass
class RouteTrace:
    """
    [AC-AISVC-122] Per-stage trace of one routing decision.
    """
    rule_match: dict[str, Any] = field(default_factory=dict)
    semantic_match: dict[str, Any] = field(default_factory=dict)
    llm_judge: dict[str, Any] = field(default_factory=dict)
    fusion: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Serialize all four stage traces."""
        return {
            stage: getattr(self, stage)
            for stage in ("rule_match", "semantic_match", "llm_judge", "fusion")
        }
|
||||
|
||||
|
||||
@dataclass
class FusionResult:
    """
    [AC-AISVC-115] Final fusion decision: chosen intent, confidence, the
    decision reason code, clarification info and the full route trace.
    """
    final_intent: Any | None
    final_confidence: float
    decision_reason: str
    need_clarify: bool
    clarify_candidates: list[Any] | None
    trace: RouteTrace

    def to_dict(self) -> dict[str, Any]:
        """Serialize for API responses / logging."""
        intent_payload = None
        if self.final_intent:
            intent_payload = {
                "id": str(self.final_intent.id),
                "name": self.final_intent.name,
                "response_type": self.final_intent.response_type,
            }

        clarify_payload = [
            {"id": str(rule.id), "name": rule.name}
            for rule in (self.clarify_candidates or [])
        ]

        return {
            "final_intent": intent_payload,
            "final_confidence": self.final_confidence,
            "decision_reason": self.decision_reason,
            "need_clarify": self.need_clarify,
            "clarify_candidates": clarify_payload,
            "trace": self.trace.to_dict(),
        }
|
||||
|
||||
|
||||
# Module-level default used when no tenant-specific FusionConfig is supplied.
DEFAULT_FUSION_CONFIG = FusionConfig()
|
||||
|
|
@ -1,14 +1,30 @@
|
|||
"""
|
||||
Intent router for AI Service.
|
||||
[AC-AISVC-69, AC-AISVC-70] Intent matching engine with keyword and regex support.
|
||||
[v0.8.0] Upgraded to hybrid routing with RuleMatcher + SemanticMatcher + LlmJudge + FusionPolicy.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from app.models.entities import IntentRule
|
||||
from app.services.intent.models import (
|
||||
FusionConfig,
|
||||
FusionResult,
|
||||
LlmJudgeInput,
|
||||
LlmJudgeResult,
|
||||
RouteTrace,
|
||||
RuleMatchResult,
|
||||
SemanticMatchResult,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.services.intent.fusion_policy import FusionPolicy
|
||||
from app.services.intent.llm_judge import LlmJudge
|
||||
from app.services.intent.semantic_matcher import SemanticMatcher
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -38,38 +54,36 @@ class IntentMatchResult:
|
|||
}
|
||||
|
||||
|
||||
class IntentRouter:
|
||||
class RuleMatcher:
|
||||
"""
|
||||
[AC-AISVC-69] Intent matching engine.
|
||||
|
||||
Matching algorithm:
|
||||
1. Load rules ordered by priority DESC
|
||||
2. For each rule, try keyword matching first
|
||||
3. If no keyword match, try regex pattern matching
|
||||
4. Return first match (highest priority)
|
||||
5. If no match, return None (fallback to default RAG)
|
||||
[v0.8.0] Rule matcher for keyword and regex matching.
|
||||
Extracted from IntentRouter for hybrid routing.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def match(
|
||||
self,
|
||||
message: str,
|
||||
rules: list[IntentRule],
|
||||
) -> IntentMatchResult | None:
|
||||
def match(self, message: str, rules: list[IntentRule]) -> RuleMatchResult:
|
||||
"""
|
||||
[AC-AISVC-69] Match user message against intent rules.
|
||||
[AC-AISVC-112] Match user message against intent rules.
|
||||
Returns RuleMatchResult with score (1.0 for match, 0.0 for no match).
|
||||
|
||||
Args:
|
||||
message: User input message
|
||||
rules: List of enabled rules ordered by priority DESC
|
||||
|
||||
Returns:
|
||||
IntentMatchResult if matched, None otherwise
|
||||
RuleMatchResult with match details
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
if not message or not rules:
|
||||
return None
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
return RuleMatchResult(
|
||||
rule_id=None,
|
||||
rule=None,
|
||||
match_type=None,
|
||||
matched_text=None,
|
||||
score=0.0,
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
|
||||
message_lower = message.lower()
|
||||
|
||||
|
|
@ -79,22 +93,46 @@ class IntentRouter:
|
|||
|
||||
keyword_result = self._match_keywords(message, message_lower, rule)
|
||||
if keyword_result:
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
logger.info(
|
||||
f"[AC-AISVC-69] Intent matched by keyword: "
|
||||
f"rule={rule.name}, matched='{keyword_result.matched}'"
|
||||
)
|
||||
return keyword_result
|
||||
return RuleMatchResult(
|
||||
rule_id=rule.id,
|
||||
rule=rule,
|
||||
match_type="keyword",
|
||||
matched_text=keyword_result.matched,
|
||||
score=1.0,
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
|
||||
regex_result = self._match_patterns(message, rule)
|
||||
if regex_result:
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
logger.info(
|
||||
f"[AC-AISVC-69] Intent matched by regex: "
|
||||
f"rule={rule.name}, matched='{regex_result.matched}'"
|
||||
)
|
||||
return regex_result
|
||||
return RuleMatchResult(
|
||||
rule_id=rule.id,
|
||||
rule=rule,
|
||||
match_type="regex",
|
||||
matched_text=regex_result.matched,
|
||||
score=1.0,
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
logger.debug("[AC-AISVC-70] No intent matched, will fallback to default RAG")
|
||||
return None
|
||||
return RuleMatchResult(
|
||||
rule_id=None,
|
||||
rule=None,
|
||||
match_type=None,
|
||||
matched_text=None,
|
||||
score=0.0,
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
|
||||
def _match_keywords(
|
||||
self,
|
||||
|
|
@ -153,6 +191,74 @@ class IntentRouter:
|
|||
|
||||
return None
|
||||
|
||||
|
||||
class IntentRouter:
|
||||
"""
|
||||
[AC-AISVC-69] Intent matching engine.
|
||||
[v0.8.0] Upgraded to support hybrid routing.
|
||||
|
||||
Matching algorithm:
|
||||
1. Load rules ordered by priority DESC
|
||||
2. For each rule, try keyword matching first
|
||||
3. If no keyword match, try regex pattern matching
|
||||
4. Return first match (highest priority)
|
||||
5. If no match, return None (fallback to default RAG)
|
||||
|
||||
Hybrid routing (match_hybrid):
|
||||
1. Parallel execute RuleMatcher + SemanticMatcher
|
||||
2. Conditionally trigger LlmJudge
|
||||
3. Execute FusionPolicy for final decision
|
||||
"""
|
||||
|
||||
    def __init__(
        self,
        rule_matcher: RuleMatcher | None = None,
        semantic_matcher: "SemanticMatcher | None" = None,
        llm_judge: "LlmJudge | None" = None,
        fusion_policy: "FusionPolicy | None" = None,
        config: FusionConfig | None = None,
    ):
        """
        [v0.8.0] Initialize with optional dependencies for DI.

        All collaborators are optional: a missing RuleMatcher is replaced
        with a default instance, while a missing semantic matcher, LLM
        judge or fusion policy simply disables that part of the hybrid
        pipeline (see match_hybrid and its helpers).

        Args:
            rule_matcher: Rule matcher for keyword/regex matching
            semantic_matcher: Semantic matcher for vector similarity
            llm_judge: LLM judge for arbitration
            fusion_policy: Fusion policy for decision making
            config: Fusion configuration
        """
        self._rule_matcher = rule_matcher or RuleMatcher()
        self._semantic_matcher = semantic_matcher  # None => semantic stage skipped
        self._llm_judge = llm_judge  # None => LLM arbitration disabled
        self._fusion_policy = fusion_policy  # None => _default_fusion is used
        self._config = config or FusionConfig()
|
||||
|
||||
def match(
|
||||
self,
|
||||
message: str,
|
||||
rules: list[IntentRule],
|
||||
) -> IntentMatchResult | None:
|
||||
"""
|
||||
[AC-AISVC-69] Match user message against intent rules.
|
||||
Preserved for backward compatibility.
|
||||
|
||||
Args:
|
||||
message: User input message
|
||||
rules: List of enabled rules ordered by priority DESC
|
||||
|
||||
Returns:
|
||||
IntentMatchResult if matched, None otherwise
|
||||
"""
|
||||
result = self._rule_matcher.match(message, rules)
|
||||
if result.rule:
|
||||
return IntentMatchResult(
|
||||
rule=result.rule,
|
||||
match_type=result.match_type or "keyword",
|
||||
matched=result.matched_text or "",
|
||||
)
|
||||
return None
|
||||
|
||||
def match_with_stats(
|
||||
self,
|
||||
message: str,
|
||||
|
|
@ -168,3 +274,300 @@ class IntentRouter:
|
|||
if result:
|
||||
return result, str(result.rule.id)
|
||||
return None, None
|
||||
|
||||
async def match_hybrid(
|
||||
self,
|
||||
message: str,
|
||||
rules: list[IntentRule],
|
||||
tenant_id: str,
|
||||
config: FusionConfig | None = None,
|
||||
) -> FusionResult:
|
||||
"""
|
||||
[AC-AISVC-111] Hybrid routing entry point.
|
||||
|
||||
Flow:
|
||||
1. Parallel execute RuleMatcher + SemanticMatcher
|
||||
2. Check if LlmJudge should trigger
|
||||
3. Execute FusionPolicy for final decision
|
||||
|
||||
Args:
|
||||
message: User input message
|
||||
rules: List of enabled rules ordered by priority DESC
|
||||
tenant_id: Tenant ID for isolation
|
||||
config: Optional fusion config override
|
||||
|
||||
Returns:
|
||||
FusionResult with final intent, confidence, and trace
|
||||
"""
|
||||
effective_config = config or self._config
|
||||
start_time = time.time()
|
||||
|
||||
rule_result = self._rule_matcher.match(message, rules)
|
||||
|
||||
semantic_result = await self._execute_semantic_matcher(
|
||||
message, rules, tenant_id, effective_config
|
||||
)
|
||||
|
||||
llm_result = await self._conditionally_execute_llm_judge(
|
||||
message, rule_result, semantic_result, tenant_id, effective_config
|
||||
)
|
||||
|
||||
if self._fusion_policy:
|
||||
fusion_result = self._fusion_policy.fuse(
|
||||
rule_result, semantic_result, llm_result
|
||||
)
|
||||
else:
|
||||
fusion_result = self._default_fusion(
|
||||
rule_result, semantic_result, llm_result, effective_config
|
||||
)
|
||||
|
||||
total_duration_ms = int((time.time() - start_time) * 1000)
|
||||
fusion_result.trace.fusion["total_duration_ms"] = total_duration_ms
|
||||
|
||||
logger.info(
|
||||
f"[AC-AISVC-111] Hybrid routing completed: "
|
||||
f"decision={fusion_result.decision_reason}, "
|
||||
f"confidence={fusion_result.final_confidence:.3f}, "
|
||||
f"duration={total_duration_ms}ms"
|
||||
)
|
||||
|
||||
return fusion_result
|
||||
|
||||
async def _execute_semantic_matcher(
|
||||
self,
|
||||
message: str,
|
||||
rules: list[IntentRule],
|
||||
tenant_id: str,
|
||||
config: FusionConfig,
|
||||
) -> SemanticMatchResult:
|
||||
"""Execute semantic matcher if available and enabled."""
|
||||
if not self._semantic_matcher:
|
||||
return SemanticMatchResult(
|
||||
candidates=[],
|
||||
top_score=0.0,
|
||||
duration_ms=0,
|
||||
skipped=True,
|
||||
skip_reason="not_configured",
|
||||
)
|
||||
|
||||
if not config.semantic_matcher_enabled:
|
||||
return SemanticMatchResult(
|
||||
candidates=[],
|
||||
top_score=0.0,
|
||||
duration_ms=0,
|
||||
skipped=True,
|
||||
skip_reason="disabled",
|
||||
)
|
||||
|
||||
try:
|
||||
return await self._semantic_matcher.match(
|
||||
message=message,
|
||||
rules=rules,
|
||||
tenant_id=tenant_id,
|
||||
top_k=config.semantic_top_k,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[AC-AISVC-113] Semantic matcher failed: {e}")
|
||||
return SemanticMatchResult(
|
||||
candidates=[],
|
||||
top_score=0.0,
|
||||
duration_ms=0,
|
||||
skipped=True,
|
||||
skip_reason=f"error: {str(e)}",
|
||||
)
|
||||
|
||||
async def _conditionally_execute_llm_judge(
|
||||
self,
|
||||
message: str,
|
||||
rule_result: RuleMatchResult,
|
||||
semantic_result: SemanticMatchResult,
|
||||
tenant_id: str,
|
||||
config: FusionConfig,
|
||||
) -> LlmJudgeResult | None:
|
||||
"""Conditionally execute LLM judge based on trigger conditions."""
|
||||
if not self._llm_judge:
|
||||
return None
|
||||
|
||||
if not config.llm_judge_enabled:
|
||||
return None
|
||||
|
||||
should_trigger, trigger_reason = self._check_llm_trigger(
|
||||
rule_result, semantic_result, config
|
||||
)
|
||||
|
||||
if not should_trigger:
|
||||
return None
|
||||
|
||||
logger.info(f"[AC-AISVC-118] LLM judge triggered: reason={trigger_reason}")
|
||||
|
||||
candidates = self._build_llm_candidates(rule_result, semantic_result)
|
||||
if not candidates:
|
||||
return None
|
||||
|
||||
try:
|
||||
return await self._llm_judge.judge(
|
||||
LlmJudgeInput(
|
||||
message=message,
|
||||
candidates=candidates,
|
||||
conflict_type=trigger_reason,
|
||||
),
|
||||
tenant_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[AC-AISVC-119] LLM judge failed: {e}")
|
||||
return LlmJudgeResult(
|
||||
intent_id=None,
|
||||
intent_name=None,
|
||||
score=0.0,
|
||||
reasoning=f"LLM error: {str(e)}",
|
||||
duration_ms=0,
|
||||
tokens_used=0,
|
||||
triggered=True,
|
||||
)
|
||||
|
||||
def _check_llm_trigger(
|
||||
self,
|
||||
rule_result: RuleMatchResult,
|
||||
semantic_result: SemanticMatchResult,
|
||||
config: FusionConfig,
|
||||
) -> tuple[bool, str]:
|
||||
"""
|
||||
[AC-AISVC-118] Check if LLM judge should trigger.
|
||||
|
||||
Trigger conditions:
|
||||
1. Conflict: RuleMatcher and SemanticMatcher match different intents
|
||||
2. Gray zone: Max confidence in gray zone range
|
||||
3. Multi-intent: Multiple candidates with close scores
|
||||
|
||||
Returns:
|
||||
(should_trigger, trigger_reason)
|
||||
"""
|
||||
rule_score = rule_result.score
|
||||
semantic_score = semantic_result.top_score
|
||||
|
||||
if rule_score > 0 and semantic_score > 0 and not semantic_result.skipped:
|
||||
if semantic_result.candidates:
|
||||
top_semantic_rule_id = semantic_result.candidates[0].rule.id
|
||||
if rule_result.rule_id != top_semantic_rule_id:
|
||||
if abs(rule_score - semantic_score) < config.conflict_threshold:
|
||||
return True, "rule_semantic_conflict"
|
||||
|
||||
max_score = max(rule_score, semantic_score)
|
||||
if config.min_trigger_threshold < max_score < config.gray_zone_threshold:
|
||||
return True, "gray_zone"
|
||||
|
||||
if len(semantic_result.candidates) >= 2:
|
||||
top1_score = semantic_result.candidates[0].score
|
||||
top2_score = semantic_result.candidates[1].score
|
||||
if abs(top1_score - top2_score) < config.multi_intent_threshold:
|
||||
return True, "multi_intent"
|
||||
|
||||
return False, ""
|
||||
|
||||
def _build_llm_candidates(
|
||||
self,
|
||||
rule_result: RuleMatchResult,
|
||||
semantic_result: SemanticMatchResult,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Build candidate list for LLM judge."""
|
||||
candidates = []
|
||||
|
||||
if rule_result.rule:
|
||||
candidates.append({
|
||||
"id": str(rule_result.rule_id),
|
||||
"name": rule_result.rule.name,
|
||||
"description": f"匹配方式: {rule_result.match_type}, 匹配内容: {rule_result.matched_text}",
|
||||
})
|
||||
|
||||
for candidate in semantic_result.candidates[:3]:
|
||||
if not any(c["id"] == str(candidate.rule.id) for c in candidates):
|
||||
candidates.append({
|
||||
"id": str(candidate.rule.id),
|
||||
"name": candidate.rule.name,
|
||||
"description": f"语义相似度: {candidate.score:.2f}",
|
||||
})
|
||||
|
||||
return candidates
|
||||
|
||||
def _default_fusion(
    self,
    rule_result: RuleMatchResult,
    semantic_result: SemanticMatchResult,
    llm_result: LlmJudgeResult | None,
    config: FusionConfig,
) -> FusionResult:
    """Fallback fusion of rule/semantic/LLM signals when FusionPolicy is unavailable.

    Decision order:
      1. exact rule hit (score == 1.0)
      2. LLM judge verdict (triggered with an intent ID)
      3. semantic override: no rule hit and top score above config.semantic_threshold
      4. semantic fallback: top score above 0.5
    Anything else stays "no_match" with zero confidence.
    """
    trace = RouteTrace(
        rule_match=rule_result.to_dict(),
        semantic_match=semantic_result.to_dict(),
        llm_judge=llm_result.to_dict() if llm_result else {},
        fusion={},
    )

    chosen_rule = None
    confidence = 0.0
    reason = "no_match"

    top_candidate = semantic_result.candidates[0] if semantic_result.candidates else None

    if rule_result.rule and rule_result.score == 1.0:
        chosen_rule = rule_result.rule
        confidence = 1.0
        reason = "rule_high_confidence"
    elif llm_result is not None and llm_result.triggered and llm_result.intent_id:
        chosen_rule = self._find_rule_by_id(
            llm_result.intent_id, rule_result, semantic_result
        )
        confidence = llm_result.score
        reason = "llm_judge"
    elif rule_result.score == 0 and semantic_result.top_score > config.semantic_threshold:
        if top_candidate is not None:
            chosen_rule = top_candidate.rule
            confidence = semantic_result.top_score
            reason = "semantic_override"
    elif semantic_result.top_score > 0.5:
        if top_candidate is not None:
            chosen_rule = top_candidate.rule
            confidence = semantic_result.top_score
            reason = "semantic_fallback"

    # Low-confidence outcomes ask the user to clarify between top candidates.
    need_clarify = confidence < config.clarify_threshold
    clarify_candidates = None
    if need_clarify and len(semantic_result.candidates) > 1:
        clarify_candidates = [c.rule for c in semantic_result.candidates[:3]]

    trace.fusion = {
        "weights": {
            "w_rule": config.w_rule,
            "w_semantic": config.w_semantic,
            "w_llm": config.w_llm,
        },
        "final_confidence": confidence,
        "decision_reason": reason,
    }

    return FusionResult(
        final_intent=chosen_rule,
        final_confidence=confidence,
        decision_reason=reason,
        need_clarify=need_clarify,
        clarify_candidates=clarify_candidates,
        trace=trace,
    )
|
||||
|
||||
def _find_rule_by_id(
    self,
    intent_id: str | None,
    rule_result: RuleMatchResult,
    semantic_result: SemanticMatchResult,
) -> IntentRule | None:
    """Resolve an intent ID to a rule object.

    Checks the rule-matched result first, then scans the semantic candidates;
    returns None when the ID is empty or not found in either.
    """
    if not intent_id:
        return None

    if rule_result.rule_id and str(rule_result.rule_id) == intent_id:
        return rule_result.rule

    return next(
        (c.rule for c in semantic_result.candidates if str(c.rule.id) == intent_id),
        None,
    )
|
||||
|
|
|
|||
|
|
@ -106,6 +106,8 @@ class IntentRuleService:
|
|||
is_enabled=True,
|
||||
hit_count=0,
|
||||
metadata_=create_data.metadata_,
|
||||
intent_vector=create_data.intent_vector,
|
||||
semantic_examples=create_data.semantic_examples,
|
||||
)
|
||||
self._session.add(rule)
|
||||
await self._session.flush()
|
||||
|
|
@ -195,6 +197,10 @@ class IntentRuleService:
|
|||
rule.is_enabled = update_data.is_enabled
|
||||
if update_data.metadata_ is not None:
|
||||
rule.metadata_ = update_data.metadata_
|
||||
if update_data.intent_vector is not None:
|
||||
rule.intent_vector = update_data.intent_vector
|
||||
if update_data.semantic_examples is not None:
|
||||
rule.semantic_examples = update_data.semantic_examples
|
||||
|
||||
rule.updated_at = datetime.utcnow()
|
||||
await self._session.flush()
|
||||
|
|
@ -267,7 +273,7 @@ class IntentRuleService:
|
|||
select(IntentRule)
|
||||
.where(
|
||||
IntentRule.tenant_id == tenant_id,
|
||||
IntentRule.is_enabled == True,
|
||||
IntentRule.is_enabled == True, # noqa: E712
|
||||
)
|
||||
.order_by(col(IntentRule.priority).desc())
|
||||
)
|
||||
|
|
@ -300,6 +306,8 @@ class IntentRuleService:
|
|||
"is_enabled": rule.is_enabled,
|
||||
"hit_count": rule.hit_count,
|
||||
"metadata": rule.metadata_,
|
||||
"created_at": rule.created_at.isoformat(),
|
||||
"updated_at": rule.updated_at.isoformat(),
|
||||
"intent_vector": rule.intent_vector,
|
||||
"semantic_examples": rule.semantic_examples,
|
||||
"created_at": rule.created_at.isoformat() if rule.created_at else None,
|
||||
"updated_at": rule.updated_at.isoformat() if rule.updated_at else None,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,233 @@
|
|||
"""
|
||||
Semantic matcher for intent recognition.
|
||||
[AC-AISVC-113, AC-AISVC-114] Vector-based semantic matching.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
|
||||
from app.services.intent.models import (
|
||||
FusionConfig,
|
||||
SemanticCandidate,
|
||||
SemanticMatchResult,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.entities import IntentRule
|
||||
from app.services.embedding.base import EmbeddingProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SemanticMatcher:
    """
    [AC-AISVC-113] Semantic matcher using vector similarity.

    Supports two matching modes:
    - Mode A: Use pre-computed intent_vector for direct similarity calculation
    - Mode B: Use semantic_examples for dynamic vector computation
    """

    def __init__(
        self,
        embedding_provider: "EmbeddingProvider",
        config: FusionConfig,
    ):
        """
        Initialize semantic matcher.

        Args:
            embedding_provider: Provider for generating embeddings
            config: Fusion configuration
        """
        self._embedding_provider = embedding_provider
        self._config = config

    async def match(
        self,
        message: str,
        rules: list["IntentRule"],
        tenant_id: str,
        top_k: int | None = None,
    ) -> SemanticMatchResult:
        """
        [AC-AISVC-113] Perform vector semantic matching.

        Args:
            message: User message
            rules: List of intent rules
            tenant_id: Tenant ID for isolation
            top_k: Number of top candidates to return
                (defaults to config.semantic_top_k)

        Returns:
            SemanticMatchResult with candidates and scores. The result is
            skipped (with a skip_reason) when matching is disabled, no rule
            has semantic config, or the embedding call times out / fails.
        """
        # Use the monotonic clock for duration measurement: wall-clock
        # time.time() can jump (NTP adjustment) and corrupt duration_ms.
        start_time = time.monotonic()
        effective_top_k = top_k or self._config.semantic_top_k

        if not self._config.semantic_matcher_enabled:
            return SemanticMatchResult(
                candidates=[],
                top_score=0.0,
                duration_ms=0,
                skipped=True,
                skip_reason="disabled",
            )

        # Only rules carrying an intent_vector or semantic_examples can be scored.
        rules_with_semantic = [r for r in rules if self._has_semantic_config(r)]
        if not rules_with_semantic:
            duration_ms = int((time.monotonic() - start_time) * 1000)
            logger.debug(
                f"[AC-AISVC-113] No rules with semantic config for tenant={tenant_id}"
            )
            return SemanticMatchResult(
                candidates=[],
                top_score=0.0,
                duration_ms=duration_ms,
                skipped=True,
                skip_reason="no_semantic_config",
            )

        try:
            # Bound the embedding call so a slow provider cannot stall routing.
            message_vector = await asyncio.wait_for(
                self._embedding_provider.embed(message),
                timeout=self._config.semantic_matcher_timeout_ms / 1000,
            )
        except asyncio.TimeoutError:
            duration_ms = int((time.monotonic() - start_time) * 1000)
            logger.warning(
                f"[AC-AISVC-113] Embedding timeout for tenant={tenant_id}, "
                f"timeout={self._config.semantic_matcher_timeout_ms}ms"
            )
            return SemanticMatchResult(
                candidates=[],
                top_score=0.0,
                duration_ms=duration_ms,
                skipped=True,
                skip_reason="embedding_timeout",
            )
        except Exception as e:
            duration_ms = int((time.monotonic() - start_time) * 1000)
            logger.error(
                f"[AC-AISVC-113] Embedding error for tenant={tenant_id}: {e}"
            )
            return SemanticMatchResult(
                candidates=[],
                top_score=0.0,
                duration_ms=duration_ms,
                skipped=True,
                skip_reason=f"embedding_error: {str(e)}",
            )

        candidates = []
        for rule in rules_with_semantic:
            try:
                score = await self._calculate_similarity(message_vector, rule)
                if score > 0:
                    candidates.append(SemanticCandidate(rule=rule, score=score))
            except Exception as e:
                # One broken rule must not abort matching for the others.
                logger.warning(
                    f"[AC-AISVC-114] Similarity calculation failed for rule={rule.id}: {e}"
                )
                continue

        candidates.sort(key=lambda x: x.score, reverse=True)
        candidates = candidates[:effective_top_k]

        duration_ms = int((time.monotonic() - start_time) * 1000)
        logger.info(
            f"[AC-AISVC-113] Semantic match completed for tenant={tenant_id}, "
            f"candidates={len(candidates)}, top_score={candidates[0].score if candidates else 0:.3f}, "
            f"duration={duration_ms}ms"
        )

        return SemanticMatchResult(
            candidates=candidates,
            top_score=candidates[0].score if candidates else 0.0,
            duration_ms=duration_ms,
            skipped=False,
            skip_reason=None,
        )

    def _has_semantic_config(self, rule: "IntentRule") -> bool:
        """
        Check if rule has semantic configuration.

        Args:
            rule: Intent rule to check

        Returns:
            True if rule has intent_vector or semantic_examples
        """
        return bool(rule.intent_vector) or bool(rule.semantic_examples)

    async def _calculate_similarity(
        self,
        message_vector: list[float],
        rule: "IntentRule",
    ) -> float:
        """
        [AC-AISVC-114] Calculate similarity between message and rule.

        Mode A: Use pre-computed intent_vector
        Mode B: Use semantic_examples for dynamic computation (best score
            across examples). NOTE(review): Mode B re-embeds the examples on
            every message; consider caching example vectors if profiling
            shows this as a hot spot.

        Args:
            message_vector: Message embedding vector
            rule: Intent rule with semantic config

        Returns:
            Similarity score (0.0 ~ 1.0); 0.0 when no config or on failure
        """
        if rule.intent_vector:
            return self._cosine_similarity(message_vector, rule.intent_vector)
        if rule.semantic_examples:
            try:
                example_vectors = await self._embedding_provider.embed_batch(
                    rule.semantic_examples
                )
                # Best matching example wins.
                return max(
                    (
                        self._cosine_similarity(message_vector, v)
                        for v in example_vectors
                    ),
                    default=0.0,
                )
            except Exception as e:
                logger.warning(
                    f"[AC-AISVC-114] Failed to compute example vectors for rule={rule.id}: {e}"
                )
                return 0.0
        return 0.0

    def _cosine_similarity(
        self,
        v1: list[float],
        v2: list[float],
    ) -> float:
        """
        Calculate cosine similarity between two vectors, clamped to [0, 1].

        Args:
            v1: First vector
            v2: Second vector

        Returns:
            Cosine similarity (0.0 ~ 1.0); 0.0 for empty or zero-norm input
        """
        if not v1 or not v2:
            return 0.0

        # Force a float dtype so integer-typed inputs do not truncate the dot product.
        v1_arr = np.asarray(v1, dtype=float)
        v2_arr = np.asarray(v2, dtype=float)

        norm1 = np.linalg.norm(v1_arr)
        norm2 = np.linalg.norm(v2_arr)

        if norm1 == 0 or norm2 == 0:
            return 0.0

        similarity = float(np.dot(v1_arr, v2_arr) / (norm1 * norm2))
        # Clamp: negative similarity counts as "no match", and floating-point
        # noise can push the ratio slightly above 1.0.
        return max(0.0, min(1.0, similarity))
|
||||
|
|
@ -83,6 +83,7 @@ class KBService:
|
|||
file_name: str,
|
||||
file_content: bytes,
|
||||
file_type: str | None = None,
|
||||
metadata: dict | None = None,
|
||||
) -> tuple[Document, IndexJob]:
|
||||
"""
|
||||
[AC-ASA-01] Upload document and create indexing job.
|
||||
|
|
@ -108,6 +109,7 @@ class KBService:
|
|||
file_size=len(file_content),
|
||||
file_type=file_type,
|
||||
status=DocumentStatus.PENDING.value,
|
||||
doc_metadata=metadata,
|
||||
)
|
||||
self._session.add(document)
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,14 @@ LLM Adapter module for AI Service.
|
|||
[AC-AISVC-02, AC-AISVC-06] Provides unified interface for LLM providers.
|
||||
"""
|
||||
|
||||
from app.services.llm.base import LLMClient, LLMConfig, LLMResponse, LLMStreamChunk
|
||||
from app.services.llm.base import (
|
||||
LLMClient,
|
||||
LLMConfig,
|
||||
LLMResponse,
|
||||
LLMStreamChunk,
|
||||
ToolCall,
|
||||
ToolDefinition,
|
||||
)
|
||||
from app.services.llm.openai_client import OpenAIClient
|
||||
|
||||
__all__ = [
|
||||
|
|
@ -12,4 +19,6 @@ __all__ = [
|
|||
"LLMResponse",
|
||||
"LLMStreamChunk",
|
||||
"OpenAIClient",
|
||||
"ToolCall",
|
||||
"ToolDefinition",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -28,18 +28,46 @@ class LLMConfig:
|
|||
extra_params: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
class ToolCall:
    """
    A single function invocation requested by the LLM (Function Calling mode).
    """
    id: str                     # provider-assigned call identifier
    name: str                   # function name the model wants to invoke
    arguments: dict[str, Any]   # parsed JSON arguments for the call

    def to_dict(self) -> dict[str, Any]:
        """Serialize this call back to the OpenAI tool_calls wire format."""
        import json
        serialized_args = json.dumps(self.arguments, ensure_ascii=False)
        return {
            "id": self.id,
            "type": "function",
            "function": {
                "name": self.name,
                "arguments": serialized_args,
            },
        }
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMResponse:
|
||||
"""
|
||||
Response from LLM generation.
|
||||
[AC-AISVC-02] Contains generated content and metadata.
|
||||
"""
|
||||
content: str
|
||||
model: str
|
||||
content: str | None = None
|
||||
model: str = ""
|
||||
usage: dict[str, int] = field(default_factory=dict)
|
||||
finish_reason: str = "stop"
|
||||
tool_calls: list[ToolCall] = field(default_factory=list)
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def has_tool_calls(self) -> bool:
|
||||
"""Check if response contains tool calls."""
|
||||
return len(self.tool_calls) > 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMStreamChunk:
|
||||
|
|
@ -50,9 +78,33 @@ class LLMStreamChunk:
|
|||
delta: str
|
||||
model: str
|
||||
finish_reason: str | None = None
|
||||
tool_calls_delta: list[dict[str, Any]] = field(default_factory=list)
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
class ToolDefinition:
    """
    Declaration of a callable tool in the OpenAI/DeepSeek function-calling schema.
    """
    name: str                   # tool (function) name
    description: str            # natural-language description shown to the model
    parameters: dict[str, Any]  # JSON-schema describing the arguments
    type: str = "function"      # wire-format discriminator

    def to_openai_format(self) -> dict[str, Any]:
        """Render as one entry of the OpenAI `tools` request field."""
        function_spec = {
            "name": self.name,
            "description": self.description,
            "parameters": self.parameters,
        }
        return {"type": self.type, "function": function_spec}
|
||||
|
||||
|
||||
class LLMClient(ABC):
|
||||
"""
|
||||
Abstract base class for LLM clients.
|
||||
|
|
@ -67,6 +119,8 @@ class LLMClient(ABC):
|
|||
self,
|
||||
messages: list[dict[str, str]],
|
||||
config: LLMConfig | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
|
|
@ -76,10 +130,12 @@ class LLMClient(ABC):
|
|||
Args:
|
||||
messages: List of chat messages with 'role' and 'content'.
|
||||
config: Optional LLM configuration overrides.
|
||||
tools: Optional list of tools for function calling.
|
||||
tool_choice: Tool choice strategy ("auto", "none", or specific tool).
|
||||
**kwargs: Additional provider-specific parameters.
|
||||
|
||||
Returns:
|
||||
LLMResponse with generated content and metadata.
|
||||
LLMResponse with generated content, tool_calls, and metadata.
|
||||
|
||||
Raises:
|
||||
LLMException: If generation fails.
|
||||
|
|
@ -91,6 +147,8 @@ class LLMClient(ABC):
|
|||
self,
|
||||
messages: list[dict[str, str]],
|
||||
config: LLMConfig | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncGenerator[LLMStreamChunk, None]:
|
||||
"""
|
||||
|
|
@ -100,6 +158,8 @@ class LLMClient(ABC):
|
|||
Args:
|
||||
messages: List of chat messages with 'role' and 'content'.
|
||||
config: Optional LLM configuration overrides.
|
||||
tools: Optional list of tools for function calling.
|
||||
tool_choice: Tool choice strategy ("auto", "none", or specific tool).
|
||||
**kwargs: Additional provider-specific parameters.
|
||||
|
||||
Yields:
|
||||
|
|
|
|||
|
|
@ -11,12 +11,16 @@ from dataclasses import dataclass
|
|||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import redis
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.services.llm.base import LLMClient, LLMConfig
|
||||
from app.services.llm.openai_client import OpenAIClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
LLM_CONFIG_FILE = Path("config/llm_config.json")
|
||||
LLM_CONFIG_REDIS_KEY = "ai_service:config:llm"
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -286,6 +290,8 @@ class LLMConfigManager:
|
|||
from app.core.config import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
self._settings = settings
|
||||
self._redis_client: redis.Redis | None = None
|
||||
|
||||
self._current_provider: str = settings.llm_provider
|
||||
self._current_config: dict[str, Any] = {
|
||||
|
|
@ -299,8 +305,75 @@ class LLMConfigManager:
|
|||
}
|
||||
self._client: LLMClient | None = None
|
||||
|
||||
self._load_from_redis()
|
||||
self._load_from_file()
|
||||
|
||||
def _load_from_redis(self) -> None:
    """Load configuration from Redis if it exists.

    Best effort: any failure (Redis disabled, unreachable, bad JSON) is
    logged and ignored so startup falls back to file/default config.

    Fix: this method was previously defined twice with identical bodies in
    the same class; the second definition shadowed the first. The duplicate
    has been removed.
    """
    try:
        if not self._settings.redis_enabled:
            return
        self._redis_client = redis.from_url(
            self._settings.redis_url,
            encoding="utf-8",
            decode_responses=True,
        )
        saved_raw = self._redis_client.get(LLM_CONFIG_REDIS_KEY)
        if not saved_raw:
            return
        saved = json.loads(saved_raw)
        # Persisted values override defaults; missing keys keep current values.
        self._current_provider = saved.get("provider", self._current_provider)
        saved_config = saved.get("config", {})
        if saved_config:
            self._current_config.update(saved_config)
        logger.info(f"[AC-ASA-16] Loaded LLM config from Redis: provider={self._current_provider}")
    except Exception as e:
        logger.warning(f"[AC-ASA-16] Failed to load LLM config from Redis: {e}")

def _save_to_redis(self) -> None:
    """Persist the current provider/config snapshot to Redis (best effort).

    Failures are logged and swallowed: Redis persistence is an optimization;
    file persistence happens separately.
    """
    try:
        if not self._settings.redis_enabled:
            return
        # Reuse the connection opened by _load_from_redis when available.
        if self._redis_client is None:
            self._redis_client = redis.from_url(
                self._settings.redis_url,
                encoding="utf-8",
                decode_responses=True,
            )
        self._redis_client.set(
            LLM_CONFIG_REDIS_KEY,
            json.dumps({
                "provider": self._current_provider,
                "config": self._current_config,
            }, ensure_ascii=False),
        )
        logger.info(f"[AC-ASA-16] Saved LLM config to Redis: provider={self._current_provider}")
    except Exception as e:
        logger.warning(f"[AC-ASA-16] Failed to save LLM config to Redis: {e}")
|
||||
|
||||
def _load_from_file(self) -> None:
|
||||
"""Load configuration from file if exists."""
|
||||
try:
|
||||
|
|
@ -364,6 +437,7 @@ class LLMConfigManager:
|
|||
self._current_provider = provider
|
||||
self._current_config = validated_config
|
||||
|
||||
self._save_to_redis()
|
||||
self._save_to_file()
|
||||
|
||||
logger.info(f"[AC-ASA-16] LLM config updated: provider={provider}")
|
||||
|
|
|
|||
|
|
@ -22,7 +22,14 @@ from tenacity import (
|
|||
|
||||
from app.core.config import get_settings
|
||||
from app.core.exceptions import AIServiceException, ErrorCode, TimeoutException
|
||||
from app.services.llm.base import LLMClient, LLMConfig, LLMResponse, LLMStreamChunk
|
||||
from app.services.llm.base import (
|
||||
LLMClient,
|
||||
LLMConfig,
|
||||
LLMResponse,
|
||||
LLMStreamChunk,
|
||||
ToolCall,
|
||||
ToolDefinition,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -95,6 +102,8 @@ class OpenAIClient(LLMClient):
|
|||
messages: list[dict[str, str]],
|
||||
config: LLMConfig,
|
||||
stream: bool = False,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> dict[str, Any]:
|
||||
"""Build request body for OpenAI API."""
|
||||
|
|
@ -106,6 +115,13 @@ class OpenAIClient(LLMClient):
|
|||
"top_p": config.top_p,
|
||||
"stream": stream,
|
||||
}
|
||||
|
||||
if tools:
|
||||
body["tools"] = [tool.to_openai_format() for tool in tools]
|
||||
|
||||
if tool_choice:
|
||||
body["tool_choice"] = tool_choice
|
||||
|
||||
body.update(config.extra_params)
|
||||
body.update(kwargs)
|
||||
return body
|
||||
|
|
@ -119,6 +135,8 @@ class OpenAIClient(LLMClient):
|
|||
self,
|
||||
messages: list[dict[str, str]],
|
||||
config: LLMConfig | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
|
|
@ -128,10 +146,12 @@ class OpenAIClient(LLMClient):
|
|||
Args:
|
||||
messages: List of chat messages with 'role' and 'content'.
|
||||
config: Optional LLM configuration overrides.
|
||||
tools: Optional list of tools for function calling.
|
||||
tool_choice: Tool choice strategy ("auto", "none", or specific tool).
|
||||
**kwargs: Additional provider-specific parameters.
|
||||
|
||||
Returns:
|
||||
LLMResponse with generated content and metadata.
|
||||
LLMResponse with generated content, tool_calls, and metadata.
|
||||
|
||||
Raises:
|
||||
LLMException: If generation fails.
|
||||
|
|
@ -140,9 +160,14 @@ class OpenAIClient(LLMClient):
|
|||
effective_config = config or self._default_config
|
||||
client = self._get_client(effective_config.timeout_seconds)
|
||||
|
||||
body = self._build_request_body(messages, effective_config, stream=False, **kwargs)
|
||||
body = self._build_request_body(
|
||||
messages, effective_config, stream=False,
|
||||
tools=tools, tool_choice=tool_choice, **kwargs
|
||||
)
|
||||
|
||||
logger.info(f"[AC-AISVC-02] Generating response with model={effective_config.model}")
|
||||
if tools:
|
||||
logger.info(f"[AC-AISVC-02] Function calling enabled with {len(tools)} tools")
|
||||
logger.info("[AC-AISVC-02] ========== FULL PROMPT TO AI ==========")
|
||||
for i, msg in enumerate(messages):
|
||||
role = msg.get("role", "unknown")
|
||||
|
|
@ -177,14 +202,18 @@ class OpenAIClient(LLMClient):
|
|||
|
||||
try:
|
||||
choice = data["choices"][0]
|
||||
content = choice["message"]["content"]
|
||||
message = choice["message"]
|
||||
content = message.get("content")
|
||||
usage = data.get("usage", {})
|
||||
finish_reason = choice.get("finish_reason", "stop")
|
||||
|
||||
tool_calls = self._parse_tool_calls(message)
|
||||
|
||||
logger.info(
|
||||
f"[AC-AISVC-02] Generated response: "
|
||||
f"tokens={usage.get('total_tokens', 'N/A')}, "
|
||||
f"finish_reason={finish_reason}"
|
||||
f"finish_reason={finish_reason}, "
|
||||
f"tool_calls={len(tool_calls)}"
|
||||
)
|
||||
|
||||
return LLMResponse(
|
||||
|
|
@ -192,6 +221,7 @@ class OpenAIClient(LLMClient):
|
|||
model=data.get("model", effective_config.model),
|
||||
usage=usage,
|
||||
finish_reason=finish_reason,
|
||||
tool_calls=tool_calls,
|
||||
metadata={"raw_response": data},
|
||||
)
|
||||
|
||||
|
|
@ -202,10 +232,33 @@ class OpenAIClient(LLMClient):
|
|||
details=[{"response": str(data)}],
|
||||
)
|
||||
|
||||
def _parse_tool_calls(self, message: dict[str, Any]) -> list[ToolCall]:
    """Extract function-type tool calls from an LLM response message.

    Non-function entries are ignored; malformed argument JSON degrades to an
    empty argument dict so one bad call never aborts response parsing.
    """
    parsed: list[ToolCall] = []
    for raw_call in message.get("tool_calls", []):
        if raw_call.get("type") != "function":
            continue
        fn = raw_call.get("function", {})
        try:
            args = json.loads(fn.get("arguments", "{}"))
        except json.JSONDecodeError:
            args = {}
        parsed.append(ToolCall(
            id=raw_call.get("id", ""),
            name=fn.get("name", ""),
            arguments=args,
        ))
    return parsed
|
||||
|
||||
async def stream_generate(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
config: LLMConfig | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncGenerator[LLMStreamChunk, None]:
|
||||
"""
|
||||
|
|
@ -215,6 +268,8 @@ class OpenAIClient(LLMClient):
|
|||
Args:
|
||||
messages: List of chat messages with 'role' and 'content'.
|
||||
config: Optional LLM configuration overrides.
|
||||
tools: Optional list of tools for function calling.
|
||||
tool_choice: Tool choice strategy ("auto", "none", or specific tool).
|
||||
**kwargs: Additional provider-specific parameters.
|
||||
|
||||
Yields:
|
||||
|
|
@ -227,9 +282,14 @@ class OpenAIClient(LLMClient):
|
|||
effective_config = config or self._default_config
|
||||
client = self._get_client(effective_config.timeout_seconds)
|
||||
|
||||
body = self._build_request_body(messages, effective_config, stream=True, **kwargs)
|
||||
body = self._build_request_body(
|
||||
messages, effective_config, stream=True,
|
||||
tools=tools, tool_choice=tool_choice, **kwargs
|
||||
)
|
||||
|
||||
logger.info(f"[AC-AISVC-06] Starting streaming generation with model={effective_config.model}")
|
||||
if tools:
|
||||
logger.info(f"[AC-AISVC-06] Function calling enabled with {len(tools)} tools")
|
||||
logger.info("[AC-AISVC-06] ========== FULL PROMPT TO AI (STREAMING) ==========")
|
||||
for i, msg in enumerate(messages):
|
||||
role = msg.get("role", "unknown")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,202 @@
|
|||
"""
|
||||
元数据字段定义缓存服务
|
||||
使用 Redis 缓存 metadata_field_definitions,减少数据库查询
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.config import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MetadataCacheService:
|
||||
"""
|
||||
元数据字段定义缓存服务
|
||||
|
||||
缓存策略:
|
||||
- Key: metadata:fields:{tenant_id}
|
||||
- Value: JSON 序列化的字段定义列表
|
||||
- TTL: 1小时(3600秒)
|
||||
- 更新策略:写时更新 + 定时刷新
|
||||
"""
|
||||
|
||||
CACHE_KEY_PREFIX = "metadata:fields"
|
||||
DEFAULT_TTL = 3600 # 1小时
|
||||
|
||||
def __init__(self):
|
||||
self._settings = get_settings()
|
||||
self._redis_client = None
|
||||
self._enabled = self._settings.redis_enabled
|
||||
|
||||
async def _get_redis(self):
|
||||
"""获取 Redis 连接(延迟初始化)"""
|
||||
if not self._enabled:
|
||||
return None
|
||||
|
||||
if self._redis_client is None:
|
||||
try:
|
||||
import redis.asyncio as redis
|
||||
self._redis_client = redis.from_url(
|
||||
self._settings.redis_url,
|
||||
decode_responses=True
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[MetadataCache] Failed to connect Redis: {e}")
|
||||
self._enabled = False
|
||||
return None
|
||||
|
||||
return self._redis_client
|
||||
|
||||
def _make_key(self, tenant_id: str) -> str:
|
||||
"""生成缓存 key"""
|
||||
return f"{self.CACHE_KEY_PREFIX}:{tenant_id}"
|
||||
|
||||
async def get_fields(self, tenant_id: str) -> list[dict[str, Any]] | None:
|
||||
"""
|
||||
获取缓存的字段定义
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
|
||||
Returns:
|
||||
字段定义列表,未缓存返回 None
|
||||
"""
|
||||
if not self._enabled:
|
||||
return None
|
||||
|
||||
try:
|
||||
redis = await self._get_redis()
|
||||
if not redis:
|
||||
return None
|
||||
|
||||
key = self._make_key(tenant_id)
|
||||
cached_data = await redis.get(key)
|
||||
|
||||
if cached_data:
|
||||
logger.info(f"[MetadataCache] Cache hit for tenant={tenant_id}")
|
||||
return json.loads(cached_data)
|
||||
|
||||
logger.info(f"[MetadataCache] Cache miss for tenant={tenant_id}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[MetadataCache] Get cache error: {e}")
|
||||
return None
|
||||
|
||||
async def set_fields(
|
||||
self,
|
||||
tenant_id: str,
|
||||
fields: list[dict[str, Any]],
|
||||
ttl: int | None = None
|
||||
) -> bool:
|
||||
"""
|
||||
缓存字段定义
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
fields: 字段定义列表
|
||||
ttl: 过期时间(秒),默认 1小时
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
if not self._enabled:
|
||||
return False
|
||||
|
||||
try:
|
||||
redis = await self._get_redis()
|
||||
if not redis:
|
||||
return False
|
||||
|
||||
key = self._make_key(tenant_id)
|
||||
ttl = ttl or self.DEFAULT_TTL
|
||||
|
||||
await redis.setex(
|
||||
key,
|
||||
ttl,
|
||||
json.dumps(fields, ensure_ascii=False, default=str)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[MetadataCache] Cached {len(fields)} fields for tenant={tenant_id}, "
|
||||
f"ttl={ttl}s"
|
||||
)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[MetadataCache] Set cache error: {e}")
|
||||
return False
|
||||
|
||||
async def invalidate(self, tenant_id: str) -> bool:
|
||||
"""
|
||||
使缓存失效(字段定义更新时调用)
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
if not self._enabled:
|
||||
return False
|
||||
|
||||
try:
|
||||
redis = await self._get_redis()
|
||||
if not redis:
|
||||
return False
|
||||
|
||||
key = self._make_key(tenant_id)
|
||||
result = await redis.delete(key)
|
||||
|
||||
if result:
|
||||
logger.info(f"[MetadataCache] Invalidated cache for tenant={tenant_id}")
|
||||
return bool(result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[MetadataCache] Invalidate error: {e}")
|
||||
return False
|
||||
|
||||
async def invalidate_all(self) -> bool:
|
||||
"""
|
||||
使所有元数据缓存失效
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
if not self._enabled:
|
||||
return False
|
||||
|
||||
try:
|
||||
redis = await self._get_redis()
|
||||
if not redis:
|
||||
return False
|
||||
|
||||
# 查找所有元数据缓存 key
|
||||
pattern = f"{self.CACHE_KEY_PREFIX}:*"
|
||||
keys = []
|
||||
async for key in redis.scan_iter(match=pattern):
|
||||
keys.append(key)
|
||||
|
||||
if keys:
|
||||
await redis.delete(*keys)
|
||||
logger.info(f"[MetadataCache] Invalidated {len(keys)} cache entries")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[MetadataCache] Invalidate all error: {e}")
|
||||
return False
|
||||
|
||||
|
||||
# Process-wide singleton instance of the cache service
_metadata_cache_service: MetadataCacheService | None = None


async def get_metadata_cache_service() -> MetadataCacheService:
    """Return the shared MetadataCacheService, creating it on first use."""
    global _metadata_cache_service
    if _metadata_cache_service is None:
        _metadata_cache_service = MetadataCacheService()
    return _metadata_cache_service
|
||||
|
|
@ -40,6 +40,19 @@ class MetadataFieldDefinitionService:
|
|||
def __init__(self, session: AsyncSession):
|
||||
self._session = session
|
||||
|
||||
async def _invalidate_cache(self, tenant_id: str) -> None:
|
||||
"""
|
||||
清除租户的元数据字段缓存
|
||||
在字段创建、更新、删除时调用
|
||||
"""
|
||||
try:
|
||||
from app.services.metadata_cache_service import get_metadata_cache_service
|
||||
cache_service = await get_metadata_cache_service()
|
||||
await cache_service.invalidate(tenant_id)
|
||||
except Exception as e:
|
||||
# 缓存失效失败不影响主流程
|
||||
logger.warning(f"[AC-IDSMETA-13] Failed to invalidate cache: {e}")
|
||||
|
||||
async def list_field_definitions(
|
||||
self,
|
||||
tenant_id: str,
|
||||
|
|
@ -180,6 +193,9 @@ class MetadataFieldDefinitionService:
|
|||
self._session.add(field)
|
||||
await self._session.flush()
|
||||
|
||||
# 清除缓存,使新字段在下次查询时生效
|
||||
await self._invalidate_cache(tenant_id)
|
||||
|
||||
logger.info(
|
||||
f"[AC-IDSMETA-13] [AC-MRS-01] Created field definition: tenant={tenant_id}, "
|
||||
f"field_key={field.field_key}, status={field.status}, field_roles={field.field_roles}"
|
||||
|
|
@ -223,6 +239,10 @@ class MetadataFieldDefinitionService:
|
|||
field.is_filterable = field_update.is_filterable
|
||||
if field_update.is_rank_feature is not None:
|
||||
field.is_rank_feature = field_update.is_rank_feature
|
||||
# [AC-MRS-01] 修复:添加 field_roles 更新逻辑
|
||||
if field_update.field_roles is not None:
|
||||
self._validate_field_roles(field_update.field_roles)
|
||||
field.field_roles = field_update.field_roles
|
||||
if field_update.status is not None:
|
||||
old_status = field.status
|
||||
field.status = field_update.status
|
||||
|
|
@ -235,6 +255,9 @@ class MetadataFieldDefinitionService:
|
|||
field.updated_at = datetime.utcnow()
|
||||
await self._session.flush()
|
||||
|
||||
# 清除缓存,使更新在下次查询时生效
|
||||
await self._invalidate_cache(tenant_id)
|
||||
|
||||
logger.info(
|
||||
f"[AC-IDSMETA-14] Updated field definition: tenant={tenant_id}, "
|
||||
f"field_id={field_id}, version={field.version}"
|
||||
|
|
|
|||
|
|
@ -17,6 +17,18 @@ from .memory_adapter import MemoryAdapter, UserMemory
|
|||
from .default_kb_tool_runner import DefaultKbToolRunner, KbToolResult, KbToolConfig, get_default_kb_tool_runner
|
||||
from .segment_humanizer import SegmentHumanizer, HumanizeConfig, LengthBucket, get_segment_humanizer
|
||||
from .runtime_observer import RuntimeObserver, RuntimeContext, get_runtime_observer
|
||||
from .slot_validation_service import (
|
||||
SlotValidationService,
|
||||
ValidationResult,
|
||||
SlotValidationError,
|
||||
BatchValidationResult,
|
||||
SlotValidationErrorCode,
|
||||
)
|
||||
from .slot_manager import (
|
||||
SlotManager,
|
||||
SlotWriteResult,
|
||||
create_slot_manager,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"PolicyRouter",
|
||||
|
|
@ -54,4 +66,12 @@ __all__ = [
|
|||
"RuntimeObserver",
|
||||
"RuntimeContext",
|
||||
"get_runtime_observer",
|
||||
"SlotValidationService",
|
||||
"ValidationResult",
|
||||
"SlotValidationError",
|
||||
"BatchValidationResult",
|
||||
"SlotValidationErrorCode",
|
||||
"SlotManager",
|
||||
"SlotWriteResult",
|
||||
"create_slot_manager",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -2,6 +2,10 @@
|
|||
Agent Orchestrator for Mid Platform.
|
||||
[AC-MARH-07] ReAct loop with iteration limit (3-5 iterations).
|
||||
|
||||
Supports two execution modes:
|
||||
1. ReAct (Text-based): Traditional Thought/Action/Observation loop
|
||||
2. Function Calling: Uses LLM's native function calling capability
|
||||
|
||||
ReAct Flow:
|
||||
1. Thought: Agent thinks about what to do
|
||||
2. Action: Agent decides to use a tool
|
||||
|
|
@ -16,6 +20,8 @@ import re
|
|||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from app.models.mid.schemas import (
|
||||
|
|
@ -26,7 +32,10 @@ from app.models.mid.schemas import (
|
|||
ToolType,
|
||||
TraceInfo,
|
||||
)
|
||||
from app.services.llm.base import ToolDefinition
|
||||
from app.services.mid.tool_guide_registry import ToolGuideRegistry, get_tool_guide_registry
|
||||
from app.services.mid.timeout_governor import TimeoutGovernor
|
||||
from app.services.mid.tool_converter import convert_tools_to_llm_format, build_tool_result_message
|
||||
from app.services.prompt.template_service import PromptTemplateService
|
||||
from app.services.prompt.variable_resolver import VariableResolver
|
||||
|
||||
|
|
@ -36,11 +45,17 @@ DEFAULT_MAX_ITERATIONS = 5
|
|||
MIN_ITERATIONS = 3
|
||||
|
||||
|
||||
class AgentMode(str, Enum):
|
||||
"""Agent execution mode."""
|
||||
REACT = "react"
|
||||
FUNCTION_CALLING = "function_calling"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolResult:
|
||||
"""Tool execution result."""
|
||||
success: bool
|
||||
output: str | None = None
|
||||
output: Any = None
|
||||
error: str | None = None
|
||||
duration_ms: int = 0
|
||||
|
||||
|
|
@ -59,6 +74,7 @@ class AgentOrchestrator:
|
|||
|
||||
Features:
|
||||
- ReAct loop with max 5 iterations (min 3)
|
||||
- Function Calling mode for supported LLMs (OpenAI, DeepSeek, etc.)
|
||||
- Per-tool timeout (2s) and end-to-end timeout (8s)
|
||||
- Automatic fallback on iteration limit or timeout
|
||||
- Template-based prompt with variable injection
|
||||
|
|
@ -70,17 +86,74 @@ class AgentOrchestrator:
|
|||
timeout_governor: TimeoutGovernor | None = None,
|
||||
llm_client: Any = None,
|
||||
tool_registry: Any = None,
|
||||
guide_registry: ToolGuideRegistry | None = None,
|
||||
template_service: PromptTemplateService | None = None,
|
||||
variable_resolver: VariableResolver | None = None,
|
||||
tenant_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
session_id: str | None = None,
|
||||
scene: str | None = None,
|
||||
mode: AgentMode = AgentMode.FUNCTION_CALLING,
|
||||
):
|
||||
self._max_iterations = max(min(max_iterations, 5), MIN_ITERATIONS)
|
||||
self._timeout_governor = timeout_governor or TimeoutGovernor()
|
||||
self._llm_client = llm_client
|
||||
self._tool_registry = tool_registry
|
||||
self._guide_registry = guide_registry
|
||||
self._template_service = template_service
|
||||
self._variable_resolver = variable_resolver or VariableResolver()
|
||||
self._tenant_id = tenant_id
|
||||
self._user_id = user_id
|
||||
self._session_id = session_id
|
||||
self._scene = scene
|
||||
self._mode = mode
|
||||
self._tools_cache: list[ToolDefinition] | None = None
|
||||
|
||||
def _get_tools_definition(self) -> list[ToolDefinition]:
|
||||
"""Get cached tools definition for Function Calling."""
|
||||
if self._tools_cache is None and self._tool_registry:
|
||||
tools = self._tool_registry.get_all_tools()
|
||||
self._tools_cache = convert_tools_to_llm_format(tools)
|
||||
return self._tools_cache or []
|
||||
|
||||
async def _get_tools_definition_async(self) -> list[ToolDefinition]:
|
||||
"""Get tools definition for Function Calling with dynamic schema support."""
|
||||
if self._tools_cache is not None:
|
||||
return self._tools_cache
|
||||
|
||||
if not self._tool_registry:
|
||||
return []
|
||||
|
||||
tools = self._tool_registry.get_all_tools()
|
||||
result = []
|
||||
|
||||
for tool in tools:
|
||||
if tool.name == "kb_search_dynamic" and self._tenant_id:
|
||||
from app.services.mid.kb_search_dynamic_tool import (
|
||||
_TOOL_SCHEMA_CACHE,
|
||||
_TOOL_SCHEMA_CACHE_TTL_SECONDS,
|
||||
)
|
||||
import time
|
||||
|
||||
cache_key = f"tool_schema:{self._tenant_id}"
|
||||
current_time = time.time()
|
||||
|
||||
if cache_key in _TOOL_SCHEMA_CACHE:
|
||||
cached_time, cached_schema = _TOOL_SCHEMA_CACHE[cache_key]
|
||||
if current_time - cached_time < _TOOL_SCHEMA_CACHE_TTL_SECONDS:
|
||||
result.append(ToolDefinition(
|
||||
name=cached_schema["name"],
|
||||
description=cached_schema["description"],
|
||||
parameters=cached_schema["parameters"],
|
||||
))
|
||||
continue
|
||||
|
||||
result.append(convert_tool_to_llm_format(tool))
|
||||
else:
|
||||
result.append(convert_tool_to_llm_format(tool))
|
||||
|
||||
self._tools_cache = result
|
||||
return result
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
|
|
@ -90,7 +163,7 @@ class AgentOrchestrator:
|
|||
on_action: Any = None,
|
||||
) -> tuple[str, ReActContext, TraceInfo]:
|
||||
"""
|
||||
[AC-MARH-07] Execute ReAct loop with iteration control.
|
||||
[AC-MARH-07] Execute agent loop with iteration control.
|
||||
|
||||
Args:
|
||||
user_message: User input message
|
||||
|
|
@ -101,6 +174,416 @@ class AgentOrchestrator:
|
|||
Returns:
|
||||
Tuple of (final_answer, react_context, trace_info)
|
||||
"""
|
||||
if self._mode == AgentMode.FUNCTION_CALLING:
|
||||
return await self._execute_function_calling(user_message, context, on_action)
|
||||
else:
|
||||
return await self._execute_react(user_message, context, on_thought, on_action)
|
||||
|
||||
async def _execute_function_calling(
|
||||
self,
|
||||
user_message: str,
|
||||
context: dict[str, Any] | None = None,
|
||||
on_action: Any = None,
|
||||
) -> tuple[str, ReActContext, TraceInfo]:
|
||||
"""
|
||||
Execute using Function Calling mode.
|
||||
|
||||
This mode uses the LLM's native function calling capability,
|
||||
which is more reliable and token-efficient than text-based ReAct.
|
||||
"""
|
||||
react_ctx = ReActContext(max_iterations=self._max_iterations)
|
||||
tool_calls: list[ToolCallTrace] = []
|
||||
start_time = time.time()
|
||||
|
||||
logger.info(
|
||||
f"[AC-MARH-07] Starting Function Calling loop: max_iterations={self._max_iterations}, "
|
||||
f"llm_client={self._llm_client is not None}, tool_registry={self._tool_registry is not None}"
|
||||
)
|
||||
|
||||
if not self._llm_client:
|
||||
logger.error("[DEBUG-ORCH] LLM client is None, returning error response")
|
||||
return "抱歉,服务配置错误,请联系管理员。", react_ctx, TraceInfo(
|
||||
mode=ExecutionMode.AGENT,
|
||||
request_id=str(uuid.uuid4()),
|
||||
generation_id=str(uuid.uuid4()),
|
||||
)
|
||||
|
||||
try:
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "system", "content": await self._build_system_prompt()},
|
||||
{"role": "user", "content": user_message},
|
||||
]
|
||||
|
||||
tools = await self._get_tools_definition_async()
|
||||
|
||||
overall_start = time.time()
|
||||
end_to_end_timeout = self._timeout_governor.end_to_end_timeout_seconds
|
||||
llm_timeout = getattr(self._timeout_governor, 'llm_timeout_seconds', 15.0)
|
||||
|
||||
while react_ctx.should_continue and react_ctx.iteration < react_ctx.max_iterations:
|
||||
react_ctx.iteration += 1
|
||||
|
||||
elapsed = time.time() - overall_start
|
||||
remaining_time = end_to_end_timeout - elapsed
|
||||
if remaining_time <= 0:
|
||||
logger.warning("[AC-MARH-09] Function Calling loop exceeded end-to-end timeout")
|
||||
react_ctx.final_answer = "抱歉,处理超时,请稍后重试或联系人工客服。"
|
||||
break
|
||||
|
||||
logger.info(
|
||||
f"[AC-MARH-07] Function Calling iteration {react_ctx.iteration}/"
|
||||
f"{react_ctx.max_iterations}, remaining_time={remaining_time:.1f}s"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[DEBUG-ORCH] Calling LLM generate with messages_count={len(messages)}, "
|
||||
f"tools_count={len(tools) if tools else 0}"
|
||||
)
|
||||
|
||||
response = await asyncio.wait_for(
|
||||
self._llm_client.generate(
|
||||
messages=messages,
|
||||
tools=tools if tools else None,
|
||||
tool_choice="auto" if tools else None,
|
||||
),
|
||||
timeout=min(llm_timeout, remaining_time)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[DEBUG-ORCH] LLM response received: has_tool_calls={response.has_tool_calls}, "
|
||||
f"content_length={len(response.content) if response.content else 0}, "
|
||||
f"tool_calls_count={len(response.tool_calls) if response.tool_calls else 0}"
|
||||
)
|
||||
|
||||
if response.has_tool_calls:
|
||||
for tool_call in response.tool_calls:
|
||||
tool_name = tool_call.name
|
||||
tool_args = tool_call.arguments
|
||||
|
||||
logger.info(f"[AC-MARH-07] Tool call: {tool_name}, args={tool_args}")
|
||||
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": response.content,
|
||||
"tool_calls": [tool_call.to_dict()],
|
||||
})
|
||||
|
||||
tool_result, tool_trace = await self._act_fc(
|
||||
tool_call.id, tool_name, tool_args, react_ctx
|
||||
)
|
||||
tool_calls.append(tool_trace)
|
||||
react_ctx.tool_calls.append(tool_trace)
|
||||
|
||||
if on_action:
|
||||
await on_action(tool_name, tool_result)
|
||||
|
||||
called_tools = {tc.tool_name for tc in react_ctx.tool_calls[:-1]}
|
||||
is_first_call = tool_name not in called_tools
|
||||
|
||||
# Extract tool_guide from output if present (added by _act_fc)
|
||||
result_output = tool_result.output if tool_result.success else {"error": tool_result.error}
|
||||
tool_guide = None
|
||||
if isinstance(result_output, dict) and "_tool_guide" in result_output:
|
||||
result_output = dict(result_output)
|
||||
tool_guide = result_output.pop("_tool_guide")
|
||||
|
||||
messages.append(build_tool_result_message(
|
||||
tool_call_id=tool_call.id,
|
||||
tool_name=tool_name,
|
||||
result=result_output,
|
||||
tool_guide=tool_guide,
|
||||
))
|
||||
|
||||
if not tool_result.success:
|
||||
if tool_trace.status == ToolCallStatus.TIMEOUT:
|
||||
react_ctx.final_answer = "抱歉,操作超时,请稍后重试或联系人工客服。"
|
||||
react_ctx.should_continue = False
|
||||
break
|
||||
else:
|
||||
react_ctx.final_answer = response.content or "抱歉,我无法处理您的请求。"
|
||||
react_ctx.should_continue = False
|
||||
break
|
||||
|
||||
if react_ctx.should_continue and not react_ctx.final_answer:
|
||||
logger.warning(f"[AC-MARH-07] Function Calling reached max iterations: {react_ctx.iteration}")
|
||||
react_ctx.final_answer = await self._force_final_answer_fc(messages)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("[AC-MARH-09] Function Calling loop timed out (end-to-end)")
|
||||
react_ctx.final_answer = "抱歉,处理超时,请稍后重试或联系人工客服。"
|
||||
tool_calls.append(ToolCallTrace(
|
||||
tool_name="fc_loop",
|
||||
tool_type=ToolType.INTERNAL,
|
||||
duration_ms=int((time.time() - start_time) * 1000),
|
||||
status=ToolCallStatus.TIMEOUT,
|
||||
error_code="E2E_TIMEOUT",
|
||||
))
|
||||
except Exception as e:
|
||||
logger.error(f"[AC-MARH-07] Function Calling error: {e}", exc_info=True)
|
||||
react_ctx.final_answer = f"抱歉,处理过程中发生错误:{str(e)}"
|
||||
|
||||
total_duration_ms = int((time.time() - start_time) * 1000)
|
||||
trace = TraceInfo(
|
||||
mode=ExecutionMode.AGENT,
|
||||
request_id=str(uuid.uuid4()),
|
||||
generation_id=str(uuid.uuid4()),
|
||||
react_iterations=react_ctx.iteration,
|
||||
tools_used=[tc.tool_name for tc in tool_calls if tc.tool_name not in ("fc_loop", "react_loop")],
|
||||
tool_calls=tool_calls if tool_calls else None,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[AC-MARH-07] Function Calling completed: iterations={react_ctx.iteration}, "
|
||||
f"duration_ms={total_duration_ms}"
|
||||
)
|
||||
|
||||
return react_ctx.final_answer or "抱歉,我暂时无法处理您的请求。", react_ctx, trace
|
||||
|
||||
async def _act_fc(
|
||||
self,
|
||||
tool_call_id: str,
|
||||
tool_name: str,
|
||||
tool_args: dict[str, Any],
|
||||
react_ctx: ReActContext,
|
||||
) -> tuple[ToolResult, ToolCallTrace]:
|
||||
"""Execute tool in Function Calling mode."""
|
||||
start_time = time.time()
|
||||
|
||||
if not self._tool_registry:
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
return ToolResult(
|
||||
success=False,
|
||||
error="Tool registry not configured",
|
||||
duration_ms=duration_ms,
|
||||
), ToolCallTrace(
|
||||
tool_name=tool_name,
|
||||
tool_type=ToolType.INTERNAL,
|
||||
duration_ms=duration_ms,
|
||||
status=ToolCallStatus.ERROR,
|
||||
error_code="NO_REGISTRY",
|
||||
)
|
||||
|
||||
try:
|
||||
final_args = dict(tool_args)
|
||||
if self._tenant_id:
|
||||
final_args["tenant_id"] = self._tenant_id
|
||||
if self._user_id:
|
||||
final_args["user_id"] = self._user_id
|
||||
if self._session_id:
|
||||
final_args["session_id"] = self._session_id
|
||||
|
||||
if tool_name == "kb_search_dynamic":
|
||||
# 确保 context 存在,供 AI 传入动态过滤条件
|
||||
if "context" not in final_args:
|
||||
final_args["context"] = {}
|
||||
# scene 参数由 AI 从元数据中选择,系统不强制覆盖
|
||||
|
||||
logger.info(
|
||||
f"[AC-MARH-07] FC Tool call starting: tool={tool_name}, "
|
||||
f"args={tool_args}, final_args={final_args}"
|
||||
)
|
||||
|
||||
result = await asyncio.wait_for(
|
||||
self._tool_registry.execute(
|
||||
name=tool_name,
|
||||
**final_args,
|
||||
),
|
||||
timeout=self._timeout_governor.per_tool_timeout_seconds
|
||||
)
|
||||
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
called_tools = {tc.tool_name for tc in react_ctx.tool_calls}
|
||||
is_first_call = tool_name not in called_tools
|
||||
|
||||
output = result.output
|
||||
if is_first_call and result.success:
|
||||
usage_guide = self._build_tool_usage_guide(tool_name)
|
||||
if usage_guide:
|
||||
if isinstance(output, dict):
|
||||
output = dict(output)
|
||||
output["_tool_guide"] = usage_guide
|
||||
elif isinstance(output, str):
|
||||
output = f"{output}\n\n---\n{usage_guide}"
|
||||
else:
|
||||
output = {"result": output, "_tool_guide": usage_guide}
|
||||
|
||||
return ToolResult(
|
||||
success=result.success,
|
||||
output=output,
|
||||
error=result.error,
|
||||
duration_ms=duration_ms,
|
||||
), ToolCallTrace(
|
||||
tool_name=tool_name,
|
||||
tool_type=ToolType.INTERNAL,
|
||||
duration_ms=duration_ms,
|
||||
status=ToolCallStatus.OK if result.success else ToolCallStatus.ERROR,
|
||||
args_digest=str(tool_args)[:100] if tool_args else None,
|
||||
result_digest=str(result.output)[:100] if result.output else None,
|
||||
arguments=tool_args,
|
||||
result=output,
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
logger.warning(f"[AC-MARH-08] FC Tool timeout: {tool_name}, duration={duration_ms}ms")
|
||||
return ToolResult(
|
||||
success=False,
|
||||
error="Tool timeout",
|
||||
duration_ms=duration_ms,
|
||||
), ToolCallTrace(
|
||||
tool_name=tool_name,
|
||||
tool_type=ToolType.INTERNAL,
|
||||
duration_ms=duration_ms,
|
||||
status=ToolCallStatus.TIMEOUT,
|
||||
error_code="TOOL_TIMEOUT",
|
||||
arguments=tool_args,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
logger.error(f"[AC-MARH-07] FC Tool error: {tool_name}, error={e}")
|
||||
return ToolResult(
|
||||
success=False,
|
||||
error=str(e),
|
||||
duration_ms=duration_ms,
|
||||
), ToolCallTrace(
|
||||
tool_name=tool_name,
|
||||
tool_type=ToolType.INTERNAL,
|
||||
duration_ms=duration_ms,
|
||||
status=ToolCallStatus.ERROR,
|
||||
error_code="TOOL_ERROR",
|
||||
arguments=tool_args,
|
||||
)
|
||||
|
||||
async def _build_system_prompt(self) -> str:
|
||||
"""Build system prompt for Function Calling mode with template support."""
|
||||
default_prompt = """你是一个智能客服助手,正在处理用户请求。
|
||||
|
||||
## 决策协议
|
||||
|
||||
1. 优先使用已有观察信息,避免重复调用同类工具。
|
||||
2. 当问题需要外部事实或结构化状态时再调用工具;如果可直接回答则不要调用。
|
||||
3. 缺少关键参数时,优先向用户追问,不要使用空参数调用工具。
|
||||
4. 工具失败时,先说明已尝试,再给出降级方案或下一步引导。
|
||||
5. 对用户输出必须拟人、自然、有同理心,不暴露"工具调用/路由/策略"等内部术语。
|
||||
|
||||
## 知识库查询强制流程
|
||||
|
||||
当用户问题需要进行知识库查询(kb_search_dynamic)时,必须遵循以下步骤:
|
||||
|
||||
**步骤1:先调用 list_document_metadata_fields**
|
||||
- 在任何知识库搜索之前,必须先调用 `list_document_metadata_fields` 工具
|
||||
- 获取可用的元数据字段(如 grade, subject, kb_scene 等)及其常见取值
|
||||
|
||||
**步骤2:分析用户意图,选择合适的过滤条件**
|
||||
- 根据用户问题和返回的元数据字段,确定合适的过滤条件
|
||||
- 从元数据字段的 common_values 中选择合适的值
|
||||
|
||||
**步骤3:调用 kb_search_dynamic 进行搜索**
|
||||
- 使用步骤1获取的元数据字段构造 context 参数
|
||||
- scene 参数必须从元数据字段的 kb_scene 常见值中选择,不要硬编码
|
||||
|
||||
**示例流程:**
|
||||
1. 调用 `list_document_metadata_fields` 获取字段信息
|
||||
2. 根据返回结果,发现可用字段:grade(年级)、subject(学科)、kb_scene(场景)
|
||||
3. 分析用户问题"三年级语文怎么学",确定过滤条件:grade="三年级", subject="语文"
|
||||
4. 从 kb_scene 的常见值中选择合适的 scene(如"学习方案")
|
||||
5. 调用 `kb_search_dynamic`,传入构造好的 context 和 scene
|
||||
|
||||
## 注意事项
|
||||
- **严禁**在调用 kb_search_dynamic 之前不调用 list_document_metadata_fields。
|
||||
"""
|
||||
|
||||
if not self._template_service or not self._tenant_id:
|
||||
return default_prompt
|
||||
|
||||
try:
|
||||
from app.core.database import get_session
|
||||
from app.core.prompts import SYSTEM_PROMPT
|
||||
|
||||
async with get_session() as session:
|
||||
template_service = PromptTemplateService(session)
|
||||
|
||||
base_prompt = await template_service.get_published_template(
|
||||
tenant_id=self._tenant_id,
|
||||
scene="agent_fc",
|
||||
resolver=self._variable_resolver,
|
||||
)
|
||||
|
||||
if not base_prompt or base_prompt == SYSTEM_PROMPT:
|
||||
base_prompt = await template_service.get_published_template(
|
||||
tenant_id=self._tenant_id,
|
||||
scene="default",
|
||||
resolver=self._variable_resolver,
|
||||
)
|
||||
|
||||
if not base_prompt or base_prompt == SYSTEM_PROMPT:
|
||||
logger.info("[AC-MARH-07] No published template found for agent_fc or default, using default prompt")
|
||||
return default_prompt
|
||||
|
||||
agent_protocol = """
|
||||
|
||||
## 智能体决策协议
|
||||
|
||||
1. 优先使用已有观察信息,避免重复调用同类工具。
|
||||
2. 当问题需要外部事实或结构化状态时再调用工具;如果可直接回答则不要调用。
|
||||
3. 缺少关键参数时,优先向用户追问,不要使用空参数调用工具。
|
||||
4. 工具失败时,先说明已尝试,再给出降级方案或下一步引导。
|
||||
|
||||
## 知识库查询强制流程
|
||||
|
||||
当用户问题需要进行知识库查询(kb_search_dynamic)时,必须遵循以下步骤:
|
||||
|
||||
**步骤1:先调用 list_document_metadata_fields**
|
||||
- 在任何知识库搜索之前,必须先调用 `list_document_metadata_fields` 工具
|
||||
- 获取可用的元数据字段(如 grade, subject, kb_scene 等)及其常见取值
|
||||
|
||||
**步骤2:分析用户意图,选择合适的过滤条件**
|
||||
- 根据用户问题和返回的元数据字段,确定合适的过滤条件
|
||||
- 从元数据字段的 common_values 中选择合适的值
|
||||
|
||||
**步骤3:调用 kb_search_dynamic 进行搜索**
|
||||
- 使用步骤1获取的元数据字段构造 context 参数
|
||||
- scene 参数必须从元数据字段的 kb_scene 常见值中选择,不要硬编码
|
||||
|
||||
## 注意事项
|
||||
- **严禁**在调用 kb_search_dynamic 之前不调用 list_document_metadata_fields。
|
||||
"""
|
||||
|
||||
final_prompt = base_prompt + agent_protocol
|
||||
|
||||
logger.info(f"[AC-MARH-07] Loaded template for tenant={self._tenant_id}")
|
||||
return final_prompt
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"[AC-MARH-07] Failed to load template, using default: {e}")
|
||||
return default_prompt
|
||||
|
||||
async def _force_final_answer_fc(self, messages: list[dict[str, Any]]) -> str:
|
||||
"""Force final answer when max iterations reached in Function Calling mode."""
|
||||
try:
|
||||
response = await self._llm_client.generate(
|
||||
messages=messages + [{"role": "user", "content": "请基于以上信息给出最终回答,不要再调用工具。"}],
|
||||
tools=None,
|
||||
)
|
||||
return response.content or "抱歉,我已经尽力处理您的请求,但可能需要更多信息。"
|
||||
except Exception as e:
|
||||
logger.error(f"[AC-MARH-07] Force final answer FC failed: {e}")
|
||||
return "抱歉,我已经尽力处理您的请求,但可能需要更多信息。请稍后重试或联系人工客服。"
|
||||
|
||||
async def _execute_react(
|
||||
self,
|
||||
user_message: str,
|
||||
context: dict[str, Any] | None = None,
|
||||
on_thought: Any = None,
|
||||
on_action: Any = None,
|
||||
) -> tuple[str, ReActContext, TraceInfo]:
|
||||
"""
|
||||
Execute using traditional ReAct mode (text-based).
|
||||
|
||||
This is the original implementation for backward compatibility.
|
||||
"""
|
||||
react_ctx = ReActContext(max_iterations=self._max_iterations)
|
||||
tool_calls: list[ToolCallTrace] = []
|
||||
start_time = time.time()
|
||||
|
|
@ -321,58 +804,108 @@ Action Input:
|
|||
return default_template
|
||||
|
||||
def _build_tools_section(self) -> str:
|
||||
"""Build rich tools section for ReAct prompt."""
|
||||
"""
|
||||
Build compact tools section for ReAct prompt.
|
||||
|
||||
Only includes tool name and brief description for initial scanning.
|
||||
Detailed usage guides are disclosed on-demand when tool is called.
|
||||
"""
|
||||
if not self._tool_registry:
|
||||
return "当前没有可用的工具。"
|
||||
|
||||
tools = self._tool_registry.list_tools(enabled_only=True)
|
||||
tools = self._tool_registry.get_all_tools()
|
||||
if not tools:
|
||||
return "当前没有可用的工具。"
|
||||
|
||||
lines = ["## 可用工具列表", "", "以下是你可以使用的工具,只能使用这些工具:", ""]
|
||||
lines = ["## 可用工具列表", "", "以下是你可以使用的工具:", ""]
|
||||
|
||||
for tool in tools:
|
||||
tool_guide = self._guide_registry.get_tool_guide(tool.name) if self._guide_registry else None
|
||||
description = tool_guide.description if tool_guide else tool.description
|
||||
lines.append(f"- **{tool.name}**: {description}")
|
||||
|
||||
lines.append("")
|
||||
lines.append("调用工具时,系统会提供该工具的详细使用说明。")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _build_tool_usage_guide(self, tool_name: str) -> str:
|
||||
"""
|
||||
Build detailed usage guide for a specific tool.
|
||||
Called when the tool is executed to provide on-demand guidance.
|
||||
"""
|
||||
tool_guide = self._guide_registry.get_tool_guide(tool_name) if self._guide_registry else None
|
||||
tool = self._tool_registry.get_tool(tool_name) if self._tool_registry else None
|
||||
|
||||
if not tool_guide and not tool:
|
||||
return ""
|
||||
|
||||
lines = [f"## {tool_name} 使用说明", ""]
|
||||
|
||||
if tool_guide:
|
||||
lines.append(f"**用途**: {tool_guide.description}")
|
||||
lines.append("")
|
||||
|
||||
if tool_guide.triggers:
|
||||
lines.append("**适用场景**:")
|
||||
for trigger in tool_guide.triggers:
|
||||
lines.append(f"- {trigger}")
|
||||
lines.append("")
|
||||
|
||||
if tool_guide.anti_triggers:
|
||||
lines.append("**不适用场景**:")
|
||||
for anti in tool_guide.anti_triggers:
|
||||
lines.append(f"- {anti}")
|
||||
lines.append("")
|
||||
|
||||
if tool_guide.content:
|
||||
lines.append(tool_guide.content)
|
||||
lines.append("")
|
||||
|
||||
if tool:
|
||||
meta = tool.metadata or {}
|
||||
lines.append(f"### {tool.name}")
|
||||
lines.append(f"用途: {tool.description}")
|
||||
|
||||
when_to_use = meta.get("when_to_use")
|
||||
when_not_to_use = meta.get("when_not_to_use")
|
||||
if when_to_use:
|
||||
lines.append(f"何时使用: {when_to_use}")
|
||||
if when_not_to_use:
|
||||
lines.append(f"何时不要使用: {when_not_to_use}")
|
||||
|
||||
params = meta.get("parameters")
|
||||
if isinstance(params, dict):
|
||||
properties = params.get("properties", {})
|
||||
required = params.get("required", [])
|
||||
if properties:
|
||||
lines.append("参数:")
|
||||
lines.append("**参数说明**:")
|
||||
for param_name, param_info in properties.items():
|
||||
param_desc = param_info.get("description", "") if isinstance(param_info, dict) else ""
|
||||
line = f" - {param_name}: {param_desc}".strip()
|
||||
req_mark = " (必填)" if param_name in required else ""
|
||||
if param_name == "tenant_id":
|
||||
line += " (系统注入,模型不要填写)"
|
||||
elif param_name in required:
|
||||
line += " (必填)"
|
||||
lines.append(line)
|
||||
req_mark = " (系统注入)"
|
||||
lines.append(f"- `{param_name}`: {param_desc}{req_mark}")
|
||||
lines.append("")
|
||||
|
||||
if meta.get("example_action_input"):
|
||||
lines.append("示例入参(JSON):")
|
||||
lines.append("**调用示例**:")
|
||||
try:
|
||||
example_text = json.dumps(meta["example_action_input"], ensure_ascii=False)
|
||||
example_text = json.dumps(meta["example_action_input"], ensure_ascii=False, indent=2)
|
||||
except Exception:
|
||||
example_text = str(meta["example_action_input"])
|
||||
lines.append(example_text)
|
||||
lines.append(f"```json\n{example_text}\n```")
|
||||
lines.append("")
|
||||
|
||||
if meta.get("result_interpretation"):
|
||||
lines.append(f"结果解释: {meta['result_interpretation']}")
|
||||
|
||||
lines.append("")
|
||||
lines.append(f"**结果说明**: {meta['result_interpretation']}")
|
||||
lines.append("")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _build_tools_guide_section(self, tool_names: list[str] | None = None) -> str:
|
||||
"""
|
||||
Build detailed tools guide section for ReAct prompt.
|
||||
|
||||
This provides comprehensive usage guides from ToolRegistry.
|
||||
Called separately from _build_tools_section for flexibility.
|
||||
|
||||
Args:
|
||||
tool_names: If provided, only include tools for these names.
|
||||
If None, include all tools.
|
||||
"""
|
||||
return self._guide_registry.build_tools_prompt_section(tool_names)
|
||||
|
||||
def _extract_json_object(self, text: str) -> dict[str, Any] | None:
|
||||
"""Extract the first valid JSON object from free text."""
|
||||
candidates = []
|
||||
|
|
@ -438,6 +971,8 @@ Action Input:
|
|||
) -> tuple[ToolResult, ToolCallTrace]:
|
||||
"""
|
||||
[AC-MARH-07, AC-MARH-08] Execute tool action with timeout.
|
||||
|
||||
On first call to a tool, appends detailed usage guide to observation.
|
||||
"""
|
||||
tool_name = thought.action or "unknown"
|
||||
start_time = time.time()
|
||||
|
|
@ -461,6 +996,23 @@ Action Input:
|
|||
if self._tenant_id:
|
||||
tool_args["tenant_id"] = self._tenant_id
|
||||
|
||||
if self._user_id:
|
||||
tool_args["user_id"] = self._user_id
|
||||
|
||||
if self._session_id:
|
||||
tool_args["session_id"] = self._session_id
|
||||
|
||||
if tool_name == "kb_search_dynamic":
|
||||
# 确保 context 存在,供 AI 传入动态过滤条件
|
||||
if "context" not in tool_args:
|
||||
tool_args["context"] = {}
|
||||
# scene 参数由 AI 从元数据中选择,系统不强制覆盖
|
||||
|
||||
logger.info(
|
||||
f"[AC-MARH-07] Tool call starting: tool={tool_name}, "
|
||||
f"action_input={thought.action_input}, final_args={tool_args}"
|
||||
)
|
||||
|
||||
result = await asyncio.wait_for(
|
||||
self._tool_registry.execute(
|
||||
tool_name=tool_name,
|
||||
|
|
@ -470,9 +1022,25 @@ Action Input:
|
|||
)
|
||||
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
called_tools = {tc.tool_name for tc in react_ctx.tool_calls}
|
||||
is_first_call = tool_name not in called_tools
|
||||
|
||||
output = result.output
|
||||
if is_first_call and result.success:
|
||||
usage_guide = self._build_tool_usage_guide(tool_name)
|
||||
if usage_guide:
|
||||
if isinstance(output, dict):
|
||||
output = dict(output)
|
||||
output["_tool_guide"] = usage_guide
|
||||
elif isinstance(output, str):
|
||||
output = f"{output}\n\n---\n{usage_guide}"
|
||||
else:
|
||||
output = {"result": output, "_tool_guide": usage_guide}
|
||||
|
||||
return ToolResult(
|
||||
success=result.success,
|
||||
output=result.output,
|
||||
output=output,
|
||||
error=result.error,
|
||||
duration_ms=duration_ms,
|
||||
), ToolCallTrace(
|
||||
|
|
@ -482,6 +1050,8 @@ Action Input:
|
|||
status=ToolCallStatus.OK if result.success else ToolCallStatus.ERROR,
|
||||
args_digest=str(thought.action_input)[:100] if thought.action_input else None,
|
||||
result_digest=str(result.output)[:100] if result.output else None,
|
||||
arguments=thought.action_input,
|
||||
result=output,
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
|
|
@ -497,6 +1067,7 @@ Action Input:
|
|||
duration_ms=duration_ms,
|
||||
status=ToolCallStatus.TIMEOUT,
|
||||
error_code="TOOL_TIMEOUT",
|
||||
arguments=thought.action_input,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -512,6 +1083,7 @@ Action Input:
|
|||
duration_ms=duration_ms,
|
||||
status=ToolCallStatus.ERROR,
|
||||
error_code="TOOL_ERROR",
|
||||
arguments=thought.action_input,
|
||||
)
|
||||
|
||||
async def _force_final_answer(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,342 @@
|
|||
"""
|
||||
Batch Ask-Back Service.
|
||||
批量追问服务 - 支持一次追问多个缺失槽位
|
||||
|
||||
[AC-MRS-SLOT-ASKBACK-01] 批量追问
|
||||
|
||||
职责:
|
||||
1. 支持一次追问多个缺失槽位
|
||||
2. 选择策略:必填优先、场景相关优先、最近未追问过优先
|
||||
3. 输出形式:单条自然语言合并提问或分段提问
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.cache.slot_state_cache import get_slot_state_cache
|
||||
from app.services.slot_definition_service import SlotDefinitionService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AskBackSlot:
|
||||
"""
|
||||
追问槽位信息
|
||||
|
||||
Attributes:
|
||||
slot_key: 槽位键名
|
||||
label: 显示标签
|
||||
ask_back_prompt: 追问提示
|
||||
priority: 优先级(数值越大越优先)
|
||||
last_asked_at: 上次追问时间戳
|
||||
is_required: 是否必填
|
||||
scene_relevance: 场景相关度
|
||||
"""
|
||||
slot_key: str
|
||||
label: str
|
||||
ask_back_prompt: str | None = None
|
||||
priority: int = 0
|
||||
last_asked_at: float | None = None
|
||||
is_required: bool = False
|
||||
scene_relevance: float = 0.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchAskBackConfig:
|
||||
"""
|
||||
批量追问配置
|
||||
|
||||
Attributes:
|
||||
max_ask_back_slots_per_turn: 每轮最多追问槽位数
|
||||
prefer_required: 是否优先追问必填槽位
|
||||
prefer_scene_relevant: 是否优先追问场景相关槽位
|
||||
avoid_recent_asked: 是否避免最近追问过的槽位
|
||||
recent_asked_threshold_seconds: 最近追问阈值(秒)
|
||||
merge_prompts: 是否合并追问提示
|
||||
merge_template: 合并模板
|
||||
"""
|
||||
max_ask_back_slots_per_turn: int = 2
|
||||
prefer_required: bool = True
|
||||
prefer_scene_relevant: bool = True
|
||||
avoid_recent_asked: bool = True
|
||||
recent_asked_threshold_seconds: float = 60.0
|
||||
merge_prompts: bool = True
|
||||
merge_template: str = "为了更好地为您服务,请告诉我:{prompts}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchAskBackResult:
|
||||
"""
|
||||
批量追问结果
|
||||
|
||||
Attributes:
|
||||
selected_slots: 选中的追问槽位列表
|
||||
prompts: 追问提示列表
|
||||
merged_prompt: 合并后的追问提示
|
||||
ask_back_count: 追问数量
|
||||
"""
|
||||
selected_slots: list[AskBackSlot] = field(default_factory=list)
|
||||
prompts: list[str] = field(default_factory=list)
|
||||
merged_prompt: str | None = None
|
||||
ask_back_count: int = 0
|
||||
|
||||
def has_ask_back(self) -> bool:
|
||||
return self.ask_back_count > 0
|
||||
|
||||
def get_prompt(self) -> str:
|
||||
"""获取最终追问提示"""
|
||||
if self.merged_prompt:
|
||||
return self.merged_prompt
|
||||
if self.prompts:
|
||||
return self.prompts[0]
|
||||
return "请提供更多信息以便我更好地帮助您。"
|
||||
|
||||
|
||||
class BatchAskBackService:
|
||||
"""
|
||||
[AC-MRS-SLOT-ASKBACK-01] 批量追问服务
|
||||
|
||||
支持一次追问多个缺失槽位,提高补全效率。
|
||||
"""
|
||||
|
||||
ASK_BACK_HISTORY_KEY_PREFIX = "slot_ask_back_history"
|
||||
ASK_BACK_HISTORY_TTL = 300
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
tenant_id: str,
|
||||
session_id: str,
|
||||
config: BatchAskBackConfig | None = None,
|
||||
):
|
||||
self._session = session
|
||||
self._tenant_id = tenant_id
|
||||
self._session_id = session_id
|
||||
self._config = config or BatchAskBackConfig()
|
||||
self._slot_def_service = SlotDefinitionService(session)
|
||||
self._cache = get_slot_state_cache()
|
||||
|
||||
async def generate_batch_ask_back(
|
||||
self,
|
||||
missing_slots: list[dict[str, str]],
|
||||
current_scene: str | None = None,
|
||||
) -> BatchAskBackResult:
|
||||
"""
|
||||
生成批量追问
|
||||
|
||||
Args:
|
||||
missing_slots: 缺失槽位列表
|
||||
current_scene: 当前场景
|
||||
|
||||
Returns:
|
||||
BatchAskBackResult: 批量追问结果
|
||||
"""
|
||||
if not missing_slots:
|
||||
return BatchAskBackResult()
|
||||
|
||||
ask_back_slots = await self._prepare_ask_back_slots(missing_slots, current_scene)
|
||||
|
||||
selected_slots = self._select_slots_for_ask_back(ask_back_slots)
|
||||
|
||||
asked_history = await self._get_asked_history()
|
||||
selected_slots = self._filter_recently_asked(selected_slots, asked_history)
|
||||
|
||||
if not selected_slots:
|
||||
selected_slots = ask_back_slots[:self._config.max_ask_back_slots_per_turn]
|
||||
|
||||
prompts = self._generate_prompts(selected_slots)
|
||||
merged_prompt = self._merge_prompts(prompts) if self._config.merge_prompts else None
|
||||
|
||||
await self._record_ask_back_history([s.slot_key for s in selected_slots])
|
||||
|
||||
return BatchAskBackResult(
|
||||
selected_slots=selected_slots,
|
||||
prompts=prompts,
|
||||
merged_prompt=merged_prompt,
|
||||
ask_back_count=len(selected_slots),
|
||||
)
|
||||
|
||||
async def _prepare_ask_back_slots(
|
||||
self,
|
||||
missing_slots: list[dict[str, str]],
|
||||
current_scene: str | None,
|
||||
) -> list[AskBackSlot]:
|
||||
"""准备追问槽位列表"""
|
||||
ask_back_slots = []
|
||||
|
||||
for missing in missing_slots:
|
||||
slot_key = missing.get("slot_key", "")
|
||||
label = missing.get("label", slot_key)
|
||||
ask_back_prompt = missing.get("ask_back_prompt")
|
||||
field_key = missing.get("field_key")
|
||||
|
||||
slot_def = await self._slot_def_service.get_slot_definition_by_key(
|
||||
self._tenant_id, slot_key
|
||||
)
|
||||
|
||||
is_required = False
|
||||
scene_relevance = 0.0
|
||||
|
||||
if slot_def:
|
||||
is_required = slot_def.required
|
||||
if not ask_back_prompt:
|
||||
ask_back_prompt = slot_def.ask_back_prompt
|
||||
|
||||
if current_scene and slot_def.scene_scope:
|
||||
if current_scene in slot_def.scene_scope:
|
||||
scene_relevance = 1.0
|
||||
|
||||
priority = self._calculate_priority(is_required, scene_relevance)
|
||||
|
||||
ask_back_slots.append(AskBackSlot(
|
||||
slot_key=slot_key,
|
||||
label=label,
|
||||
ask_back_prompt=ask_back_prompt,
|
||||
priority=priority,
|
||||
is_required=is_required,
|
||||
scene_relevance=scene_relevance,
|
||||
))
|
||||
|
||||
return ask_back_slots
|
||||
|
||||
def _calculate_priority(self, is_required: bool, scene_relevance: float) -> int:
|
||||
"""计算槽位优先级"""
|
||||
priority = 0
|
||||
|
||||
if self._config.prefer_required and is_required:
|
||||
priority += 100
|
||||
|
||||
if self._config.prefer_scene_relevant:
|
||||
priority += int(scene_relevance * 50)
|
||||
|
||||
return priority
|
||||
|
||||
def _select_slots_for_ask_back(
|
||||
self,
|
||||
ask_back_slots: list[AskBackSlot],
|
||||
) -> list[AskBackSlot]:
|
||||
"""选择要追问的槽位"""
|
||||
sorted_slots = sorted(
|
||||
ask_back_slots,
|
||||
key=lambda s: s.priority,
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
return sorted_slots[:self._config.max_ask_back_slots_per_turn]
|
||||
|
||||
async def _get_asked_history(self) -> dict[str, float]:
|
||||
"""获取最近追问历史"""
|
||||
history_key = f"{self.ASK_BACK_HISTORY_KEY_PREFIX}:{self._tenant_id}:{self._session_id}"
|
||||
|
||||
try:
|
||||
client = await self._cache._get_client()
|
||||
if client is None:
|
||||
return {}
|
||||
|
||||
import json
|
||||
data = await client.get(history_key)
|
||||
if data:
|
||||
return json.loads(data)
|
||||
except Exception as e:
|
||||
logger.warning(f"[BatchAskBack] Failed to get asked history: {e}")
|
||||
|
||||
return {}
|
||||
|
||||
async def _record_ask_back_history(self, slot_keys: list[str]) -> None:
|
||||
"""记录追问历史"""
|
||||
history_key = f"{self.ASK_BACK_HISTORY_KEY_PREFIX}:{self._tenant_id}:{self._session_id}"
|
||||
|
||||
try:
|
||||
client = await self._cache._get_client()
|
||||
if client is None:
|
||||
return
|
||||
|
||||
history = await self._get_asked_history()
|
||||
current_time = time.time()
|
||||
|
||||
for slot_key in slot_keys:
|
||||
history[slot_key] = current_time
|
||||
|
||||
import json
|
||||
await client.setex(
|
||||
history_key,
|
||||
self.ASK_BACK_HISTORY_TTL,
|
||||
json.dumps(history),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[BatchAskBack] Failed to record asked history: {e}")
|
||||
|
||||
def _filter_recently_asked(
|
||||
self,
|
||||
slots: list[AskBackSlot],
|
||||
asked_history: dict[str, float],
|
||||
) -> list[AskBackSlot]:
|
||||
"""过滤最近追问过的槽位"""
|
||||
if not self._config.avoid_recent_asked:
|
||||
return slots
|
||||
|
||||
current_time = time.time()
|
||||
threshold = self._config.recent_asked_threshold_seconds
|
||||
|
||||
return [
|
||||
slot for slot in slots
|
||||
if slot.slot_key not in asked_history or
|
||||
current_time - asked_history[slot.slot_key] > threshold
|
||||
]
|
||||
|
||||
def _generate_prompts(self, slots: list[AskBackSlot]) -> list[str]:
|
||||
"""生成追问提示列表"""
|
||||
prompts = []
|
||||
|
||||
for slot in slots:
|
||||
if slot.ask_back_prompt:
|
||||
prompts.append(slot.ask_back_prompt)
|
||||
else:
|
||||
prompts.append(f"请告诉我您的{slot.label}")
|
||||
|
||||
return prompts
|
||||
|
||||
def _merge_prompts(self, prompts: list[str]) -> str | None:
|
||||
"""合并追问提示"""
|
||||
if not prompts:
|
||||
return None
|
||||
|
||||
if len(prompts) == 1:
|
||||
return prompts[0]
|
||||
|
||||
if len(prompts) == 2:
|
||||
return f"{prompts[0]},以及{prompts[1]}"
|
||||
|
||||
all_but_last = "、".join(prompts[:-1])
|
||||
return f"{all_but_last},以及{prompts[-1]}"
|
||||
|
||||
|
||||
def create_batch_ask_back_service(
|
||||
session: AsyncSession,
|
||||
tenant_id: str,
|
||||
session_id: str,
|
||||
config: BatchAskBackConfig | None = None,
|
||||
) -> BatchAskBackService:
|
||||
"""
|
||||
创建批量追问服务实例
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
tenant_id: 租户 ID
|
||||
session_id: 会话 ID
|
||||
config: 配置
|
||||
|
||||
Returns:
|
||||
BatchAskBackService: 批量追问服务实例
|
||||
"""
|
||||
return BatchAskBackService(
|
||||
session=session,
|
||||
tenant_id=tenant_id,
|
||||
session_id=session_id,
|
||||
config=config,
|
||||
)
|
||||
|
|
@ -17,15 +17,20 @@ import logging
|
|||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.mid.schemas import ToolCallStatus, ToolCallTrace, ToolType
|
||||
from app.services.mid.metadata_filter_builder import (
|
||||
FilterBuildResult,
|
||||
FilterFieldInfo,
|
||||
MetadataFilterBuilder,
|
||||
)
|
||||
from app.services.mid.slot_state_aggregator import (
|
||||
SlotState,
|
||||
SlotStateAggregator,
|
||||
)
|
||||
from app.services.mid.timeout_governor import TimeoutGovernor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -37,6 +42,9 @@ DEFAULT_TOP_K = 5
|
|||
DEFAULT_TIMEOUT_MS = 2000
|
||||
KB_SEARCH_DYNAMIC_TOOL_NAME = "kb_search_dynamic"
|
||||
|
||||
_TOOL_SCHEMA_CACHE: dict[str, tuple[float, dict[str, Any]]] = {}
|
||||
_TOOL_SCHEMA_CACHE_TTL_SECONDS = 300 # 5 minutes
|
||||
|
||||
|
||||
@dataclass
|
||||
class KbSearchDynamicResult:
|
||||
|
|
@ -46,9 +54,21 @@ class KbSearchDynamicResult:
|
|||
applied_filter: dict[str, Any] = field(default_factory=dict)
|
||||
missing_required_slots: list[dict[str, str]] = field(default_factory=list)
|
||||
filter_debug: dict[str, Any] = field(default_factory=dict)
|
||||
filter_sources: dict[str, str] = field(default_factory=dict) # [AC-SCENE-SLOT-02] 过滤条件来源
|
||||
fallback_reason_code: str | None = None
|
||||
duration_ms: int = 0
|
||||
tool_trace: ToolCallTrace | None = None
|
||||
step_kb_binding: dict[str, Any] | None = None # [Step-KB-Binding] 步骤知识库绑定信息
|
||||
|
||||
|
||||
@dataclass
|
||||
class StepKbConfig:
|
||||
"""[Step-KB-Binding] 步骤级别的知识库配置。"""
|
||||
allowed_kb_ids: list[str] | None = None
|
||||
preferred_kb_ids: list[str] | None = None
|
||||
kb_query_hint: str | None = None
|
||||
max_kb_calls: int = 1
|
||||
step_id: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -86,12 +106,14 @@ class KbSearchDynamicTool:
|
|||
session: AsyncSession,
|
||||
timeout_governor: TimeoutGovernor | None = None,
|
||||
config: KbSearchDynamicConfig | None = None,
|
||||
slot_state_aggregator: SlotStateAggregator | None = None,
|
||||
):
|
||||
self._session = session
|
||||
self._timeout_governor = timeout_governor or TimeoutGovernor()
|
||||
self._config = config or KbSearchDynamicConfig()
|
||||
self._vector_retriever = None
|
||||
self._filter_builder: MetadataFilterBuilder | None = None
|
||||
self._slot_state_aggregator = slot_state_aggregator
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
|
|
@ -140,6 +162,128 @@ class KbSearchDynamicTool:
|
|||
},
|
||||
}
|
||||
|
||||
async def get_dynamic_tool_schema(self, tenant_id: str) -> dict[str, Any]:
|
||||
"""
|
||||
获取动态生成的工具 Schema,包含租户的元数据过滤字段。
|
||||
|
||||
使用缓存机制,避免每次都查询数据库。
|
||||
只显示关联知识库的元数据过滤字段(field_roles 包含 resource_filter)。
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
|
||||
Returns:
|
||||
动态生成的工具 Schema
|
||||
"""
|
||||
import time
|
||||
current_time = time.time()
|
||||
|
||||
cache_key = f"tool_schema:{tenant_id}"
|
||||
if cache_key in _TOOL_SCHEMA_CACHE:
|
||||
cached_time, cached_schema = _TOOL_SCHEMA_CACHE[cache_key]
|
||||
if current_time - cached_time < _TOOL_SCHEMA_CACHE_TTL_SECONDS:
|
||||
logger.debug(f"[AC-MARH-05] Tool schema cache hit for tenant={tenant_id}")
|
||||
return cached_schema
|
||||
|
||||
logger.info(f"[AC-MARH-05] Building dynamic tool schema for tenant={tenant_id}")
|
||||
|
||||
base_properties = {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "检索查询文本",
|
||||
},
|
||||
"top_k": {
|
||||
"type": "integer",
|
||||
"description": "返回结果数量",
|
||||
"default": DEFAULT_TOP_K,
|
||||
},
|
||||
}
|
||||
|
||||
required_fields = ["query"]
|
||||
context_properties = {}
|
||||
|
||||
try:
|
||||
if self._filter_builder is None:
|
||||
self._filter_builder = MetadataFilterBuilder(self._session)
|
||||
|
||||
filterable_fields = await self._filter_builder._get_filterable_fields(tenant_id)
|
||||
|
||||
for field_info in filterable_fields:
|
||||
field_schema = self._build_field_schema(field_info)
|
||||
context_properties[field_info.field_key] = field_schema
|
||||
|
||||
if field_info.required:
|
||||
required_fields.append(field_info.field_key)
|
||||
|
||||
logger.info(
|
||||
f"[AC-MARH-05] Dynamic schema built: tenant={tenant_id}, "
|
||||
f"context_fields={len(context_properties)}, required={required_fields}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[AC-MARH-05] Failed to get filterable fields: {e}, using base schema")
|
||||
|
||||
if context_properties:
|
||||
base_properties["context"] = {
|
||||
"type": "object",
|
||||
"description": "过滤条件,根据用户意图选择合适的字段传递",
|
||||
"properties": context_properties,
|
||||
}
|
||||
else:
|
||||
base_properties["context"] = {
|
||||
"type": "object",
|
||||
"description": "过滤条件(当前租户未配置元数据字段)",
|
||||
}
|
||||
|
||||
schema = {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": base_properties,
|
||||
"required": required_fields,
|
||||
},
|
||||
}
|
||||
|
||||
_TOOL_SCHEMA_CACHE[cache_key] = (current_time, schema)
|
||||
|
||||
return schema
|
||||
|
||||
def _build_field_schema(self, field_info: "FilterFieldInfo") -> dict[str, Any]:
|
||||
"""
|
||||
根据字段信息构建 JSON Schema。
|
||||
|
||||
Args:
|
||||
field_info: 字段信息
|
||||
|
||||
Returns:
|
||||
字段的 JSON Schema
|
||||
"""
|
||||
schema: dict[str, Any] = {
|
||||
"description": field_info.label or field_info.field_key,
|
||||
}
|
||||
|
||||
if field_info.required:
|
||||
schema["description"] += "(必填)"
|
||||
|
||||
field_type = field_info.field_type.lower() if field_info.field_type else "string"
|
||||
|
||||
if field_type in ("enum", "select", "array_enum", "multi_select"):
|
||||
schema["type"] = "string"
|
||||
if field_info.options:
|
||||
schema["enum"] = field_info.options
|
||||
schema["description"] += f",可选值:{', '.join(field_info.options)}"
|
||||
elif field_type in ("number", "integer", "float"):
|
||||
schema["type"] = "number"
|
||||
elif field_type == "boolean":
|
||||
schema["type"] = "boolean"
|
||||
else:
|
||||
schema["type"] = "string"
|
||||
|
||||
if field_info.default_value is not None:
|
||||
schema["default"] = field_info.default_value
|
||||
|
||||
return schema
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
query: str,
|
||||
|
|
@ -147,16 +291,24 @@ class KbSearchDynamicTool:
|
|||
scene: str = "open_consult",
|
||||
top_k: int | None = None,
|
||||
context: dict[str, Any] | None = None,
|
||||
slot_state: SlotState | None = None,
|
||||
step_kb_config: StepKbConfig | None = None,
|
||||
slot_policy: Literal["flow_strict", "agent_relaxed"] = "flow_strict",
|
||||
) -> KbSearchDynamicResult:
|
||||
"""
|
||||
[AC-MARH-05] 执行 KB 动态检索。
|
||||
[AC-MRS-SLOT-META-02] 支持槽位状态聚合和过滤构建优先级
|
||||
[Step-KB-Binding] 支持步骤级别的知识库约束
|
||||
|
||||
Args:
|
||||
query: 检索查询
|
||||
tenant_id: 租户 ID
|
||||
scene: 场景标识
|
||||
scene: 场景标识(默认值,会被 context 中的 scene 覆盖)
|
||||
top_k: 返回数量
|
||||
context: 上下文(包含动态过滤值)
|
||||
context: 上下文(包含动态过滤值,包括 scene)
|
||||
slot_state: 预聚合的槽位状态(可选,优先使用)
|
||||
step_kb_config: 步骤级别的知识库配置(可选)
|
||||
slot_policy: 槽位策略(flow_strict=流程严格模式,agent_relaxed=通用问答宽松模式)
|
||||
|
||||
Returns:
|
||||
KbSearchDynamicResult 包含检索结果和追踪信息
|
||||
|
|
@ -171,82 +323,150 @@ class KbSearchDynamicTool:
|
|||
start_time = time.time()
|
||||
top_k = top_k or self._config.top_k
|
||||
|
||||
logger.info(
|
||||
f"[AC-MARH-05] Starting KB dynamic search: tenant={tenant_id}, "
|
||||
f"query={query[:50]}..., scene={scene}, top_k={top_k}"
|
||||
)
|
||||
effective_context = dict(context) if context else {}
|
||||
effective_scene = effective_context.get("scene", scene)
|
||||
|
||||
filter_result: FilterBuildResult | None = None
|
||||
|
||||
try:
|
||||
if self._filter_builder is None:
|
||||
self._filter_builder = MetadataFilterBuilder(self._session)
|
||||
|
||||
filter_result = await self._filter_builder.build_filter(
|
||||
tenant_id=tenant_id,
|
||||
context=context,
|
||||
# [Step-KB-Binding] 记录步骤知识库约束
|
||||
step_kb_binding_info: dict[str, Any] = {}
|
||||
if step_kb_config:
|
||||
step_kb_binding_info = {
|
||||
"step_id": step_kb_config.step_id,
|
||||
"allowed_kb_ids": step_kb_config.allowed_kb_ids,
|
||||
"preferred_kb_ids": step_kb_config.preferred_kb_ids,
|
||||
"kb_query_hint": step_kb_config.kb_query_hint,
|
||||
"max_kb_calls": step_kb_config.max_kb_calls,
|
||||
}
|
||||
logger.info(
|
||||
f"[Step-KB-Binding] Step KB config applied: "
|
||||
f"step_id={step_kb_config.step_id}, "
|
||||
f"allowed_kb_ids={step_kb_config.allowed_kb_ids}, "
|
||||
f"preferred_kb_ids={step_kb_config.preferred_kb_ids}"
|
||||
)
|
||||
|
||||
if filter_result.missing_required_slots:
|
||||
logger.warning(
|
||||
f"[AC-MARH-05] Missing required slots: "
|
||||
f"{filter_result.missing_required_slots}"
|
||||
)
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
logger.info(
|
||||
f"[AC-MARH-05] 开始执行KB动态检索: tenant={tenant_id}, "
|
||||
f"query={query[:50]}..., scene={effective_scene}, top_k={top_k}, "
|
||||
f"slot_policy={slot_policy}, context_keys={list(effective_context.keys())}"
|
||||
)
|
||||
|
||||
tool_trace = ToolCallTrace(
|
||||
tool_name=self.name,
|
||||
tool_type=ToolType.INTERNAL,
|
||||
duration_ms=duration_ms,
|
||||
status=ToolCallStatus.ERROR,
|
||||
error_code="MISSING_REQUIRED_SLOTS",
|
||||
args_digest=f"query={query[:50]}, scene={scene}",
|
||||
result_digest=f"missing={len(filter_result.missing_required_slots)}",
|
||||
# [AC-MRS-SLOT-META-02] 如果提供了 slot_state,优先使用
|
||||
if slot_state is not None:
|
||||
logger.info(
|
||||
f"[AC-MRS-SLOT-META-02] Using provided slot_state: "
|
||||
f"filled={len(slot_state.filled_slots)}, "
|
||||
f"missing={len(slot_state.missing_required_slots)}"
|
||||
)
|
||||
|
||||
# 检查是否有缺失的必填槽位(仅在流程严格模式下阻断)
|
||||
if slot_state.missing_required_slots:
|
||||
if slot_policy == "flow_strict":
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
logger.info(
|
||||
f"[AC-MRS-SLOT-META-03] 流程严格模式命中缺失必填槽位,触发追问: "
|
||||
f"tenant={tenant_id}, missing={len(slot_state.missing_required_slots)}"
|
||||
)
|
||||
|
||||
tool_trace = ToolCallTrace(
|
||||
tool_name=self.name,
|
||||
tool_type=ToolType.INTERNAL,
|
||||
duration_ms=duration_ms,
|
||||
status=ToolCallStatus.ERROR,
|
||||
error_code="MISSING_REQUIRED_SLOTS",
|
||||
args_digest=f"query={query[:50]}, scene={effective_scene}",
|
||||
result_digest=f"missing={len(slot_state.missing_required_slots)}",
|
||||
arguments={"query": query, "scene": effective_scene, "context": context},
|
||||
result={"missing_required_slots": slot_state.missing_required_slots},
|
||||
)
|
||||
|
||||
return KbSearchDynamicResult(
|
||||
success=False,
|
||||
applied_filter={},
|
||||
missing_required_slots=slot_state.missing_required_slots,
|
||||
filter_debug={
|
||||
"source": "slot_state",
|
||||
"filled_slots": slot_state.filled_slots,
|
||||
"slot_to_field_map": slot_state.slot_to_field_map,
|
||||
},
|
||||
fallback_reason_code="MISSING_REQUIRED_SLOTS",
|
||||
duration_ms=duration_ms,
|
||||
tool_trace=tool_trace,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[AC-MRS-SLOT-META-03] 通用问答宽松模式检测到缺失槽位但不阻断检索: "
|
||||
f"tenant={tenant_id}, missing={len(slot_state.missing_required_slots)}"
|
||||
)
|
||||
|
||||
return KbSearchDynamicResult(
|
||||
success=False,
|
||||
applied_filter=filter_result.applied_filter,
|
||||
missing_required_slots=filter_result.missing_required_slots,
|
||||
filter_debug=filter_result.debug_info,
|
||||
fallback_reason_code="MISSING_REQUIRED_SLOTS",
|
||||
duration_ms=duration_ms,
|
||||
tool_trace=tool_trace,
|
||||
)
|
||||
# 使用 slot_state 构建 filter
|
||||
metadata_filter, filter_sources = await self._build_filter_from_slot_state(
|
||||
tenant_id=tenant_id,
|
||||
slot_state=slot_state,
|
||||
context=effective_context,
|
||||
scene_slot_context=effective_context.get("scene_slot_context"), # [AC-SCENE-SLOT-02]
|
||||
)
|
||||
else:
|
||||
# 原有逻辑:构建元数据 filter
|
||||
# 如果 context 简单(只有键值对),直接构造 filter,跳过 MetadataFilterBuilder
|
||||
metadata_filter = await self._build_filter_legacy(
|
||||
tenant_id=tenant_id,
|
||||
context=effective_context,
|
||||
query=query,
|
||||
effective_scene=effective_scene,
|
||||
start_time=start_time,
|
||||
)
|
||||
|
||||
metadata_filter = filter_result.applied_filter if filter_result.success else None
|
||||
if isinstance(metadata_filter, KbSearchDynamicResult):
|
||||
# 有错误,直接返回
|
||||
return metadata_filter
|
||||
|
||||
try:
|
||||
hits = await self._retrieve_with_timeout(
|
||||
tenant_id=tenant_id,
|
||||
query=query,
|
||||
metadata_filter=metadata_filter,
|
||||
top_k=top_k,
|
||||
step_kb_config=step_kb_config,
|
||||
)
|
||||
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
kb_hit = len(hits) > 0
|
||||
|
||||
# [Step-KB-Binding] 记录命中的知识库
|
||||
hit_kb_ids = list(set(hit.get("kb_id") for hit in hits if hit.get("kb_id")))
|
||||
if step_kb_binding_info:
|
||||
step_kb_binding_info["used_kb_ids"] = hit_kb_ids
|
||||
step_kb_binding_info["kb_hit"] = kb_hit
|
||||
|
||||
tool_trace = ToolCallTrace(
|
||||
tool_name=self.name,
|
||||
tool_type=ToolType.INTERNAL,
|
||||
duration_ms=duration_ms,
|
||||
status=ToolCallStatus.OK,
|
||||
args_digest=f"query={query[:50]}, scene={scene}",
|
||||
args_digest=f"query={query[:50]}, scene={effective_scene}",
|
||||
result_digest=f"hits={len(hits)}",
|
||||
arguments={"query": query, "scene": effective_scene, "context": context},
|
||||
result={"hits_count": len(hits), "kb_hit": kb_hit},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[AC-MARH-05] KB dynamic search completed: tenant={tenant_id}, "
|
||||
f"hits={len(hits)}, duration_ms={duration_ms}, kb_hit={kb_hit}"
|
||||
f"hits={len(hits)}, duration_ms={duration_ms}, kb_hit={kb_hit}, "
|
||||
f"hit_kb_ids={hit_kb_ids}"
|
||||
)
|
||||
|
||||
# 确定 filter 来源用于调试
|
||||
filter_source = "slot_state" if slot_state is not None else "builder"
|
||||
kb_filter_sources = filter_sources if slot_state is not None else {}
|
||||
|
||||
return KbSearchDynamicResult(
|
||||
success=True,
|
||||
hits=hits,
|
||||
applied_filter=filter_result.applied_filter if filter_result else {},
|
||||
filter_debug=filter_result.debug_info if filter_result else {},
|
||||
applied_filter=metadata_filter or {},
|
||||
filter_debug={"source": filter_source, "filter_sources": kb_filter_sources},
|
||||
duration_ms=duration_ms,
|
||||
tool_trace=tool_trace,
|
||||
step_kb_binding=step_kb_binding_info if step_kb_binding_info else None,
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
|
|
@ -262,16 +482,18 @@ class KbSearchDynamicTool:
|
|||
duration_ms=duration_ms,
|
||||
status=ToolCallStatus.TIMEOUT,
|
||||
error_code="KB_TIMEOUT",
|
||||
arguments={"query": query, "scene": effective_scene, "context": context},
|
||||
)
|
||||
|
||||
return KbSearchDynamicResult(
|
||||
success=False,
|
||||
applied_filter=filter_result.applied_filter if filter_result else {},
|
||||
missing_required_slots=filter_result.missing_required_slots if filter_result else [],
|
||||
filter_debug=filter_result.debug_info if filter_result else {},
|
||||
applied_filter=metadata_filter or {},
|
||||
missing_required_slots=[],
|
||||
filter_debug={"error": "timeout"},
|
||||
fallback_reason_code="KB_TIMEOUT",
|
||||
duration_ms=duration_ms,
|
||||
tool_trace=tool_trace,
|
||||
step_kb_binding=step_kb_binding_info if step_kb_binding_info else None,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -287,31 +509,224 @@ class KbSearchDynamicTool:
|
|||
duration_ms=duration_ms,
|
||||
status=ToolCallStatus.ERROR,
|
||||
error_code="KB_ERROR",
|
||||
arguments={"query": query, "scene": effective_scene, "context": context},
|
||||
)
|
||||
|
||||
return KbSearchDynamicResult(
|
||||
success=False,
|
||||
applied_filter=filter_result.applied_filter if filter_result else {},
|
||||
missing_required_slots=filter_result.missing_required_slots if filter_result else [],
|
||||
applied_filter=metadata_filter or {},
|
||||
missing_required_slots=[],
|
||||
filter_debug={"error": str(e)},
|
||||
fallback_reason_code="KB_ERROR",
|
||||
duration_ms=duration_ms,
|
||||
tool_trace=tool_trace,
|
||||
)
|
||||
|
||||
async def _build_filter_legacy(
|
||||
self,
|
||||
tenant_id: str,
|
||||
context: dict[str, Any],
|
||||
query: str,
|
||||
effective_scene: str,
|
||||
start_time: float,
|
||||
) -> dict[str, Any] | KbSearchDynamicResult:
|
||||
"""
|
||||
[AC-MRS-SLOT-META-02] 原有逻辑:构建元数据 filter
|
||||
|
||||
Returns:
|
||||
dict: 构建成功的 filter
|
||||
KbSearchDynamicResult: 构建失败时的错误结果
|
||||
"""
|
||||
metadata_filter: dict[str, Any] | None = None
|
||||
|
||||
# 简单 context:直接构造 filter(信任 AI 传入的值)
|
||||
# 复杂场景:使用 MetadataFilterBuilder 进行严格验证
|
||||
is_simple_context = all(
|
||||
isinstance(v, (str, int, float, bool))
|
||||
for v in context.values()
|
||||
)
|
||||
|
||||
if is_simple_context:
|
||||
# 直接构造 filter,不查询数据库
|
||||
metadata_filter = context
|
||||
logger.info(
|
||||
f"[AC-MARH-05] Using simple context as filter directly: "
|
||||
f"{metadata_filter}"
|
||||
)
|
||||
else:
|
||||
# 复杂 context,使用 MetadataFilterBuilder
|
||||
filter_result = await self._build_filter_with_builder(
|
||||
tenant_id=tenant_id,
|
||||
context=context,
|
||||
query=query,
|
||||
effective_scene=effective_scene,
|
||||
start_time=start_time,
|
||||
)
|
||||
if isinstance(filter_result, KbSearchDynamicResult):
|
||||
# 有错误,直接返回
|
||||
return filter_result
|
||||
metadata_filter = filter_result
|
||||
|
||||
return metadata_filter or {}
|
||||
|
||||
async def _build_filter_from_slot_state(
|
||||
self,
|
||||
tenant_id: str,
|
||||
slot_state: SlotState,
|
||||
context: dict[str, Any],
|
||||
scene_slot_context: Any = None, # [AC-SCENE-SLOT-02] 场景槽位上下文
|
||||
) -> tuple[dict[str, Any], dict[str, str]]:
|
||||
"""
|
||||
[AC-MRS-SLOT-META-02] 基于槽位状态构建过滤条件
|
||||
[AC-SCENE-SLOT-02] 支持场景槽位包配置的优先级
|
||||
|
||||
过滤值来源优先级:
|
||||
1. 已确认槽位值(slot_state.filled_slots)
|
||||
2. 当前请求 context 显式值
|
||||
3. 元数据默认值
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
slot_state: 槽位状态
|
||||
context: 上下文
|
||||
scene_slot_context: 场景槽位上下文
|
||||
|
||||
Returns:
|
||||
(过滤条件字典, 过滤来源映射)
|
||||
"""
|
||||
if self._filter_builder is None:
|
||||
self._filter_builder = MetadataFilterBuilder(self._session)
|
||||
|
||||
# 获取可过滤字段定义
|
||||
filterable_fields = await self._filter_builder._get_filterable_fields(tenant_id)
|
||||
|
||||
applied_filter: dict[str, Any] = {}
|
||||
filter_debug_sources: dict[str, str] = {}
|
||||
|
||||
# [AC-SCENE-SLOT-02] 如果有场景槽位上下文,优先处理场景定义的槽位
|
||||
scene_slot_keys = set()
|
||||
if scene_slot_context:
|
||||
scene_slot_keys = set(scene_slot_context.get_all_slot_keys())
|
||||
logger.debug(
|
||||
f"[AC-SCENE-SLOT-02] Processing scene slots: "
|
||||
f"scene={scene_slot_context.scene_key}, slots={scene_slot_keys}"
|
||||
)
|
||||
|
||||
for field_info in filterable_fields:
|
||||
field_key = field_info.field_key
|
||||
value = None
|
||||
source = None
|
||||
|
||||
# 优先级 1: 已确认槽位值(通过 slot_to_field_map 映射)
|
||||
if slot_state.slot_to_field_map:
|
||||
# 查找哪个 slot 映射到这个 field
|
||||
for slot_key, mapped_field_key in slot_state.slot_to_field_map.items():
|
||||
if mapped_field_key == field_key and slot_key in slot_state.filled_slots:
|
||||
value = slot_state.filled_slots[slot_key]
|
||||
source = "slot"
|
||||
break
|
||||
|
||||
# 如果 slot 映射没有命中,直接检查 slot_key 是否等于 field_key
|
||||
if value is None and field_key in slot_state.filled_slots:
|
||||
value = slot_state.filled_slots[field_key]
|
||||
source = "slot"
|
||||
|
||||
# 优先级 2: 当前请求 context 显式值
|
||||
if value is None and field_key in context:
|
||||
value = context[field_key]
|
||||
source = "context"
|
||||
|
||||
# 优先级 3: 元数据默认值
|
||||
if value is None and field_info.default_value is not None:
|
||||
value = field_info.default_value
|
||||
source = "default"
|
||||
|
||||
# 构建过滤条件
|
||||
if value is not None:
|
||||
filter_value = self._filter_builder._build_field_filter(
|
||||
field_info, value
|
||||
)
|
||||
if filter_value is not None:
|
||||
applied_filter[field_key] = filter_value
|
||||
filter_debug_sources[field_key] = source
|
||||
|
||||
logger.info(
|
||||
f"[AC-MRS-SLOT-META-02] Filter built from slot_state: "
|
||||
f"fields={len(applied_filter)}, sources={filter_debug_sources}"
|
||||
)
|
||||
|
||||
return applied_filter, filter_debug_sources
|
||||
|
||||
async def _build_filter_with_builder(
|
||||
self,
|
||||
tenant_id: str,
|
||||
context: dict[str, Any],
|
||||
query: str,
|
||||
effective_scene: str,
|
||||
start_time: float,
|
||||
) -> dict[str, Any] | KbSearchDynamicResult:
|
||||
"""
|
||||
使用 MetadataFilterBuilder 构建 filter(复杂场景)。
|
||||
|
||||
Returns:
|
||||
dict: 构建成功的 filter
|
||||
KbSearchDynamicResult: 构建失败时的错误结果
|
||||
"""
|
||||
from app.services.mid.metadata_filter_builder import FilterBuildResult
|
||||
|
||||
if self._filter_builder is None:
|
||||
self._filter_builder = MetadataFilterBuilder(self._session)
|
||||
|
||||
filter_result = await self._filter_builder.build_filter(
|
||||
tenant_id=tenant_id,
|
||||
context=context,
|
||||
)
|
||||
|
||||
if filter_result.missing_required_slots:
|
||||
logger.warning(
|
||||
f"[AC-MARH-05] Missing required slots: "
|
||||
f"{filter_result.missing_required_slots}"
|
||||
)
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
tool_trace = ToolCallTrace(
|
||||
tool_name=self.name,
|
||||
tool_type=ToolType.INTERNAL,
|
||||
duration_ms=duration_ms,
|
||||
status=ToolCallStatus.ERROR,
|
||||
error_code="MISSING_REQUIRED_SLOTS",
|
||||
args_digest=f"query={query[:50]}, scene={effective_scene}",
|
||||
result_digest=f"missing={len(filter_result.missing_required_slots)}",
|
||||
arguments={"query": query, "scene": effective_scene, "context": context},
|
||||
result={"missing_required_slots": filter_result.missing_required_slots},
|
||||
)
|
||||
|
||||
return KbSearchDynamicResult(
|
||||
success=False,
|
||||
applied_filter=filter_result.applied_filter,
|
||||
missing_required_slots=filter_result.missing_required_slots,
|
||||
filter_debug=filter_result.debug_info,
|
||||
fallback_reason_code="MISSING_REQUIRED_SLOTS",
|
||||
duration_ms=duration_ms,
|
||||
tool_trace=tool_trace,
|
||||
)
|
||||
|
||||
return filter_result.applied_filter if filter_result.success else {}
|
||||
|
||||
async def _retrieve_with_timeout(
|
||||
self,
|
||||
tenant_id: str,
|
||||
query: str,
|
||||
metadata_filter: dict[str, Any] | None = None,
|
||||
top_k: int = DEFAULT_TOP_K,
|
||||
step_kb_config: StepKbConfig | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""带超时控制的检索。"""
|
||||
timeout_seconds = self._config.timeout_ms / 1000.0
|
||||
|
||||
try:
|
||||
return await asyncio.wait_for(
|
||||
self._do_retrieve(tenant_id, query, metadata_filter, top_k),
|
||||
self._do_retrieve(tenant_id, query, metadata_filter, top_k, step_kb_config),
|
||||
timeout=timeout_seconds,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
|
|
@ -323,18 +738,36 @@ class KbSearchDynamicTool:
|
|||
query: str,
|
||||
metadata_filter: dict[str, Any] | None = None,
|
||||
top_k: int = DEFAULT_TOP_K,
|
||||
step_kb_config: StepKbConfig | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""执行实际检索。"""
|
||||
"""执行实际检索。[Step-KB-Binding] 支持步骤级别的知识库约束。"""
|
||||
if self._vector_retriever is None:
|
||||
from app.services.retrieval.vector_retriever import get_vector_retriever
|
||||
self._vector_retriever = await get_vector_retriever()
|
||||
|
||||
from app.services.retrieval.base import RetrievalContext
|
||||
|
||||
# [Step-KB-Binding] 确定要检索的知识库范围
|
||||
kb_ids = None
|
||||
if step_kb_config:
|
||||
# 如果配置了 allowed_kb_ids,则只检索这些知识库
|
||||
if step_kb_config.allowed_kb_ids:
|
||||
kb_ids = step_kb_config.allowed_kb_ids
|
||||
logger.info(
|
||||
f"[Step-KB-Binding] Restricting KB search to: {kb_ids}"
|
||||
)
|
||||
# 如果只配置了 preferred_kb_ids,优先检索这些知识库
|
||||
elif step_kb_config.preferred_kb_ids:
|
||||
kb_ids = step_kb_config.preferred_kb_ids
|
||||
logger.info(
|
||||
f"[Step-KB-Binding] Preferring KB search in: {kb_ids}"
|
||||
)
|
||||
|
||||
ctx = RetrievalContext(
|
||||
tenant_id=tenant_id,
|
||||
query=query,
|
||||
metadata=metadata_filter,
|
||||
metadata_filter=metadata_filter,
|
||||
kb_ids=kb_ids,
|
||||
)
|
||||
|
||||
result = await self._vector_retriever.retrieve(ctx)
|
||||
|
|
@ -347,6 +780,7 @@ class KbSearchDynamicTool:
|
|||
"content": hit.text,
|
||||
"score": hit.score,
|
||||
"metadata": hit.metadata,
|
||||
"kb_id": hit.metadata.get("kb_id"),
|
||||
})
|
||||
|
||||
return hits[:top_k]
|
||||
|
|
@ -384,6 +818,7 @@ async def create_kb_search_dynamic_handler(
|
|||
scene: str = "open_consult",
|
||||
top_k: int = DEFAULT_TOP_K,
|
||||
context: dict[str, Any] | None = None,
|
||||
**kwargs, # 接受系统注入的额外参数(user_id, session_id 等)
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
KB 动态检索 handler。
|
||||
|
|
@ -391,9 +826,9 @@ async def create_kb_search_dynamic_handler(
|
|||
Args:
|
||||
query: 检索查询
|
||||
tenant_id: 租户 ID
|
||||
scene: 场景标识
|
||||
scene: 场景标识(默认值,会被 context 中的 scene 覆盖)
|
||||
top_k: 返回数量
|
||||
context: 上下文
|
||||
context: 上下文(包含 scene 等过滤字段)
|
||||
|
||||
Returns:
|
||||
检索结果字典
|
||||
|
|
@ -442,6 +877,7 @@ def register_kb_search_dynamic_tool(
|
|||
scene: str = "open_consult",
|
||||
top_k: int = DEFAULT_TOP_K,
|
||||
context: dict[str, Any] | None = None,
|
||||
**kwargs, # 接受系统注入的额外参数(user_id, session_id 等)
|
||||
) -> dict[str, Any]:
|
||||
tool = KbSearchDynamicTool(
|
||||
session=session,
|
||||
|
|
|
|||
|
|
@ -542,9 +542,9 @@ def register_memory_recall_tool(
|
|||
cfg = config or MemoryRecallConfig()
|
||||
|
||||
async def memory_recall_handler(
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
tenant_id: str = "",
|
||||
user_id: str = "",
|
||||
session_id: str = "",
|
||||
recall_scope: list[str] | None = None,
|
||||
max_recent_messages: int | None = None,
|
||||
) -> dict[str, Any]:
|
||||
|
|
@ -554,6 +554,7 @@ def register_memory_recall_tool(
|
|||
timeout_governor=timeout_governor,
|
||||
config=cfg,
|
||||
)
|
||||
|
||||
result = await tool.execute(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
|
|
@ -587,12 +588,9 @@ def register_memory_recall_tool(
|
|||
"recall_scope": {"type": "array", "description": "召回范围,例如 profile/facts/preferences/summary/slots"},
|
||||
"max_recent_messages": {"type": "integer", "description": "历史回填窗口大小"}
|
||||
},
|
||||
"required": ["tenant_id", "user_id", "session_id"]
|
||||
"required": []
|
||||
},
|
||||
"example_action_input": {
|
||||
"tenant_id": "default",
|
||||
"user_id": "u_10086",
|
||||
"session_id": "s_abc_001",
|
||||
"recall_scope": ["profile", "facts", "preferences", "summary", "slots"],
|
||||
"max_recent_messages": 8
|
||||
},
|
||||
|
|
|
|||
|
|
@ -0,0 +1,281 @@
|
|||
"""
|
||||
Metadata Discovery Tool for Mid Platform.
|
||||
[AC-MARH-XX] 元数据发现工具,用于查询当前可用的元数据字段及其常见值。
|
||||
|
||||
核心特性:
|
||||
- 列出当前知识库文档中使用的元数据字段
|
||||
- 返回每个字段的常见取值(从现有文档中聚合)
|
||||
- 支持按知识库过滤
|
||||
- 返回字段定义信息(类型、用途说明等)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections import Counter
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.entities import Document, MetadataFieldDefinition, MetadataFieldStatus
|
||||
from app.services.mid.timeout_governor import TimeoutGovernor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.services.mid.tool_registry import ToolRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_TIMEOUT_MS = 2000
|
||||
DEFAULT_TOP_VALUES = 10
|
||||
|
||||
|
||||
@dataclass
|
||||
class MetadataFieldDiscoveryConfig:
|
||||
"""Configuration for metadata field discovery tool."""
|
||||
timeout_ms: int = DEFAULT_TIMEOUT_MS
|
||||
top_values_count: int = DEFAULT_TOP_VALUES
|
||||
|
||||
|
||||
@dataclass
|
||||
class MetadataFieldInfo:
|
||||
"""Information about a metadata field."""
|
||||
field_key: str
|
||||
field_type: str = "string"
|
||||
label: str = ""
|
||||
description: str | None = None
|
||||
common_values: list[str] = field(default_factory=list)
|
||||
value_count: int = 0
|
||||
is_filterable: bool = True
|
||||
options: list[str] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MetadataDiscoveryResult:
|
||||
"""Result of metadata discovery."""
|
||||
success: bool
|
||||
fields: list[MetadataFieldInfo] = field(default_factory=list)
|
||||
total_documents: int = 0
|
||||
error: str | None = None
|
||||
duration_ms: int = 0
|
||||
|
||||
|
||||
class MetadataDiscoveryTool:
|
||||
"""
|
||||
[AC-MARH-XX] 元数据发现工具。
|
||||
|
||||
用于查询当前知识库文档中使用的元数据字段及其常见值,
|
||||
帮助 AI 了解可用的过滤字段,从而更好地构造搜索请求。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
timeout_governor: TimeoutGovernor | None = None,
|
||||
config: MetadataFieldDiscoveryConfig | None = None,
|
||||
):
|
||||
self._session = session
|
||||
self._timeout_governor = timeout_governor or TimeoutGovernor()
|
||||
self._config = config or MetadataFieldDiscoveryConfig()
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
tenant_id: str,
|
||||
kb_id: str | None = None,
|
||||
include_values: bool = True,
|
||||
top_n: int | None = None,
|
||||
) -> MetadataDiscoveryResult:
|
||||
"""
|
||||
Execute metadata discovery.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID
|
||||
kb_id: Optional knowledge base ID to filter
|
||||
include_values: Whether to include common values (default True)
|
||||
top_n: Number of top values to return per field (default from config)
|
||||
|
||||
Returns:
|
||||
MetadataDiscoveryResult with field information
|
||||
"""
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
|
||||
try:
|
||||
top_n = top_n or self._config.top_values_count
|
||||
|
||||
field_definitions = await self._get_field_definitions(tenant_id)
|
||||
|
||||
document_metadata = await self._get_document_metadata(tenant_id, kb_id)
|
||||
|
||||
total_docs = len(document_metadata)
|
||||
|
||||
field_values: dict[str, Counter] = {}
|
||||
for doc_meta in document_metadata:
|
||||
if not doc_meta:
|
||||
continue
|
||||
for key, value in doc_meta.items():
|
||||
if key not in field_values:
|
||||
field_values[key] = Counter()
|
||||
str_value = str(value) if value is not None else ""
|
||||
field_values[key].update([str_value])
|
||||
|
||||
fields: list[MetadataFieldInfo] = []
|
||||
for field_key, values_counter in field_values.items():
|
||||
field_def = field_definitions.get(field_key)
|
||||
|
||||
common_values = []
|
||||
if include_values:
|
||||
most_common = values_counter.most_common(top_n)
|
||||
common_values = [v for v, _ in most_common if v]
|
||||
|
||||
field_info = MetadataFieldInfo(
|
||||
field_key=field_key,
|
||||
field_type=field_def.type if field_def else "string",
|
||||
label=field_def.label if field_def else field_key,
|
||||
description=field_def.usage_description if field_def else None,
|
||||
common_values=common_values,
|
||||
value_count=len(values_counter),
|
||||
is_filterable=field_def.is_filterable if field_def else True,
|
||||
options=field_def.options if field_def else None,
|
||||
)
|
||||
fields.append(field_info)
|
||||
|
||||
fields.sort(key=lambda f: f.value_count, reverse=True)
|
||||
|
||||
duration_ms = int((asyncio.get_event_loop().time() - start_time) * 1000)
|
||||
|
||||
logger.info(
|
||||
f"[MetadataDiscovery] Discovered {len(fields)} fields from {total_docs} documents, "
|
||||
f"duration={duration_ms}ms"
|
||||
)
|
||||
|
||||
return MetadataDiscoveryResult(
|
||||
success=True,
|
||||
fields=fields,
|
||||
total_documents=total_docs,
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[MetadataDiscovery] Discovery failed: {e}")
|
||||
return MetadataDiscoveryResult(
|
||||
success=False,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
async def _get_field_definitions(
|
||||
self,
|
||||
tenant_id: str,
|
||||
) -> dict[str, MetadataFieldDefinition]:
|
||||
"""Get field definitions for tenant."""
|
||||
stmt = select(MetadataFieldDefinition).where(
|
||||
MetadataFieldDefinition.tenant_id == tenant_id,
|
||||
MetadataFieldDefinition.status == MetadataFieldStatus.ACTIVE.value,
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
definitions = result.scalars().all()
|
||||
|
||||
return {d.field_key: d for d in definitions}
|
||||
|
||||
async def _get_document_metadata(
|
||||
self,
|
||||
tenant_id: str,
|
||||
kb_id: str | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Get all document metadata for tenant."""
|
||||
stmt = select(Document.doc_metadata).where(
|
||||
Document.tenant_id == tenant_id,
|
||||
)
|
||||
if kb_id:
|
||||
stmt = stmt.where(Document.kb_id == kb_id)
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
rows = result.scalars().all()
|
||||
|
||||
return [row for row in rows if row]
|
||||
|
||||
|
||||
def register_metadata_discovery_tool(
|
||||
registry: "ToolRegistry",
|
||||
session: AsyncSession,
|
||||
timeout_governor: TimeoutGovernor | None = None,
|
||||
config: MetadataFieldDiscoveryConfig | None = None,
|
||||
) -> None:
|
||||
"""Register metadata discovery tool to registry."""
|
||||
from app.services.mid.tool_registry import ToolType
|
||||
|
||||
cfg = config or MetadataFieldDiscoveryConfig()
|
||||
|
||||
async def metadata_discovery_handler(
|
||||
tenant_id: str = "",
|
||||
kb_id: str | None = None,
|
||||
include_values: bool = True,
|
||||
top_n: int | None = None,
|
||||
**kwargs, # 接受系统注入的额外参数(user_id, session_id 等)
|
||||
) -> dict[str, Any]:
|
||||
"""Metadata discovery tool handler."""
|
||||
tool = MetadataDiscoveryTool(
|
||||
session=session,
|
||||
timeout_governor=timeout_governor,
|
||||
config=cfg,
|
||||
)
|
||||
|
||||
result = await tool.execute(
|
||||
tenant_id=tenant_id,
|
||||
kb_id=kb_id,
|
||||
include_values=include_values,
|
||||
top_n=top_n,
|
||||
)
|
||||
# 将 dataclass 转换为 dict
|
||||
return {
|
||||
"success": result.success,
|
||||
"fields": [
|
||||
{
|
||||
"field_key": f.field_key,
|
||||
"field_type": f.field_type,
|
||||
"label": f.label,
|
||||
"description": f.description,
|
||||
"common_values": f.common_values,
|
||||
"value_count": f.value_count,
|
||||
"is_filterable": f.is_filterable,
|
||||
"options": f.options,
|
||||
}
|
||||
for f in result.fields
|
||||
],
|
||||
"total_documents": result.total_documents,
|
||||
"error": result.error,
|
||||
"duration_ms": result.duration_ms,
|
||||
}
|
||||
|
||||
registry.register(
|
||||
name="list_document_metadata_fields",
|
||||
description="列出当前知识库文档中使用的元数据字段及其常见取值,用于后续的知识库搜索过滤",
|
||||
handler=metadata_discovery_handler,
|
||||
tool_type=ToolType.INTERNAL,
|
||||
version="1.0.0",
|
||||
auth_required=False,
|
||||
timeout_ms=cfg.timeout_ms,
|
||||
enabled=True,
|
||||
metadata={
|
||||
"when_to_use": "当需要了解知识库中有哪些可用的元数据过滤字段时使用。",
|
||||
"when_not_to_use": "当已知可用的过滤字段,或不需要元数据过滤时不需要调用。",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"tenant_id": {"type": "string", "description": "租户 ID"},
|
||||
"kb_id": {"type": "string", "description": "知识库 ID(可选,用于限定范围)"},
|
||||
"include_values": {"type": "boolean", "description": "是否包含常见值列表,默认 true"},
|
||||
"top_n": {"type": "integer", "description": "每个字段返回的常见值数量,默认 10"},
|
||||
},
|
||||
"required": [],
|
||||
},
|
||||
"example_action_input": {
|
||||
"include_values": True,
|
||||
"top_n": 5,
|
||||
},
|
||||
"result_interpretation": "fields 数组包含每个字段的详细信息;common_values 是该字段在文档中的常见取值;value_count 表示该字段在多少文档中出现。",
|
||||
},
|
||||
)
|
||||
|
||||
logger.info("[MetadataDiscovery] Tool registered: list_document_metadata_fields")
|
||||
|
|
@ -165,21 +165,56 @@ class MetadataFilterBuilder:
|
|||
) -> list[FilterFieldInfo]:
|
||||
"""
|
||||
[AC-MRS-11] 获取可过滤的字段定义。
|
||||
优先从 Redis 缓存获取,未缓存则从数据库查询并缓存。
|
||||
|
||||
条件:
|
||||
- 状态=生效 (active)
|
||||
- field_roles 包含 resource_filter
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
# 1. 尝试从缓存获取
|
||||
from app.services.metadata_cache_service import get_metadata_cache_service
|
||||
cache_service = await get_metadata_cache_service()
|
||||
cached_fields = await cache_service.get_fields(tenant_id)
|
||||
|
||||
if cached_fields is not None:
|
||||
# 缓存命中,直接返回
|
||||
logger.info(
|
||||
f"[AC-MRS-11] Cache hit: Retrieved {len(cached_fields)} fields "
|
||||
f"for tenant={tenant_id} in {(time.time() - start_time)*1000:.2f}ms"
|
||||
)
|
||||
return [
|
||||
FilterFieldInfo(
|
||||
field_key=f["field_key"],
|
||||
label=f["label"],
|
||||
field_type=f["field_type"],
|
||||
required=f["required"],
|
||||
options=f.get("options"),
|
||||
default_value=f.get("default_value"),
|
||||
is_filterable=f["is_filterable"],
|
||||
)
|
||||
for f in cached_fields
|
||||
]
|
||||
|
||||
# 2. 缓存未命中,从数据库查询
|
||||
logger.info(f"[AC-MRS-11] Cache miss: Querying database for tenant={tenant_id}")
|
||||
db_start = time.time()
|
||||
|
||||
fields = await self._role_provider.get_fields_by_role(
|
||||
tenant_id=tenant_id,
|
||||
role=FieldRole.RESOURCE_FILTER.value,
|
||||
)
|
||||
|
||||
db_time = (time.time() - db_start) * 1000
|
||||
logger.info(
|
||||
f"[AC-MRS-11] Retrieved {len(fields)} resource_filter fields for tenant={tenant_id}"
|
||||
f"[AC-MRS-11] Retrieved {len(fields)} resource_filter fields "
|
||||
f"for tenant={tenant_id} from DB in {db_time:.2f}ms"
|
||||
)
|
||||
|
||||
return [
|
||||
# 3. 转换为 FilterFieldInfo 列表
|
||||
filter_fields = [
|
||||
FilterFieldInfo(
|
||||
field_key=f.field_key,
|
||||
label=f.label,
|
||||
|
|
@ -192,6 +227,29 @@ class MetadataFilterBuilder:
|
|||
for f in fields
|
||||
]
|
||||
|
||||
# 4. 缓存到 Redis
|
||||
cache_data = [
|
||||
{
|
||||
"field_key": f.field_key,
|
||||
"label": f.label,
|
||||
"field_type": f.field_type,
|
||||
"required": f.required,
|
||||
"options": f.options,
|
||||
"default_value": f.default_value,
|
||||
"is_filterable": f.is_filterable,
|
||||
}
|
||||
for f in filter_fields
|
||||
]
|
||||
await cache_service.set_fields(tenant_id, cache_data)
|
||||
|
||||
total_time = (time.time() - start_time) * 1000
|
||||
logger.info(
|
||||
f"[AC-MRS-11] Total time for tenant={tenant_id}: {total_time:.2f}ms "
|
||||
f"(DB: {db_time:.2f}ms)"
|
||||
)
|
||||
|
||||
return filter_fields
|
||||
|
||||
def _build_field_filter(
|
||||
self,
|
||||
field_info: FilterFieldInfo,
|
||||
|
|
|
|||
|
|
@ -175,7 +175,9 @@ class RoleBasedFieldProvider:
|
|||
"slot_key": slot.slot_key,
|
||||
"type": slot.type,
|
||||
"required": slot.required,
|
||||
# [AC-MRS-07-UPGRADE] 返回新旧字段
|
||||
"extract_strategy": slot.extract_strategy,
|
||||
"extract_strategies": slot.extract_strategies,
|
||||
"validation_rule": slot.validation_rule,
|
||||
"ask_back_prompt": slot.ask_back_prompt,
|
||||
"default_value": slot.default_value,
|
||||
|
|
@ -217,7 +219,9 @@ class RoleBasedFieldProvider:
|
|||
"slot_key": field.field_key,
|
||||
"type": field.type,
|
||||
"required": field.required,
|
||||
# [AC-MRS-07-UPGRADE] 返回新旧字段
|
||||
"extract_strategy": None,
|
||||
"extract_strategies": None,
|
||||
"validation_rule": None,
|
||||
"ask_back_prompt": None,
|
||||
"default_value": field.default_value,
|
||||
|
|
|
|||
|
|
@ -53,24 +53,35 @@ class RuntimeContext:
|
|||
|
||||
def to_trace_info(self) -> TraceInfo:
|
||||
"""转换为 TraceInfo。"""
|
||||
return TraceInfo(
|
||||
mode=self.mode,
|
||||
intent=self.intent,
|
||||
request_id=self.request_id,
|
||||
generation_id=self.generation_id,
|
||||
guardrail_triggered=self.guardrail_triggered,
|
||||
guardrail_rule_id=self.guardrail_rule_id,
|
||||
interrupt_consumed=self.interrupt_consumed,
|
||||
kb_tool_called=self.kb_tool_called,
|
||||
kb_hit=self.kb_hit,
|
||||
fallback_reason_code=self.fallback_reason_code,
|
||||
react_iterations=self.react_iterations,
|
||||
timeout_profile=self.timeout_profile,
|
||||
segment_stats=self.segment_stats,
|
||||
metrics_snapshot=self.metrics_snapshot,
|
||||
tools_used=[tc.tool_name for tc in self.tool_calls] if self.tool_calls else None,
|
||||
tool_calls=self.tool_calls if self.tool_calls else None,
|
||||
)
|
||||
try:
|
||||
return TraceInfo(
|
||||
mode=self.mode,
|
||||
intent=self.intent,
|
||||
request_id=self.request_id,
|
||||
generation_id=self.generation_id,
|
||||
guardrail_triggered=self.guardrail_triggered,
|
||||
guardrail_rule_id=self.guardrail_rule_id,
|
||||
interrupt_consumed=self.interrupt_consumed,
|
||||
kb_tool_called=self.kb_tool_called,
|
||||
kb_hit=self.kb_hit,
|
||||
fallback_reason_code=self.fallback_reason_code,
|
||||
react_iterations=self.react_iterations,
|
||||
timeout_profile=self.timeout_profile,
|
||||
segment_stats=self.segment_stats,
|
||||
metrics_snapshot=self.metrics_snapshot,
|
||||
tools_used=[tc.tool_name for tc in self.tool_calls] if self.tool_calls else None,
|
||||
tool_calls=self.tool_calls if self.tool_calls else None,
|
||||
)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
logger.error(
|
||||
f"[RuntimeObserver] Failed to create TraceInfo: {e}\n"
|
||||
f"Exception type: {type(e).__name__}\n"
|
||||
f"Context: mode={self.mode}, request_id={self.request_id}, "
|
||||
f"generation_id={self.generation_id}\n"
|
||||
f"Traceback:\n{traceback.format_exc()}"
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
class RuntimeObserver:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,423 @@
|
|||
"""
|
||||
Scene Slot Bundle Loader Service.
|
||||
[AC-SCENE-SLOT-02] 运行时场景槽位包加载器
|
||||
[AC-SCENE-SLOT-03] 支持缓存层
|
||||
|
||||
职责:
|
||||
1. 根据场景标识加载槽位包配置
|
||||
2. 聚合槽位定义详情
|
||||
3. 计算缺失槽位
|
||||
4. 生成追问提示
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.entities import SceneSlotBundleStatus
|
||||
from app.services.scene_slot_bundle_service import SceneSlotBundleService
|
||||
from app.services.slot_definition_service import SlotDefinitionService
|
||||
from app.services.cache.scene_slot_bundle_cache import (
|
||||
CachedSceneSlotBundle,
|
||||
get_scene_slot_bundle_cache,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SlotInfo:
|
||||
"""槽位信息"""
|
||||
slot_key: str
|
||||
type: str
|
||||
required: bool
|
||||
ask_back_prompt: str | None = None
|
||||
validation_rule: str | None = None
|
||||
linked_field_id: str | None = None
|
||||
default_value: Any = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class SceneSlotContext:
|
||||
"""
|
||||
[AC-SCENE-SLOT-02] 场景槽位上下文
|
||||
|
||||
运行时使用的场景槽位包信息
|
||||
"""
|
||||
scene_key: str
|
||||
scene_name: str
|
||||
required_slots: list[SlotInfo] = field(default_factory=list)
|
||||
optional_slots: list[SlotInfo] = field(default_factory=list)
|
||||
slot_priority: list[str] = field(default_factory=list)
|
||||
completion_threshold: float = 1.0
|
||||
ask_back_order: str = "priority"
|
||||
|
||||
def get_all_slot_keys(self) -> list[str]:
|
||||
"""获取所有槽位键名"""
|
||||
return [s.slot_key for s in self.required_slots] + [s.slot_key for s in self.optional_slots]
|
||||
|
||||
def get_required_slot_keys(self) -> list[str]:
|
||||
"""获取必填槽位键名"""
|
||||
return [s.slot_key for s in self.required_slots]
|
||||
|
||||
def get_optional_slot_keys(self) -> list[str]:
|
||||
"""获取可选槽位键名"""
|
||||
return [s.slot_key for s in self.optional_slots]
|
||||
|
||||
def get_missing_slots(
|
||||
self,
|
||||
filled_slots: dict[str, Any],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
获取缺失的必填槽位信息
|
||||
|
||||
Args:
|
||||
filled_slots: 已填充的槽位值
|
||||
|
||||
Returns:
|
||||
缺失槽位信息列表
|
||||
"""
|
||||
missing = []
|
||||
|
||||
for slot_info in self.required_slots:
|
||||
if slot_info.slot_key not in filled_slots:
|
||||
missing.append({
|
||||
"slot_key": slot_info.slot_key,
|
||||
"type": slot_info.type,
|
||||
"required": True,
|
||||
"ask_back_prompt": slot_info.ask_back_prompt,
|
||||
"validation_rule": slot_info.validation_rule,
|
||||
"linked_field_id": slot_info.linked_field_id,
|
||||
})
|
||||
|
||||
return missing
|
||||
|
||||
def get_ordered_missing_slots(
|
||||
self,
|
||||
filled_slots: dict[str, Any],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
按优先级顺序获取缺失的必填槽位
|
||||
|
||||
Args:
|
||||
filled_slots: 已填充的槽位值
|
||||
|
||||
Returns:
|
||||
按优先级排序的缺失槽位信息列表
|
||||
"""
|
||||
missing = self.get_missing_slots(filled_slots)
|
||||
|
||||
if not missing:
|
||||
return []
|
||||
|
||||
if self.ask_back_order == "required_first":
|
||||
return missing
|
||||
|
||||
if self.ask_back_order == "priority" and self.slot_priority:
|
||||
priority_map = {slot_key: idx for idx, slot_key in enumerate(self.slot_priority)}
|
||||
missing.sort(key=lambda x: priority_map.get(x["slot_key"], 999))
|
||||
|
||||
return missing
|
||||
|
||||
def get_completion_ratio(
|
||||
self,
|
||||
filled_slots: dict[str, Any],
|
||||
) -> float:
|
||||
"""
|
||||
计算完成比例
|
||||
|
||||
Args:
|
||||
filled_slots: 已填充的槽位值
|
||||
|
||||
Returns:
|
||||
完成比例 (0.0 - 1.0)
|
||||
"""
|
||||
if not self.required_slots:
|
||||
return 1.0
|
||||
|
||||
filled_count = sum(
|
||||
1 for slot_info in self.required_slots
|
||||
if slot_info.slot_key in filled_slots
|
||||
)
|
||||
|
||||
return filled_count / len(self.required_slots)
|
||||
|
||||
def is_complete(
|
||||
self,
|
||||
filled_slots: dict[str, Any],
|
||||
) -> bool:
|
||||
"""
|
||||
检查是否完成
|
||||
|
||||
Args:
|
||||
filled_slots: 已填充的槽位值
|
||||
|
||||
Returns:
|
||||
是否达到完成阈值
|
||||
"""
|
||||
return self.get_completion_ratio(filled_slots) >= self.completion_threshold
|
||||
|
||||
|
||||
class SceneSlotBundleLoader:
|
||||
"""
|
||||
[AC-SCENE-SLOT-02] 场景槽位包加载器
|
||||
[AC-SCENE-SLOT-03] 支持缓存层
|
||||
|
||||
运行时加载场景槽位包配置
|
||||
"""
|
||||
|
||||
def __init__(self, session: AsyncSession, use_cache: bool = True):
|
||||
self._session = session
|
||||
self._bundle_service = SceneSlotBundleService(session)
|
||||
self._slot_service = SlotDefinitionService(session)
|
||||
self._use_cache = use_cache
|
||||
self._cache = get_scene_slot_bundle_cache() if use_cache else None
|
||||
|
||||
async def load_scene_context(
|
||||
self,
|
||||
tenant_id: str,
|
||||
scene_key: str,
|
||||
) -> SceneSlotContext | None:
|
||||
"""
|
||||
加载场景槽位上下文
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
scene_key: 场景标识
|
||||
|
||||
Returns:
|
||||
场景槽位上下文或 None
|
||||
"""
|
||||
# [AC-SCENE-SLOT-03] 尝试从缓存获取
|
||||
if self._use_cache and self._cache:
|
||||
cached = await self._cache.get(tenant_id, scene_key)
|
||||
if cached and cached.status == SceneSlotBundleStatus.ACTIVE.value:
|
||||
logger.debug(
|
||||
f"[AC-SCENE-SLOT-03] Cache hit for scene: {scene_key}"
|
||||
)
|
||||
return await self._build_context_from_cached(tenant_id, cached)
|
||||
|
||||
bundle = await self._bundle_service.get_active_bundle_by_scene(
|
||||
tenant_id=tenant_id,
|
||||
scene_key=scene_key,
|
||||
)
|
||||
|
||||
if not bundle:
|
||||
logger.debug(
|
||||
f"[AC-SCENE-SLOT-02] No active bundle found for scene: {scene_key}"
|
||||
)
|
||||
return None
|
||||
|
||||
# [AC-SCENE-SLOT-03] 写入缓存
|
||||
if self._use_cache and self._cache:
|
||||
cached_bundle = CachedSceneSlotBundle(
|
||||
scene_key=bundle.scene_key,
|
||||
scene_name=bundle.scene_name,
|
||||
description=bundle.description,
|
||||
required_slots=bundle.required_slots,
|
||||
optional_slots=bundle.optional_slots,
|
||||
slot_priority=bundle.slot_priority,
|
||||
completion_threshold=bundle.completion_threshold,
|
||||
ask_back_order=bundle.ask_back_order,
|
||||
status=bundle.status,
|
||||
version=bundle.version,
|
||||
)
|
||||
await self._cache.set(tenant_id, scene_key, cached_bundle)
|
||||
|
||||
return await self._build_context_from_bundle(tenant_id, bundle)
|
||||
|
||||
async def _build_context_from_cached(
|
||||
self,
|
||||
tenant_id: str,
|
||||
cached: CachedSceneSlotBundle,
|
||||
) -> SceneSlotContext:
|
||||
"""从缓存构建场景槽位上下文"""
|
||||
all_slots = await self._slot_service.list_slot_definitions(tenant_id)
|
||||
slot_map = {slot.slot_key: slot for slot in all_slots}
|
||||
|
||||
return self._build_context(cached, slot_map)
|
||||
|
||||
async def _build_context_from_bundle(
|
||||
self,
|
||||
tenant_id: str,
|
||||
bundle: Any,
|
||||
) -> SceneSlotContext:
|
||||
"""从数据库模型构建场景槽位上下文"""
|
||||
all_slots = await self._slot_service.list_slot_definitions(tenant_id)
|
||||
slot_map = {slot.slot_key: slot for slot in all_slots}
|
||||
|
||||
cached = CachedSceneSlotBundle(
|
||||
scene_key=bundle.scene_key,
|
||||
scene_name=bundle.scene_name,
|
||||
description=bundle.description,
|
||||
required_slots=bundle.required_slots,
|
||||
optional_slots=bundle.optional_slots,
|
||||
slot_priority=bundle.slot_priority,
|
||||
completion_threshold=bundle.completion_threshold,
|
||||
ask_back_order=bundle.ask_back_order,
|
||||
status=bundle.status,
|
||||
version=bundle.version,
|
||||
)
|
||||
|
||||
return self._build_context(cached, slot_map)
|
||||
|
||||
def _build_context(
|
||||
self,
|
||||
cached: CachedSceneSlotBundle,
|
||||
slot_map: dict[str, Any],
|
||||
) -> SceneSlotContext:
|
||||
"""构建场景槽位上下文"""
|
||||
required_slot_infos = []
|
||||
for slot_key in cached.required_slots:
|
||||
if slot_key in slot_map:
|
||||
slot_def = slot_map[slot_key]
|
||||
required_slot_infos.append(SlotInfo(
|
||||
slot_key=slot_def.slot_key,
|
||||
type=slot_def.type,
|
||||
required=True,
|
||||
ask_back_prompt=slot_def.ask_back_prompt,
|
||||
validation_rule=slot_def.validation_rule,
|
||||
linked_field_id=str(slot_def.linked_field_id) if slot_def.linked_field_id else None,
|
||||
default_value=slot_def.default_value,
|
||||
))
|
||||
else:
|
||||
logger.warning(
|
||||
f"[AC-SCENE-SLOT-02] Required slot not found: {slot_key}"
|
||||
)
|
||||
|
||||
optional_slot_infos = []
|
||||
for slot_key in cached.optional_slots:
|
||||
if slot_key in slot_map:
|
||||
slot_def = slot_map[slot_key]
|
||||
optional_slot_infos.append(SlotInfo(
|
||||
slot_key=slot_def.slot_key,
|
||||
type=slot_def.type,
|
||||
required=slot_def.required,
|
||||
ask_back_prompt=slot_def.ask_back_prompt,
|
||||
validation_rule=slot_def.validation_rule,
|
||||
linked_field_id=str(slot_def.linked_field_id) if slot_def.linked_field_id else None,
|
||||
default_value=slot_def.default_value,
|
||||
))
|
||||
else:
|
||||
logger.warning(
|
||||
f"[AC-SCENE-SLOT-02] Optional slot not found: {slot_key}"
|
||||
)
|
||||
|
||||
context = SceneSlotContext(
|
||||
scene_key=cached.scene_key,
|
||||
scene_name=cached.scene_name,
|
||||
required_slots=required_slot_infos,
|
||||
optional_slots=optional_slot_infos,
|
||||
slot_priority=cached.slot_priority or [],
|
||||
completion_threshold=cached.completion_threshold,
|
||||
ask_back_order=cached.ask_back_order,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[AC-SCENE-SLOT-02] Loaded scene context: scene={cached.scene_key}, "
|
||||
f"required={len(required_slot_infos)}, optional={len(optional_slot_infos)}, "
|
||||
f"threshold={cached.completion_threshold}"
|
||||
)
|
||||
|
||||
return context
|
||||
|
||||
async def invalidate_cache(
|
||||
self,
|
||||
tenant_id: str,
|
||||
scene_key: str,
|
||||
) -> bool:
|
||||
"""
|
||||
[AC-SCENE-SLOT-03] 使缓存失效
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
scene_key: 场景标识
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
if self._cache:
|
||||
return await self._cache.invalidate_on_update(tenant_id, scene_key)
|
||||
return True
|
||||
|
||||
async def get_missing_slots_for_scene(
|
||||
self,
|
||||
tenant_id: str,
|
||||
scene_key: str,
|
||||
filled_slots: dict[str, Any],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
获取场景缺失的必填槽位
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
scene_key: 场景标识
|
||||
filled_slots: 已填充的槽位值
|
||||
|
||||
Returns:
|
||||
缺失槽位信息列表
|
||||
"""
|
||||
context = await self.load_scene_context(tenant_id, scene_key)
|
||||
|
||||
if not context:
|
||||
return []
|
||||
|
||||
return context.get_ordered_missing_slots(filled_slots)
|
||||
|
||||
async def generate_ask_back_prompt(
|
||||
self,
|
||||
tenant_id: str,
|
||||
scene_key: str,
|
||||
filled_slots: dict[str, Any],
|
||||
max_slots: int = 2,
|
||||
) -> str | None:
|
||||
"""
|
||||
生成追问提示
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
scene_key: 场景标识
|
||||
filled_slots: 已填充的槽位值
|
||||
max_slots: 最多追问的槽位数量
|
||||
|
||||
Returns:
|
||||
追问提示或 None
|
||||
"""
|
||||
missing_slots = await self.get_missing_slots_for_scene(
|
||||
tenant_id=tenant_id,
|
||||
scene_key=scene_key,
|
||||
filled_slots=filled_slots,
|
||||
)
|
||||
|
||||
if not missing_slots:
|
||||
return None
|
||||
|
||||
context = await self.load_scene_context(tenant_id, scene_key)
|
||||
ask_back_order = context.ask_back_order if context else "priority"
|
||||
|
||||
if ask_back_order == "parallel":
|
||||
prompts = []
|
||||
for missing in missing_slots[:max_slots]:
|
||||
if missing.get("ask_back_prompt"):
|
||||
prompts.append(missing["ask_back_prompt"])
|
||||
else:
|
||||
slot_key = missing.get("slot_key", "相关信息")
|
||||
prompts.append(f"您的{slot_key}")
|
||||
|
||||
if len(prompts) == 1:
|
||||
return prompts[0]
|
||||
elif len(prompts) == 2:
|
||||
return f"为了更好地为您服务,请告诉我{prompts[0]}和{prompts[1]}。"
|
||||
else:
|
||||
all_but_last = "、".join(prompts[:-1])
|
||||
return f"为了更好地为您服务,请告诉我{all_but_last},以及{prompts[-1]}。"
|
||||
else:
|
||||
first_missing = missing_slots[0]
|
||||
ask_back_prompt = first_missing.get("ask_back_prompt")
|
||||
if ask_back_prompt:
|
||||
return ask_back_prompt
|
||||
|
||||
slot_key = first_missing.get("slot_key", "相关信息")
|
||||
return f"为了更好地为您提供帮助,请告诉我您的{slot_key}。"
|
||||
|
|
@ -0,0 +1,334 @@
|
|||
"""
|
||||
Scene Slot Bundle Metrics Service.
|
||||
[AC-SCENE-SLOT-04] 场景槽位包监控指标
|
||||
|
||||
职责:
|
||||
1. 收集场景槽位包相关的监控指标
|
||||
2. 提供告警检测接口
|
||||
3. 支持指标导出
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from collections import defaultdict
|
||||
import threading
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SceneSlotMetricPoint:
|
||||
"""单个指标数据点"""
|
||||
timestamp: datetime
|
||||
tenant_id: str
|
||||
scene_key: str
|
||||
metric_name: str
|
||||
value: float
|
||||
tags: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SceneSlotMetricsSummary:
|
||||
"""指标汇总"""
|
||||
total_requests: int = 0
|
||||
cache_hits: int = 0
|
||||
cache_misses: int = 0
|
||||
missing_slots_triggered: int = 0
|
||||
ask_back_triggered: int = 0
|
||||
scene_not_configured: int = 0
|
||||
slot_not_found: int = 0
|
||||
avg_completion_ratio: float = 0.0
|
||||
|
||||
@property
|
||||
def cache_hit_rate(self) -> float:
|
||||
if self.total_requests == 0:
|
||||
return 0.0
|
||||
return self.cache_hits / self.total_requests
|
||||
|
||||
|
||||
class SceneSlotMetricsCollector:
    """
    [AC-SCENE-SLOT-04] Scene slot bundle metrics collector.

    Collects the following metrics:
    - scene_slot_requests_total: total scene-slot requests
    - scene_slot_cache_hits: cache hit count
    - scene_slot_cache_misses: cache miss count
    - scene_slot_missing_triggered: missing-slot trigger count
    - scene_slot_ask_back_triggered: ask-back trigger count
    - scene_slot_not_configured: scene-not-configured count
    - scene_slot_not_found: slot-not-found count
    - scene_slot_completion_ratio: slot completion ratio

    The raw point buffer is bounded by ``max_points`` and guarded by a lock,
    so the collector can be shared between threads.
    """

    # Maps recorded point metric_name values to the per-scene counter keys
    # returned by get_metrics_by_scene(). Note the singular->plural mapping
    # for the cache metrics: points are recorded as "cache_hit"/"cache_miss"
    # but reported as "cache_hits"/"cache_misses".
    _COUNTER_KEYS = {
        "cache_hit": "cache_hits",
        "cache_miss": "cache_misses",
        "missing_slots_triggered": "missing_slots_triggered",
        "ask_back_triggered": "ask_back_triggered",
        "scene_not_configured": "scene_not_configured",
        "slot_not_found": "slot_not_found",
    }

    def __init__(self, max_points: int = 10000):
        """
        Args:
            max_points: maximum number of raw points retained in memory; the
                oldest points are dropped once the buffer exceeds this limit.
        """
        self._max_points = max_points
        self._points: list[SceneSlotMetricPoint] = []
        self._counters: dict[str, int] = defaultdict(int)
        self._lock = threading.Lock()
        # Naive UTC timestamps are used throughout so exported isoformat()
        # strings stay consistent with SceneSlotMetricPoint.timestamp.
        self._start_time = datetime.utcnow()

    def record(
        self,
        tenant_id: str,
        scene_key: str,
        metric_name: str,
        value: float = 1.0,
        tags: dict[str, str] | None = None,
    ) -> None:
        """
        Record a metric point.

        Args:
            tenant_id: tenant ID
            scene_key: scene identifier
            metric_name: metric name
            value: metric value (1.0 for plain counter events)
            tags: extra labels
        """
        point = SceneSlotMetricPoint(
            timestamp=datetime.utcnow(),
            tenant_id=tenant_id,
            scene_key=scene_key,
            metric_name=metric_name,
            value=value,
            tags=tags or {},
        )

        with self._lock:
            self._points.append(point)
            if len(self._points) > self._max_points:
                # Keep only the newest max_points entries.
                self._points = self._points[-self._max_points:]

            counter_key = f"{tenant_id}:{scene_key}:{metric_name}"
            self._counters[counter_key] += 1

    def record_cache_hit(self, tenant_id: str, scene_key: str) -> None:
        """Record a cache hit."""
        self.record(tenant_id, scene_key, "cache_hit")

    def record_cache_miss(self, tenant_id: str, scene_key: str) -> None:
        """Record a cache miss."""
        self.record(tenant_id, scene_key, "cache_miss")

    def record_missing_slots(self, tenant_id: str, scene_key: str, count: int = 1) -> None:
        """Record that ``count`` required slots were found missing."""
        self.record(tenant_id, scene_key, "missing_slots_triggered", float(count))

    def record_ask_back(self, tenant_id: str, scene_key: str) -> None:
        """Record that an ask-back (follow-up question) was triggered."""
        self.record(tenant_id, scene_key, "ask_back_triggered")

    def record_scene_not_configured(self, tenant_id: str, scene_key: str) -> None:
        """Record a request for a scene that has no slot-bundle configuration."""
        self.record(tenant_id, scene_key, "scene_not_configured")

    def record_slot_not_found(self, tenant_id: str, scene_key: str, slot_key: str) -> None:
        """Record that a referenced slot definition was not found."""
        self.record(tenant_id, scene_key, "slot_not_found", tags={"slot_key": slot_key})

    def record_completion_ratio(self, tenant_id: str, scene_key: str, ratio: float) -> None:
        """Record the observed slot completion ratio for one request."""
        self.record(tenant_id, scene_key, "completion_ratio", ratio)

    def get_summary(self, tenant_id: str | None = None) -> SceneSlotMetricsSummary:
        """
        Build a summary over the retained points.

        Args:
            tenant_id: tenant ID (optional; None aggregates all tenants)

        Returns:
            SceneSlotMetricsSummary over the matching points. Every point
            counts toward ``total_requests``, including completion-ratio
            samples.
        """
        summary = SceneSlotMetricsSummary()
        completion_ratios: list[float] = []

        with self._lock:
            points = self._points.copy()

        for point in points:
            if tenant_id and point.tenant_id != tenant_id:
                continue

            summary.total_requests += 1

            name = point.metric_name
            if name == "cache_hit":
                summary.cache_hits += 1
            elif name == "cache_miss":
                summary.cache_misses += 1
            elif name == "missing_slots_triggered":
                summary.missing_slots_triggered += int(point.value)
            elif name == "ask_back_triggered":
                summary.ask_back_triggered += 1
            elif name == "scene_not_configured":
                summary.scene_not_configured += 1
            elif name == "slot_not_found":
                summary.slot_not_found += 1
            elif name == "completion_ratio":
                completion_ratios.append(point.value)

        # BUGFIX: the previous running-average formula divided by the overall
        # request count rather than the number of ratio samples, producing a
        # meaningless value. Average the ratio samples directly.
        if completion_ratios:
            summary.avg_completion_ratio = sum(completion_ratios) / len(completion_ratios)

        return summary

    def get_metrics_by_scene(
        self,
        tenant_id: str,
        scene_key: str,
    ) -> dict[str, Any]:
        """
        Get metrics for a specific scene.

        Args:
            tenant_id: tenant ID
            scene_key: scene identifier

        Returns:
            Metrics dict. ``requests`` counts every retained point for the
            scene, including completion-ratio samples.
        """
        metrics: dict[str, Any] = {
            "requests": 0,
            "cache_hits": 0,
            "cache_misses": 0,
            "missing_slots_triggered": 0,
            "ask_back_triggered": 0,
            # BUGFIX: this key was previously missing, so check_alerts()'s
            # scene_not_configured_rate was always 0 and the alert never fired.
            "scene_not_configured": 0,
            "slot_not_found": 0,
            "avg_completion_ratio": 0.0,
        }

        with self._lock:
            points = [p for p in self._points if p.tenant_id == tenant_id and p.scene_key == scene_key]

        if not points:
            return metrics

        completion_ratios: list[float] = []

        for point in points:
            metrics["requests"] += 1
            if point.metric_name == "completion_ratio":
                completion_ratios.append(point.value)
                continue
            # BUGFIX: the old code tested `point.metric_name in metrics`, but
            # points are recorded as "cache_hit"/"cache_miss" while the dict
            # keys are plural ("cache_hits"/"cache_misses"), so cache counts
            # (and completion ratios) were never accumulated. Map explicitly.
            counter_key = self._COUNTER_KEYS.get(point.metric_name)
            if counter_key is not None:
                metrics[counter_key] += int(point.value)

        if completion_ratios:
            metrics["avg_completion_ratio"] = sum(completion_ratios) / len(completion_ratios)

        return metrics

    def check_alerts(
        self,
        tenant_id: str,
        scene_key: str,
        thresholds: dict[str, float] | None = None,
    ) -> list[dict[str, Any]]:
        """
        Evaluate alert conditions for one scene.

        Args:
            tenant_id: tenant ID
            scene_key: scene identifier
            thresholds: threshold overrides (merged over the defaults)

        Returns:
            List of alert dicts (alert_type, severity, message,
            current_value, threshold, suggestion).
        """
        default_thresholds = {
            "cache_hit_rate_low": 0.5,
            "missing_slots_rate_high": 0.3,
            "scene_not_configured_rate_high": 0.1,
        }

        effective_thresholds = {**default_thresholds, **(thresholds or {})}

        metrics = self.get_metrics_by_scene(tenant_id, scene_key)
        alerts: list[dict[str, Any]] = []

        if metrics["requests"] > 0:
            cache_hit_rate = metrics["cache_hits"] / metrics["requests"]
            if cache_hit_rate < effective_thresholds["cache_hit_rate_low"]:
                alerts.append({
                    "alert_type": "cache_hit_rate_low",
                    "severity": "warning",
                    "message": f"场景 {scene_key} 的缓存命中率 ({cache_hit_rate:.2%}) 低于阈值 ({effective_thresholds['cache_hit_rate_low']:.0%})",
                    "current_value": cache_hit_rate,
                    "threshold": effective_thresholds["cache_hit_rate_low"],
                    "suggestion": "检查场景槽位包配置是否频繁变更,或增加缓存 TTL",
                })

            missing_slots_rate = metrics["missing_slots_triggered"] / metrics["requests"]
            if missing_slots_rate > effective_thresholds["missing_slots_rate_high"]:
                alerts.append({
                    "alert_type": "missing_slots_rate_high",
                    "severity": "warning",
                    "message": f"场景 {scene_key} 的缺失槽位触发率 ({missing_slots_rate:.2%}) 高于阈值 ({effective_thresholds['missing_slots_rate_high']:.0%})",
                    "current_value": missing_slots_rate,
                    "threshold": effective_thresholds["missing_slots_rate_high"],
                    "suggestion": "检查槽位配置是否合理,或优化槽位提取策略",
                })

            scene_not_configured_rate = metrics.get("scene_not_configured", 0) / metrics["requests"]
            if scene_not_configured_rate > effective_thresholds["scene_not_configured_rate_high"]:
                alerts.append({
                    "alert_type": "scene_not_configured_rate_high",
                    "severity": "error",
                    "message": f"场景 {scene_key} 未配置率 ({scene_not_configured_rate:.2%}) 高于阈值 ({effective_thresholds['scene_not_configured_rate_high']:.0%})",
                    "current_value": scene_not_configured_rate,
                    "threshold": effective_thresholds["scene_not_configured_rate_high"],
                    "suggestion": "请为该场景创建场景槽位包配置",
                })

        return alerts

    def export_metrics(self, tenant_id: str | None = None) -> dict[str, Any]:
        """
        Export raw metric data.

        Args:
            tenant_id: tenant ID (optional; None exports all tenants)

        Returns:
            Dict with the collection window, raw points, and a summary.
        """
        with self._lock:
            points = [
                {
                    "timestamp": p.timestamp.isoformat(),
                    "tenant_id": p.tenant_id,
                    "scene_key": p.scene_key,
                    "metric_name": p.metric_name,
                    "value": p.value,
                    "tags": p.tags,
                }
                for p in self._points
                if tenant_id is None or p.tenant_id == tenant_id
            ]

        # get_summary() re-acquires the (non-reentrant) lock, so it must run
        # outside the `with` block above.
        return {
            "start_time": self._start_time.isoformat(),
            "end_time": datetime.utcnow().isoformat(),
            "total_points": len(points),
            "points": points,
            "summary": self.get_summary(tenant_id).__dict__,
        }

    def reset(self) -> None:
        """Drop all retained points/counters and restart the collection window."""
        with self._lock:
            self._points = []
            self._counters = defaultdict(int)
            self._start_time = datetime.utcnow()
|
||||
|
||||
|
||||
# Lazily created process-wide collector instance.
_metrics_collector: SceneSlotMetricsCollector | None = None


def get_scene_slot_metrics_collector() -> SceneSlotMetricsCollector:
    """Return the shared scene-slot metrics collector, creating it on first use."""
    global _metrics_collector
    collector = _metrics_collector
    if collector is None:
        collector = SceneSlotMetricsCollector()
        _metrics_collector = collector
    return collector
|
||||
|
|
@ -0,0 +1,500 @@
|
|||
"""
|
||||
Slot Backfill Service.
|
||||
槽位回填服务 - 处理槽位值的提取、校验、确认、写回
|
||||
|
||||
[AC-MRS-SLOT-BACKFILL-01] 槽位值回填确认
|
||||
|
||||
职责:
|
||||
1. 从用户回复提取候选槽位值
|
||||
2. 调用 SlotManager 校验并归一化
|
||||
3. 校验失败返回 ask_back_prompt 二次追问
|
||||
4. 校验通过写入状态并标记 source/confidence
|
||||
5. 对低置信度值增加确认话术
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.mid.schemas import SlotSource
|
||||
from app.services.mid.slot_manager import SlotManager, SlotWriteResult
|
||||
from app.services.mid.slot_state_aggregator import SlotStateAggregator, SlotState
|
||||
from app.services.mid.slot_strategy_executor import (
|
||||
ExtractContext,
|
||||
SlotStrategyExecutor,
|
||||
StrategyChainResult,
|
||||
)
|
||||
from app.services.slot_definition_service import SlotDefinitionService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BackfillStatus(str, Enum):
    """Outcome of a slot backfill attempt."""
    SUCCESS = "success"  # value validated and written
    VALIDATION_FAILED = "validation_failed"  # SlotManager rejected the value
    EXTRACTION_FAILED = "extraction_failed"  # the strategy chain produced no value
    NEEDS_CONFIRMATION = "needs_confirmation"  # written, but confidence below threshold
    NO_CANDIDATES = "no_candidates"  # no candidate values supplied
|
||||
|
||||
|
||||
@dataclass
class BackfillResult:
    """
    Result of a single slot backfill.

    Attributes:
        status: backfill outcome
        slot_key: slot key
        value: final value (after validation passed)
        normalized_value: value after normalization
        source: where the value came from
        confidence: confidence score
        error_message: error details
        ask_back_prompt: follow-up question prompt
        confirmation_prompt: confirmation prompt (for low-confidence values)
        updated_state: slot state after the write
    """
    status: BackfillStatus
    slot_key: str | None = None
    value: Any = None
    normalized_value: Any = None
    source: str = "unknown"
    confidence: float = 0.0
    error_message: str | None = None
    ask_back_prompt: str | None = None
    confirmation_prompt: str | None = None
    updated_state: SlotState | None = None

    def is_success(self) -> bool:
        """True when the value was validated and written."""
        return self.status is BackfillStatus.SUCCESS

    def needs_ask_back(self) -> bool:
        """True when the caller should re-ask the user for this slot."""
        return self.status in {
            BackfillStatus.VALIDATION_FAILED,
            BackfillStatus.EXTRACTION_FAILED,
        }

    def needs_confirmation(self) -> bool:
        """True when a low-confidence value awaits user confirmation."""
        return self.status is BackfillStatus.NEEDS_CONFIRMATION

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dict (status flattened to its string value)."""
        payload: dict[str, Any] = {
            "status": self.status.value,
            "slot_key": self.slot_key,
            "value": self.value,
            "normalized_value": self.normalized_value,
            "source": self.source,
            "confidence": self.confidence,
            "error_message": self.error_message,
            "ask_back_prompt": self.ask_back_prompt,
            "confirmation_prompt": self.confirmation_prompt,
        }
        return payload
|
||||
|
||||
|
||||
@dataclass
class BatchBackfillResult:
    """Aggregate result of backfilling several slots."""
    results: list[BackfillResult] = field(default_factory=list)
    success_count: int = 0
    failed_count: int = 0
    confirmation_needed_count: int = 0

    def add_result(self, result: BackfillResult) -> None:
        """Append one result and bump the matching counter."""
        self.results.append(result)
        if result.is_success():
            self.success_count += 1
            return
        if result.needs_confirmation():
            self.confirmation_needed_count += 1
            return
        self.failed_count += 1

    def get_ask_back_prompts(self) -> list[str]:
        """Every non-empty ask-back prompt, in result order."""
        return [item.ask_back_prompt for item in self.results if item.ask_back_prompt]

    def get_confirmation_prompts(self) -> list[str]:
        """Every non-empty confirmation prompt, in result order."""
        return [item.confirmation_prompt for item in self.results if item.confirmation_prompt]
|
||||
|
||||
|
||||
class SlotBackfillService:
    """
    [AC-MRS-SLOT-BACKFILL-01] Slot backfill service.

    Handles the extract -> validate -> confirm -> write flow for slot values:
    1. Extract candidate slot values from the user's reply
    2. Validate and normalize via SlotManager
    3. On validation failure, return an ask_back_prompt for a follow-up question
    4. On success, write the value into state with its source/confidence
    5. Attach a confirmation prompt for low-confidence values
    """

    # A written value with confidence below this threshold yields
    # NEEDS_CONFIRMATION and a confirmation prompt.
    CONFIDENCE_THRESHOLD_LOW = 0.5
    # NOTE(review): declared but not referenced in this class — confirm intended use.
    CONFIDENCE_THRESHOLD_HIGH = 0.8

    def __init__(
        self,
        session: AsyncSession,
        tenant_id: str,
        session_id: str | None = None,
        slot_manager: SlotManager | None = None,
        strategy_executor: SlotStrategyExecutor | None = None,
    ):
        """
        Args:
            session: async database session
            tenant_id: tenant ID
            session_id: conversation session ID; state writes are skipped when None
            slot_manager: validation/write collaborator (default: new SlotManager)
            strategy_executor: extraction strategy runner (default: new executor)
        """
        self._session = session
        self._tenant_id = tenant_id
        self._session_id = session_id
        self._slot_manager = slot_manager or SlotManager(session=session)
        self._strategy_executor = strategy_executor or SlotStrategyExecutor()
        self._slot_def_service = SlotDefinitionService(session)
        # Lazily constructed in _get_state_aggregator().
        self._state_aggregator: SlotStateAggregator | None = None

    async def _get_state_aggregator(self) -> SlotStateAggregator:
        """Return the per-session state aggregator, creating it on first use."""
        if self._state_aggregator is None:
            self._state_aggregator = SlotStateAggregator(
                session=self._session,
                tenant_id=self._tenant_id,
                session_id=self._session_id,
            )
        return self._state_aggregator

    async def backfill_single_slot(
        self,
        slot_key: str,
        candidate_value: Any,
        source: str = "user_confirmed",
        confidence: float = 1.0,
        strategies: list[str] | None = None,
    ) -> BackfillResult:
        """
        Backfill a single slot.

        Flow:
        1. If extraction strategies are given, run them first
        2. Validate the candidate value (SlotManager.write_slot)
        3. Decide the next step from the validation outcome
        4. Write session state (only when a session_id is bound)

        Args:
            slot_key: slot key
            candidate_value: candidate value
            source: value origin
            confidence: initial confidence
            strategies: extraction strategy chain (optional)

        Returns:
            BackfillResult: backfill outcome
        """
        final_value = candidate_value
        final_source = source
        final_confidence = confidence

        if strategies:
            # The candidate is treated as raw user text for the strategy chain.
            extracted_result = await self._extract_value(
                slot_key=slot_key,
                user_input=str(candidate_value),
                strategies=strategies,
            )

            if extracted_result.success:
                final_value = extracted_result.final_value
                final_source = extracted_result.final_strategy or source
                final_confidence = self._get_confidence_for_strategy(final_source)
            else:
                # Extraction failed: ask the user again rather than writing.
                ask_back_prompt = await self._slot_manager.get_ask_back_prompt(
                    self._tenant_id, slot_key
                )
                return BackfillResult(
                    status=BackfillStatus.EXTRACTION_FAILED,
                    slot_key=slot_key,
                    value=candidate_value,
                    source=source,
                    confidence=confidence,
                    error_message=extracted_result.steps[-1].failure_reason if extracted_result.steps else "提取失败",
                    ask_back_prompt=ask_back_prompt,
                )

        # Sources that are not valid SlotSource values silently fall back to
        # USER_CONFIRMED here.
        write_result = await self._slot_manager.write_slot(
            tenant_id=self._tenant_id,
            slot_key=slot_key,
            value=final_value,
            source=SlotSource(final_source) if final_source in [s.value for s in SlotSource] else SlotSource.USER_CONFIRMED,
            confidence=final_confidence,
        )

        if not write_result.success:
            return BackfillResult(
                status=BackfillStatus.VALIDATION_FAILED,
                slot_key=slot_key,
                value=final_value,
                source=final_source,
                confidence=final_confidence,
                error_message=write_result.error.error_message if write_result.error else "校验失败",
                ask_back_prompt=write_result.ask_back_prompt,
            )

        normalized_value = write_result.value
        updated_state = None

        if self._session_id:
            aggregator = await self._get_state_aggregator()
            updated_state = await aggregator.update_slot(
                slot_key=slot_key,
                value=normalized_value,
                source=final_source,
                confidence=final_confidence,
            )

        result_status = BackfillStatus.SUCCESS
        confirmation_prompt = None

        # Low confidence: the value is written, but the caller should confirm
        # it with the user.
        if final_confidence < self.CONFIDENCE_THRESHOLD_LOW:
            result_status = BackfillStatus.NEEDS_CONFIRMATION
            confirmation_prompt = self._generate_confirmation_prompt(
                slot_key, normalized_value
            )

        return BackfillResult(
            status=result_status,
            slot_key=slot_key,
            value=final_value,
            normalized_value=normalized_value,
            source=final_source,
            confidence=final_confidence,
            confirmation_prompt=confirmation_prompt,
            updated_state=updated_state,
        )

    async def backfill_multiple_slots(
        self,
        candidates: dict[str, Any],
        source: str = "user_confirmed",
        confidence: float = 1.0,
    ) -> BatchBackfillResult:
        """
        Backfill several slots sequentially.

        Args:
            candidates: candidate values keyed by slot_key
            source: value origin
            confidence: initial confidence

        Returns:
            BatchBackfillResult: aggregate outcome
        """
        batch_result = BatchBackfillResult()

        for slot_key, value in candidates.items():
            result = await self.backfill_single_slot(
                slot_key=slot_key,
                candidate_value=value,
                source=source,
                confidence=confidence,
            )
            batch_result.add_result(result)

        return batch_result

    async def backfill_from_user_response(
        self,
        user_response: str,
        expected_slots: list[str],
        strategies: list[str] | None = None,
    ) -> BatchBackfillResult:
        """
        Extract slot values from a user reply, then backfill them.

        Slots without a definition are skipped silently.

        Args:
            user_response: user reply text
            expected_slots: slot keys expected in the reply
            strategies: extraction strategy chain (default: rule then llm)

        Returns:
            BatchBackfillResult: aggregate outcome
        """
        batch_result = BatchBackfillResult()

        for slot_key in expected_slots:
            slot_def = await self._slot_def_service.get_slot_definition_by_key(
                self._tenant_id, slot_key
            )

            if not slot_def:
                continue

            extract_strategies = strategies or ["rule", "llm"]

            extracted_result = await self._extract_value(
                slot_key=slot_key,
                user_input=user_response,
                strategies=extract_strategies,
                slot_type=slot_def.type,
                validation_rule=slot_def.validation_rule,
            )

            if not extracted_result.success:
                ask_back_prompt = slot_def.ask_back_prompt or f"请提供{slot_key}信息"
                batch_result.add_result(BackfillResult(
                    status=BackfillStatus.EXTRACTION_FAILED,
                    slot_key=slot_key,
                    error_message="无法从回复中提取",
                    ask_back_prompt=ask_back_prompt,
                ))
                continue

            source = self._get_source_for_strategy(extracted_result.final_strategy)
            confidence = self._get_confidence_for_strategy(source)

            result = await self.backfill_single_slot(
                slot_key=slot_key,
                candidate_value=extracted_result.final_value,
                source=source,
                confidence=confidence,
            )
            batch_result.add_result(result)

        return batch_result

    async def _extract_value(
        self,
        slot_key: str,
        user_input: str,
        strategies: list[str],
        slot_type: str = "string",
        validation_rule: str | None = None,
    ) -> StrategyChainResult:
        """
        Run the extraction strategy chain for one slot.

        Args:
            slot_key: slot key
            user_input: user input text
            strategies: extraction strategy chain
            slot_type: slot type
            validation_rule: validation rule

        Returns:
            StrategyChainResult: extraction outcome
        """
        context = ExtractContext(
            tenant_id=self._tenant_id,
            slot_key=slot_key,
            user_input=user_input,
            slot_type=slot_type,
            validation_rule=validation_rule,
        )

        return await self._strategy_executor.execute_chain(
            strategies=strategies,
            context=context,
        )

    def _get_source_for_strategy(self, strategy: str | None) -> str:
        """Map an extraction strategy name to a SlotSource value ("unknown" if unmapped)."""
        strategy_source_map = {
            "rule": SlotSource.RULE_EXTRACTED.value,
            "llm": SlotSource.LLM_INFERRED.value,
            "user_input": SlotSource.USER_CONFIRMED.value,
        }
        return strategy_source_map.get(strategy or "", "unknown")

    def _get_confidence_for_strategy(self, source: str) -> float:
        """Map a value source to a default confidence (0.5 if unmapped)."""
        confidence_map = {
            SlotSource.USER_CONFIRMED.value: 1.0,
            SlotSource.RULE_EXTRACTED.value: 0.9,
            SlotSource.LLM_INFERRED.value: 0.7,
            "context": 0.5,
            SlotSource.DEFAULT.value: 0.3,
        }
        return confidence_map.get(source, 0.5)

    def _generate_confirmation_prompt(
        self,
        slot_key: str,
        value: Any,
    ) -> str:
        """Build the confirmation question shown for a low-confidence value.

        Note: slot_key is currently unused in the generated text.
        """
        return f"我理解您说的是「{value}」,对吗?"

    async def confirm_low_confidence_slot(
        self,
        slot_key: str,
        confirmed: bool,
    ) -> BackfillResult:
        """
        Apply the user's answer to a low-confidence confirmation.

        Confirmed: the slot is promoted to USER_CONFIRMED with confidence 1.0.
        Rejected: the slot is cleared and an ask-back prompt is returned.
        Without a bound session_id this is a no-op reported as SUCCESS.

        Args:
            slot_key: slot key
            confirmed: whether the user confirmed the value

        Returns:
            BackfillResult: confirmation outcome
        """
        if not self._session_id:
            return BackfillResult(
                status=BackfillStatus.SUCCESS,
                slot_key=slot_key,
            )

        aggregator = await self._get_state_aggregator()

        if confirmed:
            updated_state = await aggregator.update_slot(
                slot_key=slot_key,
                source=SlotSource.USER_CONFIRMED.value,
                confidence=1.0,
            )
            return BackfillResult(
                status=BackfillStatus.SUCCESS,
                slot_key=slot_key,
                source=SlotSource.USER_CONFIRMED.value,
                confidence=1.0,
                updated_state=updated_state,
            )
        else:
            await aggregator.clear_slot(slot_key)
            ask_back_prompt = await self._slot_manager.get_ask_back_prompt(
                self._tenant_id, slot_key
            )
            return BackfillResult(
                status=BackfillStatus.VALIDATION_FAILED,
                slot_key=slot_key,
                ask_back_prompt=ask_back_prompt or f"请重新提供{slot_key}信息",
            )
|
||||
|
||||
|
||||
def create_slot_backfill_service(
    session: AsyncSession,
    tenant_id: str,
    session_id: str | None = None,
) -> SlotBackfillService:
    """Build a SlotBackfillService with default collaborators.

    Args:
        session: database session
        tenant_id: tenant ID
        session_id: conversation session ID

    Returns:
        SlotBackfillService bound to the given tenant/session.
    """
    service = SlotBackfillService(
        session=session,
        tenant_id=tenant_id,
        session_id=session_id,
    )
    return service
|
||||
|
|
@ -0,0 +1,368 @@
|
|||
"""
|
||||
Slot Extraction Integration Service.
|
||||
槽位提取集成服务 - 将自动提取能力接入主链路
|
||||
|
||||
[AC-MRS-SLOT-EXTRACT-01] slot extraction 集成
|
||||
|
||||
职责:
|
||||
1. 接入点:memory_recall 之后、KB 检索之前
|
||||
2. 执行策略链:rule -> llm -> user_input
|
||||
3. 抽取结果统一走 SlotManager 校验
|
||||
4. 提供 trace:extracted_slots、validation_pass/fail、ask_back_triggered
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.mid.schemas import SlotSource
|
||||
from app.services.mid.slot_backfill_service import (
|
||||
BackfillResult,
|
||||
BackfillStatus,
|
||||
SlotBackfillService,
|
||||
)
|
||||
from app.services.mid.slot_manager import SlotManager
|
||||
from app.services.mid.slot_state_aggregator import SlotState, SlotStateAggregator
|
||||
from app.services.mid.slot_strategy_executor import (
|
||||
ExtractContext,
|
||||
SlotStrategyExecutor,
|
||||
StrategyChainResult,
|
||||
)
|
||||
from app.services.slot_definition_service import SlotDefinitionService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtractionTrace:
|
||||
"""
|
||||
提取追踪信息
|
||||
|
||||
Attributes:
|
||||
slot_key: 槽位键名
|
||||
strategy: 使用的策略
|
||||
extracted_value: 提取的值
|
||||
validation_passed: 校验是否通过
|
||||
final_value: 最终值(校验后)
|
||||
execution_time_ms: 执行时间
|
||||
failure_reason: 失败原因
|
||||
"""
|
||||
slot_key: str
|
||||
strategy: str | None = None
|
||||
extracted_value: Any = None
|
||||
validation_passed: bool = False
|
||||
final_value: Any = None
|
||||
execution_time_ms: float = 0.0
|
||||
failure_reason: str | None = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"slot_key": self.slot_key,
|
||||
"strategy": self.strategy,
|
||||
"extracted_value": self.extracted_value,
|
||||
"validation_passed": self.validation_passed,
|
||||
"final_value": self.final_value,
|
||||
"execution_time_ms": self.execution_time_ms,
|
||||
"failure_reason": self.failure_reason,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
class ExtractionResult:
    """
    Aggregate outcome of an extract-and-fill pass.

    Attributes:
        success: whether at least one slot was extracted
        extracted_slots: successfully extracted values keyed by slot_key
        failed_slots: slot keys that could not be filled
        traces: per-slot extraction traces
        total_execution_time_ms: total wall-clock time
        ask_back_triggered: whether any ask-back prompt was produced
        ask_back_prompts: collected ask-back prompts
    """
    success: bool = False
    extracted_slots: dict[str, Any] = field(default_factory=dict)
    failed_slots: list[str] = field(default_factory=list)
    traces: list[ExtractionTrace] = field(default_factory=list)
    total_execution_time_ms: float = 0.0
    ask_back_triggered: bool = False
    ask_back_prompts: list[str] = field(default_factory=list)

    def to_dict(self) -> dict[str, Any]:
        """Serialize the result (traces flattened via their own to_dict)."""
        payload: dict[str, Any] = {
            "success": self.success,
            "extracted_slots": self.extracted_slots,
            "failed_slots": self.failed_slots,
            "traces": [trace.to_dict() for trace in self.traces],
            "total_execution_time_ms": self.total_execution_time_ms,
            "ask_back_triggered": self.ask_back_triggered,
            "ask_back_prompts": self.ask_back_prompts,
        }
        return payload
|
||||
|
||||
|
||||
class SlotExtractionIntegration:
    """
    [AC-MRS-SLOT-EXTRACT-01] Slot extraction integration service.

    Wires automatic slot extraction into the main pipeline:
    - Hook point: after memory_recall, before KB retrieval
    - Runs the strategy chain: rule -> llm -> user_input
    - Extracted values are validated through SlotManager (via the backfill service)
    - Emits per-slot traces (extracted_slots, validation_pass/fail, ask_back_triggered)
    """

    # Strategy chain used when the caller does not supply one.
    DEFAULT_STRATEGIES = ["rule", "llm"]

    def __init__(
        self,
        session: AsyncSession,
        tenant_id: str,
        session_id: str | None = None,
        slot_manager: SlotManager | None = None,
        strategy_executor: SlotStrategyExecutor | None = None,
    ):
        """
        Args:
            session: async database session
            tenant_id: tenant ID
            session_id: conversation session ID; state writes are skipped when None
            slot_manager: validation collaborator (default: new SlotManager)
            strategy_executor: extraction strategy runner (default: new executor)
        """
        self._session = session
        self._tenant_id = tenant_id
        self._session_id = session_id
        self._slot_manager = slot_manager or SlotManager(session=session)
        self._strategy_executor = strategy_executor or SlotStrategyExecutor()
        self._slot_def_service = SlotDefinitionService(session)
        # Lazily constructed in _get_backfill_service().
        self._backfill_service: SlotBackfillService | None = None

    async def _get_backfill_service(self) -> SlotBackfillService:
        """Return the backfill service, creating it on first use (shares our collaborators)."""
        if self._backfill_service is None:
            self._backfill_service = SlotBackfillService(
                session=self._session,
                tenant_id=self._tenant_id,
                session_id=self._session_id,
                slot_manager=self._slot_manager,
                strategy_executor=self._strategy_executor,
            )
        return self._backfill_service

    async def extract_and_fill(
        self,
        user_input: str,
        target_slots: list[str] | None = None,
        strategies: list[str] | None = None,
        slot_state: SlotState | None = None,
    ) -> ExtractionResult:
        """
        Extract slot values from user input and fill the slots.

        Args:
            user_input: user input text
            target_slots: slots to fill (None -> all missing required slots)
            strategies: extraction strategy chain (default: rule -> llm)
            slot_state: current slot state, used to find missing slots

        Returns:
            ExtractionResult; ``success`` is True iff at least one slot was
            extracted (an empty target list short-circuits to success).
        """
        start_time = time.time()

        strategies = strategies or self.DEFAULT_STRATEGIES

        if target_slots is None:
            target_slots = await self._get_missing_required_slots(slot_state)

        if not target_slots:
            # Nothing to extract: report success with an empty result.
            return ExtractionResult(success=True)

        result = ExtractionResult()

        for slot_key in target_slots:
            trace = await self._extract_single_slot(
                slot_key=slot_key,
                user_input=user_input,
                strategies=strategies,
            )
            result.traces.append(trace)

            if trace.validation_passed and trace.final_value is not None:
                result.extracted_slots[slot_key] = trace.final_value
            else:
                result.failed_slots.append(slot_key)
                # Only failures with a recorded reason trigger an ask-back lookup.
                if trace.failure_reason:
                    ask_back_prompt = await self._slot_manager.get_ask_back_prompt(
                        self._tenant_id, slot_key
                    )
                    if ask_back_prompt:
                        result.ask_back_prompts.append(ask_back_prompt)

        result.total_execution_time_ms = (time.time() - start_time) * 1000
        result.success = len(result.extracted_slots) > 0
        result.ask_back_triggered = len(result.ask_back_prompts) > 0

        if result.extracted_slots and self._session_id:
            await self._save_extracted_slots(result.extracted_slots)

        logger.info(
            f"[AC-MRS-SLOT-EXTRACT-01] Extraction completed: "
            f"tenant={self._tenant_id}, extracted={len(result.extracted_slots)}, "
            f"failed={len(result.failed_slots)}, time_ms={result.total_execution_time_ms:.2f}"
        )

        return result

    async def _extract_single_slot(
        self,
        slot_key: str,
        user_input: str,
        strategies: list[str],
    ) -> ExtractionTrace:
        """Extract one slot: run the strategy chain, then validate/persist via backfill."""
        start_time = time.time()
        trace = ExtractionTrace(slot_key=slot_key)

        slot_def = await self._slot_def_service.get_slot_definition_by_key(
            self._tenant_id, slot_key
        )

        if not slot_def:
            trace.failure_reason = "Slot definition not found"
            trace.execution_time_ms = (time.time() - start_time) * 1000
            return trace

        context = ExtractContext(
            tenant_id=self._tenant_id,
            slot_key=slot_key,
            user_input=user_input,
            slot_type=slot_def.type,
            validation_rule=slot_def.validation_rule,
            session_id=self._session_id,
        )

        chain_result = await self._strategy_executor.execute_chain(
            strategies=strategies,
            context=context,
            ask_back_prompt=slot_def.ask_back_prompt,
        )

        trace.strategy = chain_result.final_strategy
        trace.extracted_value = chain_result.final_value
        trace.execution_time_ms = (time.time() - start_time) * 1000

        if not chain_result.success:
            trace.failure_reason = "Extraction failed"
            # Prefer the last step's specific reason when available.
            if chain_result.steps:
                last_step = chain_result.steps[-1]
                trace.failure_reason = last_step.failure_reason
            return trace

        backfill_service = await self._get_backfill_service()
        source = self._get_source_for_strategy(chain_result.final_strategy)

        backfill_result = await backfill_service.backfill_single_slot(
            slot_key=slot_key,
            candidate_value=chain_result.final_value,
            source=source,
            confidence=self._get_confidence_for_source(source),
        )

        trace.validation_passed = backfill_result.is_success()
        trace.final_value = backfill_result.normalized_value

        if not backfill_result.is_success():
            trace.failure_reason = backfill_result.error_message or "Validation failed"

        return trace

    async def _get_missing_required_slots(
        self,
        slot_state: SlotState | None,
    ) -> list[str]:
        """Return missing required slot keys from state, or all required slot keys as fallback."""
        if slot_state and slot_state.missing_required_slots:
            return [
                s.get("slot_key")
                for s in slot_state.missing_required_slots
                if s.get("slot_key")
            ]

        required_defs = await self._slot_def_service.list_slot_definitions(
            tenant_id=self._tenant_id,
            required=True,
        )

        return [d.slot_key for d in required_defs]

    async def _save_extracted_slots(
        self,
        extracted_slots: dict[str, Any],
    ) -> None:
        """Persist extracted values into the session slot state.

        NOTE(review): values are re-written with a hard-coded RULE_EXTRACTED
        source and 0.9 confidence, even though _extract_single_slot already
        persisted them via the backfill service using the actual strategy's
        source — confirm this second write is intended.
        """
        if not self._session_id:
            return

        aggregator = SlotStateAggregator(
            session=self._session,
            tenant_id=self._tenant_id,
            session_id=self._session_id,
        )

        for slot_key, value in extracted_slots.items():
            await aggregator.update_slot(
                slot_key=slot_key,
                value=value,
                source=SlotSource.RULE_EXTRACTED.value,
                confidence=0.9,
            )

    def _get_source_for_strategy(self, strategy: str | None) -> str:
        """Map an extraction strategy name to a SlotSource value ("unknown" if unmapped)."""
        strategy_source_map = {
            "rule": SlotSource.RULE_EXTRACTED.value,
            "llm": SlotSource.LLM_INFERRED.value,
            "user_input": SlotSource.USER_CONFIRMED.value,
        }
        return strategy_source_map.get(strategy or "", "unknown")

    def _get_confidence_for_source(self, source: str) -> float:
        """Map a value source to a default confidence (0.5 if unmapped)."""
        confidence_map = {
            SlotSource.USER_CONFIRMED.value: 1.0,
            SlotSource.RULE_EXTRACTED.value: 0.9,
            SlotSource.LLM_INFERRED.value: 0.7,
        }
        return confidence_map.get(source, 0.5)
|
||||
|
||||
|
||||
async def integrate_slot_extraction(
    session: AsyncSession,
    tenant_id: str,
    session_id: str,
    user_input: str,
    slot_state: SlotState | None = None,
    strategies: list[str] | None = None,
) -> ExtractionResult:
    """Convenience wrapper: run slot extraction and backfill in one call.

    Args:
        session: Database session.
        tenant_id: Tenant ID.
        session_id: Session ID.
        user_input: User utterance to extract from.
        slot_state: Current aggregated slot state, if any.
        strategies: Extraction strategy chain to run.

    Returns:
        ExtractionResult: The extraction outcome.
    """
    integration = SlotExtractionIntegration(
        session=session,
        tenant_id=tenant_id,
        session_id=session_id,
    )
    result = await integration.extract_and_fill(
        user_input=user_input,
        slot_state=slot_state,
        strategies=strategies,
    )
    return result
|
||||
|
|
@ -0,0 +1,379 @@
|
|||
"""
|
||||
Slot Manager Service.
|
||||
槽位管理服务 - 统一槽位写入入口,集成校验逻辑
|
||||
|
||||
职责:
|
||||
1. 在槽位值写入前执行校验
|
||||
2. 管理槽位值的来源和置信度
|
||||
3. 提供槽位写入的统一接口
|
||||
4. 返回校验失败时的追问提示
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.entities import SlotDefinition
|
||||
from app.models.mid.schemas import SlotSource
|
||||
from app.services.mid.slot_validation_service import (
|
||||
BatchValidationResult,
|
||||
SlotValidationError,
|
||||
SlotValidationService,
|
||||
)
|
||||
from app.services.slot_definition_service import SlotDefinitionService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SlotWriteResult:
    """Outcome of a single slot write attempt.

    Attributes:
        success: Whether validation passed and the value was accepted.
        slot_key: Slot key.
        value: Final value written (normalized when validation ran).
        error: Validation error details (present on failure).
        ask_back_prompt: Follow-up question for the user (present on failure).
    """

    def __init__(
        self,
        success: bool,
        slot_key: str,
        value: Any | None = None,
        error: SlotValidationError | None = None,
        ask_back_prompt: str | None = None,
    ):
        self.success = success
        self.slot_key = slot_key
        self.value = value
        self.error = error
        self.ask_back_prompt = ask_back_prompt

    def to_dict(self) -> dict[str, Any]:
        """Serialize the result, omitting absent error/prompt fields."""
        payload: dict[str, Any] = {
            "success": self.success,
            "slot_key": self.slot_key,
            "value": self.value,
        }
        error = self.error
        if error:
            payload["error"] = {
                "error_code": error.error_code,
                "error_message": error.error_message,
            }
        prompt = self.ask_back_prompt
        if prompt:
            payload["ask_back_prompt"] = prompt
        return payload
|
||||
|
||||
|
||||
class SlotManager:
    """Unified entry point for slot writes.

    Loads validation rules from ``SlotDefinition`` and runs them before a
    slot value is accepted; failures carry an ask-back prompt for the user.
    """

    def __init__(
        self,
        session: AsyncSession | None = None,
        validation_service: SlotValidationService | None = None,
        slot_def_service: SlotDefinitionService | None = None,
    ):
        """Initialize the slot manager.

        Args:
            session: Database session; used to build a definition service
                lazily when ``slot_def_service`` is not provided.
            validation_service: Validation service instance (a default one
                is created when omitted).
            slot_def_service: Slot-definition service instance.
        """
        self._session = session
        self._validation_service = validation_service or SlotValidationService()
        self._slot_def_service = slot_def_service
        # Cache keyed by "{tenant_id}:{slot_key}"; also caches None for
        # definitions that were looked up and not found (negative cache).
        self._slot_def_cache: dict[str, SlotDefinition | dict[str, Any] | None] = {}

    async def write_slot(
        self,
        tenant_id: str,
        slot_key: str,
        value: Any,
        source: SlotSource = SlotSource.USER_CONFIRMED,
        confidence: float = 1.0,
        skip_validation: bool = False,
    ) -> SlotWriteResult:
        """Write a single slot value (with validation).

        Flow:
        1. Load the slot definition.
        2. Run validation (unless skipped).
        3. Return the outcome.

        Args:
            tenant_id: Tenant ID.
            slot_key: Slot key.
            value: Slot value.
            source: Value provenance.
                NOTE(review): not used anywhere in this method's body —
                confirm whether it should be recorded with the value.
            confidence: Confidence score.
                NOTE(review): likewise unused here.
            skip_validation: Skip validation (for special scenarios).

        Returns:
            SlotWriteResult: The write outcome.
        """
        # Load the slot definition (memoized per manager instance).
        slot_def = await self._get_slot_definition(tenant_id, slot_key)

        # No definition and validation skipped: allow the write (dynamic slot).
        # NOTE(review): this branch and the next return identical results;
        # the skip_validation split only changes the log level.
        if slot_def is None and skip_validation:
            logger.debug(
                f"[SlotManager] Writing slot without definition: "
                f"tenant_id={tenant_id}, slot_key={slot_key}"
            )
            return SlotWriteResult(
                success=True,
                slot_key=slot_key,
                value=value,
            )

        if slot_def is None:
            # Undefined slot: allow the write but log it.
            logger.info(
                f"[SlotManager] Slot definition not found, allowing write: "
                f"tenant_id={tenant_id}, slot_key={slot_key}"
            )
            return SlotWriteResult(
                success=True,
                slot_key=slot_key,
                value=value,
            )

        # Run validation.
        if not skip_validation:
            validation_result = self._validation_service.validate_slot_value(
                slot_def, value, tenant_id
            )

            if not validation_result.ok:
                logger.info(
                    f"[SlotManager] Slot validation failed: "
                    f"tenant_id={tenant_id}, slot_key={slot_key}, "
                    f"error_code={validation_result.error_code}"
                )
                return SlotWriteResult(
                    success=False,
                    slot_key=slot_key,
                    error=SlotValidationError(
                        slot_key=slot_key,
                        error_code=validation_result.error_code or "VALIDATION_FAILED",
                        error_message=validation_result.error_message or "校验失败",
                        ask_back_prompt=validation_result.ask_back_prompt,
                    ),
                    ask_back_prompt=validation_result.ask_back_prompt,
                )

            # Use the normalized value produced by validation.
            value = validation_result.normalized_value

            logger.debug(
                f"[SlotManager] Slot validation passed: "
                f"tenant_id={tenant_id}, slot_key={slot_key}"
            )

        return SlotWriteResult(
            success=True,
            slot_key=slot_key,
            value=value,
        )

    async def write_slots(
        self,
        tenant_id: str,
        values: dict[str, Any],
        source: SlotSource = SlotSource.USER_CONFIRMED,
        confidence: float = 1.0,
        skip_validation: bool = False,
    ) -> BatchValidationResult:
        """Write multiple slot values in one batch (with validation).

        Args:
            tenant_id: Tenant ID.
            values: Slot values, ``{slot_key: value}``.
            source: Value provenance. NOTE(review): unused in the body.
            confidence: Confidence score. NOTE(review): unused in the body.
            skip_validation: Skip validation entirely.

        Returns:
            BatchValidationResult: The batch validation outcome.
        """
        if skip_validation:
            return BatchValidationResult(
                ok=True,
                validated_values=values,
            )

        # Load every relevant slot definition; keys without a definition are
        # dropped by the helper, so undefined slots are simply not validated.
        slot_defs = await self._get_slot_definitions(tenant_id, list(values.keys()))

        # Run batch validation.
        result = self._validation_service.validate_slots(
            slot_defs, values, tenant_id
        )

        if not result.ok:
            logger.info(
                f"[SlotManager] Batch slot validation failed: "
                f"tenant_id={tenant_id}, errors={[e.slot_key for e in result.errors]}"
            )
        else:
            logger.debug(
                f"[SlotManager] Batch slot validation passed: "
                f"tenant_id={tenant_id}, slots={list(values.keys())}"
            )

        return result

    async def validate_before_write(
        self,
        tenant_id: str,
        slot_key: str,
        value: Any,
    ) -> tuple[bool, SlotValidationError | None]:
        """Pre-validate a slot value without writing it.

        Args:
            tenant_id: Tenant ID.
            slot_key: Slot key.
            value: Slot value.

        Returns:
            Tuple of (passed, error details or None).
        """
        slot_def = await self._get_slot_definition(tenant_id, slot_key)

        if slot_def is None:
            # Undefined slot: treated as passing.
            return True, None

        result = self._validation_service.validate_slot_value(
            slot_def, value, tenant_id
        )

        if result.ok:
            return True, None

        return False, SlotValidationError(
            slot_key=slot_key,
            error_code=result.error_code or "VALIDATION_FAILED",
            error_message=result.error_message or "校验失败",
            ask_back_prompt=result.ask_back_prompt,
        )

    async def get_ask_back_prompt(
        self,
        tenant_id: str,
        slot_key: str,
    ) -> str | None:
        """Look up the ask-back prompt configured for a slot.

        Args:
            tenant_id: Tenant ID.
            slot_key: Slot key.

        Returns:
            The prompt string, or None when the slot (or prompt) is absent.
        """
        slot_def = await self._get_slot_definition(tenant_id, slot_key)
        if slot_def is None:
            return None

        # Definitions may be ORM entities or plain dicts (see
        # _get_slot_definition's return type), hence the isinstance branch.
        if isinstance(slot_def, SlotDefinition):
            return slot_def.ask_back_prompt
        return slot_def.get("ask_back_prompt")

    async def _get_slot_definition(
        self,
        tenant_id: str,
        slot_key: str,
    ) -> SlotDefinition | dict[str, Any] | None:
        """Fetch a slot definition (memoized per manager instance).

        Args:
            tenant_id: Tenant ID.
            slot_key: Slot key.

        Returns:
            The slot definition, or None when unavailable (no service or
            session configured, or the definition does not exist).
        """
        cache_key = f"{tenant_id}:{slot_key}"

        if cache_key in self._slot_def_cache:
            return self._slot_def_cache[cache_key]

        slot_def = None
        if self._slot_def_service:
            slot_def = await self._slot_def_service.get_slot_definition_by_key(
                tenant_id, slot_key
            )
        elif self._session:
            # Fall back to building a definition service from the session.
            service = SlotDefinitionService(self._session)
            slot_def = await service.get_slot_definition_by_key(tenant_id, slot_key)

        self._slot_def_cache[cache_key] = slot_def
        return slot_def

    async def _get_slot_definitions(
        self,
        tenant_id: str,
        slot_keys: list[str],
    ) -> list[SlotDefinition | dict[str, Any]]:
        """Fetch slot definitions for several keys, skipping missing ones.

        Args:
            tenant_id: Tenant ID.
            slot_keys: Slot keys to resolve.

        Returns:
            Definitions for the keys that exist (order follows slot_keys).
        """
        slot_defs = []
        for key in slot_keys:
            slot_def = await self._get_slot_definition(tenant_id, key)
            if slot_def:
                slot_defs.append(slot_def)
        return slot_defs

    def clear_cache(self) -> None:
        """Drop all memoized slot definitions."""
        self._slot_def_cache.clear()
|
||||
|
||||
|
||||
def create_slot_manager(
    session: AsyncSession | None = None,
    validation_service: SlotValidationService | None = None,
    slot_def_service: SlotDefinitionService | None = None,
) -> SlotManager:
    """Factory for a :class:`SlotManager`.

    Args:
        session: Database session.
        validation_service: Validation service instance.
        slot_def_service: Slot-definition service instance.

    Returns:
        SlotManager: A configured slot manager.
    """
    manager = SlotManager(
        session=session,
        validation_service=validation_service,
        slot_def_service=slot_def_service,
    )
    return manager
|
||||
|
|
@ -0,0 +1,562 @@
|
|||
"""
|
||||
Slot State Aggregator Service.
|
||||
槽位状态聚合服务 - 统一维护本轮槽位状态
|
||||
|
||||
职责:
|
||||
1. 聚合来自 memory_recall 的槽位值
|
||||
2. 叠加本轮输入的槽位值
|
||||
3. 识别缺失的必填槽位
|
||||
4. 支持槽位与元数据字段的关联映射
|
||||
5. 为 KB 检索过滤提供统一的槽位值来源
|
||||
6. [AC-MRS-SLOT-CACHE-01] 多轮状态持久化
|
||||
|
||||
[AC-MRS-SLOT-META-01] 槽位与元数据关联机制
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.mid.schemas import MemorySlot, SlotSource
|
||||
from app.services.cache.slot_state_cache import (
|
||||
CachedSlotState,
|
||||
CachedSlotValue,
|
||||
get_slot_state_cache,
|
||||
)
|
||||
from app.services.mid.slot_manager import SlotManager
|
||||
from app.services.slot_definition_service import SlotDefinitionService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class SlotState:
    """Aggregated slot state for the current turn.

    Attributes:
        filled_slots: Filled slot values, ``{slot_key: value}``.
        missing_required_slots: Descriptors of required slots still missing.
        slot_sources: Provenance per slot, ``{slot_key: source}``.
        slot_confidence: Confidence per slot, ``{slot_key: confidence}``.
        slot_to_field_map: Slot-to-metadata-field mapping,
            ``{slot_key: field_key}``.
    """
    filled_slots: dict[str, Any] = field(default_factory=dict)
    missing_required_slots: list[dict[str, str]] = field(default_factory=list)
    slot_sources: dict[str, str] = field(default_factory=dict)
    slot_confidence: dict[str, float] = field(default_factory=dict)
    slot_to_field_map: dict[str, str] = field(default_factory=dict)

    def get_value_for_filter(self, field_key: str) -> Any:
        """Resolve a value usable for KB retrieval filtering.

        Checks for a directly filled slot first, then reverse-maps the
        metadata field key through ``slot_to_field_map``.
        """
        try:
            return self.filled_slots[field_key]
        except KeyError:
            pass

        for candidate_slot, mapped_field in self.slot_to_field_map.items():
            if mapped_field != field_key:
                continue
            if candidate_slot in self.filled_slots:
                return self.filled_slots[candidate_slot]

        return None

    def to_debug_info(self) -> dict[str, Any]:
        """Render a debug-friendly snapshot of the state."""
        exported = (
            "filled_slots",
            "missing_required_slots",
            "slot_sources",
            "slot_to_field_map",
        )
        return {name: getattr(self, name) for name in exported}
|
||||
|
||||
|
||||
class SlotStateAggregator:
    """
    [AC-MRS-SLOT-META-01] Slot state aggregator.

    Maintains the unified slot state for the current turn, supporting:
    - initialization from memory_recall
    - layering in the current turn's slot values
    - detection of missing required slots
    - linking slots to metadata fields
    - [AC-MRS-SLOT-CACHE-01] multi-turn state persistence
    """

    def __init__(
        self,
        session: AsyncSession,
        tenant_id: str,
        slot_manager: SlotManager | None = None,
        session_id: str | None = None,
    ):
        # Dependencies; a SlotManager is built on demand when not injected,
        # and the cache handle is resolved once per aggregator instance.
        self._session = session
        self._tenant_id = tenant_id
        self._session_id = session_id
        self._slot_manager = slot_manager or SlotManager(session=session)
        self._slot_def_service = SlotDefinitionService(session)
        self._cache = get_slot_state_cache()

    async def aggregate(
        self,
        memory_slots: dict[str, MemorySlot] | None = None,
        current_input_slots: dict[str, Any] | None = None,
        context: dict[str, Any] | None = None,
        use_cache: bool = True,
        scene_slot_context: Any = None,  # [AC-SCENE-SLOT-02] scene slot context
    ) -> SlotState:
        """Aggregate the slot state for this turn.

        Flow (later steps overwrite earlier ones for the same slot key):
        1. [AC-MRS-SLOT-CACHE-01] Load existing state from the cache.
        2. Initialize filled slots from ``memory_slots``.
        3. Layer in ``current_input_slots`` (higher priority).
        4. Extract slot values from ``context`` (lowest priority; never
           overwrites an already-filled slot).
        5. Build the slot-to-metadata-field mapping.
        6. Identify missing required slots.
        7. [AC-MRS-SLOT-CACHE-01] Write the state back to the cache.

        Args:
            memory_slots: Slots recalled by memory_recall.
            current_input_slots: Slot values from the current turn's input.
            context: Context information that may carry slot values.
            use_cache: Whether to use the cache (default True).

        Returns:
            SlotState: The aggregated slot state.
        """
        state = SlotState()

        # [AC-MRS-SLOT-CACHE-01] 1. Load existing state from the cache.
        cached_state = None
        if use_cache and self._session_id:
            cached_state = await self._cache.get(self._tenant_id, self._session_id)
            if cached_state:
                for slot_key, cached_value in cached_state.filled_slots.items():
                    state.filled_slots[slot_key] = cached_value.value
                    state.slot_sources[slot_key] = cached_value.source
                    state.slot_confidence[slot_key] = cached_value.confidence
                state.slot_to_field_map = cached_state.slot_to_field_map.copy()
                logger.info(
                    f"[AC-MRS-SLOT-CACHE-01] Loaded from cache: "
                    f"tenant={self._tenant_id}, session={self._session_id}, "
                    f"slots={list(state.filled_slots.keys())}"
                )

        # 2. Initialize from memory_slots.
        if memory_slots:
            for slot_key, memory_slot in memory_slots.items():
                state.filled_slots[slot_key] = memory_slot.value
                state.slot_sources[slot_key] = memory_slot.source.value
                state.slot_confidence[slot_key] = memory_slot.confidence
            logger.info(
                f"[AC-MRS-SLOT-META-01] Initialized from memory: "
                f"tenant={self._tenant_id}, slots={list(memory_slots.keys())}"
            )

        # 3. Layer in the current turn's input (higher priority).
        if current_input_slots:
            for slot_key, value in current_input_slots.items():
                if value is not None:
                    state.filled_slots[slot_key] = value
                    state.slot_sources[slot_key] = SlotSource.USER_CONFIRMED.value
                    state.slot_confidence[slot_key] = 1.0
            logger.info(
                f"[AC-MRS-SLOT-META-01] Merged current input: "
                f"tenant={self._tenant_id}, slots={list(current_input_slots.keys())}"
            )

        # 4. Extract slot values from context (lowest priority).
        if context:
            context_slots = self._extract_slots_from_context(context)
            for slot_key, value in context_slots.items():
                if slot_key not in state.filled_slots and value is not None:
                    state.filled_slots[slot_key] = value
                    state.slot_sources[slot_key] = "context"
                    state.slot_confidence[slot_key] = 0.5

        # 5. Load slot definitions and build field associations.
        await self._build_slot_mappings(state)

        # 6. Identify missing required slots.
        await self._identify_missing_required_slots(state, scene_slot_context)

        # [AC-MRS-SLOT-CACHE-01] 7. Write back to the cache.
        if use_cache and self._session_id:
            await self._save_to_cache(state)

        logger.info(
            f"[AC-MRS-SLOT-META-01] Slot state aggregated: "
            f"tenant={self._tenant_id}, filled={len(state.filled_slots)}, "
            f"missing={len(state.missing_required_slots)}"
        )

        return state

    async def _save_to_cache(self, state: SlotState) -> None:
        """
        [AC-MRS-SLOT-CACHE-01] Persist the slot state to the cache.
        """
        if not self._session_id:
            return

        # Re-wrap plain values into CachedSlotValue entries; unknown
        # provenance/confidence falls back to "unknown" / 1.0.
        cached_slots = {}
        for slot_key, value in state.filled_slots.items():
            source = state.slot_sources.get(slot_key, "unknown")
            confidence = state.slot_confidence.get(slot_key, 1.0)
            cached_slots[slot_key] = CachedSlotValue(
                value=value,
                source=source,
                confidence=confidence,
            )

        cached_state = CachedSlotState(
            filled_slots=cached_slots,
            slot_to_field_map=state.slot_to_field_map.copy(),
        )

        await self._cache.set(self._tenant_id, self._session_id, cached_state)
        logger.debug(
            f"[AC-MRS-SLOT-CACHE-01] Saved to cache: "
            f"tenant={self._tenant_id}, session={self._session_id}"
        )

    async def update_slot(
        self,
        slot_key: str,
        value: Any,
        source: str = "user_confirmed",
        confidence: float = 1.0,
    ) -> SlotState | None:
        """
        [AC-MRS-SLOT-CACHE-01] Update a single slot value and save it to
        the cache.

        Args:
            slot_key: Slot key.
            value: Slot value.
            source: Value provenance.
            confidence: Confidence score.

        Returns:
            The updated slot state, or None when no session_id is set.
        """
        if not self._session_id:
            return None

        cached_value = CachedSlotValue(
            value=value,
            source=source,
            confidence=confidence,
        )

        # merge_and_set folds the new slot into whatever is already cached.
        cached_state = await self._cache.merge_and_set(
            tenant_id=self._tenant_id,
            session_id=self._session_id,
            new_slots={slot_key: cached_value},
        )

        # Rebuild a SlotState view from the merged cache entry.
        state = SlotState()
        state.filled_slots = cached_state.get_simple_filled_slots()
        state.slot_sources = cached_state.get_slot_sources()
        state.slot_confidence = cached_state.get_slot_confidence()
        state.slot_to_field_map = cached_state.slot_to_field_map.copy()

        await self._identify_missing_required_slots(state)

        return state

    async def clear_slot(self, slot_key: str) -> bool:
        """
        [AC-MRS-SLOT-CACHE-01] Clear a single slot value.

        Args:
            slot_key: Slot key.

        Returns:
            Whether the operation succeeded.
        """
        if not self._session_id:
            return False

        return await self._cache.clear_slot(
            tenant_id=self._tenant_id,
            session_id=self._session_id,
            slot_key=slot_key,
        )

    async def clear_all_slots(self) -> bool:
        """
        [AC-MRS-SLOT-CACHE-01] Clear the entire slot state.

        Returns:
            Whether the operation succeeded.
        """
        if not self._session_id:
            return False

        return await self._cache.delete(
            tenant_id=self._tenant_id,
            session_id=self._session_id,
        )

    def _extract_slots_from_context(self, context: dict[str, Any]) -> dict[str, Any]:
        """
        Extract likely slot values from the context dict.

        Only a fixed allow-list of keys is considered (scene, product_line,
        region, grade, category, type, status, priority); None values are
        skipped.
        """
        slots = {}
        slot_candidates = [
            "scene", "product_line", "region", "grade",
            "category", "type", "status", "priority"
        ]

        for key in slot_candidates:
            if key in context and context[key] is not None:
                slots[key] = context[key]

        return slots

    async def _build_slot_mappings(self, state: SlotState) -> None:
        """
        Build the slot-to-metadata-field mapping.

        Joins SlotDefinition to MetadataFieldDefinition via linked_field_id.
        Any failure is logged and swallowed so aggregation can proceed
        without mappings.
        """
        try:
            # Fetch all slot definitions for the tenant.
            slot_defs = await self._slot_def_service.list_slot_definitions(
                tenant_id=self._tenant_id
            )

            for slot_def in slot_defs:
                if slot_def.linked_field_id:
                    # Resolve the linked metadata field.
                    # NOTE(review): the import and service construction run
                    # once per linked slot — consider hoisting out of the loop.
                    from app.services.metadata_field_definition_service import (
                        MetadataFieldDefinitionService
                    )
                    field_service = MetadataFieldDefinitionService(self._session)
                    linked_field = await field_service.get_field_definition(
                        tenant_id=self._tenant_id,
                        field_id=str(slot_def.linked_field_id)
                    )

                    if linked_field:
                        state.slot_to_field_map[slot_def.slot_key] = linked_field.field_key

                        # Warn (without blocking) on type inconsistencies.
                        await self._check_type_consistency(
                            slot_def, linked_field, state
                        )
        except Exception as e:
            logger.warning(
                f"[AC-MRS-SLOT-META-01] Failed to build slot mappings: {e}"
            )

    async def _check_type_consistency(
        self,
        slot_def: Any,
        linked_field: Any,
        state: SlotState,
    ) -> None:
        """
        Check consistency between a slot and its linked metadata field.

        Mismatches are logged as warnings only — nothing is blocked.
        """
        # Type consistency.
        if slot_def.type != linked_field.type:
            logger.warning(
                f"[AC-MRS-SLOT-META-01] Type mismatch: "
                f"slot='{slot_def.slot_key}' type={slot_def.type} vs "
                f"field='{linked_field.field_key}' type={linked_field.type}"
            )

        # Required-flag consistency.
        if slot_def.required != linked_field.required:
            logger.warning(
                f"[AC-MRS-SLOT-META-01] Required mismatch: "
                f"slot='{slot_def.slot_key}' required={slot_def.required} vs "
                f"field='{linked_field.field_key}' required={linked_field.required}"
            )

        # Options consistency (for enum/array_enum types only).
        if slot_def.type in ["enum", "array_enum"]:
            slot_options = set(slot_def.options or [])
            field_options = set(linked_field.options or [])
            if slot_options != field_options:
                logger.warning(
                    f"[AC-MRS-SLOT-META-01] Options mismatch: "
                    f"slot='{slot_def.slot_key}' options={slot_options} vs "
                    f"field='{linked_field.field_key}' options={field_options}"
                )

    async def _identify_missing_required_slots(
        self,
        state: SlotState,
        scene_slot_context: Any = None,  # [AC-SCENE-SLOT-02] scene slot context
    ) -> None:
        """
        Identify missing required slots and append descriptors to
        ``state.missing_required_slots``.

        Based on SlotDefinition.required (and the linked field's label when
        available). [AC-SCENE-SLOT-02] When a scene slot context is given,
        the scene's own required-slot list takes precedence. Failures are
        logged and swallowed.
        """
        try:
            # [AC-SCENE-SLOT-02] Scene context present: use the scene's
            # required slots instead of the tenant-wide required flags.
            if scene_slot_context:
                scene_required_keys = set(scene_slot_context.get_required_slot_keys())
                logger.info(
                    f"[AC-SCENE-SLOT-02] Using scene required slots: "
                    f"scene={scene_slot_context.scene_key}, "
                    f"required_keys={scene_required_keys}"
                )

                # Fetch every slot defined for the tenant to look up prompts.
                all_slot_defs = await self._slot_def_service.list_slot_definitions(
                    tenant_id=self._tenant_id,
                )
                slot_def_map = {slot.slot_key: slot for slot in all_slot_defs}

                for slot_key in scene_required_keys:
                    if slot_key not in state.filled_slots:
                        slot_def = slot_def_map.get(slot_key)
                        ask_back_prompt = slot_def.ask_back_prompt if slot_def else None

                        # Fall back to a generic ask-back template.
                        if not ask_back_prompt:
                            ask_back_prompt = f"请提供{slot_key}信息"

                        missing_info = {
                            "slot_key": slot_key,
                            "label": slot_key,
                            "reason": "scene_required_slot_missing",
                            "ask_back_prompt": ask_back_prompt,
                            "scene": scene_slot_context.scene_key,
                        }

                        if slot_def and slot_def.linked_field_id:
                            missing_info["linked_field_id"] = str(slot_def.linked_field_id)

                        state.missing_required_slots.append(missing_info)

                return

            # No scene context: fetch all required slot definitions.
            required_slot_defs = await self._slot_def_service.list_slot_definitions(
                tenant_id=self._tenant_id,
                required=True,
            )

            for slot_def in required_slot_defs:
                if slot_def.slot_key not in state.filled_slots:
                    # Configured ask-back prompt, if any.
                    ask_back_prompt = slot_def.ask_back_prompt

                    # Fall back to a generic ask-back template.
                    if not ask_back_prompt:
                        ask_back_prompt = f"请提供{slot_def.slot_key}信息"

                    missing_info = {
                        "slot_key": slot_def.slot_key,
                        "label": slot_def.slot_key,
                        "reason": "required_slot_missing",
                        "ask_back_prompt": ask_back_prompt,
                    }

                    # Prefer the linked metadata field's label when present.
                    # NOTE(review): import/service construction per missing
                    # slot — consider hoisting out of the loop.
                    if slot_def.linked_field_id:
                        from app.services.metadata_field_definition_service import (
                            MetadataFieldDefinitionService
                        )
                        field_service = MetadataFieldDefinitionService(self._session)
                        linked_field = await field_service.get_field_definition(
                            tenant_id=self._tenant_id,
                            field_id=str(slot_def.linked_field_id)
                        )
                        if linked_field:
                            missing_info["label"] = linked_field.label
                            missing_info["field_key"] = linked_field.field_key

                    state.missing_required_slots.append(missing_info)

                    logger.info(
                        f"[AC-MRS-SLOT-META-01] Missing required slot: "
                        f"slot_key={slot_def.slot_key}"
                    )
        except Exception as e:
            logger.warning(
                f"[AC-MRS-SLOT-META-01] Failed to identify missing slots: {e}"
            )

    async def generate_ask_back_response(
        self,
        state: SlotState,
        missing_slot_key: str | None = None,
    ) -> str | None:
        """
        Generate the ask-back response text for a missing slot.

        Args:
            state: Current slot state.
            missing_slot_key: Specific slot to ask about; when None, the
                first missing slot is used.

        Returns:
            The ask-back text, or None when nothing is missing (or the
            requested slot is not in the missing list).
        """
        if not state.missing_required_slots:
            return None

        missing_info = None

        # When a specific slot key is requested, find its descriptor.
        if missing_slot_key:
            for missing in state.missing_required_slots:
                if missing.get("slot_key") == missing_slot_key:
                    missing_info = missing
                    break
        else:
            # Default to the first missing slot.
            missing_info = state.missing_required_slots[0]

        if missing_info is None:
            return None

        # Prefer the configured ask_back_prompt.
        ask_back_prompt = missing_info.get("ask_back_prompt")
        if ask_back_prompt:
            return ask_back_prompt

        # Fall back to a generic template.
        label = missing_info.get("label", missing_info.get("slot_key", "相关信息"))
        return f"为了更好地为您提供帮助,请告诉我您的{label}。"
|
||||
|
||||
|
||||
def create_slot_state_aggregator(
    session: AsyncSession,
    tenant_id: str,
) -> SlotStateAggregator:
    """Factory for a :class:`SlotStateAggregator`.

    Args:
        session: Database session.
        tenant_id: Tenant ID.

    Returns:
        SlotStateAggregator: A configured aggregator (without a session_id,
        so cache-backed persistence stays disabled).
    """
    return SlotStateAggregator(session=session, tenant_id=tenant_id)
|
||||
|
|
@ -0,0 +1,355 @@
|
|||
"""
|
||||
Slot Strategy Executor.
|
||||
[AC-MRS-07-UPGRADE] 槽位提取策略链执行器
|
||||
|
||||
按顺序执行提取策略链,直到成功提取并通过校验。
|
||||
支持失败分类和详细日志追踪。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Callable
|
||||
|
||||
from app.models.entities import ExtractFailureType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ExtractStrategyType(str, Enum):
    """Identifiers for the supported slot-extraction strategies."""

    RULE = "rule"
    LLM = "llm"
    USER_INPUT = "user_input"
|
||||
|
||||
|
||||
@dataclass
class StrategyStepResult:
    """Result of executing a single extraction strategy step."""
    # Strategy identifier that was attempted (e.g. "rule", "llm", "user_input").
    strategy: str
    # Whether this step produced a usable value.
    success: bool
    # Extracted value when success is True.
    value: Any = None
    # Failure classification when the step failed.
    failure_type: ExtractFailureType | None = None
    # Human-readable failure explanation (empty on success).
    failure_reason: str = ""
    # Wall-clock duration of this step in milliseconds.
    execution_time_ms: float = 0.0
|
||||
|
||||
|
||||
@dataclass
class StrategyChainResult:
    """Aggregated result of running an extraction strategy chain."""
    slot_key: str
    success: bool
    final_value: Any = None
    final_strategy: str | None = None
    steps: list[StrategyStepResult] = field(default_factory=list)
    total_execution_time_ms: float = 0.0
    ask_back_prompt: str | None = None

    def to_dict(self) -> dict[str, Any]:
        """Serialize the chain result, including per-step details."""

        def _step_payload(step: StrategyStepResult) -> dict[str, Any]:
            # Successful steps expose their value; failed steps expose None
            # plus the failure classification and reason.
            return {
                "strategy": step.strategy,
                "success": step.success,
                "value": step.value if step.success else None,
                "failure_type": step.failure_type.value if step.failure_type else None,
                "failure_reason": step.failure_reason,
                "execution_time_ms": step.execution_time_ms,
            }

        return {
            "slot_key": self.slot_key,
            "success": self.success,
            "final_value": self.final_value,
            "final_strategy": self.final_strategy,
            "steps": [_step_payload(step) for step in self.steps],
            "total_execution_time_ms": self.total_execution_time_ms,
            "ask_back_prompt": self.ask_back_prompt,
        }
|
||||
|
||||
|
||||
@dataclass
class ExtractContext:
    """Input context handed to each extraction strategy."""
    # Tenant the extraction runs for.
    tenant_id: str
    # Key of the slot being extracted.
    slot_key: str
    # Raw user utterance to extract the value from.
    user_input: str
    # Declared type of the slot (from its SlotDefinition).
    slot_type: str
    # Optional validation rule attached to the slot definition.
    validation_rule: str | None = None
    # Optional conversation history; presumably role/content message dicts — TODO confirm against callers.
    history: list[dict[str, str]] | None = None
    # Session the extraction belongs to, when available.
    session_id: str | None = None
|
||||
|
||||
|
||||
class SlotStrategyExecutor:
    """
    [AC-MRS-07-UPGRADE] Slot extraction strategy-chain executor.

    Responsibilities:
    1. Run extraction strategies in chain order.
    2. Stop and return as soon as one step succeeds and passes validation.
    3. On a step failure, record the reason and continue with the next strategy.
    4. When every strategy fails, return a structured failure result.
    5. Emit traceable logs (slot_key, strategy, reason).
    """

    def __init__(
        self,
        rule_extractor: Callable[[ExtractContext], Any] | None = None,
        llm_extractor: Callable[[ExtractContext], Any] | None = None,
        user_input_extractor: Callable[[ExtractContext], Any] | None = None,
    ):
        """
        Initialize the executor.

        Args:
            rule_extractor: Rule-based extractor; defaults to a no-op placeholder.
            llm_extractor: LLM-based extractor; defaults to a no-op placeholder.
            user_input_extractor: User-input extractor; defaults to a no-op placeholder.
        """
        # Map each strategy name to its extractor; any entry may be overridden
        # by the caller, otherwise the built-in placeholder (returns None) is used.
        # NOTE(review): extractors are awaited in _execute_single_strategy, so
        # injected callables are expected to be coroutine functions — confirm callers.
        self._extractors: dict[str, Callable[[ExtractContext], Any]] = {
            ExtractStrategyType.RULE.value: rule_extractor or self._default_rule_extract,
            ExtractStrategyType.LLM.value: llm_extractor or self._default_llm_extract,
            ExtractStrategyType.USER_INPUT.value: user_input_extractor or self._default_user_input_extract,
        }

    async def execute_chain(
        self,
        strategies: list[str],
        context: ExtractContext,
        ask_back_prompt: str | None = None,
    ) -> StrategyChainResult:
        """
        Execute the extraction strategy chain.

        Args:
            strategies: Strategy chain, e.g. ["user_input", "rule", "llm"].
            context: Extraction context.
            ask_back_prompt: Ask-back prompt returned when every strategy fails.

        Returns:
            StrategyChainResult: Execution result (first success wins, or a
            structured failure carrying every attempted step).
        """
        import time

        start_time = time.time()
        steps: list[StrategyStepResult] = []

        logger.info(
            f"[SlotStrategyExecutor] Starting strategy chain for slot '{context.slot_key}': "
            f"strategies={strategies}, tenant={context.tenant_id}"
        )

        for idx, strategy in enumerate(strategies):
            step_start = time.time()

            logger.info(
                f"[SlotStrategyExecutor] Executing step {idx + 1}/{len(strategies)}: "
                f"slot_key={context.slot_key}, strategy={strategy}"
            )

            step_result = await self._execute_single_strategy(strategy, context)
            # Per-step timing is recorded even for failures so traces stay complete.
            step_result.execution_time_ms = (time.time() - step_start) * 1000
            steps.append(step_result)

            if step_result.success:
                # First success short-circuits the chain.
                total_time = (time.time() - start_time) * 1000
                logger.info(
                    f"[SlotStrategyExecutor] Strategy chain succeeded at step {idx + 1}: "
                    f"slot_key={context.slot_key}, strategy={strategy}, "
                    f"total_time_ms={total_time:.2f}"
                )
                return StrategyChainResult(
                    slot_key=context.slot_key,
                    success=True,
                    final_value=step_result.value,
                    final_strategy=strategy,
                    steps=steps,
                    total_execution_time_ms=total_time,
                )
            else:
                logger.warning(
                    f"[SlotStrategyExecutor] Step {idx + 1} failed: "
                    f"slot_key={context.slot_key}, strategy={strategy}, "
                    f"failure_type={step_result.failure_type}, "
                    f"reason={step_result.failure_reason}"
                )

        # All strategies failed — return a structured failure with the ask-back prompt.
        total_time = (time.time() - start_time) * 1000
        logger.warning(
            f"[SlotStrategyExecutor] All strategies failed for slot '{context.slot_key}': "
            f"attempted={len(strategies)}, total_time_ms={total_time:.2f}"
        )

        return StrategyChainResult(
            slot_key=context.slot_key,
            success=False,
            final_value=None,
            final_strategy=None,
            steps=steps,
            total_execution_time_ms=total_time,
            ask_back_prompt=ask_back_prompt,
        )

    async def _execute_single_strategy(
        self,
        strategy: str,
        context: ExtractContext,
    ) -> StrategyStepResult:
        """
        Execute a single extraction strategy.

        Args:
            strategy: Strategy type name.
            context: Extraction context.

        Returns:
            StrategyStepResult: Single-step execution result.
        """
        extractor = self._extractors.get(strategy)

        # Unknown strategy names are reported as runtime errors rather than raising.
        if not extractor:
            return StrategyStepResult(
                strategy=strategy,
                success=False,
                failure_type=ExtractFailureType.EXTRACT_RUNTIME_ERROR,
                failure_reason=f"Unknown strategy: {strategy}",
            )

        try:
            # Run the extraction itself.
            value = await extractor(context)

            # Treat None / empty string as "nothing extracted".
            if value is None or value == "":
                return StrategyStepResult(
                    strategy=strategy,
                    success=False,
                    failure_type=ExtractFailureType.EXTRACT_EMPTY,
                    failure_reason="Extracted value is empty",
                )

            # Apply validation when the slot configures a rule.
            if context.validation_rule:
                is_valid, error_msg = self._validate_value(value, context)
                if not is_valid:
                    return StrategyStepResult(
                        strategy=strategy,
                        success=False,
                        failure_type=ExtractFailureType.EXTRACT_VALIDATION_FAIL,
                        failure_reason=f"Validation failed: {error_msg}",
                    )

            return StrategyStepResult(
                strategy=strategy,
                success=True,
                value=value,
            )

        except Exception as e:
            # Any extractor exception becomes a structured failure; the chain continues.
            logger.exception(
                f"[SlotStrategyExecutor] Runtime error in strategy '{strategy}' "
                f"for slot '{context.slot_key}': {e}"
            )
            return StrategyStepResult(
                strategy=strategy,
                success=False,
                failure_type=ExtractFailureType.EXTRACT_RUNTIME_ERROR,
                failure_reason=f"Runtime error: {str(e)}",
            )

    def _validate_value(self, value: Any, context: ExtractContext) -> tuple[bool, str]:
        """
        Validate an extracted value against the context's validation rule.

        Args:
            value: Extracted value.
            context: Extraction context.

        Returns:
            Tuple of (passed, error message).
        """
        import re

        validation_rule = context.validation_rule
        if not validation_rule:
            return True, ""

        try:
            # Only rules that look like anchored regexes are applied; everything
            # else currently passes (extension point for future rule kinds).
            if validation_rule.startswith("^") or validation_rule.endswith("$"):
                if re.match(validation_rule, str(value)):
                    return True, ""
                return False, f"Value '{value}' does not match pattern '{validation_rule}'"

            # Other rule kinds can be added here.
            return True, ""

        except re.error as e:
            logger.warning(f"[SlotStrategyExecutor] Invalid validation rule pattern: {e}")
            return True, ""  # a broken pattern must not block extraction — let the value through

    async def _default_rule_extract(self, context: ExtractContext) -> Any:
        """Default rule-based extraction (placeholder): always returns None."""
        # A real deployment should wire in VariableExtractor or another rule engine.
        logger.debug(f"[SlotStrategyExecutor] Default rule extract for '{context.slot_key}'")
        return None

    async def _default_llm_extract(self, context: ExtractContext) -> Any:
        """Default LLM-based extraction (placeholder): always returns None."""
        logger.debug(f"[SlotStrategyExecutor] Default LLM extract for '{context.slot_key}'")
        return None

    async def _default_user_input_extract(self, context: ExtractContext) -> Any:
        """Default user-input extraction: returning None signals an ask-back is needed."""
        # The "user_input" strategy normally means the user must be asked directly.
        logger.debug(f"[SlotStrategyExecutor] User input required for '{context.slot_key}'")
        return None
|
||||
|
||||
|
||||
# 便捷函数
|
||||
async def execute_extract_strategies(
    strategies: list[str],
    tenant_id: str,
    slot_key: str,
    user_input: str,
    slot_type: str = "string",
    validation_rule: str | None = None,
    ask_back_prompt: str | None = None,
    history: list[dict[str, str]] | None = None,
    rule_extractor: Callable[[ExtractContext], Any] | None = None,
    llm_extractor: Callable[[ExtractContext], Any] | None = None,
) -> StrategyChainResult:
    """
    Convenience wrapper: build an executor and a context, then run the
    extraction strategy chain in a single call.

    Args:
        strategies: Ordered strategy chain to attempt.
        tenant_id: Tenant identifier.
        slot_key: Key of the slot being filled.
        user_input: Raw user utterance.
        slot_type: Declared slot type.
        validation_rule: Optional validation rule.
        ask_back_prompt: Prompt returned when every strategy fails.
        history: Optional conversation history.
        rule_extractor: Optional rule-based extractor override.
        llm_extractor: Optional LLM-based extractor override.

    Returns:
        StrategyChainResult: Outcome of the chain execution.
    """
    extraction_context = ExtractContext(
        tenant_id=tenant_id,
        slot_key=slot_key,
        user_input=user_input,
        slot_type=slot_type,
        validation_rule=validation_rule,
        history=history,
    )
    chain_runner = SlotStrategyExecutor(
        rule_extractor=rule_extractor,
        llm_extractor=llm_extractor,
    )
    return await chain_runner.execute_chain(strategies, extraction_context, ask_back_prompt)
|
||||
|
|
@ -0,0 +1,572 @@
|
|||
"""
|
||||
Slot Validation Service.
|
||||
槽位校验规则 runtime 生效服务
|
||||
|
||||
提供槽位值的运行时校验能力,支持:
|
||||
1. 正则表达式校验
|
||||
2. JSON Schema 校验
|
||||
3. 类型校验(string/number/boolean/enum/array_enum)
|
||||
4. 必填校验
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import jsonschema
|
||||
from jsonschema.exceptions import ValidationError as JsonSchemaValidationError
|
||||
|
||||
from app.models.entities import SlotDefinition
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Error code definitions
class SlotValidationErrorCode:
    """Error codes emitted by slot validation."""

    SLOT_REQUIRED_MISSING = "SLOT_REQUIRED_MISSING"  # required slot has no value
    SLOT_TYPE_INVALID = "SLOT_TYPE_INVALID"  # value does not match the declared slot type
    SLOT_REGEX_MISMATCH = "SLOT_REGEX_MISMATCH"  # value fails the configured regex rule
    SLOT_JSON_SCHEMA_MISMATCH = "SLOT_JSON_SCHEMA_MISMATCH"  # value fails the configured JSON Schema
    SLOT_VALIDATION_RULE_INVALID = "SLOT_VALIDATION_RULE_INVALID"  # the rule itself is malformed (bad regex / bad JSON)
    SLOT_ENUM_INVALID = "SLOT_ENUM_INVALID"  # enum value not among the allowed options
    SLOT_ARRAY_ENUM_INVALID = "SLOT_ARRAY_ENUM_INVALID"  # array element is not a string or not among the allowed options
|
||||
|
||||
|
||||
@dataclass
class ValidationResult:
    """
    Validation result for a single slot.

    Attributes:
        ok: Whether validation passed.
        normalized_value: Normalized value (e.g. after type coercion).
        error_code: Error code (on failure).
        error_message: Error description (on failure).
        ask_back_prompt: Ask-back prompt (on failure, when the slot configures one).
    """

    ok: bool
    normalized_value: Any | None = None
    error_code: str | None = None
    error_message: str | None = None
    ask_back_prompt: str | None = None
|
||||
|
||||
|
||||
@dataclass
class SlotValidationError:
    """
    Details of one slot validation error.

    Attributes:
        slot_key: Slot key name.
        error_code: Error code.
        error_message: Error description.
        ask_back_prompt: Ask-back prompt, if the slot configures one.
    """

    slot_key: str
    error_code: str
    error_message: str
    ask_back_prompt: str | None = None
|
||||
|
||||
|
||||
@dataclass
class BatchValidationResult:
    """
    Result of validating multiple slots in one batch.

    Attributes:
        ok: Whether every slot passed validation.
        errors: List of validation errors.
        validated_values: Dict of values that passed validation (normalized).
    """

    ok: bool
    errors: list[SlotValidationError] = field(default_factory=list)
    validated_values: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class SlotValidationService:
    """
    Slot validation service.

    Runs validation before slot values are written back, supporting:
    - Regular-expression rules
    - JSON Schema rules
    - Type validation (string/number/boolean/enum/array_enum)
    - Required-field validation
    """

    # Slot types this service knows how to validate.
    VALID_TYPES = ["string", "number", "boolean", "enum", "array_enum"]

    def __init__(self):
        """Initialize the slot validation service."""
        # Parsed JSON Schemas keyed by their raw string, so repeated rules
        # are parsed only once.
        self._schema_cache: dict[str, dict] = {}

    def validate_slot_value(
        self,
        slot_def: dict[str, Any] | SlotDefinition,
        value: Any,
        tenant_id: str | None = None,
    ) -> ValidationResult:
        """
        Validate a single slot value.

        Validation order:
        1. Required check (when required=true and the value is empty)
        2. Type check
        3. validation_rule check (regex or JSON Schema)

        Args:
            slot_def: Slot definition (dict or SlotDefinition object).
            value: Value to validate.
            tenant_id: Tenant ID (for log records only).

        Returns:
            ValidationResult: Validation outcome.
        """
        # Normalize to a dict so both input shapes share one code path.
        if isinstance(slot_def, SlotDefinition):
            slot_dict = {
                "slot_key": slot_def.slot_key,
                "type": slot_def.type,
                "required": slot_def.required,
                "validation_rule": slot_def.validation_rule,
                "ask_back_prompt": slot_def.ask_back_prompt,
            }
        else:
            slot_dict = slot_def

        slot_key = slot_dict.get("slot_key", "unknown")
        slot_type = slot_dict.get("type", "string")
        required = slot_dict.get("required", False)
        validation_rule = slot_dict.get("validation_rule")
        ask_back_prompt = slot_dict.get("ask_back_prompt")

        # 1. Required check.
        if required and self._is_empty_value(value):
            logger.info(
                f"[SlotValidation] Required slot missing: "
                f"tenant_id={tenant_id}, slot_key={slot_key}"
            )
            return ValidationResult(
                ok=False,
                error_code=SlotValidationErrorCode.SLOT_REQUIRED_MISSING,
                error_message=f"槽位 '{slot_key}' 为必填项",
                ask_back_prompt=ask_back_prompt,
            )

        # Empty and optional: skip the remaining checks.
        if self._is_empty_value(value):
            return ValidationResult(ok=True, normalized_value=value)

        # 2. Type check (may coerce the value, e.g. "3" -> 3).
        type_result = self._validate_type(slot_dict, value, tenant_id)
        if not type_result.ok:
            return ValidationResult(
                ok=False,
                error_code=type_result.error_code,
                error_message=type_result.error_message,
                ask_back_prompt=ask_back_prompt,
            )

        normalized_value = type_result.normalized_value

        # 3. validation_rule check (only when a non-blank rule is configured).
        if validation_rule and str(validation_rule).strip():
            rule_result = self._validate_rule(
                slot_dict, normalized_value, tenant_id
            )
            if not rule_result.ok:
                return ValidationResult(
                    ok=False,
                    error_code=rule_result.error_code,
                    error_message=rule_result.error_message,
                    ask_back_prompt=ask_back_prompt,
                )
            # NOTE(review): `or` would drop a falsy normalized value (0, False, "")
            # returned by the rule step — confirm rules never normalize to falsy values.
            normalized_value = rule_result.normalized_value or normalized_value

        logger.debug(
            f"[SlotValidation] Slot validation passed: "
            f"tenant_id={tenant_id}, slot_key={slot_key}, type={slot_type}"
        )

        return ValidationResult(ok=True, normalized_value=normalized_value)

    def validate_slots(
        self,
        slot_defs: list[dict[str, Any] | SlotDefinition],
        values: dict[str, Any],
        tenant_id: str | None = None,
    ) -> BatchValidationResult:
        """
        Validate multiple slot values in one batch.

        Args:
            slot_defs: Slot definition list.
            values: Slot value dict {slot_key: value}.
            tenant_id: Tenant ID (for log records only).

        Returns:
            BatchValidationResult: Batch validation outcome.
        """
        errors: list[SlotValidationError] = []
        validated_values: dict[str, Any] = {}

        # Build a slot_key -> definition map for O(1) lookups.
        slot_def_map: dict[str, dict[str, Any] | SlotDefinition] = {}
        for slot_def in slot_defs:
            if isinstance(slot_def, SlotDefinition):
                slot_def_map[slot_def.slot_key] = slot_def
            else:
                slot_def_map[slot_def.get("slot_key", "")] = slot_def

        # Validate every supplied value.
        for slot_key, value in values.items():
            slot_def = slot_def_map.get(slot_key)
            if not slot_def:
                # Undefined slot: skip validation (dynamic slots are allowed through).
                validated_values[slot_key] = value
                continue

            result = self.validate_slot_value(slot_def, value, tenant_id)

            if result.ok:
                validated_values[slot_key] = result.normalized_value
            else:
                errors.append(
                    SlotValidationError(
                        slot_key=slot_key,
                        error_code=result.error_code or "UNKNOWN_ERROR",
                        error_message=result.error_message or "校验失败",
                        ask_back_prompt=result.ask_back_prompt,
                    )
                )

        # Report required slots that were not supplied at all.
        for slot_def in slot_defs:
            if isinstance(slot_def, SlotDefinition):
                slot_key = slot_def.slot_key
                required = slot_def.required
                ask_back_prompt = slot_def.ask_back_prompt
            else:
                slot_key = slot_def.get("slot_key", "")
                required = slot_def.get("required", False)
                ask_back_prompt = slot_def.get("ask_back_prompt")

            if required and slot_key not in values:
                # Avoid duplicating an error already recorded for this slot.
                if not any(e.slot_key == slot_key for e in errors):
                    errors.append(
                        SlotValidationError(
                            slot_key=slot_key,
                            error_code=SlotValidationErrorCode.SLOT_REQUIRED_MISSING,
                            error_message=f"槽位 '{slot_key}' 为必填项",
                            ask_back_prompt=ask_back_prompt,
                        )
                    )

        return BatchValidationResult(
            ok=len(errors) == 0,
            errors=errors,
            validated_values=validated_values,
        )

    def _is_empty_value(self, value: Any) -> bool:
        """Return True for None, blank/whitespace-only strings, and empty lists."""
        if value is None:
            return True
        if isinstance(value, str) and not value.strip():
            return True
        if isinstance(value, list) and len(value) == 0:
            return True
        return False

    def _validate_type(
        self,
        slot_def: dict[str, Any],
        value: Any,
        tenant_id: str | None,
    ) -> ValidationResult:
        """
        Type validation (with best-effort coercion for string/number/boolean).

        Args:
            slot_def: Slot definition dict.
            value: Value to validate.
            tenant_id: Tenant ID.

        Returns:
            ValidationResult: Validation outcome; normalized_value carries
            the possibly-coerced value.
        """
        slot_key = slot_def.get("slot_key", "unknown")
        slot_type = slot_def.get("type", "string")

        if slot_type not in self.VALID_TYPES:
            logger.warning(
                f"[SlotValidation] Unknown slot type: "
                f"tenant_id={tenant_id}, slot_key={slot_key}, type={slot_type}"
            )
            # Unknown types do not block; only a warning is logged.
            return ValidationResult(ok=True, normalized_value=value)

        try:
            if slot_type == "string":
                if not isinstance(value, str):
                    # Coerce non-strings via str().
                    normalized = str(value)
                    return ValidationResult(ok=True, normalized_value=normalized)
                return ValidationResult(ok=True, normalized_value=value)

            elif slot_type == "number":
                # bool is a subclass of int, so reject it explicitly first.
                if isinstance(value, bool):
                    return ValidationResult(
                        ok=False,
                        error_code=SlotValidationErrorCode.SLOT_TYPE_INVALID,
                        error_message=f"槽位 '{slot_key}' 类型应为数字,但得到布尔值",
                    )
                if isinstance(value, int | float):
                    return ValidationResult(ok=True, normalized_value=value)
                # Try to parse numeric strings ("." decides float vs int).
                if isinstance(value, str):
                    try:
                        if "." in value:
                            normalized = float(value)
                        else:
                            normalized = int(value)
                        return ValidationResult(ok=True, normalized_value=normalized)
                    except ValueError:
                        pass
                return ValidationResult(
                    ok=False,
                    error_code=SlotValidationErrorCode.SLOT_TYPE_INVALID,
                    error_message=f"槽位 '{slot_key}' 类型应为数字",
                )

            elif slot_type == "boolean":
                if isinstance(value, bool):
                    return ValidationResult(ok=True, normalized_value=value)
                # Accept common truthy/falsy spellings (English and Chinese).
                if isinstance(value, str):
                    lower_val = value.lower()
                    if lower_val in ("true", "1", "yes", "是", "真"):
                        return ValidationResult(ok=True, normalized_value=True)
                    if lower_val in ("false", "0", "no", "否", "假"):
                        return ValidationResult(ok=True, normalized_value=False)
                return ValidationResult(
                    ok=False,
                    error_code=SlotValidationErrorCode.SLOT_TYPE_INVALID,
                    error_message=f"槽位 '{slot_key}' 类型应为布尔值",
                )

            elif slot_type == "enum":
                if not isinstance(value, str):
                    return ValidationResult(
                        ok=False,
                        error_code=SlotValidationErrorCode.SLOT_TYPE_INVALID,
                        error_message=f"槽位 '{slot_key}' 类型应为字符串(枚举)",
                    )
                # If options are configured, the value must be one of them.
                options = slot_def.get("options") or []
                if options and value not in options:
                    return ValidationResult(
                        ok=False,
                        error_code=SlotValidationErrorCode.SLOT_ENUM_INVALID,
                        error_message=f"槽位 '{slot_key}' 的值 '{value}' 不在允许选项 {options} 中",
                    )
                return ValidationResult(ok=True, normalized_value=value)

            elif slot_type == "array_enum":
                if not isinstance(value, list):
                    return ValidationResult(
                        ok=False,
                        error_code=SlotValidationErrorCode.SLOT_TYPE_INVALID,
                        error_message=f"槽位 '{slot_key}' 类型应为数组",
                    )
                # Every element must be a string and (when options exist) allowed.
                options = slot_def.get("options") or []
                for item in value:
                    if not isinstance(item, str):
                        return ValidationResult(
                            ok=False,
                            error_code=SlotValidationErrorCode.SLOT_ARRAY_ENUM_INVALID,
                            error_message=f"槽位 '{slot_key}' 的数组元素应为字符串",
                        )
                    if options and item not in options:
                        return ValidationResult(
                            ok=False,
                            error_code=SlotValidationErrorCode.SLOT_ARRAY_ENUM_INVALID,
                            error_message=f"槽位 '{slot_key}' 的值 '{item}' 不在允许选项 {options} 中",
                        )
                return ValidationResult(ok=True, normalized_value=value)

        except Exception as e:
            logger.error(
                f"[SlotValidation] Type validation error: "
                f"tenant_id={tenant_id}, slot_key={slot_key}, error={e}"
            )
            return ValidationResult(
                ok=False,
                error_code=SlotValidationErrorCode.SLOT_TYPE_INVALID,
                error_message=f"槽位 '{slot_key}' 类型校验异常: {str(e)}",
            )

        # Defensive fallback: every VALID_TYPES branch above returns, so this
        # is only reached if the type table and the branches ever diverge.
        return ValidationResult(ok=True, normalized_value=value)

    def _validate_rule(
        self,
        slot_def: dict[str, Any],
        value: Any,
        tenant_id: str | None,
    ) -> ValidationResult:
        """
        Rule validation (regex or JSON Schema).

        Decision logic:
        1. Rules that look like JSON (starting with "{" or "[") are treated
           as JSON Schema.
        2. Everything else is treated as a regular expression.

        Args:
            slot_def: Slot definition dict.
            value: Value to validate.
            tenant_id: Tenant ID.

        Returns:
            ValidationResult: Validation outcome.
        """
        slot_key = slot_def.get("slot_key", "unknown")
        validation_rule = str(slot_def.get("validation_rule", "")).strip()

        if not validation_rule:
            return ValidationResult(ok=True, normalized_value=value)

        # Decide between JSON Schema and regex:
        # JSON Schemas conventionally start with "{" or "[".
        is_json_schema = validation_rule.strip().startswith(("{", "["))

        if is_json_schema:
            return self._validate_json_schema(
                slot_key, validation_rule, value, tenant_id
            )
        else:
            return self._validate_regex(
                slot_key, validation_rule, value, tenant_id
            )

    def _validate_regex(
        self,
        slot_key: str,
        pattern: str,
        value: Any,
        tenant_id: str | None,
    ) -> ValidationResult:
        """
        Regular-expression validation.

        Args:
            slot_key: Slot key name.
            pattern: Regular expression.
            value: Value to validate.
            tenant_id: Tenant ID.

        Returns:
            ValidationResult: Validation outcome. An invalid pattern is
            reported as SLOT_VALIDATION_RULE_INVALID (it does not pass).
        """
        try:
            # Match against the string form of the value.
            str_value = str(value) if value is not None else ""

            # NOTE: re.search matches anywhere in the string; rules that need
            # a full match must carry their own ^...$ anchors.
            if not re.search(pattern, str_value):
                logger.info(
                    f"[SlotValidation] Regex mismatch: "
                    f"tenant_id={tenant_id}, slot_key={slot_key}"
                )
                return ValidationResult(
                    ok=False,
                    error_code=SlotValidationErrorCode.SLOT_REGEX_MISMATCH,
                    error_message=f"槽位 '{slot_key}' 的值不符合格式要求",
                )

            return ValidationResult(ok=True, normalized_value=value)

        except re.error as e:
            logger.warning(
                f"[SlotValidation] Invalid regex pattern: "
                f"tenant_id={tenant_id}, slot_key={slot_key}, pattern={pattern}, error={e}"
            )
            return ValidationResult(
                ok=False,
                error_code=SlotValidationErrorCode.SLOT_VALIDATION_RULE_INVALID,
                error_message=f"槽位 '{slot_key}' 的校验规则配置无效(非法正则)",
            )

    def _validate_json_schema(
        self,
        slot_key: str,
        schema_str: str,
        value: Any,
        tenant_id: str | None,
    ) -> ValidationResult:
        """
        JSON Schema validation.

        Args:
            slot_key: Slot key name.
            schema_str: JSON Schema as a string.
            value: Value to validate.
            tenant_id: Tenant ID.

        Returns:
            ValidationResult: Validation outcome. Malformed schemas are
            reported as SLOT_VALIDATION_RULE_INVALID; schema mismatches as
            SLOT_JSON_SCHEMA_MISMATCH.
        """
        try:
            # Parse the schema, caching by its raw string.
            schema = self._schema_cache.get(schema_str)
            if schema is None:
                schema = json.loads(schema_str)
                self._schema_cache[schema_str] = schema

            # Run the validation itself.
            jsonschema.validate(instance=value, schema=schema)
            return ValidationResult(ok=True, normalized_value=value)

        except json.JSONDecodeError as e:
            logger.warning(
                f"[SlotValidation] Invalid JSON schema: "
                f"tenant_id={tenant_id}, slot_key={slot_key}, error={e}"
            )
            return ValidationResult(
                ok=False,
                error_code=SlotValidationErrorCode.SLOT_VALIDATION_RULE_INVALID,
                error_message=f"槽位 '{slot_key}' 的校验规则配置无效(非法 JSON)",
            )

        except JsonSchemaValidationError as e:
            logger.info(
                f"[SlotValidation] JSON schema mismatch: "
                f"tenant_id={tenant_id}, slot_key={slot_key}, error={e.message}"
            )
            return ValidationResult(
                ok=False,
                error_code=SlotValidationErrorCode.SLOT_JSON_SCHEMA_MISMATCH,
                error_message=f"槽位 '{slot_key}' 的值不符合格式要求: {e.message}",
            )

        except Exception as e:
            logger.error(
                f"[SlotValidation] JSON schema validation error: "
                f"tenant_id={tenant_id}, slot_key={slot_key}, error={e}"
            )
            return ValidationResult(
                ok=False,
                error_code=SlotValidationErrorCode.SLOT_VALIDATION_RULE_INVALID,
                error_message=f"槽位 '{slot_key}' 的校验规则执行异常: {str(e)}",
            )

    def clear_cache(self) -> None:
        """Clear the JSON Schema cache."""
        self._schema_cache.clear()
|
||||
|
|
@ -127,7 +127,8 @@ class ToolCallRecorder:
|
|||
logger.info(
|
||||
f"[AC-IDMP-15] Tool call recorded: tool={trace.tool_name}, "
|
||||
f"type={trace.tool_type.value}, duration_ms={trace.duration_ms}, "
|
||||
f"status={trace.status.value}, session={session_id}"
|
||||
f"status={trace.status.value}, session={session_id}, "
|
||||
f"args_digest={trace.args_digest}, result_digest={trace.result_digest}"
|
||||
)
|
||||
|
||||
def record_success(
|
||||
|
|
@ -153,6 +154,8 @@ class ToolCallRecorder:
|
|||
auth_applied=auth_applied,
|
||||
args_digest=ToolCallTrace.compute_digest(args) if args else None,
|
||||
result_digest=ToolCallTrace.compute_digest(result) if result else None,
|
||||
arguments=args if isinstance(args, dict) else None,
|
||||
result=result,
|
||||
)
|
||||
self.record(session_id, trace)
|
||||
return trace
|
||||
|
|
@ -179,6 +182,7 @@ class ToolCallRecorder:
|
|||
registry_version=registry_version,
|
||||
auth_applied=auth_applied,
|
||||
args_digest=ToolCallTrace.compute_digest(args) if args else None,
|
||||
arguments=args if isinstance(args, dict) else None,
|
||||
)
|
||||
self.record(session_id, trace)
|
||||
return trace
|
||||
|
|
@ -207,6 +211,7 @@ class ToolCallRecorder:
|
|||
registry_version=registry_version,
|
||||
auth_applied=auth_applied,
|
||||
args_digest=ToolCallTrace.compute_digest(args) if args else None,
|
||||
arguments=args if isinstance(args, dict) else None,
|
||||
)
|
||||
self.record(session_id, trace)
|
||||
return trace
|
||||
|
|
@ -231,6 +236,7 @@ class ToolCallRecorder:
|
|||
error_code=reason,
|
||||
registry_version=registry_version,
|
||||
args_digest=ToolCallTrace.compute_digest(args) if args else None,
|
||||
arguments=args if isinstance(args, dict) else None,
|
||||
)
|
||||
self.record(session_id, trace)
|
||||
return trace
|
||||
|
|
|
|||
|
|
@ -0,0 +1,111 @@
|
|||
"""
|
||||
Tool definition converter for Function Calling.
|
||||
Converts ToolRegistry definitions to LLM ToolDefinition format.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.services.llm.base import ToolDefinition
|
||||
from app.services.mid.tool_registry import ToolDefinition as RegistryToolDefinition
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Context parameters the platform injects itself; they must never be exposed
# to the LLM as callable-tool parameters.
_INJECTED_PARAMS = ("tenant_id", "user_id", "session_id")


def convert_tool_to_llm_format(tool: RegistryToolDefinition) -> ToolDefinition:
    """
    Convert a ToolRegistry tool definition to the LLM ToolDefinition format.

    The registry metadata may carry a JSON-Schema-style "parameters" object.
    Missing or malformed schemas fall back to an empty object schema, and the
    platform-injected context parameters (tenant_id/user_id/session_id) are
    stripped from both "properties" and "required" so the model never tries
    to supply them.

    Bug fix vs the previous version: the schema is shallow-copied before the
    "type"/"properties"/"required" defaults are filled in, so the registry's
    shared metadata dict is no longer mutated on every conversion.

    Args:
        tool: Tool definition from ToolRegistry.

    Returns:
        ToolDefinition for Function Calling.
    """
    meta = tool.metadata or {}
    raw_parameters = meta.get("parameters")
    if not isinstance(raw_parameters, dict):
        raw_parameters = {}

    # Work on a copy — never mutate the registry's shared metadata in place.
    parameters: dict[str, Any] = dict(raw_parameters)
    parameters.setdefault("type", "object")

    # Normalize "properties" to a dict and drop injected context parameters.
    properties = parameters.get("properties")
    if not isinstance(properties, dict):
        properties = {}
    parameters["properties"] = {
        key: spec for key, spec in properties.items() if key not in _INJECTED_PARAMS
    }

    # Normalize "required" to a list and drop injected context parameters.
    required = parameters.get("required")
    if not isinstance(required, list):
        required = []
    parameters["required"] = [name for name in required if name not in _INJECTED_PARAMS]

    return ToolDefinition(
        name=tool.name,
        description=tool.description,
        parameters=parameters,
    )
|
||||
|
||||
|
||||
def convert_tools_to_llm_format(tools: list[RegistryToolDefinition]) -> list[ToolDefinition]:
    """
    Convert every registry tool definition into the Function Calling format.

    Args:
        tools: Tool definitions coming from ToolRegistry.

    Returns:
        Converted ToolDefinition objects, in the same order as the input.
    """
    return list(map(convert_tool_to_llm_format, tools))
|
||||
|
||||
|
||||
def build_tool_result_message(
|
||||
tool_call_id: str,
|
||||
tool_name: str,
|
||||
result: dict[str, Any],
|
||||
tool_guide: str | None = None,
|
||||
) -> dict[str, str]:
|
||||
"""
|
||||
Build a tool result message for the conversation.
|
||||
|
||||
Args:
|
||||
tool_call_id: ID of the tool call
|
||||
tool_name: Name of the tool
|
||||
result: Tool execution result
|
||||
tool_guide: Optional tool usage guide to append
|
||||
|
||||
Returns:
|
||||
Message dict with role='tool'
|
||||
"""
|
||||
if isinstance(result, dict):
|
||||
result_copy = {k: v for k, v in result.items() if k != "_tool_guide"}
|
||||
content = str(result_copy)
|
||||
else:
|
||||
content = str(result)
|
||||
|
||||
if tool_guide:
|
||||
content = f"{content}\n\n---\n{tool_guide}"
|
||||
|
||||
return {
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call_id,
|
||||
"content": content,
|
||||
}
|
||||
|
|
@ -0,0 +1,313 @@
|
|||
"""
|
||||
Tool Guide Registry for Mid Platform.
|
||||
Provides tool-based usage guidance with caching support.
|
||||
|
||||
Tool guides are usage manuals for tools, loaded on-demand with metadata scanning.
|
||||
This separates tool definitions (Function Calling) from usage guides (Tool Guides).
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TOOLS_DIR = Path(__file__).parent.parent.parent.parent / "tools"
|
||||
|
||||
|
||||
@dataclass
class ToolGuideMetadata:
    """Lightweight tool guide metadata for quick scanning (~100 tokens)."""
    name: str  # guide identifier
    description: str  # one-line summary of the guide
    triggers: list[str] = field(default_factory=list)  # presumably cues that suggest applying this guide — confirm against the .md parser
    anti_triggers: list[str] = field(default_factory=list)  # presumably cues that suggest NOT applying it — confirm against the .md parser
    tools: list[str] = field(default_factory=list)  # names of the tools this guide documents
||||
|
||||
|
||||
@dataclass
class ToolGuideDefinition:
    """Full tool guide definition with complete content."""
    name: str  # guide identifier
    description: str  # one-line summary of the guide
    triggers: list[str]  # presumably cues that suggest applying this guide — confirm against the .md parser
    anti_triggers: list[str]  # presumably cues that suggest NOT applying it — confirm against the .md parser
    tools: list[str]  # names of the tools this guide documents
    content: str  # guide body (processed form)
    raw_markdown: str  # original markdown source the guide was parsed from
|
||||
|
||||
|
||||
class ToolGuideRegistry:
    """
    Tool guide registry with caching support.

    Features:
    - Load tool guides from .md files
    - Cache tool guides in memory for high-frequency access
    - Provide lightweight metadata for quick scanning
    - Provide full content on demand
    """

    # Process-wide singleton state. NOTE(review): _initialized is set once,
    # so a later ToolGuideRegistry(tools_dir=...) call silently keeps the
    # first tools_dir — confirm this is intended.
    _instance = None
    _initialized = False

    def __new__(cls, *args, **kwargs):
        # Classic singleton: every construction returns the shared instance.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, tools_dir: Path | None = None):
        """Initialize the registry (first construction only).

        Args:
            tools_dir: Directory containing *.md guide files; defaults to
                the module-level TOOLS_DIR. Ignored after first init.
        """
        # __init__ runs on every construction of the singleton; bail out
        # after the first so cached state is not reset.
        if ToolGuideRegistry._initialized:
            return

        self._tools_dir = tools_dir or TOOLS_DIR
        # Full guide definitions, keyed by guide name.
        self._tool_guides: dict[str, ToolGuideDefinition] = {}
        # Lightweight metadata, keyed by guide name (cheap scanning).
        self._metadata_cache: dict[str, ToolGuideMetadata] = {}
        # Reverse index: tool name -> guide name.
        self._tool_to_guide: dict[str, str] = {}
        # Lazy-load flag; accessors call load_tools() on first use.
        self._loaded = False

        ToolGuideRegistry._initialized = True
        logger.info(f"[ToolGuideRegistry] Initialized with tools_dir={self._tools_dir}")

    def load_tools(self, force_reload: bool = False) -> None:
        """
        Load all tool guides from tools directory into cache.

        Args:
            force_reload: Force reload even if already loaded
        """
        if self._loaded and not force_reload:
            logger.debug("[ToolGuideRegistry] Tool guides already loaded, skipping")
            return

        if not self._tools_dir.exists():
            # Missing directory is non-fatal; the registry just stays empty.
            logger.warning(f"[ToolGuideRegistry] Tools directory not found: {self._tools_dir}")
            return

        # Reset all three caches so a force_reload drops stale entries.
        self._tool_guides.clear()
        self._metadata_cache.clear()
        self._tool_to_guide.clear()

        for tool_file in self._tools_dir.glob("*.md"):
            try:
                tool_guide = self._parse_tool_file(tool_file)
                if tool_guide:
                    self._tool_guides[tool_guide.name] = tool_guide
                    # Derive the lightweight metadata view from the full guide.
                    self._metadata_cache[tool_guide.name] = ToolGuideMetadata(
                        name=tool_guide.name,
                        description=tool_guide.description,
                        triggers=tool_guide.triggers,
                        anti_triggers=tool_guide.anti_triggers,
                        tools=tool_guide.tools,
                    )
                    # Populate the reverse index (tool -> guide).
                    for tool_name in tool_guide.tools:
                        self._tool_to_guide[tool_name] = tool_guide.name
                    logger.info(f"[ToolGuideRegistry] Loaded tool guide: {tool_guide.name} (tools: {tool_guide.tools})")
            except Exception as e:
                # One bad file must not prevent the others from loading.
                logger.error(f"[ToolGuideRegistry] Failed to load tool guide from {tool_file}: {e}")

        self._loaded = True
        logger.info(f"[ToolGuideRegistry] Loaded {len(self._tool_guides)} tool guides")

    def _parse_tool_file(self, file_path: Path) -> ToolGuideDefinition | None:
        """
        Parse a tool guide markdown file.

        Expected format:
        ---
        name: tool_name
        description: Tool description
        triggers:
          - trigger 1
          - trigger 2
        anti_triggers:
          - anti trigger 1
        tools:
          - tool_name
        ---

        ## Usage Guide
        ...

        Returns None (with a warning) when the file has no frontmatter.
        """
        content = file_path.read_text(encoding="utf-8")

        # Split "---\n<frontmatter>\n---\n<body>"; DOTALL lets the body span lines.
        frontmatter_match = re.match(r"^---\s*\n(.*?)\n---\s*\n(.*)$", content, re.DOTALL)
        if not frontmatter_match:
            logger.warning(f"[ToolGuideRegistry] No frontmatter found in {file_path}")
            return None

        frontmatter_text = frontmatter_match.group(1)
        body = frontmatter_match.group(2)

        # Minimal hand-rolled YAML-ish parser: top-level "key: value" pairs
        # plus two-space-indented "- item" lists. Nested structures are not
        # supported.
        metadata: dict[str, Any] = {}
        current_key = None
        current_list: list[str] | None = None

        for line in frontmatter_text.split("\n"):
            if not line.strip():
                continue

            key_match = re.match(r"^(\w+):\s*(.*)$", line)
            if key_match:
                current_key = key_match.group(1)
                value = key_match.group(2).strip()

                if value:
                    # Scalar value on the same line.
                    metadata[current_key] = value
                    current_list = None
                else:
                    # Bare "key:" starts a list; indented items follow.
                    current_list = []
                    metadata[current_key] = current_list
            elif line.startswith("  - ") and current_list is not None:
                # "  - item" -> strip the 4-char list prefix.
                current_list.append(line[4:].strip())

        name = metadata.get("name", file_path.stem)
        description = metadata.get("description", "")
        triggers = metadata.get("triggers", [])
        anti_triggers = metadata.get("anti_triggers", [])
        tools = metadata.get("tools", [])

        # Normalize scalar shorthand ("triggers: foo") to a one-item list.
        if isinstance(triggers, str):
            triggers = [triggers]
        if isinstance(anti_triggers, str):
            anti_triggers = [anti_triggers]
        if isinstance(tools, str):
            tools = [tools]

        return ToolGuideDefinition(
            name=name,
            description=description,
            triggers=triggers,
            anti_triggers=anti_triggers,
            tools=tools,
            content=body.strip(),
            raw_markdown=content,
        )

    def get_tool_guide(self, name: str) -> ToolGuideDefinition | None:
        """Get full tool guide definition by name."""
        if not self._loaded:
            self.load_tools()
        return self._tool_guides.get(name)

    def get_tool_metadata(self, name: str) -> ToolGuideMetadata | None:
        """Get lightweight tool guide metadata by name."""
        if not self._loaded:
            self.load_tools()
        return self._metadata_cache.get(name)

    def get_guide_for_tool(self, tool_name: str) -> ToolGuideDefinition | None:
        """Get tool guide associated with a tool (via the reverse index)."""
        if not self._loaded:
            self.load_tools()
        guide_name = self._tool_to_guide.get(tool_name)
        if guide_name:
            return self._tool_guides.get(guide_name)
        return None

    def list_tools(self) -> list[str]:
        """List all tool guide names."""
        if not self._loaded:
            self.load_tools()
        return list(self._tool_guides.keys())

    def list_tool_metadata(self) -> list[ToolGuideMetadata]:
        """List all tool guide metadata (lightweight)."""
        if not self._loaded:
            self.load_tools()
        return list(self._metadata_cache.values())

    def build_tools_prompt_section(self, tool_names: list[str] | None = None) -> str:
        """
        Build tools section for ReAct prompt.

        Args:
            tool_names: If provided, only include tools for these names.
                If None, include all tools.

        Returns:
            Formatted tools section string (empty string when nothing matches)
        """
        if not self._loaded:
            self.load_tools()

        if not self._tool_guides:
            return ""

        tools_to_include: list[ToolGuideDefinition] = []

        if tool_names:
            for tool_name in tool_names:
                tool_guide = self.get_guide_for_tool(tool_name)
                # Several tools can map to the same guide; de-duplicate.
                if tool_guide and tool_guide not in tools_to_include:
                    tools_to_include.append(tool_guide)
        else:
            tools_to_include = list(self._tool_guides.values())

        if not tools_to_include:
            return ""

        lines = ["## 工具使用指南", ""]
        lines.append("以下是每个工具的详细使用说明:")
        lines.append("")

        for tool_guide in tools_to_include:
            lines.append(f"### {tool_guide.name}")
            lines.append(f"描述: {tool_guide.description}")

            if tool_guide.triggers:
                lines.append("触发条件:")
                for trigger in tool_guide.triggers:
                    lines.append(f"  - {trigger}")

            if tool_guide.anti_triggers:
                lines.append("不应触发:")
                for anti in tool_guide.anti_triggers:
                    lines.append(f"  - {anti}")

            lines.append("")
            lines.append(tool_guide.content)
            lines.append("")

        return "\n".join(lines)

    def build_compact_tools_section(self) -> str:
        """
        Build compact tools section with only name and description.
        This is used for the initial tool list, with full guidance loaded separately.
        """
        if not self._loaded:
            self.load_tools()

        if not self._metadata_cache:
            return "当前没有可用的工具使用指南。"

        lines = ["## 可用工具列表", "", "以下是你可以使用的工具:", ""]

        for meta in self._metadata_cache.values():
            lines.append(f"- **{meta.name}**: {meta.description}")
            if meta.tools:
                lines.append(f"  关联工具: {', '.join(meta.tools)}")

        return "\n".join(lines)
|
||||
|
||||
|
||||
_tool_guide_registry: ToolGuideRegistry | None = None
|
||||
|
||||
|
||||
def get_tool_guide_registry() -> ToolGuideRegistry:
    """Return the process-wide tool guide registry, creating it lazily."""
    global _tool_guide_registry
    if _tool_guide_registry is not None:
        return _tool_guide_registry
    _tool_guide_registry = ToolGuideRegistry()
    return _tool_guide_registry
|
||||
|
||||
|
||||
def init_tool_guide_registry(tools_dir: Path | None = None) -> ToolGuideRegistry:
    """Initialize the global tool guide registry and eagerly load guides.

    NOTE(review): ToolGuideRegistry is a singleton whose __init__ runs only
    once; if an instance already exists, *tools_dir* is silently ignored —
    confirm this is intended.
    """
    global _tool_guide_registry
    registry = ToolGuideRegistry(tools_dir=tools_dir)
    _tool_guide_registry = registry
    registry.load_tools()
    return registry
|
||||
|
|
@ -117,172 +117,124 @@ class ToolRegistry:
|
|||
self._tools[name] = tool
|
||||
|
||||
logger.info(
|
||||
f"[AC-IDMP-19] Tool registered: name={name}, type={tool_type.value}, "
|
||||
f"version={version}, auth_required={auth_required}"
|
||||
f"[AC-IDMP-19] Registered tool: {name} v{version} "
|
||||
f"(type={tool_type.value}, auth={auth_required}, timeout={timeout_ms}ms)"
|
||||
)
|
||||
|
||||
return tool
|
||||
|
||||
def unregister(self, name: str) -> bool:
|
||||
"""Unregister a tool."""
|
||||
if name in self._tools:
|
||||
del self._tools[name]
|
||||
logger.info(f"[AC-IDMP-19] Tool unregistered: {name}")
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_tool(self, name: str) -> ToolDefinition | None:
|
||||
"""Get tool definition by name."""
|
||||
"""Get tool by name."""
|
||||
return self._tools.get(name)
|
||||
|
||||
def list_tools(
|
||||
self,
|
||||
tool_type: ToolType | None = None,
|
||||
enabled_only: bool = True,
|
||||
) -> list[ToolDefinition]:
|
||||
"""List registered tools, optionally filtered."""
|
||||
tools = list(self._tools.values())
|
||||
def list_tools(self) -> list[str]:
|
||||
"""List all registered tool names."""
|
||||
return list(self._tools.keys())
|
||||
|
||||
if tool_type:
|
||||
tools = [t for t in tools if t.tool_type == tool_type]
|
||||
def get_all_tools(self) -> list[ToolDefinition]:
|
||||
"""Get all registered tools."""
|
||||
return list(self._tools.values())
|
||||
|
||||
if enabled_only:
|
||||
tools = [t for t in tools if t.enabled]
|
||||
def is_enabled(self, name: str) -> bool:
|
||||
"""Check if tool is enabled."""
|
||||
tool = self._tools.get(name)
|
||||
return tool.enabled if tool else False
|
||||
|
||||
return tools
|
||||
|
||||
def enable_tool(self, name: str) -> bool:
|
||||
"""Enable a tool."""
|
||||
def set_enabled(self, name: str, enabled: bool) -> bool:
|
||||
"""Enable or disable a tool."""
|
||||
tool = self._tools.get(name)
|
||||
if tool:
|
||||
tool.enabled = True
|
||||
logger.info(f"[AC-IDMP-19] Tool enabled: {name}")
|
||||
return True
|
||||
return False
|
||||
|
||||
def disable_tool(self, name: str) -> bool:
|
||||
"""Disable a tool."""
|
||||
tool = self._tools.get(name)
|
||||
if tool:
|
||||
tool.enabled = False
|
||||
logger.info(f"[AC-IDMP-19] Tool disabled: {name}")
|
||||
tool.enabled = enabled
|
||||
logger.info(f"[AC-IDMP-19] Tool {name} {'enabled' if enabled else 'disabled'}")
|
||||
return True
|
||||
return False
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
tool_name: str,
|
||||
args: dict[str, Any],
|
||||
auth_context: dict[str, Any] | None = None,
|
||||
name: str,
|
||||
**kwargs: Any,
|
||||
) -> ToolExecutionResult:
|
||||
"""
|
||||
[AC-IDMP-19] Execute a tool with governance.
|
||||
Execute a tool with governance.
|
||||
|
||||
Args:
|
||||
tool_name: Tool name to execute
|
||||
args: Tool arguments
|
||||
auth_context: Authentication context
|
||||
name: Tool name
|
||||
**kwargs: Tool arguments
|
||||
|
||||
Returns:
|
||||
ToolExecutionResult with output and metadata
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
tool = self._tools.get(tool_name)
|
||||
tool = self._tools.get(name)
|
||||
if not tool:
|
||||
logger.warning(f"[AC-IDMP-19] Tool not found: {tool_name}")
|
||||
return ToolExecutionResult(
|
||||
success=False,
|
||||
error=f"Tool not found: {tool_name}",
|
||||
duration_ms=0,
|
||||
error=f"Tool not found: {name}",
|
||||
registry_version=self._version,
|
||||
)
|
||||
|
||||
if not tool.enabled:
|
||||
logger.warning(f"[AC-IDMP-19] Tool disabled: {tool_name}")
|
||||
return ToolExecutionResult(
|
||||
success=False,
|
||||
error=f"Tool disabled: {tool_name}",
|
||||
duration_ms=0,
|
||||
registry_version=tool.version,
|
||||
error=f"Tool is disabled: {name}",
|
||||
registry_version=self._version,
|
||||
)
|
||||
|
||||
auth_applied = False
|
||||
if tool.auth_required:
|
||||
if not auth_context:
|
||||
logger.warning(f"[AC-IDMP-19] Auth required but no context: {tool_name}")
|
||||
return ToolExecutionResult(
|
||||
success=False,
|
||||
error="Authentication required",
|
||||
duration_ms=int((time.time() - start_time) * 1000),
|
||||
auth_applied=False,
|
||||
registry_version=tool.version,
|
||||
)
|
||||
auth_applied = True
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
timeout_seconds = tool.timeout_ms / 1000.0
|
||||
if not tool.handler:
|
||||
return ToolExecutionResult(
|
||||
success=False,
|
||||
error=f"Tool has no handler: {name}",
|
||||
registry_version=self._version,
|
||||
)
|
||||
|
||||
result = await asyncio.wait_for(
|
||||
tool.handler(**args) if tool.handler else asyncio.sleep(0),
|
||||
timeout=timeout_seconds,
|
||||
result = await self._timeout_governor.execute_with_timeout(
|
||||
lambda: tool.handler(**kwargs),
|
||||
timeout_ms=tool.timeout_ms,
|
||||
)
|
||||
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
logger.info(
|
||||
f"[AC-IDMP-19] Tool executed: name={tool_name}, "
|
||||
f"duration_ms={duration_ms}, success=True"
|
||||
)
|
||||
|
||||
return ToolExecutionResult(
|
||||
success=True,
|
||||
output=result,
|
||||
duration_ms=duration_ms,
|
||||
auth_applied=auth_applied,
|
||||
registry_version=tool.version,
|
||||
auth_applied=tool.auth_required,
|
||||
registry_version=self._version,
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
logger.warning(
|
||||
f"[AC-IDMP-19] Tool timeout: name={tool_name}, "
|
||||
f"duration_ms={duration_ms}"
|
||||
)
|
||||
return ToolExecutionResult(
|
||||
success=False,
|
||||
error=f"Tool timeout after {tool.timeout_ms}ms",
|
||||
error=f"Tool execution timeout after {tool.timeout_ms}ms",
|
||||
duration_ms=duration_ms,
|
||||
auth_applied=auth_applied,
|
||||
registry_version=tool.version,
|
||||
registry_version=self._version,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
logger.error(
|
||||
f"[AC-IDMP-19] Tool error: name={tool_name}, error={e}"
|
||||
)
|
||||
logger.error(f"[AC-IDMP-19] Tool execution error: {name} - {e}")
|
||||
return ToolExecutionResult(
|
||||
success=False,
|
||||
error=str(e),
|
||||
duration_ms=duration_ms,
|
||||
auth_applied=auth_applied,
|
||||
registry_version=tool.version,
|
||||
registry_version=self._version,
|
||||
)
|
||||
|
||||
def create_trace(
|
||||
def build_trace(
|
||||
self,
|
||||
tool_name: str,
|
||||
args: dict[str, Any],
|
||||
result: ToolExecutionResult,
|
||||
args_digest: str | None = None,
|
||||
) -> ToolCallTrace:
|
||||
"""
|
||||
[AC-IDMP-19] Create ToolCallTrace from execution result.
|
||||
"""
|
||||
tool = self._tools.get(tool_name)
|
||||
"""Build a tool call trace from execution result."""
|
||||
import hashlib
|
||||
args_digest = hashlib.md5(str(args).encode()).hexdigest()[:8]
|
||||
|
||||
return ToolCallTrace(
|
||||
tool_name=tool_name,
|
||||
tool_type=tool.tool_type if tool else ToolType.INTERNAL,
|
||||
tool_type=tool.tool_type if (tool := self._tools.get(tool_name)) else ToolType.INTERNAL,
|
||||
registry_version=result.registry_version,
|
||||
auth_applied=result.auth_applied,
|
||||
duration_ms=result.duration_ms,
|
||||
|
|
@ -293,6 +245,8 @@ class ToolRegistry:
|
|||
error_code=result.error if not result.success else None,
|
||||
args_digest=args_digest,
|
||||
result_digest=str(result.output)[:100] if result.output else None,
|
||||
arguments=args,
|
||||
result=result.output,
|
||||
)
|
||||
|
||||
def get_governance_report(self) -> dict[str, Any]:
|
||||
|
|
|
|||
|
|
@ -23,9 +23,9 @@ RAG Optimization (rag-optimization/spec.md):
|
|||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from sse_starlette.sse import ServerSentEvent
|
||||
|
|
@ -46,7 +46,6 @@ from app.services.flow.engine import FlowEngine
|
|||
from app.services.guardrail.behavior_service import BehaviorRuleService
|
||||
from app.services.guardrail.input_scanner import InputScanner
|
||||
from app.services.guardrail.output_filter import OutputFilter
|
||||
from app.services.guardrail.word_service import ForbiddenWordService
|
||||
from app.services.intent.router import IntentRouter
|
||||
from app.services.intent.rule_service import IntentRuleService
|
||||
from app.services.llm.base import LLMClient, LLMConfig, LLMResponse
|
||||
|
|
@ -90,6 +89,8 @@ class GenerationContext:
|
|||
10. confidence_result: Confidence calculation result
|
||||
11. messages_saved: Whether messages were saved
|
||||
12. final_response: Final ChatResponse
|
||||
|
||||
[v0.8.0] Extended with route_trace for hybrid routing observability.
|
||||
"""
|
||||
tenant_id: str
|
||||
session_id: str
|
||||
|
|
@ -115,6 +116,11 @@ class GenerationContext:
|
|||
target_kb_ids: list[str] | None = None
|
||||
behavior_rules: list[str] = field(default_factory=list)
|
||||
|
||||
# [v0.8.0] Hybrid routing fields
|
||||
route_trace: dict[str, Any] | None = None
|
||||
fusion_confidence: float | None = None
|
||||
fusion_decision_reason: str | None = None
|
||||
|
||||
diagnostics: dict[str, Any] = field(default_factory=dict)
|
||||
execution_steps: list[dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
|
|
@ -487,7 +493,7 @@ class OrchestratorService:
|
|||
finish_reason="flow_step",
|
||||
)
|
||||
ctx.diagnostics["flow_handled"] = True
|
||||
logger.info(f"[AC-AISVC-75] Flow provided reply, skipping LLM")
|
||||
logger.info("[AC-AISVC-75] Flow provided reply, skipping LLM")
|
||||
|
||||
else:
|
||||
ctx.diagnostics["flow_check_enabled"] = True
|
||||
|
|
@ -501,8 +507,8 @@ class OrchestratorService:
|
|||
"""
|
||||
[AC-AISVC-69, AC-AISVC-70] Step 3: Match intent rules and route.
|
||||
Routes to: fixed reply, RAG with target KBs, flow start, or transfer.
|
||||
[v0.8.0] Upgraded to use match_hybrid() for hybrid routing.
|
||||
"""
|
||||
# Skip if flow already handled the request
|
||||
if ctx.diagnostics.get("flow_handled"):
|
||||
logger.info("[AC-AISVC-69] Flow already handled, skipping intent matching")
|
||||
return
|
||||
|
|
@ -513,7 +519,6 @@ class OrchestratorService:
|
|||
return
|
||||
|
||||
try:
|
||||
# Load enabled rules ordered by priority
|
||||
async with get_session() as session:
|
||||
from app.services.intent.rule_service import IntentRuleService
|
||||
rule_service = IntentRuleService(session)
|
||||
|
|
@ -524,33 +529,64 @@ class OrchestratorService:
|
|||
ctx.diagnostics["intent_matched"] = False
|
||||
return
|
||||
|
||||
# Match intent
|
||||
ctx.intent_match = self._intent_router.match(
|
||||
fusion_result = await self._intent_router.match_hybrid(
|
||||
message=ctx.current_message,
|
||||
rules=rules,
|
||||
tenant_id=ctx.tenant_id,
|
||||
)
|
||||
|
||||
if ctx.intent_match:
|
||||
ctx.route_trace = fusion_result.trace.to_dict()
|
||||
ctx.fusion_confidence = fusion_result.final_confidence
|
||||
ctx.fusion_decision_reason = fusion_result.decision_reason
|
||||
|
||||
if fusion_result.final_intent:
|
||||
ctx.intent_match = type(
|
||||
"IntentMatchResult",
|
||||
(),
|
||||
{
|
||||
"rule": fusion_result.final_intent,
|
||||
"match_type": fusion_result.decision_reason,
|
||||
"matched": "",
|
||||
"to_dict": lambda: {
|
||||
"rule_id": str(fusion_result.final_intent.id),
|
||||
"rule_name": fusion_result.final_intent.name,
|
||||
"match_type": fusion_result.decision_reason,
|
||||
"matched": "",
|
||||
"response_type": fusion_result.final_intent.response_type,
|
||||
"target_kb_ids": (
|
||||
fusion_result.final_intent.target_kb_ids or []
|
||||
),
|
||||
"flow_id": (
|
||||
str(fusion_result.final_intent.flow_id)
|
||||
if fusion_result.final_intent.flow_id else None
|
||||
),
|
||||
"fixed_reply": fusion_result.final_intent.fixed_reply,
|
||||
"transfer_message": fusion_result.final_intent.transfer_message,
|
||||
},
|
||||
},
|
||||
)()
|
||||
|
||||
logger.info(
|
||||
f"[AC-AISVC-69] Intent matched: rule={ctx.intent_match.rule.name}, "
|
||||
f"response_type={ctx.intent_match.rule.response_type}"
|
||||
f"[AC-AISVC-69] Intent matched: rule={fusion_result.final_intent.name}, "
|
||||
f"response_type={fusion_result.final_intent.response_type}, "
|
||||
f"decision={fusion_result.decision_reason}, "
|
||||
f"confidence={fusion_result.final_confidence:.3f}"
|
||||
)
|
||||
|
||||
ctx.diagnostics["intent_match"] = ctx.intent_match.to_dict()
|
||||
ctx.diagnostics["fusion_result"] = fusion_result.to_dict()
|
||||
|
||||
# Increment hit count
|
||||
async with get_session() as session:
|
||||
rule_service = IntentRuleService(session)
|
||||
await rule_service.increment_hit_count(
|
||||
tenant_id=ctx.tenant_id,
|
||||
rule_id=ctx.intent_match.rule.id,
|
||||
rule_id=fusion_result.final_intent.id,
|
||||
)
|
||||
|
||||
# Route based on response_type
|
||||
if ctx.intent_match.rule.response_type == "fixed":
|
||||
# Fixed reply - skip LLM
|
||||
rule = fusion_result.final_intent
|
||||
if rule.response_type == "fixed":
|
||||
ctx.llm_response = LLMResponse(
|
||||
content=ctx.intent_match.rule.fixed_reply or "收到您的消息。",
|
||||
content=rule.fixed_reply or "收到您的消息。",
|
||||
model="intent_fixed",
|
||||
usage={},
|
||||
finish_reason="intent_fixed",
|
||||
|
|
@ -558,20 +594,18 @@ class OrchestratorService:
|
|||
ctx.diagnostics["intent_handled"] = True
|
||||
logger.info("[AC-AISVC-70] Intent fixed reply, skipping LLM")
|
||||
|
||||
elif ctx.intent_match.rule.response_type == "rag":
|
||||
# RAG with target KBs
|
||||
ctx.target_kb_ids = ctx.intent_match.rule.target_kb_ids or []
|
||||
elif rule.response_type == "rag":
|
||||
ctx.target_kb_ids = rule.target_kb_ids or []
|
||||
logger.info(f"[AC-AISVC-70] Intent RAG, target_kb_ids={ctx.target_kb_ids}")
|
||||
|
||||
elif ctx.intent_match.rule.response_type == "flow":
|
||||
# Start script flow
|
||||
if ctx.intent_match.rule.flow_id and self._flow_engine:
|
||||
elif rule.response_type == "flow":
|
||||
if rule.flow_id and self._flow_engine:
|
||||
async with get_session() as session:
|
||||
flow_engine = FlowEngine(session)
|
||||
instance, first_step = await flow_engine.start(
|
||||
tenant_id=ctx.tenant_id,
|
||||
session_id=ctx.session_id,
|
||||
flow_id=ctx.intent_match.rule.flow_id,
|
||||
flow_id=rule.flow_id,
|
||||
)
|
||||
if first_step:
|
||||
ctx.llm_response = LLMResponse(
|
||||
|
|
@ -583,10 +617,9 @@ class OrchestratorService:
|
|||
ctx.diagnostics["intent_handled"] = True
|
||||
logger.info("[AC-AISVC-70] Intent flow started, skipping LLM")
|
||||
|
||||
elif ctx.intent_match.rule.response_type == "transfer":
|
||||
# Transfer to human
|
||||
elif rule.response_type == "transfer":
|
||||
ctx.llm_response = LLMResponse(
|
||||
content=ctx.intent_match.rule.transfer_message or "正在为您转接人工客服...",
|
||||
content=rule.transfer_message or "正在为您转接人工客服...",
|
||||
model="intent_transfer",
|
||||
usage={},
|
||||
finish_reason="intent_transfer",
|
||||
|
|
@ -600,9 +633,25 @@ class OrchestratorService:
|
|||
ctx.diagnostics["intent_handled"] = True
|
||||
logger.info("[AC-AISVC-70] Intent transfer, skipping LLM")
|
||||
|
||||
if fusion_result.need_clarify:
|
||||
ctx.diagnostics["need_clarify"] = True
|
||||
ctx.diagnostics["clarify_candidates"] = [
|
||||
{"id": str(r.id), "name": r.name}
|
||||
for r in (fusion_result.clarify_candidates or [])
|
||||
]
|
||||
logger.info(
|
||||
f"[AC-AISVC-121] Low confidence, need clarify: "
|
||||
f"confidence={fusion_result.final_confidence:.3f}, "
|
||||
f"candidates={len(fusion_result.clarify_candidates or [])}"
|
||||
)
|
||||
|
||||
else:
|
||||
ctx.diagnostics["intent_match_enabled"] = True
|
||||
ctx.diagnostics["intent_matched"] = False
|
||||
ctx.diagnostics["fusion_result"] = fusion_result.to_dict()
|
||||
logger.info(
|
||||
f"[AC-AISVC-69] No intent matched, decision={fusion_result.decision_reason}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"[AC-AISVC-69] Intent matching failed: {e}")
|
||||
|
|
@ -1122,6 +1171,7 @@ class OrchestratorService:
|
|||
[AC-AISVC-02] Build final ChatResponse from generation context.
|
||||
Step 12 of the 12-step pipeline.
|
||||
Uses filtered_reply from Step 9.
|
||||
[v0.8.0] Includes route_trace in response metadata.
|
||||
"""
|
||||
# Use filtered_reply if available, otherwise use llm_response.content
|
||||
if ctx.filtered_reply:
|
||||
|
|
@ -1142,6 +1192,10 @@ class OrchestratorService:
|
|||
"execution_steps": ctx.execution_steps,
|
||||
}
|
||||
|
||||
# [v0.8.0] Include route_trace in response metadata
|
||||
if ctx.route_trace:
|
||||
response_metadata["route_trace"] = ctx.route_trace
|
||||
|
||||
return ChatResponse(
|
||||
reply=reply,
|
||||
confidence=confidence,
|
||||
|
|
|
|||
|
|
@ -178,6 +178,9 @@ class PromptTemplateService:
|
|||
current_version = v
|
||||
break
|
||||
|
||||
# Get latest version for current_content (not just published)
|
||||
latest_version = versions[0] if versions else None
|
||||
|
||||
return {
|
||||
"id": str(template.id),
|
||||
"name": template.name,
|
||||
|
|
@ -185,6 +188,8 @@ class PromptTemplateService:
|
|||
"description": template.description,
|
||||
"is_default": template.is_default,
|
||||
"metadata": template.metadata_,
|
||||
"current_content": latest_version.system_instruction if latest_version else None,
|
||||
"variables": latest_version.variables if latest_version else [],
|
||||
"current_version": {
|
||||
"version": current_version.version,
|
||||
"status": current_version.status,
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
Retrieval module for AI Service.
|
||||
[AC-AISVC-16] Provides retriever implementations with plugin architecture.
|
||||
RAG Optimization: Two-stage retrieval, RRF hybrid ranking, metadata filtering.
|
||||
[AC-AISVC-RES-01~15] Strategy routing and mode routing for retrieval pipeline.
|
||||
"""
|
||||
|
||||
from app.services.retrieval.base import (
|
||||
|
|
@ -32,6 +33,29 @@ from app.services.retrieval.optimized_retriever import (
|
|||
get_optimized_retriever,
|
||||
)
|
||||
from app.services.retrieval.vector_retriever import VectorRetriever, get_vector_retriever
|
||||
from app.services.retrieval.routing_config import (
|
||||
RagRuntimeMode,
|
||||
RoutingConfig,
|
||||
StrategyContext,
|
||||
StrategyType,
|
||||
StrategyResult,
|
||||
)
|
||||
from app.services.retrieval.strategy_router import (
|
||||
RollbackRecord,
|
||||
StrategyRouter,
|
||||
get_strategy_router,
|
||||
)
|
||||
from app.services.retrieval.mode_router import (
|
||||
ComplexityAnalyzer,
|
||||
ModeRouter,
|
||||
ModeRouteResult,
|
||||
get_mode_router,
|
||||
)
|
||||
from app.services.retrieval.strategy_integration import (
|
||||
RetrievalStrategyIntegration,
|
||||
RetrievalStrategyResult,
|
||||
get_retrieval_strategy_integration,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BaseRetriever",
|
||||
|
|
@ -55,4 +79,19 @@ __all__ = [
|
|||
"get_knowledge_indexer",
|
||||
"IndexingProgress",
|
||||
"IndexingResult",
|
||||
"RagRuntimeMode",
|
||||
"RoutingConfig",
|
||||
"StrategyContext",
|
||||
"StrategyType",
|
||||
"StrategyResult",
|
||||
"RollbackRecord",
|
||||
"StrategyRouter",
|
||||
"get_strategy_router",
|
||||
"ComplexityAnalyzer",
|
||||
"ModeRouter",
|
||||
"ModeRouteResult",
|
||||
"get_mode_router",
|
||||
"RetrievalStrategyIntegration",
|
||||
"RetrievalStrategyResult",
|
||||
"get_retrieval_strategy_integration",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ class RetrievalContext:
|
|||
metadata: dict[str, Any] | None = None
|
||||
tag_filter: "TagFilter | None" = None
|
||||
kb_ids: list[str] | None = None
|
||||
metadata_filter: dict[str, Any] | None = None
|
||||
|
||||
def get_tag_filter_dict(self) -> dict[str, str | list[str] | None] | None:
|
||||
"""获取标签过滤器的字典表示"""
|
||||
|
|
|
|||
|
|
@ -0,0 +1,438 @@
|
|||
"""
|
||||
Mode Router for RAG Runtime Mode Selection.
|
||||
[AC-AISVC-RES-09, AC-AISVC-RES-10, AC-AISVC-RES-11] Routes to direct/react/auto mode.
|
||||
|
||||
Mode Descriptions:
|
||||
- direct: Low-latency generic retrieval path (single KB call)
|
||||
- react: Multi-step ReAct retrieval path (high accuracy)
|
||||
- auto: Automatic selection based on complexity/confidence rules
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from app.services.retrieval.routing_config import (
|
||||
RagRuntimeMode,
|
||||
RoutingConfig,
|
||||
StrategyContext,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.services.retrieval.base import RetrievalResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class ComplexityAnalyzer:
    """
    Analyzes query complexity for mode routing decisions.

    Complexity factors:
    - Query length
    - Number of conditions/constraints (conjunction patterns)
    - Presence of logical operators (and, or, not)
    - Cross-domain indicators
    - Multi-step reasoning requirements
    - Multiple question marks

    The score feeds mode routing: higher scores favor the multi-step
    ReAct retrieval path over the direct low-latency path.
    """

    # Queries shorter than this add nothing to the length factor.
    short_query_threshold: int = 20
    # Queries longer than this get the maximum length contribution (0.3).
    long_query_threshold: int = 100

    # Conjunction / constraint markers ("and", "or", "but", "if", ...).
    condition_patterns: list[str] = field(default_factory=lambda: [
        r"和|与|及|并且|同时",
        r"或者|还是|要么",
        r"但是|不过|然而",
        r"如果|假如|假设",
        r"既.*又",
        r"不仅.*而且",
    ])

    # Reasoning markers ("why", "how", "difference", "compare", "analyze").
    reasoning_patterns: list[str] = field(default_factory=lambda: [
        r"为什么|原因|理由",
        r"怎么|如何|怎样",
        r"区别|差异|不同",
        r"比较|对比|优劣",
        r"分析|评估|判断",
    ])

    # Cross-domain / aggregation markers ("across", "all", "summarize").
    cross_domain_patterns: list[str] = field(default_factory=lambda: [
        r"跨|多|各个",
        r"所有|全部|整体",
        r"综合|汇总|统计",
    ])

    def analyze(self, query: str) -> float:
        """
        Analyze query complexity and return a score (0.0 ~ 1.0).

        Higher score indicates more complex query that may benefit from ReAct mode.

        Args:
            query: User query text

        Returns:
            Complexity score (0.0 = simple, 1.0 = very complex)
        """
        if not query:
            return 0.0

        score = 0.0

        # Length factor: short adds nothing, long adds 0.3, medium 0.15.
        query_length = len(query)
        if query_length < self.short_query_threshold:
            score += 0.0
        elif query_length > self.long_query_threshold:
            score += 0.3
        else:
            score += 0.15

        # Constraint factor: count every conjunction occurrence across patterns.
        condition_count = 0
        for pattern in self.condition_patterns:
            matches = re.findall(pattern, query)
            condition_count += len(matches)

        if condition_count >= 3:
            score += 0.3
        elif condition_count >= 2:
            score += 0.2
        elif condition_count >= 1:
            score += 0.1

        # Reasoning factor: a single matching pattern is enough.
        for pattern in self.reasoning_patterns:
            if re.search(pattern, query):
                score += 0.15
                break

        # Cross-domain factor: a single matching pattern is enough.
        for pattern in self.cross_domain_patterns:
            if re.search(pattern, query):
                score += 0.15
                break

        # Multi-question factor. Bug fix: the original summed
        # query.count("?") twice, double-counting ASCII question marks so a
        # single "?" earned the bonus; count ASCII and full-width marks
        # once each instead.
        question_marks = query.count("?") + query.count("？")
        if question_marks >= 2:
            score += 0.1

        # Clamp: the individual factors can sum past 1.0.
        return min(1.0, score)
|
||||
|
||||
|
||||
@dataclass
class ModeRouteResult:
    """Result from mode routing decision."""
    # The selected runtime mode (direct / react / auto).
    mode: RagRuntimeMode
    # Confidence in the routing decision — presumably 0.0~1.0; TODO confirm range.
    confidence: float
    # Complexity score produced by ComplexityAnalyzer (0.0 ~ 1.0).
    complexity_score: float
    # Whether the caller should fall back to the ReAct path.
    should_fallback_to_react: bool = False
    # Human-readable reason for the fallback, when applicable.
    fallback_reason: str | None = None
    # Extra routing details for observability.
    diagnostics: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class DirectRetrievalExecutor:
    """
    [AC-AISVC-RES-09] Direct retrieval executor for low-latency path.

    Single KB call without multi-step reasoning.
    """

    def __init__(self):
        # Lazily resolved OptimizedRetriever; populated on first execute().
        self._retriever = None

    async def execute(
        self,
        ctx: StrategyContext,
    ) -> "RetrievalResult":
        """
        Execute direct retrieval (single KB call).

        Args:
            ctx: Strategy context carrying tenant, query, filter and KB ids.

        Returns:
            The RetrievalResult returned by the underlying retriever.
        """
        # Imported locally, presumably to avoid circular imports — TODO confirm.
        from app.services.retrieval.optimized_retriever import get_optimized_retriever
        from app.services.retrieval.base import RetrievalContext

        # Lazy init; no lock, so concurrent first calls may resolve twice
        # (last writer wins).
        if self._retriever is None:
            self._retriever = await get_optimized_retriever()

        # Translate the strategy context into the retriever's own context type.
        retrieval_ctx = RetrievalContext(
            tenant_id=ctx.tenant_id,
            query=ctx.query,
            metadata_filter=ctx.metadata_filter,
            kb_ids=ctx.kb_ids,
        )

        return await self._retriever.retrieve(retrieval_ctx)
|
||||
|
||||
|
||||
class ReactRetrievalExecutor:
    """
    [AC-AISVC-RES-10] ReAct retrieval executor for multi-step path.

    Uses AgentOrchestrator for multi-step reasoning and KB calls.
    """

    def __init__(self, max_steps: int = 5):
        # Hard upper bound on ReAct iterations; the effective limit at run
        # time is min(config.react_max_steps, self._max_steps).
        self._max_steps = max_steps

    async def execute(
        self,
        ctx: StrategyContext,
        config: RoutingConfig,
    ) -> tuple[str, "RetrievalResult | None", dict[str, Any]]:
        """
        Execute ReAct retrieval (multi-step reasoning).

        Args:
            ctx: Strategy context (tenant, query, filters, extra context).
            config: Routing configuration; only react_max_steps is consumed here.

        Returns:
            Tuple of (final_answer, retrieval_result, react_context).
            retrieval_result is always None on this path; retrieval happens
            through the orchestrator's tool calls instead.
        """
        # Imported locally, presumably to keep heavy agent dependencies off
        # the module import path / avoid cycles — TODO confirm.
        from app.services.mid.agent_orchestrator import AgentOrchestrator, AgentMode
        from app.services.mid.tool_registry import ToolRegistry
        from app.services.mid.timeout_governor import TimeoutGovernor
        from app.services.llm.factory import get_llm_config_manager

        try:
            llm_manager = get_llm_config_manager()
            llm_client = llm_manager.get_client()
        except Exception as e:
            # Best-effort: the orchestrator is still constructed with
            # llm_client=None when the LLM config is unavailable.
            logger.warning(f"[ModeRouter] Failed to get LLM client: {e}")
            llm_client = None

        # NOTE(review): two distinct TimeoutGovernor instances are created
        # (one inside ToolRegistry, one for the orchestrator) — confirm
        # whether they should share a single instance.
        tool_registry = ToolRegistry(timeout_governor=TimeoutGovernor())
        timeout_governor = TimeoutGovernor()

        orchestrator = AgentOrchestrator(
            max_iterations=min(config.react_max_steps, self._max_steps),
            timeout_governor=timeout_governor,
            llm_client=llm_client,
            tool_registry=tool_registry,
            tenant_id=ctx.tenant_id,
            mode=AgentMode.FUNCTION_CALLING,
        )

        # Seed the agent context with retrieval inputs plus caller extras;
        # additional_context keys can override the three fixed keys.
        base_context = {
            "query": ctx.query,
            "metadata_filter": ctx.metadata_filter,
            "kb_ids": ctx.kb_ids,
            **ctx.additional_context,
        }

        # trace is intentionally unused here; diagnostics are summarized below.
        final_answer, react_ctx, trace = await orchestrator.execute(
            user_message=ctx.query,
            context=base_context,
        )

        return final_answer, None, {
            "iterations": react_ctx.iteration,
            "tool_calls": [tc.model_dump() for tc in react_ctx.tool_calls] if react_ctx.tool_calls else [],
            "final_answer": final_answer,
        }
|
||||
|
||||
|
||||
class ModeRouter:
    """
    [AC-AISVC-RES-09, AC-AISVC-RES-10, AC-AISVC-RES-11]
    Mode router for RAG runtime mode selection.

    Mode Selection:
    - direct: Low-latency generic retrieval (single KB call)
    - react: Multi-step ReAct retrieval (high accuracy)
    - auto: Automatic selection based on complexity/confidence

    Auto Mode Rules:
    - Direct conditions:
      - Short query, clear intent
      - High metadata confidence
      - No cross-domain/multi-condition
    - React conditions:
      - Multi-condition/multi-constraint
      - Low metadata confidence
      - Need for secondary confirmation or multi-step reasoning
    """

    def __init__(
        self,
        config: RoutingConfig | None = None,
    ):
        self._config = config or RoutingConfig()
        self._complexity_analyzer = ComplexityAnalyzer()
        self._direct_executor = DirectRetrievalExecutor()
        # The react executor's hard cap mirrors the configured max steps.
        self._react_executor = ReactRetrievalExecutor(
            max_steps=self._config.react_max_steps
        )

    @property
    def config(self) -> RoutingConfig:
        """Get current configuration."""
        return self._config

    def update_config(self, new_config: RoutingConfig) -> None:
        """
        [AC-AISVC-RES-15] Update routing configuration (hot reload).
        """
        self._config = new_config
        # NOTE(review): reaches into the executor's private attribute; a
        # setter on ReactRetrievalExecutor would be cleaner.
        self._react_executor._max_steps = new_config.react_max_steps

        logger.info(
            f"[AC-AISVC-RES-15] ModeRouter config updated: "
            f"mode={new_config.rag_runtime_mode.value}, "
            f"react_max_steps={new_config.react_max_steps}, "
            f"confidence_threshold={new_config.react_trigger_confidence_threshold}"
        )

    def route(
        self,
        ctx: StrategyContext,
    ) -> ModeRouteResult:
        """
        [AC-AISVC-RES-09, AC-AISVC-RES-10, AC-AISVC-RES-11]
        Route to appropriate mode based on configuration and context.

        An explicit direct/react configuration short-circuits; in auto mode
        the analyzer's complexity is merged (max) with the caller-provided
        score and the configured thresholds are applied.

        Args:
            ctx: Strategy context with query, metadata, confidence, etc.

        Returns:
            ModeRouteResult with selected mode and diagnostics
        """
        configured_mode = self._config.get_rag_runtime_mode()

        if configured_mode == RagRuntimeMode.DIRECT:
            logger.info(
                f"[AC-AISVC-RES-09] Mode routing to DIRECT: tenant={ctx.tenant_id}"
            )
            return ModeRouteResult(
                mode=RagRuntimeMode.DIRECT,
                confidence=ctx.metadata_confidence,
                complexity_score=ctx.complexity_score,
                diagnostics={"configured_mode": "direct"},
            )

        if configured_mode == RagRuntimeMode.REACT:
            logger.info(
                f"[AC-AISVC-RES-10] Mode routing to REACT: tenant={ctx.tenant_id}"
            )
            return ModeRouteResult(
                mode=RagRuntimeMode.REACT,
                confidence=ctx.metadata_confidence,
                complexity_score=ctx.complexity_score,
                diagnostics={"configured_mode": "react"},
            )

        # AUTO: effective complexity is the max of the analyzed score and the
        # caller-provided hint.
        complexity_score = self._complexity_analyzer.analyze(ctx.query)
        effective_complexity = max(complexity_score, ctx.complexity_score)

        should_use_react = self._config.should_trigger_react_in_auto_mode(
            confidence=ctx.metadata_confidence,
            complexity_score=effective_complexity,
        )

        selected_mode = RagRuntimeMode.REACT if should_use_react else RagRuntimeMode.DIRECT

        logger.info(
            f"[AC-AISVC-RES-11, AC-AISVC-RES-12, AC-AISVC-RES-13] "
            f"Auto mode routing: selected={selected_mode.value}, "
            f"confidence={ctx.metadata_confidence:.2f}, "
            f"complexity={effective_complexity:.2f}, "
            f"tenant={ctx.tenant_id}"
        )

        return ModeRouteResult(
            mode=selected_mode,
            confidence=ctx.metadata_confidence,
            complexity_score=effective_complexity,
            diagnostics={
                "configured_mode": "auto",
                "analyzed_complexity": complexity_score,
                "provided_complexity": ctx.complexity_score,
                "react_trigger_confidence": self._config.react_trigger_confidence_threshold,
                "react_trigger_complexity": self._config.react_trigger_complexity_score,
            },
        )

    async def execute_direct(
        self,
        ctx: StrategyContext,
    ) -> "RetrievalResult":
        """
        Execute direct retrieval mode.
        """
        return await self._direct_executor.execute(ctx)

    async def execute_react(
        self,
        ctx: StrategyContext,
    ) -> tuple[str, "RetrievalResult | None", dict[str, Any]]:
        """
        Execute ReAct retrieval mode.
        """
        return await self._react_executor.execute(ctx, self._config)

    async def execute_with_fallback(
        self,
        ctx: StrategyContext,
    ) -> tuple["RetrievalResult | None", str | None, ModeRouteResult]:
        """
        [AC-AISVC-RES-14] Execute with fallback from direct to react on low confidence.

        At most one of the first two tuple slots is populated: direct results
        come back as a RetrievalResult, react results as a final-answer string.

        Args:
            ctx: Strategy context

        Returns:
            Tuple of (RetrievalResult or None, final_answer or None, ModeRouteResult)
        """
        route_result = self.route(ctx)

        if route_result.mode == RagRuntimeMode.DIRECT:
            retrieval_result = await self._direct_executor.execute(ctx)

            # Confidence proxy for the direct attempt: its best hit score.
            max_score = 0.0
            if retrieval_result and retrieval_result.hits:
                max_score = max((h.score for h in retrieval_result.hits), default=0.0)

            if self._config.should_fallback_direct_to_react(max_score):
                logger.info(
                    f"[AC-AISVC-RES-14] Direct mode low confidence fallback to react: "
                    f"confidence={max_score:.2f}, threshold={self._config.direct_fallback_confidence_threshold}"
                )

                # react_ctx diagnostics are currently discarded on this path.
                final_answer, _, react_ctx = await self._react_executor.execute(ctx, self._config)

                return (
                    None,
                    final_answer,
                    ModeRouteResult(
                        mode=RagRuntimeMode.REACT,
                        confidence=max_score,
                        complexity_score=route_result.complexity_score,
                        should_fallback_to_react=True,
                        fallback_reason="low_confidence",
                        diagnostics={
                            **route_result.diagnostics,
                            "fallback_from": "direct",
                            "direct_confidence": max_score,
                        },
                    ),
                )

            return retrieval_result, None, route_result

        # Routed straight to react (explicit config or auto decision).
        final_answer, _, react_ctx = await self._react_executor.execute(ctx, self._config)

        return None, final_answer, route_result
|
||||
|
||||
|
||||
# Process-wide ModeRouter instance, created lazily by get_mode_router().
_mode_router: ModeRouter | None = None


def get_mode_router() -> ModeRouter:
    """Return the shared ModeRouter, building it on first use."""
    global _mode_router
    router = _mode_router
    if router is None:
        router = ModeRouter()
        _mode_router = router
    return router


def reset_mode_router() -> None:
    """Drop the shared ModeRouter so the next access rebuilds it (for testing)."""
    global _mode_router
    _mode_router = None
|
||||
|
|
@ -0,0 +1,187 @@
|
|||
"""
|
||||
Retrieval and Embedding Strategy Configuration.
|
||||
[AC-AISVC-RES-01~15] Configuration for strategy routing and mode routing.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
|
||||
class StrategyType(str, Enum):
    """Strategy type for retrieval pipeline selection."""

    DEFAULT = "default"    # existing production retrieval path
    ENHANCED = "enhanced"  # new strategy pipeline
|
||||
|
||||
|
||||
class RagRuntimeMode(str, Enum):
    """RAG runtime mode for execution path selection."""

    DIRECT = "direct"  # single KB call, low latency
    REACT = "react"    # multi-step ReAct reasoning
    AUTO = "auto"      # choose per-request via complexity/confidence
|
||||
|
||||
|
||||
@dataclass
class RoutingConfig:
    """
    [AC-AISVC-RES-01~15] Routing configuration for strategy and mode selection.

    Configuration hierarchy:
    1. Strategy selection (default vs enhanced)
    2. Mode selection (direct/react/auto)
    3. Auto routing rules (complexity/confidence thresholds)
    4. Fallback behavior
    """

    enabled: bool = True
    strategy: StrategyType = StrategyType.DEFAULT

    # Grayscale rollout: fraction in [0.0, 1.0] (enforced by validate()) plus
    # an explicit tenant allowlist.
    grayscale_percentage: float = 0.0
    grayscale_allowlist: list[str] = field(default_factory=list)

    rag_runtime_mode: RagRuntimeMode = RagRuntimeMode.AUTO

    # Auto-mode triggers: react is selected when confidence drops below, or
    # complexity rises above, these thresholds.
    react_trigger_confidence_threshold: float = 0.6
    react_trigger_complexity_score: float = 0.5
    react_max_steps: int = 5

    # Direct-mode escalation to react when the best hit score is low.
    direct_fallback_on_low_confidence: bool = True
    direct_fallback_confidence_threshold: float = 0.4

    performance_budget_ms: int = 5000
    performance_degradation_threshold: float = 0.2

    def should_use_enhanced_strategy(self, tenant_id: str | None = None) -> bool:
        """
        [AC-AISVC-RES-02, AC-AISVC-RES-03] Determine if enhanced strategy should be used.

        Priority:
        1. If strategy is explicitly set to ENHANCED, use enhanced
        2. If strategy is DEFAULT, use default
        3. If grayscale is enabled, check percentage/allowlist

        NOTE(review): StrategyType only defines DEFAULT and ENHANCED, and both
        early-return above the grayscale checks, so the grayscale logic below
        is currently unreachable — confirm whether grayscale should be
        consulted before the DEFAULT early-return.
        """
        if self.strategy == StrategyType.ENHANCED:
            return True

        if self.strategy == StrategyType.DEFAULT:
            return False

        if self.grayscale_percentage > 0:
            import hashlib
            if tenant_id:
                # Deterministic tenant bucketing so rollout assignment is
                # sticky (md5 used for distribution, not security).
                hash_val = int(hashlib.md5(tenant_id.encode()).hexdigest()[:8], 16)
                return (hash_val % 100) < (self.grayscale_percentage * 100)
            return False

        if self.grayscale_allowlist and tenant_id:
            return tenant_id in self.grayscale_allowlist

        return False

    def get_rag_runtime_mode(self) -> RagRuntimeMode:
        """Get the configured RAG runtime mode."""
        return self.rag_runtime_mode

    def should_fallback_direct_to_react(self, confidence: float) -> bool:
        """
        [AC-AISVC-RES-14] Determine if direct mode should fallback to react.

        Args:
            confidence: Retrieval confidence score (0.0 ~ 1.0)

        Returns:
            True if fallback should be triggered
        """
        if not self.direct_fallback_on_low_confidence:
            return False

        # Strictly-below comparison: a score equal to the threshold stays direct.
        return confidence < self.direct_fallback_confidence_threshold

    def should_trigger_react_in_auto_mode(
        self,
        confidence: float,
        complexity_score: float,
    ) -> bool:
        """
        [AC-AISVC-RES-11, AC-AISVC-RES-12, AC-AISVC-RES-13]
        Determine if react mode should be triggered in auto mode.

        React wins when confidence is strictly below the confidence threshold
        OR complexity is strictly above the complexity threshold; otherwise
        direct is used.

        Args:
            confidence: Metadata inference confidence (0.0 ~ 1.0)
            complexity_score: Query complexity score (0.0 ~ 1.0)

        Returns:
            True if react mode should be used
        """
        if confidence < self.react_trigger_confidence_threshold:
            return True

        if complexity_score > self.react_trigger_complexity_score:
            return True

        return False

    def validate(self) -> tuple[bool, list[str]]:
        """
        [AC-AISVC-RES-06] Validate configuration consistency.

        Returns:
            (is_valid, list of error messages)
        """
        errors = []

        if self.grayscale_percentage < 0 or self.grayscale_percentage > 1.0:
            errors.append("grayscale_percentage must be between 0.0 and 1.0")

        if self.react_trigger_confidence_threshold < 0 or self.react_trigger_confidence_threshold > 1.0:
            errors.append("react_trigger_confidence_threshold must be between 0.0 and 1.0")

        if self.react_trigger_complexity_score < 0 or self.react_trigger_complexity_score > 1.0:
            errors.append("react_trigger_complexity_score must be between 0.0 and 1.0")

        if self.react_max_steps < 3 or self.react_max_steps > 10:
            errors.append("react_max_steps must be between 3 and 10")

        if self.direct_fallback_confidence_threshold < 0 or self.direct_fallback_confidence_threshold > 1.0:
            errors.append("direct_fallback_confidence_threshold must be between 0.0 and 1.0")

        if self.performance_budget_ms < 1000:
            errors.append("performance_budget_ms must be at least 1000")

        if self.performance_degradation_threshold < 0 or self.performance_degradation_threshold > 1.0:
            errors.append("performance_degradation_threshold must be between 0.0 and 1.0")

        return (len(errors) == 0, errors)
|
||||
|
||||
|
||||
@dataclass
class StrategyContext:
    """Context for strategy routing decision."""

    tenant_id: str
    query: str
    # Optional structured filter handed through to retrieval.
    metadata_filter: dict[str, Any] | None = None
    # Confidence of the metadata inference that produced the filter (0.0 ~ 1.0).
    metadata_confidence: float = 1.0
    # Caller-provided complexity hint; merged (max) with analyzer output in auto mode.
    complexity_score: float = 0.0
    kb_ids: list[str] | None = None
    top_k: int = 5
    # Extra keys merged into the ReAct agent context (can override fixed keys).
    additional_context: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
class StrategyResult:
    """Result from strategy routing."""

    # Strategy the router selected (default vs enhanced).
    strategy: StrategyType
    # Runtime mode to execute under.
    mode: RagRuntimeMode
    should_fallback: bool = False
    fallback_reason: str | None = None
    # Free-form routing diagnostics for logging/observability.
    diagnostics: dict[str, Any] = field(default_factory=dict)
|
||||
|
|
@ -0,0 +1,102 @@
|
|||
"""
|
||||
Retrieval Strategy Module for AI Service.
|
||||
[AC-AISVC-RES-01~15] 策略化检索与嵌入模块。
|
||||
|
||||
核心组件:
|
||||
- RetrievalStrategyConfig: 策略配置模型
|
||||
- BasePipeline: Pipeline 抽象基类
|
||||
- DefaultPipeline: 默认策略(复用现有逻辑)
|
||||
- EnhancedPipeline: 增强策略(新端到端流程)
|
||||
- MetadataInferenceService: 元数据推断统一入口
|
||||
- StrategyRouter: 策略路由器
|
||||
- ModeRouter: 模式路由器(direct/react/auto)
|
||||
- RollbackManager: 回退管理器
|
||||
"""
|
||||
|
||||
from app.services.retrieval.strategy.config import (
|
||||
FilterMode,
|
||||
GrayscaleConfig,
|
||||
HybridRetrievalConfig,
|
||||
MetadataInferenceConfig,
|
||||
ModeRouterConfig,
|
||||
PipelineConfig,
|
||||
RerankerConfig,
|
||||
RetrievalStrategyConfig,
|
||||
RuntimeMode,
|
||||
StrategyType,
|
||||
get_strategy_config,
|
||||
set_strategy_config,
|
||||
)
|
||||
from app.services.retrieval.strategy.default_pipeline import (
|
||||
DefaultPipeline,
|
||||
get_default_pipeline,
|
||||
)
|
||||
from app.services.retrieval.strategy.enhanced_pipeline import (
|
||||
EnhancedPipeline,
|
||||
get_enhanced_pipeline,
|
||||
)
|
||||
from app.services.retrieval.strategy.metadata_inference import (
|
||||
InferenceContext,
|
||||
InferenceResult,
|
||||
MetadataInferenceService,
|
||||
)
|
||||
from app.services.retrieval.strategy.mode_router import (
|
||||
ModeDecision,
|
||||
ModeRouter,
|
||||
get_mode_router,
|
||||
)
|
||||
from app.services.retrieval.strategy.pipeline_base import (
|
||||
BasePipeline,
|
||||
MetadataFilterResult,
|
||||
PipelineContext,
|
||||
PipelineResult,
|
||||
)
|
||||
from app.services.retrieval.strategy.rollback_manager import (
|
||||
AuditLog,
|
||||
RollbackManager,
|
||||
RollbackResult,
|
||||
RollbackTrigger,
|
||||
get_rollback_manager,
|
||||
)
|
||||
from app.services.retrieval.strategy.strategy_router import (
|
||||
RoutingDecision,
|
||||
StrategyRouter,
|
||||
get_strategy_router,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BasePipeline",
|
||||
"PipelineContext",
|
||||
"PipelineResult",
|
||||
"MetadataFilterResult",
|
||||
"DefaultPipeline",
|
||||
"get_default_pipeline",
|
||||
"EnhancedPipeline",
|
||||
"get_enhanced_pipeline",
|
||||
"RetrievalStrategyConfig",
|
||||
"GrayscaleConfig",
|
||||
"PipelineConfig",
|
||||
"RerankerConfig",
|
||||
"ModeRouterConfig",
|
||||
"HybridRetrievalConfig",
|
||||
"MetadataInferenceConfig",
|
||||
"StrategyType",
|
||||
"FilterMode",
|
||||
"RuntimeMode",
|
||||
"get_strategy_config",
|
||||
"set_strategy_config",
|
||||
"MetadataInferenceService",
|
||||
"InferenceContext",
|
||||
"InferenceResult",
|
||||
"StrategyRouter",
|
||||
"RoutingDecision",
|
||||
"get_strategy_router",
|
||||
"ModeRouter",
|
||||
"ModeDecision",
|
||||
"get_mode_router",
|
||||
"RollbackManager",
|
||||
"RollbackResult",
|
||||
"RollbackTrigger",
|
||||
"AuditLog",
|
||||
"get_rollback_manager",
|
||||
]
|
||||
|
|
@ -0,0 +1,201 @@
|
|||
"""
|
||||
Retrieval Strategy Configuration.
|
||||
[AC-AISVC-RES-01~15] 检索策略配置模型。
|
||||
"""
|
||||
|
||||
import random
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
|
||||
class StrategyType(str, Enum):
    """Strategy type.

    NOTE(review): duplicates the StrategyType defined in the sibling routing
    config module — consider sharing a single definition.
    """

    DEFAULT = "default"
    ENHANCED = "enhanced"
|
||||
|
||||
|
||||
class RuntimeMode(str, Enum):
    """Runtime mode.

    NOTE(review): parallels RagRuntimeMode in the sibling routing config
    module — consider sharing a single definition.
    """

    DIRECT = "direct"
    REACT = "react"
    AUTO = "auto"
|
||||
|
||||
|
||||
class FilterMode(str, Enum):
    """Metadata filter mode."""

    HARD = "hard"  # filter is mandatory
    SOFT = "soft"  # filter is advisory
    NONE = "none"  # no filtering
|
||||
|
||||
|
||||
@dataclass
|
||||
class GrayscaleConfig:
|
||||
"""灰度发布配置。【AC-AISVC-RES-03】"""
|
||||
enabled: bool = False
|
||||
percentage: float = 0.0
|
||||
allowlist: list[str] = field(default_factory=list)
|
||||
|
||||
def should_use_enhanced(self, tenant_id: str, user_id: str | None = None) -> bool:
|
||||
"""判断是否应该使用增强策略。"""
|
||||
if not self.enabled:
|
||||
return False
|
||||
|
||||
if tenant_id in self.allowlist or (user_id and user_id in self.allowlist):
|
||||
return True
|
||||
|
||||
return random.random() * 100 < self.percentage
|
||||
|
||||
|
||||
@dataclass
class HybridRetrievalConfig:
    """Hybrid retrieval (dense + keyword) fusion configuration."""

    # Channel weights; presumably expected to sum to 1.0 — not enforced here.
    dense_weight: float = 0.7
    keyword_weight: float = 0.3
    # RRF smoothing constant k (classic 1 / (k + rank) fusion).
    rrf_k: int = 60
    enable_keyword: bool = True
    # Keyword channel presumably over-fetches top_k * multiplier candidates
    # before fusion — confirm against the pipeline implementation.
    keyword_top_k_multiplier: int = 2
|
||||
|
||||
|
||||
@dataclass
class RerankerConfig:
    """Reranker configuration. [AC-AISVC-RES-08]"""

    enabled: bool = False
    model: str = "cross-encoder"
    # Number of candidates kept after reranking.
    top_k_after_rerank: int = 5
    min_score_threshold: float = 0.3
|
||||
|
||||
|
||||
@dataclass
class ModeRouterConfig:
    """Mode routing configuration. [AC-AISVC-RES-09~15]"""

    # NOTE(review): defaults to DIRECT here, while the sibling routing config
    # module defaults its runtime mode to AUTO — confirm which is intended.
    runtime_mode: RuntimeMode = RuntimeMode.DIRECT
    react_trigger_confidence_threshold: float = 0.6
    react_trigger_complexity_score: float = 0.5
    react_max_steps: int = 5
    direct_fallback_on_low_confidence: bool = True
    # Queries at or below this length count as "short" for the direct fast path.
    short_query_threshold: int = 20

    def should_use_react(
        self,
        query: str,
        confidence: float | None = None,
        complexity_score: float | None = None,
    ) -> bool:
        """Decide whether ReAct mode should be used. [AC-AISVC-RES-11~13]

        Explicit REACT/DIRECT configuration short-circuits. In AUTO mode a
        short, high-confidence query stays direct; otherwise high complexity
        or low confidence triggers ReAct.

        Fix: the original tested `if confidence and ...` / `if complexity_score
        and ...`, so a value of 0.0 — the least confident / least complex case —
        was treated as "not provided" and fell through. In particular,
        confidence=0.0 incorrectly selected direct mode. Compare against None
        explicitly instead of relying on truthiness.

        Args:
            query: User query text.
            confidence: Metadata inference confidence (0.0 ~ 1.0), or None.
            complexity_score: Query complexity (0.0 ~ 1.0), or None.

        Returns:
            True if ReAct mode should be used.
        """
        if self.runtime_mode == RuntimeMode.REACT:
            return True
        if self.runtime_mode == RuntimeMode.DIRECT:
            return False

        # AUTO: a short query with sufficiently high confidence stays direct.
        if (
            len(query) <= self.short_query_threshold
            and confidence is not None
            and confidence >= self.react_trigger_confidence_threshold
        ):
            return False

        if complexity_score is not None and complexity_score >= self.react_trigger_complexity_score:
            return True

        if confidence is not None and confidence < self.react_trigger_confidence_threshold:
            return True

        return False
|
||||
|
||||
|
||||
@dataclass
class MetadataInferenceConfig:
    """Metadata inference configuration."""

    enabled: bool = True
    confidence_high_threshold: float = 0.8
    confidence_low_threshold: float = 0.5
    default_filter_mode: FilterMode = FilterMode.SOFT
    cache_ttl_seconds: int = 300

    def determine_filter_mode(self, confidence: float | None) -> FilterMode:
        """Map an inference confidence onto a filter mode.

        None maps to NONE; at or above the high threshold maps to HARD; at or
        above the low threshold maps to SOFT; anything lower maps to NONE.
        """
        if confidence is None:
            return FilterMode.NONE

        # Walk the thresholds from strictest to loosest and take the first match.
        ladder = (
            (self.confidence_high_threshold, FilterMode.HARD),
            (self.confidence_low_threshold, FilterMode.SOFT),
        )
        for floor, mode in ladder:
            if confidence >= floor:
                return mode
        return FilterMode.NONE
|
||||
|
||||
|
||||
@dataclass
class PipelineConfig:
    """Pipeline configuration."""

    top_k: int = 5
    # Hits scoring below this threshold are filtered out.
    score_threshold: float = 0.01
    min_hits: int = 1
    two_stage_enabled: bool = True
    # Stage one presumably over-fetches top_k * this factor — confirm against
    # the pipeline implementation.
    two_stage_expand_factor: int = 10
    hybrid: HybridRetrievalConfig = field(default_factory=HybridRetrievalConfig)
|
||||
|
||||
|
||||
@dataclass
class RetrievalStrategyConfig:
    """Top-level retrieval strategy configuration. [AC-AISVC-RES-01~15]"""

    active_strategy: StrategyType = StrategyType.DEFAULT
    grayscale: GrayscaleConfig = field(default_factory=GrayscaleConfig)
    pipeline: PipelineConfig = field(default_factory=PipelineConfig)
    reranker: RerankerConfig = field(default_factory=RerankerConfig)
    mode_router: ModeRouterConfig = field(default_factory=ModeRouterConfig)
    metadata_inference: MetadataInferenceConfig = field(default_factory=MetadataInferenceConfig)
    # Rollback/monitoring thresholds; the dict keys are the consumer contract.
    performance_thresholds: dict[str, float] = field(default_factory=lambda: {
        "max_latency_ms": 2000.0,
        "min_success_rate": 0.95,
        "max_error_rate": 0.05,
    })

    def is_enhanced_enabled(self, tenant_id: str, user_id: str | None = None) -> bool:
        """Whether the enhanced strategy applies: explicit override, else grayscale."""
        if self.active_strategy == StrategyType.ENHANCED:
            return True
        return self.grayscale.should_use_enhanced(tenant_id, user_id)

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dict.

        NOTE(review): intentionally partial — e.g. reranker.min_score_threshold
        and several mode_router/metadata_inference fields are omitted; confirm
        that consumers do not expect a full round-trip.
        """
        return {
            "active_strategy": self.active_strategy.value,
            "grayscale": {
                "enabled": self.grayscale.enabled,
                "percentage": self.grayscale.percentage,
                "allowlist": self.grayscale.allowlist,
            },
            "pipeline": {
                "top_k": self.pipeline.top_k,
                "score_threshold": self.pipeline.score_threshold,
                "min_hits": self.pipeline.min_hits,
                "two_stage_enabled": self.pipeline.two_stage_enabled,
            },
            "reranker": {
                "enabled": self.reranker.enabled,
                "model": self.reranker.model,
                "top_k_after_rerank": self.reranker.top_k_after_rerank,
            },
            "mode_router": {
                "runtime_mode": self.mode_router.runtime_mode.value,
                "react_trigger_confidence_threshold": self.mode_router.react_trigger_confidence_threshold,
            },
            "metadata_inference": {
                "enabled": self.metadata_inference.enabled,
                "confidence_high_threshold": self.metadata_inference.confidence_high_threshold,
            },
            "performance_thresholds": self.performance_thresholds,
        }
|
||||
|
||||
|
||||
# Process-wide strategy config, created lazily by get_strategy_config().
_global_config: RetrievalStrategyConfig | None = None


def get_strategy_config() -> RetrievalStrategyConfig:
    """Return the shared strategy config, building defaults on first use."""
    global _global_config
    cfg = _global_config
    if cfg is None:
        cfg = RetrievalStrategyConfig()
        _global_config = cfg
    return cfg


def set_strategy_config(config: RetrievalStrategyConfig) -> None:
    """Replace the shared strategy config."""
    global _global_config
    _global_config = config
|
||||
|
|
@ -0,0 +1,117 @@
|
|||
"""
|
||||
Default Pipeline.
|
||||
[AC-AISVC-RES-01] 默认策略 Pipeline,复用现有逻辑。
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from app.services.retrieval.base import RetrievalContext, RetrievalResult
|
||||
from app.services.retrieval.optimized_retriever import OptimizedRetriever, get_optimized_retriever
|
||||
from app.services.retrieval.strategy.config import PipelineConfig
|
||||
from app.services.retrieval.strategy.pipeline_base import (
|
||||
BasePipeline,
|
||||
PipelineContext,
|
||||
PipelineResult,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DefaultPipeline(BasePipeline):
    """
    Default strategy pipeline. [AC-AISVC-RES-01]

    Reuses the existing OptimizedRetriever logic so online behavior is
    unchanged.
    """

    def __init__(
        self,
        config: PipelineConfig | None = None,
        optimized_retriever: OptimizedRetriever | None = None,
    ):
        # The retriever is injectable for tests; otherwise resolved lazily.
        self._config = config or PipelineConfig()
        self._optimized_retriever = optimized_retriever

    @property
    def name(self) -> str:
        # Stable identifier reported in PipelineResult.pipeline_name.
        return "default_pipeline"

    @property
    def description(self) -> str:
        # User-facing description string — kept verbatim.
        return "默认检索策略,复用现有 OptimizedRetriever 逻辑。"

    async def _get_retriever(self) -> OptimizedRetriever:
        # Lazy resolution of the shared retriever; no lock, so concurrent
        # first calls may resolve twice (last writer wins).
        if self._optimized_retriever is None:
            self._optimized_retriever = await get_optimized_retriever()
        return self._optimized_retriever

    async def retrieve(self, ctx: PipelineContext) -> PipelineResult:
        """Run the default retrieval flow. [AC-AISVC-RES-01]

        Never raises: failures are logged and converted into an empty result
        via _create_empty_result (inherited from BasePipeline).
        """
        start_time = time.time()

        logger.info(
            f"[DefaultPipeline] Starting retrieval: tenant={ctx.tenant_id}, "
            f"query={ctx.query[:50]}..."
        )

        try:
            retriever = await self._get_retriever()

            # Unwrap the structured filter into the plain dict the retriever expects.
            metadata_filter = None
            if ctx.metadata_filter:
                metadata_filter = ctx.metadata_filter.filter_dict

            retrieval_ctx = RetrievalContext(
                tenant_id=ctx.tenant_id,
                query=ctx.query,
                session_id=ctx.session_id,
                metadata_filter=metadata_filter,
                kb_ids=ctx.kb_ids,
            )

            result = await retriever.retrieve(retrieval_ctx)

            latency_ms = (time.time() - start_time) * 1000

            logger.info(
                f"[DefaultPipeline] Retrieval completed: hits={len(result.hits)}, "
                f"latency_ms={latency_ms:.2f}"
            )

            return PipelineResult(
                retrieval_result=result,
                pipeline_name=self.name,
                metadata_filter_applied=metadata_filter is not None,
                latency_ms=latency_ms,
                diagnostics={
                    "retriever": "OptimizedRetriever",
                    **(result.diagnostics or {}),
                },
            )

        except Exception as e:
            # Degrade to an empty result rather than propagating.
            latency_ms = (time.time() - start_time) * 1000
            logger.error(f"[DefaultPipeline] Retrieval error: {e}", exc_info=True)
            return self._create_empty_result(ctx, error=str(e), latency_ms=latency_ms)

    async def health_check(self) -> bool:
        """Health check: delegates to the retriever; False on any error."""
        try:
            retriever = await self._get_retriever()
            return await retriever.health_check()
        except Exception as e:
            logger.error(f"[DefaultPipeline] Health check failed: {e}")
            return False
|
||||
|
||||
|
||||
# Process-wide DefaultPipeline instance, created lazily below.
_default_pipeline: DefaultPipeline | None = None


async def get_default_pipeline() -> DefaultPipeline:
    """Return the shared DefaultPipeline singleton, creating it lazily.

    NOTE(review): no lock around the check-then-create; concurrent first
    calls may construct two instances (last writer wins) — harmless only if
    construction is cheap and stateless, confirm.
    """
    global _default_pipeline
    if _default_pipeline is None:
        _default_pipeline = DefaultPipeline()
    return _default_pipeline
|
||||
|
|
@ -0,0 +1,364 @@
|
|||
"""
|
||||
Enhanced Pipeline.
|
||||
[AC-AISVC-RES-02] 增强策略 Pipeline,新端到端流程。
|
||||
"""
|
||||
|
||||
import asyncio
import logging
import re
import time
from dataclasses import dataclass, field
from typing import Any

import numpy as np

from app.core.config import get_settings
from app.core.qdrant_client import QdrantClient, get_qdrant_client
from app.services.embedding.nomic_provider import NomicEmbeddingProvider
from app.services.retrieval.base import RetrievalHit, RetrievalResult
from app.services.retrieval.optimized_retriever import RRFCombiner
from app.services.retrieval.strategy.config import (
    HybridRetrievalConfig,
    PipelineConfig,
    RerankerConfig,
)
from app.services.retrieval.strategy.pipeline_base import (
    BasePipeline,
    PipelineContext,
    PipelineResult,
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
@dataclass
class RetrievalCandidate:
    """A single retrieval candidate flowing through the hybrid search stages."""

    id: str
    text: str
    # Final (fused) score; the per-channel scores are kept separately below.
    score: float
    vector_score: float = 0.0
    keyword_score: float = 0.0
    # Fix: the original declared `metadata: dict[str, Any] = None`, a default
    # that contradicts the annotation and relied on __post_init__ to patch it.
    # default_factory gives every instance its own dict.
    metadata: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        # Kept for backward compatibility: callers that explicitly pass
        # metadata=None still end up with an empty dict.
        if self.metadata is None:
            self.metadata = {}
|
||||
|
||||
|
||||
class EnhancedPipeline(BasePipeline):
|
||||
"""
|
||||
增强策略 Pipeline。【AC-AISVC-RES-02】
|
||||
|
||||
新端到端流程:
|
||||
1. Dense 向量检索
|
||||
2. Keyword 关键词检索
|
||||
3. RRF 融合排序
|
||||
4. 可选重排
|
||||
"""
|
||||
|
||||
    def __init__(
        self,
        config: PipelineConfig | None = None,
        reranker_config: RerankerConfig | None = None,
        qdrant_client: QdrantClient | None = None,
    ):
        # All collaborators are injectable for testing; otherwise resolved
        # lazily or defaulted.
        self._config = config or PipelineConfig()
        self._reranker_config = reranker_config or RerankerConfig()
        self._qdrant_client = qdrant_client
        # RRF fusion with the configured smoothing constant k.
        self._rrf_combiner = RRFCombiner(k=self._config.hybrid.rrf_k)
        # Resolved on first use in _get_embedding_provider().
        self._embedding_provider: NomicEmbeddingProvider | None = None
|
||||
|
||||
    @property
    def name(self) -> str:
        # Stable identifier used as hit source and in diagnostics.
        return "enhanced_pipeline"
|
||||
|
||||
    @property
    def description(self) -> str:
        # User-facing description string — kept verbatim.
        return "增强检索策略,支持 Dense + Keyword + RRF 组合检索。"
|
||||
|
||||
async def _get_client(self) -> QdrantClient:
|
||||
if self._qdrant_client is None:
|
||||
self._qdrant_client = await get_qdrant_client()
|
||||
return self._qdrant_client
|
||||
|
||||
async def _get_embedding_provider(self) -> NomicEmbeddingProvider:
|
||||
if self._embedding_provider is None:
|
||||
from app.services.embedding.factory import get_embedding_config_manager
|
||||
manager = get_embedding_config_manager()
|
||||
provider = await manager.get_provider()
|
||||
if isinstance(provider, NomicEmbeddingProvider):
|
||||
self._embedding_provider = provider
|
||||
else:
|
||||
self._embedding_provider = NomicEmbeddingProvider(
|
||||
base_url=settings.ollama_base_url,
|
||||
model=settings.ollama_embedding_model,
|
||||
dimension=settings.qdrant_vector_size,
|
||||
)
|
||||
return self._embedding_provider
|
||||
|
||||
async def retrieve(self, ctx: PipelineContext) -> PipelineResult:
|
||||
"""执行增强检索流程。【AC-AISVC-RES-02】"""
|
||||
start_time = time.time()
|
||||
|
||||
logger.info(
|
||||
f"[EnhancedPipeline] Starting retrieval: tenant={ctx.tenant_id}, "
|
||||
f"query={ctx.query[:50]}..."
|
||||
)
|
||||
|
||||
try:
|
||||
provider = await self._get_embedding_provider()
|
||||
embedding_result = await provider.embed_query(ctx.query)
|
||||
|
||||
candidates = await self._hybrid_retrieve(
|
||||
tenant_id=ctx.tenant_id,
|
||||
query=ctx.query,
|
||||
embedding_result=embedding_result,
|
||||
metadata_filter=ctx.metadata_filter.filter_dict if ctx.metadata_filter else None,
|
||||
kb_ids=ctx.kb_ids,
|
||||
)
|
||||
|
||||
if self._reranker_config.enabled and ctx.use_reranker:
|
||||
candidates = await self._rerank(
|
||||
candidates=candidates,
|
||||
query=ctx.query,
|
||||
)
|
||||
|
||||
top_k = self._config.top_k
|
||||
final_candidates = candidates[:top_k]
|
||||
|
||||
hits = [
|
||||
RetrievalHit(
|
||||
text=c.text,
|
||||
score=c.score,
|
||||
source=self.name,
|
||||
metadata=c.metadata,
|
||||
)
|
||||
for c in final_candidates
|
||||
if c.score >= self._config.score_threshold
|
||||
]
|
||||
|
||||
latency_ms = (time.time() - start_time) * 1000
|
||||
|
||||
logger.info(
|
||||
f"[EnhancedPipeline] Retrieval completed: hits={len(hits)}, "
|
||||
f"latency_ms={latency_ms:.2f}"
|
||||
)
|
||||
|
||||
result = RetrievalResult(
|
||||
hits=hits,
|
||||
diagnostics={
|
||||
"total_candidates": len(candidates),
|
||||
"after_rerank": self._reranker_config.enabled and ctx.use_reranker,
|
||||
},
|
||||
)
|
||||
|
||||
return PipelineResult(
|
||||
retrieval_result=result,
|
||||
pipeline_name=self.name,
|
||||
used_reranker=self._reranker_config.enabled and ctx.use_reranker,
|
||||
metadata_filter_applied=ctx.metadata_filter is not None,
|
||||
latency_ms=latency_ms,
|
||||
diagnostics={
|
||||
"dense_weight": self._config.hybrid.dense_weight,
|
||||
"keyword_weight": self._config.hybrid.keyword_weight,
|
||||
},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
latency_ms = (time.time() - start_time) * 1000
|
||||
logger.error(f"[EnhancedPipeline] Retrieval error: {e}", exc_info=True)
|
||||
return self._create_empty_result(ctx, error=str(e), latency_ms=latency_ms)
|
||||
|
||||
async def _hybrid_retrieve(
|
||||
self,
|
||||
tenant_id: str,
|
||||
query: str,
|
||||
embedding_result: Any,
|
||||
metadata_filter: dict[str, Any] | None = None,
|
||||
kb_ids: list[str] | None = None,
|
||||
) -> list[RetrievalCandidate]:
|
||||
"""混合检索:Dense + Keyword + RRF。"""
|
||||
client = await self._get_client()
|
||||
top_k = self._config.top_k
|
||||
expand_factor = self._config.hybrid.keyword_top_k_multiplier
|
||||
|
||||
vector_task = self._dense_search(
|
||||
client=client,
|
||||
tenant_id=tenant_id,
|
||||
embedding=embedding_result.embedding_full,
|
||||
top_k=top_k * expand_factor,
|
||||
metadata_filter=metadata_filter,
|
||||
kb_ids=kb_ids,
|
||||
)
|
||||
|
||||
keyword_task = self._keyword_search(
|
||||
client=client,
|
||||
tenant_id=tenant_id,
|
||||
query=query,
|
||||
top_k=top_k * expand_factor,
|
||||
metadata_filter=metadata_filter,
|
||||
kb_ids=kb_ids,
|
||||
) if self._config.hybrid.enable_keyword else asyncio.sleep(0, result=[])
|
||||
|
||||
vector_results, keyword_results = await asyncio.gather(
|
||||
vector_task, keyword_task, return_exceptions=True
|
||||
)
|
||||
|
||||
if isinstance(vector_results, Exception):
|
||||
logger.warning(f"[EnhancedPipeline] Dense search failed: {vector_results}")
|
||||
vector_results = []
|
||||
|
||||
if isinstance(keyword_results, Exception):
|
||||
logger.warning(f"[EnhancedPipeline] Keyword search failed: {keyword_results}")
|
||||
keyword_results = []
|
||||
|
||||
combined = self._rrf_combiner.combine(
|
||||
vector_results=vector_results,
|
||||
bm25_results=keyword_results,
|
||||
vector_weight=self._config.hybrid.dense_weight,
|
||||
bm25_weight=self._config.hybrid.keyword_weight,
|
||||
)
|
||||
|
||||
candidates = []
|
||||
for item in combined:
|
||||
candidates.append(RetrievalCandidate(
|
||||
id=item.get("id", ""),
|
||||
text=item.get("payload", {}).get("text", ""),
|
||||
score=item.get("score", 0.0),
|
||||
vector_score=item.get("vector_score", 0.0),
|
||||
keyword_score=item.get("bm25_score", 0.0),
|
||||
metadata=item.get("payload", {}),
|
||||
))
|
||||
|
||||
return candidates
|
||||
|
||||
async def _dense_search(
|
||||
self,
|
||||
client: QdrantClient,
|
||||
tenant_id: str,
|
||||
embedding: list[float],
|
||||
top_k: int,
|
||||
metadata_filter: dict[str, Any] | None = None,
|
||||
kb_ids: list[str] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Dense 向量检索。"""
|
||||
try:
|
||||
results = await client.search(
|
||||
tenant_id=tenant_id,
|
||||
query_vector=embedding,
|
||||
limit=top_k,
|
||||
vector_name="full",
|
||||
metadata_filter=metadata_filter,
|
||||
kb_ids=kb_ids,
|
||||
)
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.error(f"[EnhancedPipeline] Dense search error: {e}")
|
||||
return []
|
||||
|
||||
async def _keyword_search(
|
||||
self,
|
||||
client: QdrantClient,
|
||||
tenant_id: str,
|
||||
query: str,
|
||||
top_k: int,
|
||||
metadata_filter: dict[str, Any] | None = None,
|
||||
kb_ids: list[str] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Keyword 关键词检索。"""
|
||||
try:
|
||||
qdrant = await client.get_client()
|
||||
collection_name = client.get_collection_name(tenant_id)
|
||||
|
||||
query_terms = set(re.findall(r'\w+', query.lower()))
|
||||
|
||||
results = await qdrant.scroll(
|
||||
collection_name=collection_name,
|
||||
limit=top_k * 3,
|
||||
with_payload=True,
|
||||
)
|
||||
|
||||
scored_results = []
|
||||
for point in results[0]:
|
||||
text = point.payload.get("text", "").lower()
|
||||
text_terms = set(re.findall(r'\w+', text))
|
||||
overlap = len(query_terms & text_terms)
|
||||
|
||||
if overlap > 0:
|
||||
score = overlap / (len(query_terms) + len(text_terms) - overlap)
|
||||
scored_results.append({
|
||||
"id": str(point.id),
|
||||
"score": score,
|
||||
"payload": point.payload or {},
|
||||
})
|
||||
|
||||
scored_results.sort(key=lambda x: x["score"], reverse=True)
|
||||
return scored_results[:top_k]
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"[EnhancedPipeline] Keyword search failed: {e}")
|
||||
return []
|
||||
|
||||
async def _rerank(
|
||||
self,
|
||||
candidates: list[RetrievalCandidate],
|
||||
query: str,
|
||||
) -> list[RetrievalCandidate]:
|
||||
"""可选重排。"""
|
||||
if not candidates:
|
||||
return candidates
|
||||
|
||||
try:
|
||||
provider = await self._get_embedding_provider()
|
||||
query_embedding = await provider.embed_query(query)
|
||||
|
||||
reranked = []
|
||||
for candidate in candidates:
|
||||
candidate_text = candidate.text[:500]
|
||||
if candidate_text:
|
||||
candidate_embedding = await provider.embed(candidate_text)
|
||||
similarity = self._cosine_similarity(
|
||||
query_embedding.embedding_full,
|
||||
candidate_embedding,
|
||||
)
|
||||
candidate.score = similarity
|
||||
|
||||
if candidate.score >= self._reranker_config.min_score_threshold:
|
||||
reranked.append(candidate)
|
||||
|
||||
reranked.sort(key=lambda x: x.score, reverse=True)
|
||||
return reranked[:self._reranker_config.top_k_after_rerank]
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"[EnhancedPipeline] Rerank failed: {e}")
|
||||
return candidates
|
||||
|
||||
def _cosine_similarity(self, vec1: list[float], vec2: list[float]) -> float:
|
||||
"""计算余弦相似度。"""
|
||||
a = np.array(vec1)
|
||||
b = np.array(vec2)
|
||||
return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
"""健康检查。"""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
qdrant = await client.get_client()
|
||||
await qdrant.get_collections()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"[EnhancedPipeline] Health check failed: {e}")
|
||||
return False
|
||||
|
||||
|
||||
# Process-wide singleton, created on first use.
_enhanced_pipeline: EnhancedPipeline | None = None


async def get_enhanced_pipeline() -> EnhancedPipeline:
    """Return the shared EnhancedPipeline instance, creating it lazily."""
    global _enhanced_pipeline
    pipeline = _enhanced_pipeline
    if pipeline is None:
        pipeline = EnhancedPipeline()
        _enhanced_pipeline = pipeline
    return pipeline
|
||||
|
|
@ -0,0 +1,136 @@
|
|||
"""
|
||||
Metadata Inference Service.
|
||||
[AC-AISVC-RES-04] 元数据推断统一入口。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.mid.metadata_filter_builder import (
|
||||
FilterBuildResult,
|
||||
MetadataFilterBuilder,
|
||||
)
|
||||
from app.services.retrieval.strategy.config import (
|
||||
FilterMode,
|
||||
MetadataInferenceConfig,
|
||||
)
|
||||
from app.services.retrieval.strategy.pipeline_base import MetadataFilterResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class InferenceContext:
    """Input context for a single metadata-inference run."""

    tenant_id: str  # tenant the query belongs to
    query: str  # raw user query text
    session_id: str | None = None
    user_id: str | None = None
    channel_type: str | None = None
    # Pre-existing context values; slot state fills only keys missing here.
    existing_context: dict[str, Any] = field(default_factory=dict)
    # Opaque slot-filling state; the service reads `filled_slots` if present.
    slot_state: Any = None
|
||||
|
||||
|
||||
@dataclass
class InferenceResult:
    """Outcome of one metadata-inference run."""

    filter_result: MetadataFilterResult  # filter dict + mode + confidence
    inferred_fields: dict[str, Any] = field(default_factory=dict)  # fields the filter applied
    # Per-field confidence scores; not populated by MetadataInferenceService.infer().
    confidence_scores: dict[str, float] = field(default_factory=dict)
    overall_confidence: float | None = None  # None when nothing could be inferred
    inference_source: str = "unknown"  # component that produced the inference
|
||||
|
||||
|
||||
class MetadataInferenceService:
    """
    Unified entry point for metadata inference. [AC-AISVC-RES-04]

    Responsibilities:
    1. Provide a single, strategy-agnostic metadata-inference entry point
    2. Decide the hard/soft filter mode from the inferred confidence
    3. Stay consistent with the existing MetadataFilterBuilder
    """

    def __init__(
        self,
        session: AsyncSession,
        config: MetadataInferenceConfig | None = None,
    ):
        self._session = session
        self._config = config or MetadataInferenceConfig()
        # Created lazily on the first infer() call.
        self._filter_builder: MetadataFilterBuilder | None = None

    async def infer(self, ctx: InferenceContext) -> InferenceResult:
        """Run metadata inference for one query. [AC-AISVC-RES-04]

        Merges slot state into the caller context, builds a metadata filter
        via MetadataFilterBuilder, scores a heuristic confidence, and maps
        that confidence to a filter mode through the config.
        """
        logger.info(
            f"[MetadataInference] Starting inference: tenant={ctx.tenant_id}, "
            f"query={ctx.query[:50]}..."
        )

        if self._filter_builder is None:
            self._filter_builder = MetadataFilterBuilder(self._session)

        # Copy so the caller's existing_context dict is never mutated.
        effective_context = dict(ctx.existing_context)

        if ctx.slot_state:
            effective_context = await self._merge_slot_state(
                effective_context, ctx.slot_state
            )

        build_result = await self._filter_builder.build_filter(
            tenant_id=ctx.tenant_id,
            context=effective_context,
        )

        confidence = self._calculate_confidence(build_result, effective_context)
        filter_mode = self._config.determine_filter_mode(confidence)

        filter_result = MetadataFilterResult(
            filter_dict=build_result.applied_filter,
            filter_mode=filter_mode,
            confidence=confidence,
            missing_required_slots=build_result.missing_required_slots,
            debug_info=build_result.debug_info,
        )

        logger.info(
            f"[MetadataInference] Inference completed: filter_mode={filter_mode.value}, "
            f"confidence={confidence}"
        )

        return InferenceResult(
            filter_result=filter_result,
            inferred_fields=build_result.applied_filter,
            overall_confidence=confidence,
            inference_source="metadata_filter_builder",
        )

    async def _merge_slot_state(
        self, context: dict[str, Any], slot_state: Any
    ) -> dict[str, Any]:
        """Merge filled slots into *context* (in place), context values win.

        NOTE(review): declared async but performs no await — could be a
        plain method; confirm no async override is planned.
        """
        if hasattr(slot_state, 'filled_slots'):
            for slot_key, slot_value in slot_state.filled_slots.items():
                # Never overwrite a value the caller already supplied.
                if slot_key not in context:
                    context[slot_key] = slot_value
        return context

    def _calculate_confidence(
        self, build_result: FilterBuildResult, context: dict[str, Any]
    ) -> float | None:
        """Heuristic confidence for the built filter; None when no filter.

        Tiers: missing required slots -> 0.3; empty filter -> None;
        empty context -> 0.5; otherwise stepped by the fraction of context
        keys the filter actually applied.
        """
        if build_result.missing_required_slots:
            return 0.3
        if not build_result.applied_filter:
            return None
        if not context:
            return 0.5
        applied_ratio = len(build_result.applied_filter) / max(len(context), 1)
        if applied_ratio >= 0.8:
            return 0.9
        elif applied_ratio >= 0.5:
            return 0.7
        return 0.5
|
||||
|
|
@ -0,0 +1,118 @@
|
|||
"""
|
||||
Mode Router.
|
||||
[AC-AISVC-RES-09~15] 模式路由器。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
|
||||
from app.services.retrieval.strategy.config import ModeRouterConfig, RuntimeMode
|
||||
from app.services.retrieval.strategy.pipeline_base import PipelineResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class ModeDecision:
    """Outcome of mode routing (direct vs. react)."""

    mode: RuntimeMode
    reason: str  # human-readable routing explanation, surfaced in diagnostics
    confidence: float | None = None  # metadata-inference confidence, if available
    complexity_score: float | None = None  # heuristic query complexity in [0, 1]
|
||||
|
||||
|
||||
class ModeRouter:
    """
    Mode router. [AC-AISVC-RES-09~15]

    Responsibilities:
    1. Select direct/react per the configured rag_runtime_mode
    2. In auto mode, route by query complexity and inference confidence
    3. Support a react fallback when a direct run is low-confidence
    """

    def __init__(self, config: ModeRouterConfig | None = None):
        self._config = config or ModeRouterConfig()

    def decide(
        self,
        query: str,
        confidence: float | None = None,
        complexity_score: float | None = None,
    ) -> ModeDecision:
        """Decide which runtime mode to use. [AC-AISVC-RES-09~13]

        Args:
            query: User query text.
            confidence: Optional metadata-inference confidence.
            complexity_score: Optional pre-computed complexity; computed
                heuristically from the query when not supplied.
        """
        # Fixed modes short-circuit the auto heuristics entirely.
        if self._config.runtime_mode == RuntimeMode.REACT:
            return ModeDecision(mode=RuntimeMode.REACT, reason="runtime_mode=react")

        if self._config.runtime_mode == RuntimeMode.DIRECT:
            return ModeDecision(mode=RuntimeMode.DIRECT, reason="runtime_mode=direct")

        # `is None` (not truthiness): a legitimate pre-computed score of 0.0
        # must be honored, not silently recomputed.
        calculated_complexity = (
            complexity_score
            if complexity_score is not None
            else self._calculate_complexity(query)
        )

        if self._should_use_direct(query, confidence, calculated_complexity):
            return ModeDecision(
                mode=RuntimeMode.DIRECT,
                reason="auto: short_query_high_confidence",
                confidence=confidence,
                complexity_score=calculated_complexity,
            )

        return ModeDecision(
            mode=RuntimeMode.REACT,
            reason="auto: complex_or_low_confidence",
            confidence=confidence,
            complexity_score=calculated_complexity,
        )

    def should_fallback_to_react(self, direct_result: PipelineResult) -> bool:
        """Whether a direct-mode result is weak enough to retry in react mode. [AC-AISVC-RES-14]

        Weak means: empty result, best score below 0.3, or fewer than 2 hits.
        """
        if not self._config.direct_fallback_on_low_confidence:
            return False
        if direct_result.is_empty:
            return True
        max_score = direct_result.retrieval_result.max_score
        # NOTE(review): hard-coded score/hit-count cutoffs — consider config.
        if max_score < 0.3:
            return True
        if direct_result.retrieval_result.hit_count < 2:
            return True
        return False

    def _should_use_direct(
        self, query: str, confidence: float | None, complexity_score: float
    ) -> bool:
        """Auto-mode heuristic: direct for short, confident, simple queries.

        Bug fix: the previous `if confidence and ...` treated a confidence
        of exactly 0.0 as "missing", skipping the low-confidence guard and
        allowing direct mode on zero-confidence queries. `is not None` keeps
        0.0 meaning *lowest* confidence (forces react).
        """
        if len(query) <= self._config.short_query_threshold:
            if (
                confidence is not None
                and confidence >= self._config.react_trigger_confidence_threshold
            ):
                return True
        if (
            confidence is not None
            and confidence < self._config.react_trigger_confidence_threshold
        ):
            return False
        if complexity_score >= self._config.react_trigger_complexity_score:
            return False
        return True

    def _calculate_complexity(self, query: str) -> float:
        """Heuristic complexity in [0, 1] from length and Chinese connectives."""
        score = 0.0
        if len(query) > 50:
            score += 0.2
        if len(query) > 100:
            score += 0.2
        condition_words = ["和", "或", "但是", "如果", "同时", "并且", "或者", "以及"]
        for word in condition_words:
            if word in query:
                score += 0.1
        return min(score, 1.0)

    def get_config(self) -> ModeRouterConfig:
        """Return the active router configuration."""
        return self._config

    def update_config(self, config: ModeRouterConfig) -> None:
        """Replace the router configuration at runtime."""
        self._config = config
|
||||
|
||||
|
||||
# Process-wide singleton, created on first use.
_mode_router: ModeRouter | None = None


def get_mode_router() -> ModeRouter:
    """Return the shared ModeRouter instance, creating it lazily."""
    global _mode_router
    router = _mode_router
    if router is None:
        router = ModeRouter()
        _mode_router = router
    return router
|
||||
|
|
@ -0,0 +1,116 @@
|
|||
"""
|
||||
Pipeline Base Classes.
|
||||
[AC-AISVC-RES-01~15] Pipeline 抽象基类。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from app.services.retrieval.base import RetrievalContext, RetrievalHit, RetrievalResult
|
||||
from app.services.retrieval.strategy.config import FilterMode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class MetadataFilterResult:
    """Result of metadata filter inference for one query."""

    filter_dict: dict[str, Any] = field(default_factory=dict)  # field -> value constraints
    filter_mode: FilterMode = FilterMode.NONE  # how the filter should be applied
    confidence: float | None = None  # inference confidence; None when unknown
    missing_required_slots: list[dict[str, str]] = field(default_factory=list)
    debug_info: dict[str, Any] = field(default_factory=dict)  # builder diagnostics
|
||||
|
||||
|
||||
@dataclass
class PipelineContext:
    """Execution context handed to a pipeline's retrieve()."""

    retrieval_ctx: RetrievalContext  # tenant/query/session/kb scoping
    metadata_filter: MetadataFilterResult | None = None
    use_reranker: bool = False
    use_react: bool = False
    react_iteration: int = 0  # current ReAct loop iteration (0 = first pass)
    previous_results: list[RetrievalHit] = field(default_factory=list)
    extra: dict[str, Any] = field(default_factory=dict)  # free-form extension data

    @property
    def tenant_id(self) -> str:
        """Shortcut for retrieval_ctx.tenant_id."""
        return self.retrieval_ctx.tenant_id

    @property
    def query(self) -> str:
        """Shortcut for retrieval_ctx.query."""
        return self.retrieval_ctx.query

    @property
    def session_id(self) -> str | None:
        """Shortcut for retrieval_ctx.session_id."""
        return self.retrieval_ctx.session_id

    @property
    def kb_ids(self) -> list[str] | None:
        """Shortcut for retrieval_ctx.kb_ids."""
        return self.retrieval_ctx.kb_ids
|
||||
|
||||
|
||||
@dataclass
class PipelineResult:
    """Outcome of one pipeline retrieve() call, plus execution metadata."""

    retrieval_result: RetrievalResult
    pipeline_name: str = ""  # which pipeline produced this result
    used_reranker: bool = False
    used_react: bool = False
    react_iterations: int = 0
    metadata_filter_applied: bool = False
    fallback_triggered: bool = False
    fallback_reason: str | None = None
    latency_ms: float = 0.0  # wall-clock time of the retrieve() call
    diagnostics: dict[str, Any] = field(default_factory=dict)

    @property
    def hits(self) -> list[RetrievalHit]:
        """Shortcut for retrieval_result.hits."""
        return self.retrieval_result.hits

    @property
    def is_empty(self) -> bool:
        """Shortcut for retrieval_result.is_empty."""
        return self.retrieval_result.is_empty
|
||||
|
||||
|
||||
class BasePipeline(ABC):
    """Abstract base class for retrieval pipelines. [AC-AISVC-RES-01, AC-AISVC-RES-02]"""

    @property
    @abstractmethod
    def name(self) -> str:
        """Unique pipeline name, used in logs and result provenance."""
        pass

    @property
    @abstractmethod
    def description(self) -> str:
        """Human-readable description of the pipeline."""
        pass

    @abstractmethod
    async def retrieve(self, ctx: PipelineContext) -> PipelineResult:
        """Run retrieval for the given context and return a PipelineResult."""
        pass

    @abstractmethod
    async def health_check(self) -> bool:
        """Return True when the pipeline's backing services are reachable."""
        pass

    def _create_empty_result(
        self,
        ctx: PipelineContext,
        error: str | None = None,
        latency_ms: float = 0.0,
    ) -> PipelineResult:
        """Build an empty PipelineResult, recording *error* in diagnostics.

        NOTE(review): *ctx* is currently unused — kept for subclasses or a
        planned extension? Confirm before removing.
        """
        diagnostics = {"error": error} if error else {}
        return PipelineResult(
            retrieval_result=RetrievalResult(hits=[], diagnostics=diagnostics),
            pipeline_name=self.name,
            latency_ms=latency_ms,
            diagnostics=diagnostics,
        )
|
||||
|
|
@ -0,0 +1,301 @@
|
|||
"""
|
||||
Retrieval Strategy - Unified Entry Point.
|
||||
[AC-AISVC-RES-01~15] 检索策略统一入口。
|
||||
|
||||
整合:
|
||||
- StrategyRouter: 策略路由(default/enhanced)
|
||||
- ModeRouter: 模式路由(direct/react/auto)
|
||||
- MetadataInferenceService: 元数据推断
|
||||
- RollbackManager: 回退管理
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.retrieval.base import RetrievalContext, RetrievalResult
|
||||
from app.services.retrieval.strategy.config import (
|
||||
RetrievalStrategyConfig,
|
||||
RuntimeMode,
|
||||
StrategyType,
|
||||
)
|
||||
from app.services.retrieval.strategy.default_pipeline import DefaultPipeline
|
||||
from app.services.retrieval.strategy.enhanced_pipeline import EnhancedPipeline
|
||||
from app.services.retrieval.strategy.metadata_inference import (
|
||||
InferenceContext,
|
||||
MetadataInferenceService,
|
||||
)
|
||||
from app.services.retrieval.strategy.mode_router import ModeDecision, ModeRouter
|
||||
from app.services.retrieval.strategy.pipeline_base import (
|
||||
MetadataFilterResult,
|
||||
PipelineContext,
|
||||
PipelineResult,
|
||||
)
|
||||
from app.services.retrieval.strategy.rollback_manager import (
|
||||
RollbackManager,
|
||||
RollbackTrigger,
|
||||
)
|
||||
from app.services.retrieval.strategy.strategy_router import (
|
||||
RoutingDecision,
|
||||
StrategyRouter,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class RetrievalStrategyResult:
    """Final result of one RetrievalStrategy.retrieve() call."""

    retrieval_result: RetrievalResult  # hits plus pipeline diagnostics
    strategy_used: StrategyType  # strategy actually executed (after any rollback)
    mode_used: RuntimeMode  # runtime mode actually used (after any fallback)
    metadata_filter: MetadataFilterResult | None  # None when inference failed/was skipped
    latency_ms: float  # end-to-end wall-clock latency
    diagnostics: dict[str, Any]  # routing reasons, grayscale flag, pipeline diagnostics
|
||||
|
||||
|
||||
class RetrievalStrategy:
    """
    Unified entry point for retrieval strategy. [AC-AISVC-RES-01~15]

    Composes all strategy components:
    1. Metadata inference (MetadataInferenceService)
    2. Strategy routing (StrategyRouter)
    3. Mode routing (ModeRouter)
    4. Rollback management (RollbackManager)

    Usage:
    ```python
    strategy = RetrievalStrategy(session)
    result = await strategy.retrieve(
        tenant_id="tenant_1",
        query="user question",
        context={"user_id": "user_1"},
    )
    ```
    """

    def __init__(
        self,
        session: AsyncSession,
        config: RetrievalStrategyConfig | None = None,
    ):
        self._session = session
        self._config = config or RetrievalStrategyConfig()

        self._strategy_router = StrategyRouter(self._config)
        self._mode_router = ModeRouter(self._config.mode_router)
        self._rollback_manager = RollbackManager(self._config)
        # Created lazily on first use (needs the DB session).
        self._metadata_inference: MetadataInferenceService | None = None

    async def retrieve(
        self,
        tenant_id: str,
        query: str,
        context: dict[str, Any] | None = None,
        kb_ids: list[str] | None = None,
        session_id: str | None = None,
        user_id: str | None = None,
        use_reranker: bool = False,
        use_react: bool = False,
    ) -> RetrievalStrategyResult:
        """
        Execute the retrieval strategy. [AC-AISVC-RES-01~15]

        Flow:
        1. Metadata inference
        2. Strategy routing
        3. Mode routing
        4. Run retrieval
        5. Check whether a fallback is needed

        Args:
            tenant_id: Tenant ID.
            query: Query text.
            context: Contextual information.
            kb_ids: Knowledge-base ID list.
            session_id: Session ID.
            user_id: User ID.
            use_reranker: Whether to use the reranker.
            use_react: Whether to force ReAct mode.

        Returns:
            RetrievalStrategyResult with the retrieval result and diagnostics.

        Never raises: on any exception an empty result with the error in its
        diagnostics is returned and a rollback is recorded.
        """
        start_time = time.time()
        context = context or {}

        logger.info(
            f"[RetrievalStrategy] Starting retrieval: tenant={tenant_id}, "
            f"query={query[:50]}..."
        )

        try:
            # 1. Infer the metadata filter (best-effort; may be None).
            metadata_filter = await self._infer_metadata(
                tenant_id=tenant_id,
                query=query,
                context=context,
                session_id=session_id,
                user_id=user_id,
            )

            # 2. Pick the pipeline (default vs. enhanced, incl. grayscale).
            routing_decision = await self._strategy_router.route(tenant_id, user_id)

            # 3. Pick the runtime mode from the inference confidence.
            mode_decision = self._mode_router.decide(
                query=query,
                confidence=metadata_filter.confidence if metadata_filter else None,
            )

            retrieval_ctx = RetrievalContext(
                tenant_id=tenant_id,
                query=query,
                session_id=session_id,
                metadata_filter=metadata_filter.filter_dict if metadata_filter else None,
                kb_ids=kb_ids,
            )

            pipeline_ctx = PipelineContext(
                retrieval_ctx=retrieval_ctx,
                metadata_filter=metadata_filter,
                use_reranker=use_reranker or self._config.reranker.enabled,
                use_react=use_react or mode_decision.mode == RuntimeMode.REACT,
            )

            # 4. Execute the routed pipeline.
            pipeline_result = await routing_decision.pipeline.retrieve(pipeline_ctx)

            # 5. Direct-mode results that look weak are retried once with
            # use_react=True on the same pipeline.
            if mode_decision.mode == RuntimeMode.DIRECT and self._mode_router.should_fallback_to_react(pipeline_result):
                logger.info("[RetrievalStrategy] Falling back to react mode")
                pipeline_ctx.use_react = True
                pipeline_result = await routing_decision.pipeline.retrieve(pipeline_ctx)
                mode_decision = ModeDecision(
                    mode=RuntimeMode.REACT,
                    reason="fallback_from_direct",
                )

            latency_ms = (time.time() - start_time) * 1000

            # May trigger a strategy rollback for future requests; does not
            # affect this request's result.
            self._check_performance(latency_ms, tenant_id)

            logger.info(
                f"[RetrievalStrategy] Retrieval completed: strategy={routing_decision.strategy.value}, "
                f"mode={mode_decision.mode.value}, hits={len(pipeline_result.hits)}, "
                f"latency_ms={latency_ms:.2f}"
            )

            return RetrievalStrategyResult(
                retrieval_result=pipeline_result.retrieval_result,
                strategy_used=routing_decision.strategy,
                mode_used=mode_decision.mode,
                metadata_filter=metadata_filter,
                latency_ms=latency_ms,
                diagnostics={
                    "routing_reason": routing_decision.reason,
                    "mode_reason": mode_decision.reason,
                    "grayscale_hit": routing_decision.grayscale_hit,
                    **pipeline_result.diagnostics,
                },
            )

        except Exception as e:
            latency_ms = (time.time() - start_time) * 1000
            logger.error(
                f"[RetrievalStrategy] Retrieval error: {e}",
                exc_info=True
            )

            # Record the failure; subsequent requests run the default strategy.
            self._rollback_manager.rollback(
                trigger=RollbackTrigger.ERROR,
                reason=str(e),
                tenant_id=tenant_id,
            )

            return RetrievalStrategyResult(
                retrieval_result=RetrievalResult(
                    hits=[],
                    diagnostics={"error": str(e)},
                ),
                strategy_used=StrategyType.DEFAULT,
                mode_used=RuntimeMode.DIRECT,
                metadata_filter=None,
                latency_ms=latency_ms,
                diagnostics={"error": str(e)},
            )

    async def _infer_metadata(
        self,
        tenant_id: str,
        query: str,
        context: dict[str, Any],
        session_id: str | None = None,
        user_id: str | None = None,
    ) -> MetadataFilterResult | None:
        """Run metadata inference; returns None (and logs) on any failure."""
        try:
            if self._metadata_inference is None:
                self._metadata_inference = MetadataInferenceService(
                    self._session,
                    self._config.metadata_inference,
                )

            inference_ctx = InferenceContext(
                tenant_id=tenant_id,
                query=query,
                session_id=session_id,
                user_id=user_id,
                existing_context=context,
            )

            result = await self._metadata_inference.infer(inference_ctx)
            return result.filter_result

        except Exception as e:
            logger.warning(f"[RetrievalStrategy] Metadata inference failed: {e}")
            return None

    def _check_performance(self, latency_ms: float, tenant_id: str | None) -> None:
        """Feed latency to the rollback manager, which may trigger a rollback."""
        self._rollback_manager.check_and_rollback(
            metrics={"latency_ms": latency_ms},
            tenant_id=tenant_id,
        )

    def get_config(self) -> RetrievalStrategyConfig:
        """Return the current configuration."""
        return self._config

    def update_config(self, config: RetrievalStrategyConfig) -> None:
        """Replace the configuration and propagate it to all sub-components."""
        self._config = config
        self._strategy_router.update_config(config)
        self._mode_router.update_config(config.mode_router)
        self._rollback_manager.update_config(config)

    async def health_check(self) -> dict[str, bool]:
        """Health-check both pipelines; a failure maps to False, never raises.

        NOTE(review): reaches into StrategyRouter's private
        _get_default_pipeline/_get_enhanced_pipeline — consider a public
        accessor on the router.
        """
        results = {}

        try:
            default_pipeline = await self._strategy_router._get_default_pipeline()
            results["default_pipeline"] = await default_pipeline.health_check()
        except Exception:
            results["default_pipeline"] = False

        try:
            enhanced_pipeline = await self._strategy_router._get_enhanced_pipeline()
            results["enhanced_pipeline"] = await enhanced_pipeline.health_check()
        except Exception:
            results["enhanced_pipeline"] = False

        return results
|
||||
|
||||
|
||||
async def create_retrieval_strategy(
    session: AsyncSession,
    config: RetrievalStrategyConfig | None = None,
) -> RetrievalStrategy:
    """Build a fresh RetrievalStrategy bound to the given DB session."""
    strategy = RetrievalStrategy(session, config)
    return strategy
|
||||
|
|
@ -0,0 +1,192 @@
|
|||
"""
|
||||
Rollback Manager.
|
||||
[AC-AISVC-RES-07] 策略回退与审计管理器。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from app.services.retrieval.strategy.config import RetrievalStrategyConfig, StrategyType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RollbackTrigger(str, Enum):
    """Why a strategy rollback was initiated."""

    MANUAL = "manual"  # operator-initiated
    ERROR = "error"  # retrieval raised an exception
    PERFORMANCE = "performance"  # latency/error-rate threshold breached
    TIMEOUT = "timeout"
|
||||
|
||||
|
||||
@dataclass
class AuditLog:
    """One audit record of a strategy change."""

    timestamp: str  # ISO-format timestamp of the change
    action: str  # e.g. "rollback"
    from_strategy: str  # StrategyType value before the change
    to_strategy: str  # StrategyType value after the change
    trigger: str  # RollbackTrigger value
    reason: str  # free-form explanation
    tenant_id: str | None = None  # tenant scope, if any
    details: dict[str, Any] = field(default_factory=dict)  # extra context
|
||||
|
||||
|
||||
@dataclass
class RollbackResult:
    """Outcome of a rollback attempt."""

    success: bool  # False when already on the default strategy
    previous_strategy: StrategyType
    current_strategy: StrategyType
    trigger: RollbackTrigger
    reason: str
    audit_log: AuditLog | None = None  # present only when a rollback happened
|
||||
|
||||
|
||||
class RollbackManager:
    """
    Strategy rollback manager. [AC-AISVC-RES-07]

    Responsibilities:
    1. Roll back to the default strategy when the active strategy misbehaves.
    2. Record audit logs in a bounded in-memory list.
    3. Support manually triggered rollback.
    """

    def __init__(
        self,
        config: RetrievalStrategyConfig | None = None,
        max_audit_logs: int = 1000,
    ):
        """
        Args:
            config: Strategy configuration; a fresh default is created when None.
            max_audit_logs: Upper bound on retained audit entries.
        """
        self._config = config or RetrievalStrategyConfig()
        self._max_audit_logs = max_audit_logs
        self._audit_logs: list[AuditLog] = []
        # Strategy that was active before the most recent rollback.
        self._previous_strategy: StrategyType = StrategyType.DEFAULT

    def rollback(
        self,
        trigger: RollbackTrigger,
        reason: str,
        tenant_id: str | None = None,
        details: dict[str, Any] | None = None,
    ) -> RollbackResult:
        """Execute a rollback to the default strategy. [AC-AISVC-RES-07]

        Args:
            trigger: What initiated the rollback.
            reason: Human-readable explanation for the audit trail.
            tenant_id: Optional tenant scope for the audit entry.
            details: Optional free-form context stored in the audit entry.

        Returns:
            A RollbackResult; ``success`` is False (and no audit entry is
            written) when the default strategy is already active.
        """
        previous = self._config.active_strategy
        current = StrategyType.DEFAULT

        if previous == StrategyType.DEFAULT:
            # Nothing to roll back; report without touching the audit log.
            return RollbackResult(
                success=False,
                previous_strategy=previous,
                current_strategy=current,
                trigger=trigger,
                reason="Already on default strategy",
            )

        self._previous_strategy = previous
        self._config.active_strategy = current

        audit_log = AuditLog(
            # Timezone-aware UTC timestamp: datetime.utcnow() is deprecated
            # (Python 3.12) and produced ambiguous naive datetimes.
            timestamp=datetime.now(timezone.utc).isoformat(),
            action="rollback",
            from_strategy=previous.value,
            to_strategy=current.value,
            trigger=trigger.value,
            reason=reason,
            tenant_id=tenant_id,
            details=details or {},
        )

        self._add_audit_log(audit_log)

        # Lazy %-style args avoid string formatting when INFO is disabled.
        logger.info(
            "[RollbackManager] Rollback executed: from=%s, to=%s, trigger=%s",
            previous.value,
            current.value,
            trigger.value,
        )

        return RollbackResult(
            success=True,
            previous_strategy=previous,
            current_strategy=current,
            trigger=trigger,
            reason=reason,
            audit_log=audit_log,
        )

    def check_and_rollback(
        self, metrics: dict[str, float], tenant_id: str | None = None
    ) -> RollbackResult | None:
        """Check performance metrics and auto-rollback. [AC-AISVC-RES-08]

        Args:
            metrics: Expected keys are "latency_ms" and "error_rate";
                missing keys default to 0 (never trigger).
            tenant_id: Optional tenant scope for the audit entry.

        Returns:
            The RollbackResult when a threshold was breached, else None.
        """
        thresholds = self._config.performance_thresholds

        latency = metrics.get("latency_ms", 0)
        if latency > thresholds.get("max_latency_ms", 2000):
            return self.rollback(
                trigger=RollbackTrigger.PERFORMANCE,
                reason=f"Latency {latency}ms exceeds threshold",
                tenant_id=tenant_id,
            )

        error_rate = metrics.get("error_rate", 0)
        if error_rate > thresholds.get("max_error_rate", 0.05):
            return self.rollback(
                trigger=RollbackTrigger.ERROR,
                reason=f"Error rate {error_rate} exceeds threshold",
                tenant_id=tenant_id,
            )

        return None

    def _add_audit_log(self, log: AuditLog) -> None:
        """Append an entry, trimming to the newest ``_max_audit_logs``."""
        self._audit_logs.append(log)
        if len(self._audit_logs) > self._max_audit_logs:
            self._audit_logs = self._audit_logs[-self._max_audit_logs:]

    def record_audit(
        self,
        action: str,
        details: dict[str, Any],
        tenant_id: str | None = None,
    ) -> AuditLog:
        """Record a plain (non-rollback) audit entry.

        Args:
            action: Caller-supplied action name.
            details: Free-form context; ``details["reason"]`` (if present)
                becomes the entry's reason field.
            tenant_id: Optional tenant scope.

        Returns:
            The stored AuditLog entry.
        """
        audit_log = AuditLog(
            timestamp=datetime.now(timezone.utc).isoformat(),
            action=action,
            # Not a strategy change: from/to both record the active strategy.
            from_strategy=self._config.active_strategy.value,
            to_strategy=self._config.active_strategy.value,
            trigger="n/a",
            reason=details.get("reason", ""),
            tenant_id=tenant_id,
            details=details,
        )

        self._add_audit_log(audit_log)

        logger.info(
            "[RollbackManager] Audit recorded: action=%s, strategy=%s",
            action,
            self._config.active_strategy.value,
        )

        return audit_log

    def get_audit_logs(self, limit: int = 100) -> list[AuditLog]:
        """Return the newest ``limit`` audit entries (oldest first)."""
        return self._audit_logs[-limit:]

    def get_config(self) -> RetrievalStrategyConfig:
        """Return the live configuration object."""
        return self._config

    def update_config(self, config: RetrievalStrategyConfig) -> None:
        """Replace the configuration (hot reload)."""
        self._config = config
|
||||
|
||||
|
||||
# Lazily-created module-level singleton.
_rollback_manager: RollbackManager | None = None


def get_rollback_manager() -> RollbackManager:
    """Return the process-wide RollbackManager, creating it on first use."""
    global _rollback_manager
    manager = _rollback_manager
    if manager is None:
        manager = RollbackManager()
        _rollback_manager = manager
    return manager
|
||||
|
|
@ -0,0 +1,109 @@
|
|||
"""
|
||||
Strategy Router.
|
||||
[AC-AISVC-RES-01~03] 策略路由器。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from app.services.retrieval.strategy.config import (
|
||||
RetrievalStrategyConfig,
|
||||
StrategyType,
|
||||
)
|
||||
from app.services.retrieval.strategy.default_pipeline import DefaultPipeline, get_default_pipeline
|
||||
from app.services.retrieval.strategy.enhanced_pipeline import EnhancedPipeline, get_enhanced_pipeline
|
||||
from app.services.retrieval.strategy.pipeline_base import BasePipeline
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class RoutingDecision:
    """Result of a strategy routing decision."""

    strategy: StrategyType  # which strategy was selected
    pipeline: BasePipeline  # pipeline instance that will serve the request
    reason: str  # "active_strategy=enhanced" | "grayscale_hit" | "default_strategy"
    grayscale_hit: bool = False  # True when grayscale rules opted this caller in
|
||||
|
||||
|
||||
class StrategyRouter:
    """
    Strategy router. [AC-AISVC-RES-01~03]

    Responsibilities:
    1. Select the default or enhanced strategy based on configuration.
    2. Support grayscale release (percentage/allowlist).
    3. Leave in-flight default-strategy requests untouched.
    """

    def __init__(
        self,
        config: RetrievalStrategyConfig | None = None,
        default_pipeline: DefaultPipeline | None = None,
        enhanced_pipeline: EnhancedPipeline | None = None,
    ):
        self._config = config or RetrievalStrategyConfig()
        # Pipelines are resolved lazily on first use when not injected.
        self._default_pipeline = default_pipeline
        self._enhanced_pipeline = enhanced_pipeline

    async def route(
        self, tenant_id: str, user_id: str | None = None
    ) -> RoutingDecision:
        """Pick the pipeline for this request. [AC-AISVC-RES-01~03]"""
        # Global switch: enhanced strategy forced on for every caller.
        if self._config.active_strategy == StrategyType.ENHANCED:
            return RoutingDecision(
                strategy=StrategyType.ENHANCED,
                pipeline=await self._get_enhanced_pipeline(),
                reason="active_strategy=enhanced",
            )

        # Grayscale rules (percentage / allowlist) may opt this caller in.
        if self._config.grayscale.should_use_enhanced(tenant_id, user_id):
            return RoutingDecision(
                strategy=StrategyType.ENHANCED,
                pipeline=await self._get_enhanced_pipeline(),
                reason="grayscale_hit",
                grayscale_hit=True,
            )

        # Fall through: stay on the default strategy.
        return RoutingDecision(
            strategy=StrategyType.DEFAULT,
            pipeline=await self._get_default_pipeline(),
            reason="default_strategy",
        )

    async def _get_default_pipeline(self) -> DefaultPipeline:
        """Lazily resolve and cache the default pipeline."""
        if self._default_pipeline is None:
            self._default_pipeline = await get_default_pipeline()
        return self._default_pipeline

    async def _get_enhanced_pipeline(self) -> EnhancedPipeline:
        """Lazily resolve and cache the enhanced pipeline."""
        if self._enhanced_pipeline is None:
            self._enhanced_pipeline = await get_enhanced_pipeline()
        return self._enhanced_pipeline

    def get_config(self) -> RetrievalStrategyConfig:
        """Return the live configuration object."""
        return self._config

    def update_config(self, config: RetrievalStrategyConfig) -> None:
        """Replace the configuration and log the new active strategy."""
        self._config = config
        logger.info(f"[StrategyRouter] Config updated: strategy={config.active_strategy.value}")
|
||||
|
||||
|
||||
# Module-level singleton, created on demand.
_strategy_router: StrategyRouter | None = None


def get_strategy_router() -> StrategyRouter:
    """Return the process-wide StrategyRouter, creating it on first use."""
    global _strategy_router
    router = _strategy_router
    if router is None:
        router = StrategyRouter()
        _strategy_router = router
    return router


def set_strategy_router(router: StrategyRouter) -> None:
    """Set the global strategy router instance."""
    global _strategy_router
    _strategy_router = router
|
||||
|
|
@ -0,0 +1,233 @@
|
|||
"""
|
||||
Retrieval Strategy Integration for Dialogue Flow.
|
||||
[AC-AISVC-RES-01~15] Integrates StrategyRouter and ModeRouter into dialogue pipeline.
|
||||
|
||||
Usage:
|
||||
from app.services.retrieval.strategy_integration import RetrievalStrategyIntegration
|
||||
|
||||
integration = RetrievalStrategyIntegration()
|
||||
result = await integration.execute(ctx)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from app.services.retrieval.routing_config import (
|
||||
RagRuntimeMode,
|
||||
StrategyType,
|
||||
RoutingConfig,
|
||||
StrategyContext,
|
||||
StrategyResult,
|
||||
)
|
||||
from app.services.retrieval.strategy_router import (
|
||||
StrategyRouter,
|
||||
get_strategy_router,
|
||||
)
|
||||
from app.services.retrieval.mode_router import (
|
||||
ModeRouter,
|
||||
ModeRouteResult,
|
||||
get_mode_router,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.services.retrieval.base import RetrievalResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class RetrievalStrategyResult:
    """Combined result from strategy and mode routing."""

    retrieval_result: "RetrievalResult | None"  # retrieval output; None when nothing was retrieved
    final_answer: str | None  # answer produced during mode execution, if any
    strategy: StrategyType  # strategy chosen by StrategyRouter
    mode: RagRuntimeMode  # runtime mode chosen by ModeRouter
    should_fallback: bool = False  # True when a fallback (react/default) was signalled
    fallback_reason: str | None = None  # why the fallback happened
    mode_route_result: ModeRouteResult | None = None  # raw mode-router decision
    diagnostics: dict[str, Any] = field(default_factory=dict)  # per-router diagnostics + timing
    duration_ms: int = 0  # end-to-end execution time in milliseconds
|
||||
|
||||
|
||||
class RetrievalStrategyIntegration:
    """
    [AC-AISVC-RES-01~15] Integration layer for retrieval strategy.

    Combines StrategyRouter and ModeRouter to provide a unified interface
    for the dialogue pipeline.

    Flow:
    1. StrategyRouter selects default or enhanced strategy
    2. ModeRouter selects direct, react, or auto mode
    3. Execute retrieval with selected strategy and mode
    4. Handle fallback scenarios
    """

    def __init__(
        self,
        config: RoutingConfig | None = None,
        strategy_router: StrategyRouter | None = None,
        mode_router: ModeRouter | None = None,
    ):
        self._config = config or RoutingConfig()
        # Fall back to the module singletons when routers are not injected.
        self._strategy_router = strategy_router or get_strategy_router()
        self._mode_router = mode_router or get_mode_router()

    @property
    def config(self) -> RoutingConfig:
        """Get current configuration."""
        return self._config

    def update_config(self, new_config: RoutingConfig) -> None:
        """
        [AC-AISVC-RES-15] Update all routing configurations.

        Propagates the new config to both routers so they stay in sync.
        """
        self._config = new_config
        self._strategy_router.update_config(new_config)
        self._mode_router.update_config(new_config)

        logger.info(
            f"[AC-AISVC-RES-15] RetrievalStrategyIntegration config updated: "
            f"strategy={new_config.strategy.value}, mode={new_config.rag_runtime_mode.value}"
        )

    async def execute(
        self,
        ctx: StrategyContext,
    ) -> RetrievalStrategyResult:
        """
        Execute retrieval with strategy and mode routing.

        Args:
            ctx: Strategy context with tenant, query, metadata, etc.

        Returns:
            RetrievalStrategyResult with retrieval results and diagnostics

        Raises:
            Exception: Re-raised when execution fails while already on the
                default strategy (no further fallback is possible).
        """
        start_time = time.time()

        strategy_result = self._strategy_router.route(ctx)

        mode_result = self._mode_router.route(ctx)

        logger.info(
            f"[AC-AISVC-RES-01~15] Strategy routing: "
            f"strategy={strategy_result.strategy.value}, mode={mode_result.mode.value}, "
            f"tenant={ctx.tenant_id}, query_len={len(ctx.query)}"
        )

        retrieval_result = None
        final_answer = None
        should_fallback = False
        fallback_reason = None

        try:
            if mode_result.mode == RagRuntimeMode.REACT:
                # ReAct mode: the agentic loop produces the answer directly.
                # The react context is not consumed here.
                answer, retrieval_result, _react_ctx = await self._mode_router.execute_react(ctx)
                final_answer = answer
            else:
                # DIRECT and AUTO share the same fallback-capable path
                # (previously two byte-identical branches).
                retrieval_result, answer, mode_result = await self._mode_router.execute_with_fallback(ctx)

                if answer is not None:
                    final_answer = answer
                # Propagate the mode router's fallback signal regardless of
                # whether a direct answer was produced.
                should_fallback = mode_result.should_fallback_to_react
                fallback_reason = mode_result.fallback_reason

        except Exception as e:
            logger.error(
                f"[AC-AISVC-RES-07] Retrieval strategy execution failed: {e}"
            )

            if strategy_result.strategy == StrategyType.ENHANCED:
                # [AC-AISVC-RES-07] Enhanced strategy failed: roll back and
                # retry once with the plain optimized retriever.
                self._strategy_router.rollback(
                    reason=str(e),
                    tenant_id=ctx.tenant_id,
                )

                # Local imports: these modules are only needed on this path.
                from app.services.retrieval.optimized_retriever import get_optimized_retriever
                from app.services.retrieval.base import RetrievalContext

                retriever = await get_optimized_retriever()
                retrieval_ctx = RetrievalContext(
                    tenant_id=ctx.tenant_id,
                    query=ctx.query,
                    metadata_filter=ctx.metadata_filter,
                    kb_ids=ctx.kb_ids,
                )
                retrieval_result = await retriever.retrieve(retrieval_ctx)

                should_fallback = True
                fallback_reason = str(e)

            else:
                # Already on the default strategy: nothing to fall back to.
                raise

        duration_ms = int((time.time() - start_time) * 1000)

        return RetrievalStrategyResult(
            retrieval_result=retrieval_result,
            final_answer=final_answer,
            strategy=strategy_result.strategy,
            mode=mode_result.mode,
            should_fallback=should_fallback,
            fallback_reason=fallback_reason,
            mode_route_result=mode_result,
            diagnostics={
                "strategy_diagnostics": strategy_result.diagnostics,
                "mode_diagnostics": mode_result.diagnostics,
                "duration_ms": duration_ms,
            },
            duration_ms=duration_ms,
        )

    def get_current_strategy(self) -> StrategyType:
        """Get current active strategy."""
        return self._strategy_router.current_strategy

    def get_rollback_records(self, limit: int = 10) -> list[dict[str, Any]]:
        """Get recent rollback records as plain dicts (for API responses)."""
        records = self._strategy_router.get_rollback_records(limit)
        return [
            {
                "timestamp": r.timestamp,
                "from_strategy": r.from_strategy.value,
                "to_strategy": r.to_strategy.value,
                "reason": r.reason,
                "tenant_id": r.tenant_id,
            }
            for r in records
        ]

    def validate_config(self) -> tuple[bool, list[str]]:
        """Validate current configuration."""
        return self._config.validate()
|
||||
|
||||
|
||||
# Module-level singleton, created on demand.
_integration: RetrievalStrategyIntegration | None = None


def get_retrieval_strategy_integration() -> RetrievalStrategyIntegration:
    """Get or create RetrievalStrategyIntegration singleton."""
    global _integration
    instance = _integration
    if instance is None:
        instance = RetrievalStrategyIntegration()
        _integration = instance
    return instance


def reset_retrieval_strategy_integration() -> None:
    """Reset RetrievalStrategyIntegration singleton (for testing)."""
    global _integration
    _integration = None
|
||||
|
|
@ -0,0 +1,403 @@
|
|||
"""
|
||||
Strategy Router for Retrieval and Embedding.
|
||||
[AC-AISVC-RES-01, AC-AISVC-RES-02, AC-AISVC-RES-03] Routes to default or enhanced strategy.
|
||||
|
||||
Key Features:
|
||||
- Default strategy preserves existing online logic
|
||||
- Enhanced strategy is configurable and can be rolled back
|
||||
- Supports grayscale release (percentage/allowlist)
|
||||
- Supports rollback on error or performance degradation
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from app.services.retrieval.routing_config import (
|
||||
RagRuntimeMode,
|
||||
StrategyType,
|
||||
RoutingConfig,
|
||||
StrategyContext,
|
||||
StrategyResult,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.services.retrieval.base import RetrievalResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class RollbackRecord:
    """Record for a strategy rollback event."""

    timestamp: float  # epoch seconds (time.time()) when the rollback happened
    from_strategy: StrategyType  # strategy that was rolled back from
    to_strategy: StrategyType  # strategy rolled back to (DEFAULT in this module)
    reason: str  # human-readable cause (often str(exception))
    tenant_id: str | None = None  # tenant scope for auditing, when known
    request_id: str | None = None  # request correlation id, when known
|
||||
|
||||
|
||||
class RollbackManager:
    """
    [AC-AISVC-RES-07] Manages strategy rollback and audit logging.

    Keeps a bounded in-memory list of RollbackRecord entries.
    """

    def __init__(self, max_records: int = 100):
        self._records: list[RollbackRecord] = []
        self._max_records = max_records

    def record_rollback(
        self,
        from_strategy: StrategyType,
        to_strategy: StrategyType,
        reason: str,
        tenant_id: str | None = None,
        request_id: str | None = None,
    ) -> None:
        """Record a rollback event and trim the history to its bound."""
        self._records.append(
            RollbackRecord(
                timestamp=time.time(),
                from_strategy=from_strategy,
                to_strategy=to_strategy,
                reason=reason,
                tenant_id=tenant_id,
                request_id=request_id,
            )
        )

        # Drop the oldest entries once the bound is exceeded.
        overflow = len(self._records) - self._max_records
        if overflow > 0:
            del self._records[:overflow]

        logger.info(
            f"[AC-AISVC-RES-07] Rollback recorded: {from_strategy.value} -> {to_strategy.value}, "
            f"reason={reason}, tenant={tenant_id}"
        )

    def get_recent_rollbacks(self, limit: int = 10) -> list[RollbackRecord]:
        """Return the newest ``limit`` rollback records (oldest first)."""
        return self._records[-limit:]

    def get_rollback_count(self, since_timestamp: float | None = None) -> int:
        """Count recorded rollbacks, optionally only those at/after a timestamp."""
        if since_timestamp is None:
            return len(self._records)
        return sum(1 for r in self._records if r.timestamp >= since_timestamp)
|
||||
|
||||
|
||||
class DefaultPipeline:
    """
    [AC-AISVC-RES-01] Default pipeline that preserves existing online logic.

    Uses the existing OptimizedRetriever without any new features.
    """

    def __init__(self):
        # Resolved lazily on the first execute() call.
        self._retriever = None

    async def execute(self, ctx: StrategyContext) -> "RetrievalResult":
        """Run the default retrieval strategy for *ctx*."""
        # Local imports (presumably to avoid an import cycle at module load
        # -- confirm against the retrieval package layout).
        from app.services.retrieval.base import RetrievalContext
        from app.services.retrieval.optimized_retriever import get_optimized_retriever

        if self._retriever is None:
            self._retriever = await get_optimized_retriever()

        request = RetrievalContext(
            tenant_id=ctx.tenant_id,
            query=ctx.query,
            metadata_filter=ctx.metadata_filter,
            kb_ids=ctx.kb_ids,
        )
        return await self._retriever.retrieve(request)
|
||||
|
||||
|
||||
class EnhancedPipeline:
    """
    [AC-AISVC-RES-02] Enhanced pipeline with new end-to-end retrieval features.

    Features:
    - Document preprocessing (cleaning/normalization)
    - Structured chunking (markdown/tables/FAQ)
    - Metadata generation and mounting
    - Embedding strategy (document/query prefix + Matryoshka)
    - Metadata inference and filtering (hard/soft filter)
    - Retrieval strategy (Dense + Keyword/Hybrid + RRF)
    - Optional reranking
    """

    def __init__(
        self,
        config: RoutingConfig | None = None,
    ):
        self._config = config or RoutingConfig()
        # Built lazily on the first execute() call.
        self._retriever = None

    async def execute(self, ctx: StrategyContext) -> "RetrievalResult":
        """Run the enhanced retrieval strategy (two-stage + hybrid) for *ctx*."""
        # Local imports (presumably to avoid an import cycle at module load
        # -- confirm against the retrieval package layout).
        from app.services.retrieval.base import RetrievalContext
        from app.services.retrieval.optimized_retriever import OptimizedRetriever

        if self._retriever is None:
            # Enhanced configuration: two-stage retrieval plus hybrid search.
            self._retriever = OptimizedRetriever(
                two_stage_enabled=True,
                hybrid_enabled=True,
            )

        request = RetrievalContext(
            tenant_id=ctx.tenant_id,
            query=ctx.query,
            metadata_filter=ctx.metadata_filter,
            kb_ids=ctx.kb_ids,
        )
        return await self._retriever.retrieve(request)
|
||||
|
||||
|
||||
class StrategyRouter:
    """
    [AC-AISVC-RES-01, AC-AISVC-RES-02, AC-AISVC-RES-03]
    Strategy router for retrieval and embedding.

    Decision Flow:
    1. Check if enhanced strategy is enabled via configuration
    2. Check grayscale rules (percentage/allowlist)
    3. Route to appropriate pipeline (default/enhanced)
    4. Handle rollback on error or performance degradation

    Constraints:
    - Default strategy MUST preserve existing online logic
    - Enhanced strategy MUST be configurable and rollback-able
    """

    def __init__(
        self,
        config: RoutingConfig | None = None,
        rollback_manager: RollbackManager | None = None,
    ):
        # Fresh defaults are used for any dependency that is not injected.
        self._config = config or RoutingConfig()
        self._rollback_manager = rollback_manager or RollbackManager()
        self._default_pipeline = DefaultPipeline()
        self._enhanced_pipeline = EnhancedPipeline(self._config)

        # Last strategy selected by route(); starts on the safe default.
        self._current_strategy = StrategyType.DEFAULT
        # Kill switch: when False, route() always returns DEFAULT.
        self._strategy_enabled = True

    @property
    def current_strategy(self) -> StrategyType:
        """Get current active strategy."""
        return self._current_strategy

    @property
    def config(self) -> RoutingConfig:
        """Get current configuration."""
        return self._config

    def update_config(self, new_config: RoutingConfig) -> None:
        """
        [AC-AISVC-RES-15] Update routing configuration (hot reload).

        Args:
            new_config: New configuration to apply
        """
        old_strategy = self._config.strategy
        self._config = new_config

        logger.info(
            f"[AC-AISVC-RES-15] Routing config updated: "
            f"strategy={new_config.strategy.value}, "
            f"mode={new_config.rag_runtime_mode.value}, "
            f"grayscale={new_config.grayscale_percentage:.2%}"
        )

        # Extra log line only when the active strategy actually changed.
        if old_strategy != new_config.strategy:
            logger.info(
                f"[AC-AISVC-RES-02] Strategy changed: {old_strategy.value} -> {new_config.strategy.value}"
            )

    def route(
        self,
        ctx: StrategyContext,
    ) -> StrategyResult:
        """
        [AC-AISVC-RES-01, AC-AISVC-RES-02] Route to appropriate strategy.

        Args:
            ctx: Strategy context with tenant, query, metadata, etc.

        Returns:
            StrategyResult with selected strategy and mode
        """
        if not self._strategy_enabled:
            # Kill switch engaged: never leave the default path.
            logger.info("[AC-AISVC-RES-07] Strategy disabled, using default")
            return StrategyResult(
                strategy=StrategyType.DEFAULT,
                mode=self._config.rag_runtime_mode,
                should_fallback=False,
                diagnostics={"reason": "strategy_disabled"},
            )

        # Config decides (active strategy + grayscale percentage/allowlist).
        use_enhanced = self._config.should_use_enhanced_strategy(ctx.tenant_id)

        if use_enhanced:
            self._current_strategy = StrategyType.ENHANCED
            logger.info(
                f"[AC-AISVC-RES-02] Routing to ENHANCED strategy: tenant={ctx.tenant_id}"
            )
        else:
            self._current_strategy = StrategyType.DEFAULT
            logger.info(
                f"[AC-AISVC-RES-01] Routing to DEFAULT strategy: tenant={ctx.tenant_id}"
            )

        return StrategyResult(
            strategy=self._current_strategy,
            mode=self._config.rag_runtime_mode,
            diagnostics={
                "grayscale_percentage": self._config.grayscale_percentage,
                "in_allowlist": ctx.tenant_id in self._config.grayscale_allowlist if ctx.tenant_id else False,
            },
        )

    async def execute(
        self,
        ctx: StrategyContext,
    ) -> tuple["RetrievalResult", StrategyResult]:
        """
        Execute retrieval with strategy routing.

        Args:
            ctx: Strategy context

        Returns:
            Tuple of (RetrievalResult, StrategyResult)

        Raises:
            Exception: Re-raised when the DEFAULT pipeline itself fails
                (there is no further fallback).
        """
        start_time = time.time()
        result = self.route(ctx)

        try:
            if result.strategy == StrategyType.ENHANCED:
                retrieval_result = await self._enhanced_pipeline.execute(ctx)
            else:
                retrieval_result = await self._default_pipeline.execute(ctx)

            duration_ms = int((time.time() - start_time) * 1000)

            # [AC-AISVC-RES-08] Performance watch: this only logs a warning;
            # it does not trigger an automatic rollback here.
            if duration_ms > self._config.performance_budget_ms:
                degradation = (duration_ms - self._config.performance_budget_ms) / self._config.performance_budget_ms
                if degradation > self._config.performance_degradation_threshold:
                    logger.warning(
                        f"[AC-AISVC-RES-08] Performance degradation detected: "
                        f"duration={duration_ms}ms, budget={self._config.performance_budget_ms}ms, "
                        f"degradation={degradation:.2%}"
                    )

            return retrieval_result, result

        except Exception as e:
            logger.error(
                f"[AC-AISVC-RES-07] Strategy execution failed: {e}, "
                f"strategy={result.strategy.value}"
            )

            if result.strategy == StrategyType.ENHANCED:
                # Enhanced pipeline failed: audit the rollback and retry
                # this one request on the default pipeline.
                self._rollback_manager.record_rollback(
                    from_strategy=StrategyType.ENHANCED,
                    to_strategy=StrategyType.DEFAULT,
                    reason=str(e),
                    tenant_id=ctx.tenant_id,
                )

                logger.info("[AC-AISVC-RES-07] Falling back to DEFAULT strategy")

                retrieval_result = await self._default_pipeline.execute(ctx)

                return retrieval_result, StrategyResult(
                    strategy=StrategyType.DEFAULT,
                    mode=result.mode,
                    should_fallback=True,
                    fallback_reason=str(e),
                    diagnostics=result.diagnostics,
                )

            # Default pipeline failed: nothing left to fall back to.
            raise

    def rollback(
        self,
        reason: str,
        tenant_id: str | None = None,
        request_id: str | None = None,
    ) -> None:
        """
        [AC-AISVC-RES-07] Force rollback to default strategy.

        Args:
            reason: Reason for rollback
            tenant_id: Optional tenant ID for audit
            request_id: Optional request ID for audit
        """
        if self._current_strategy == StrategyType.ENHANCED:
            self._rollback_manager.record_rollback(
                from_strategy=StrategyType.ENHANCED,
                to_strategy=StrategyType.DEFAULT,
                reason=reason,
                tenant_id=tenant_id,
                request_id=request_id,
            )

        # The strategy is reset even when no audit record was written above.
        self._current_strategy = StrategyType.DEFAULT
        self._config.strategy = StrategyType.DEFAULT

        logger.info(
            f"[AC-AISVC-RES-07] Rollback executed: reason={reason}, tenant={tenant_id}"
        )

    def get_rollback_records(self, limit: int = 10) -> list[RollbackRecord]:
        """Get recent rollback records."""
        return self._rollback_manager.get_recent_rollbacks(limit)

    def validate_config(self) -> tuple[bool, list[str]]:
        """
        [AC-AISVC-RES-06] Validate current configuration.

        Returns:
            Tuple of (is_valid, list of error messages)
        """
        return self._config.validate()
|
||||
|
||||
|
||||
# Module-level singleton, created on demand.
_strategy_router: StrategyRouter | None = None


def get_strategy_router() -> StrategyRouter:
    """Get or create StrategyRouter singleton."""
    global _strategy_router
    router = _strategy_router
    if router is None:
        router = StrategyRouter()
        _strategy_router = router
    return router


def reset_strategy_router() -> None:
    """Reset StrategyRouter singleton (for testing)."""
    global _strategy_router
    _strategy_router = None
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue