feat: implement Phase 2 API for metadata role separation [AC-MRS-01~16]
- Task 2.3: SlotDefinitionService with CRUD operations [AC-MRS-07,08] - Task 2.4: Extend MetadataFieldDefinition API with by-role endpoint [AC-MRS-01,04,05,06,16] - Task 2.5: SlotDefinition API with CRUD endpoints [AC-MRS-07,08,16] - Task 2.6: Runtime slot API for mid platform [AC-MRS-09,10] - Task 5.1: Unit tests for RoleBasedFieldProvider and SlotDefinitionService [AC-MRS-01~16]
This commit is contained in:
parent
662ba2b101
commit
5c1f311656
|
|
@ -1,6 +1,7 @@
|
|||
"""
|
||||
Admin API routes for AI Service management.
|
||||
[AC-ASA-01, AC-ASA-02, AC-ASA-05, AC-ASA-07, AC-ASA-08, AC-AISVC-50] Admin management endpoints.
|
||||
[AC-MRS-07,08,16] Slot definition management endpoints.
|
||||
"""
|
||||
|
||||
from app.api.admin.api_key import router as api_key_router
|
||||
|
|
@ -19,6 +20,7 @@ from app.api.admin.prompt_templates import router as prompt_templates_router
|
|||
from app.api.admin.rag import router as rag_router
|
||||
from app.api.admin.script_flows import router as script_flows_router
|
||||
from app.api.admin.sessions import router as sessions_router
|
||||
from app.api.admin.slot_definition import router as slot_definition_router
|
||||
from app.api.admin.tenants import router as tenants_router
|
||||
|
||||
__all__ = [
|
||||
|
|
@ -38,5 +40,6 @@ __all__ = [
|
|||
"rag_router",
|
||||
"script_flows_router",
|
||||
"sessions_router",
|
||||
"slot_definition_router",
|
||||
"tenants_router",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
"""
|
||||
Metadata Field Definition API.
|
||||
[AC-IDSMETA-13, AC-IDSMETA-14] 元数据字段定义管理接口,支持字段级状态治理。
|
||||
[AC-MRS-01,04,05,06,16] 支持字段角色分层配置和按角色查询。
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
|
@ -20,10 +21,12 @@ from app.models.entities import (
|
|||
MetadataFieldStatus,
|
||||
)
|
||||
from app.services.metadata_field_definition_service import MetadataFieldDefinitionService
|
||||
from app.services.mid.role_based_field_provider import RoleBasedFieldProvider, InvalidRoleError
|
||||
from app.schemas.metadata import VALID_FIELD_ROLES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/admin/metadata-schemas", tags=["MetadataSchemas"])
|
||||
router = APIRouter(prefix="/admin/metadata-schemas", tags=["MetadataSchema"])
|
||||
|
||||
|
||||
def get_current_tenant_id() -> str:
|
||||
|
|
@ -34,11 +37,33 @@ def get_current_tenant_id() -> str:
|
|||
return tenant_id
|
||||
|
||||
|
||||
def _field_to_dict(f: MetadataFieldDefinition) -> dict[str, Any]:
|
||||
"""Convert field definition to dict with field_roles"""
|
||||
return {
|
||||
"id": str(f.id),
|
||||
"tenant_id": str(f.tenant_id),
|
||||
"field_key": f.field_key,
|
||||
"label": f.label,
|
||||
"type": f.type,
|
||||
"required": f.required,
|
||||
"options": f.options,
|
||||
"default_value": f.default_value,
|
||||
"scope": f.scope,
|
||||
"is_filterable": f.is_filterable,
|
||||
"is_rank_feature": f.is_rank_feature,
|
||||
"field_roles": f.field_roles or [],
|
||||
"status": f.status,
|
||||
"version": f.version,
|
||||
"created_at": f.created_at.isoformat() if f.created_at else None,
|
||||
"updated_at": f.updated_at.isoformat() if f.updated_at else None,
|
||||
}
|
||||
|
||||
|
||||
@router.get(
|
||||
"",
|
||||
operation_id="listMetadataSchemas",
|
||||
summary="List metadata schemas",
|
||||
description="[AC-IDSMETA-13] 获取元数据字段定义列表,支持按状态和范围过滤",
|
||||
description="[AC-IDSMETA-13] [AC-MRS-06] 获取元数据字段定义列表,支持按状态、范围、角色过滤",
|
||||
)
|
||||
async def list_schemas(
|
||||
tenant_id: Annotated[str, Depends(get_current_tenant_id)],
|
||||
|
|
@ -49,28 +74,32 @@ async def list_schemas(
|
|||
scope: Annotated[str | None, Query(
|
||||
description="按适用范围过滤: kb_document/intent_rule/script_flow/prompt_template"
|
||||
)] = None,
|
||||
field_role: Annotated[str | None, Query(
|
||||
description="[AC-MRS-06] 按字段角色过滤: resource_filter/slot/prompt_var/routing_signal"
|
||||
)] = None,
|
||||
include_deprecated: Annotated[bool, Query(
|
||||
description="是否包含已废弃的字段"
|
||||
)] = False,
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
[AC-IDSMETA-13] 列出元数据字段定义
|
||||
[AC-IDSMETA-13] [AC-MRS-06] 列出元数据字段定义
|
||||
|
||||
Args:
|
||||
status: 按状态过滤
|
||||
scope: 按适用范围过滤
|
||||
field_role: [AC-MRS-06] 按字段角色过滤
|
||||
include_deprecated: 是否包含已废弃的字段(当 status 未指定时生效)
|
||||
"""
|
||||
logger.info(
|
||||
f"[AC-IDSMETA-13] Listing metadata field definitions: "
|
||||
f"tenant={tenant_id}, status={status}, scope={scope}, include_deprecated={include_deprecated}"
|
||||
f"[AC-IDSMETA-13] [AC-MRS-06] Listing metadata field definitions: "
|
||||
f"tenant={tenant_id}, status={status}, scope={scope}, field_role={field_role}"
|
||||
)
|
||||
|
||||
if status and status not in [s.value for s in MetadataFieldStatus]:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={
|
||||
"code": "INVALID_STATUS",
|
||||
"error_code": "INVALID_STATUS",
|
||||
"message": f"Invalid status: {status}",
|
||||
"details": {
|
||||
"valid_values": [s.value for s in MetadataFieldStatus]
|
||||
|
|
@ -78,34 +107,29 @@ async def list_schemas(
|
|||
}
|
||||
)
|
||||
|
||||
if field_role and field_role not in VALID_FIELD_ROLES:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={
|
||||
"error_code": "INVALID_ROLE",
|
||||
"message": f"Invalid role '{field_role}'. Valid roles are: {', '.join(VALID_FIELD_ROLES)}",
|
||||
"details": {
|
||||
"valid_roles": VALID_FIELD_ROLES
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
service = MetadataFieldDefinitionService(session)
|
||||
|
||||
if include_deprecated and not status:
|
||||
fields = await service.get_field_definitions_for_read(tenant_id, scope)
|
||||
if field_role:
|
||||
fields = [f for f in fields if field_role in (f.field_roles or [])]
|
||||
else:
|
||||
fields = await service.list_field_definitions(tenant_id, status, scope)
|
||||
fields = await service.list_field_definitions(tenant_id, status, scope, field_role)
|
||||
|
||||
return JSONResponse(
|
||||
content={
|
||||
"items": [
|
||||
{
|
||||
"id": str(f.id),
|
||||
"field_key": f.field_key,
|
||||
"label": f.label,
|
||||
"type": f.type,
|
||||
"required": f.required,
|
||||
"options": f.options,
|
||||
"default": f.default_value,
|
||||
"scope": f.scope,
|
||||
"is_filterable": f.is_filterable,
|
||||
"is_rank_feature": f.is_rank_feature,
|
||||
"status": f.status,
|
||||
"created_at": f.created_at.isoformat() if f.created_at else None,
|
||||
"updated_at": f.updated_at.isoformat() if f.updated_at else None,
|
||||
}
|
||||
for f in fields
|
||||
]
|
||||
}
|
||||
content=[_field_to_dict(f) for f in fields]
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -113,7 +137,7 @@ async def list_schemas(
|
|||
"",
|
||||
operation_id="createMetadataSchema",
|
||||
summary="Create metadata schema",
|
||||
description="[AC-IDSMETA-13] 创建新的元数据字段定义",
|
||||
description="[AC-IDSMETA-13] [AC-MRS-01,02,03] 创建新的元数据字段定义,支持 field_roles 多选配置",
|
||||
status_code=201,
|
||||
)
|
||||
async def create_schema(
|
||||
|
|
@ -122,11 +146,11 @@ async def create_schema(
|
|||
field_create: MetadataFieldDefinitionCreate,
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
[AC-IDSMETA-13] 创建元数据字段定义
|
||||
[AC-IDSMETA-13] [AC-MRS-01,02,03] 创建元数据字段定义
|
||||
"""
|
||||
logger.info(
|
||||
f"[AC-IDSMETA-13] Creating metadata field definition: "
|
||||
f"tenant={tenant_id}, field_key={field_create.field_key}"
|
||||
f"[AC-IDSMETA-13] [AC-MRS-01] Creating metadata field definition: "
|
||||
f"tenant={tenant_id}, field_key={field_create.field_key}, field_roles={field_create.field_roles}"
|
||||
)
|
||||
|
||||
service = MetadataFieldDefinitionService(session)
|
||||
|
|
@ -138,36 +162,104 @@ async def create_schema(
|
|||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={
|
||||
"code": "VALIDATION_ERROR",
|
||||
"error_code": "VALIDATION_ERROR",
|
||||
"message": str(e),
|
||||
}
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=201,
|
||||
content={
|
||||
"id": str(field.id),
|
||||
"field_key": field.field_key,
|
||||
"label": field.label,
|
||||
"type": field.type,
|
||||
"required": field.required,
|
||||
"options": field.options,
|
||||
"default": field.default_value,
|
||||
"scope": field.scope,
|
||||
"is_filterable": field.is_filterable,
|
||||
"is_rank_feature": field.is_rank_feature,
|
||||
"status": field.status,
|
||||
"created_at": field.created_at.isoformat() if field.created_at else None,
|
||||
"updated_at": field.updated_at.isoformat() if field.updated_at else None,
|
||||
}
|
||||
content=_field_to_dict(field)
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/by-role",
|
||||
operation_id="getMetadataSchemasByRole",
|
||||
summary="Get metadata schemas by role",
|
||||
description="[AC-MRS-04,05] 按指定角色查询所有包含该角色的活跃字段定义",
|
||||
)
|
||||
async def get_schemas_by_role(
|
||||
tenant_id: Annotated[str, Depends(get_current_tenant_id)],
|
||||
session: Annotated[AsyncSession, Depends(get_session)],
|
||||
role: Annotated[str, Query(
|
||||
description="[AC-MRS-04] 字段角色: resource_filter/slot/prompt_var/routing_signal"
|
||||
)],
|
||||
include_deprecated: Annotated[bool, Query(
|
||||
description="是否包含已废弃字段"
|
||||
)] = False,
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
[AC-MRS-04,05] 按角色查询字段定义
|
||||
|
||||
Args:
|
||||
role: 字段角色
|
||||
include_deprecated: 是否包含已废弃字段
|
||||
"""
|
||||
logger.info(
|
||||
f"[AC-MRS-04] Getting metadata schemas by role: "
|
||||
f"tenant={tenant_id}, role={role}, include_deprecated={include_deprecated}"
|
||||
)
|
||||
|
||||
provider = RoleBasedFieldProvider(session)
|
||||
|
||||
try:
|
||||
fields = await provider.get_fields_by_role(tenant_id, role, include_deprecated)
|
||||
except InvalidRoleError as e:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={
|
||||
"error_code": "INVALID_ROLE",
|
||||
"message": str(e),
|
||||
"details": {
|
||||
"valid_roles": e.valid_roles
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
content=[_field_to_dict(f) for f in fields]
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{id}",
|
||||
operation_id="getMetadataSchema",
|
||||
summary="Get metadata schema by ID",
|
||||
description="获取单个元数据字段定义",
|
||||
)
|
||||
async def get_schema(
|
||||
tenant_id: Annotated[str, Depends(get_current_tenant_id)],
|
||||
session: Annotated[AsyncSession, Depends(get_session)],
|
||||
id: str,
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
获取单个元数据字段定义
|
||||
"""
|
||||
logger.info(
|
||||
f"Getting metadata field definition: tenant={tenant_id}, id={id}"
|
||||
)
|
||||
|
||||
service = MetadataFieldDefinitionService(session)
|
||||
field = await service.get_field_definition(tenant_id, id)
|
||||
|
||||
if not field:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content={
|
||||
"error_code": "NOT_FOUND",
|
||||
"message": f"Field definition {id} not found",
|
||||
}
|
||||
)
|
||||
|
||||
return JSONResponse(content=_field_to_dict(field))
|
||||
|
||||
|
||||
@router.put(
|
||||
"/{id}",
|
||||
operation_id="updateMetadataSchema",
|
||||
summary="Update metadata schema",
|
||||
description="[AC-IDSMETA-14] 更新元数据字段定义,支持状态切换",
|
||||
description="[AC-IDSMETA-14] [AC-MRS-01] 更新元数据字段定义,支持修改 field_roles",
|
||||
)
|
||||
async def update_schema(
|
||||
tenant_id: Annotated[str, Depends(get_current_tenant_id)],
|
||||
|
|
@ -176,18 +268,18 @@ async def update_schema(
|
|||
field_update: MetadataFieldDefinitionUpdate,
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
[AC-IDSMETA-14] 更新元数据字段定义
|
||||
[AC-IDSMETA-14] [AC-MRS-01] 更新元数据字段定义
|
||||
"""
|
||||
logger.info(
|
||||
f"[AC-IDSMETA-14] Updating metadata field definition: "
|
||||
f"tenant={tenant_id}, id={id}"
|
||||
f"[AC-IDSMETA-14] [AC-MRS-01] Updating metadata field definition: "
|
||||
f"tenant={tenant_id}, id={id}, field_roles={field_update.field_roles}"
|
||||
)
|
||||
|
||||
if field_update.status and field_update.status not in [s.value for s in MetadataFieldStatus]:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={
|
||||
"code": "INVALID_STATUS",
|
||||
"error_code": "INVALID_STATUS",
|
||||
"message": f"Invalid status: {field_update.status}",
|
||||
"details": {
|
||||
"valid_values": [s.value for s in MetadataFieldStatus]
|
||||
|
|
@ -203,7 +295,7 @@ async def update_schema(
|
|||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={
|
||||
"code": "VALIDATION_ERROR",
|
||||
"error_code": "VALIDATION_ERROR",
|
||||
"message": str(e),
|
||||
}
|
||||
)
|
||||
|
|
@ -212,30 +304,49 @@ async def update_schema(
|
|||
return JSONResponse(
|
||||
status_code=404,
|
||||
content={
|
||||
"code": "NOT_FOUND",
|
||||
"error_code": "NOT_FOUND",
|
||||
"message": f"Field definition {id} not found",
|
||||
}
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
|
||||
return JSONResponse(
|
||||
content={
|
||||
"id": str(field.id),
|
||||
"field_key": field.field_key,
|
||||
"label": field.label,
|
||||
"type": field.type,
|
||||
"required": field.required,
|
||||
"options": field.options,
|
||||
"default": field.default_value,
|
||||
"scope": field.scope,
|
||||
"is_filterable": field.is_filterable,
|
||||
"is_rank_feature": field.is_rank_feature,
|
||||
"status": field.status,
|
||||
"created_at": field.created_at.isoformat() if field.created_at else None,
|
||||
"updated_at": field.updated_at.isoformat() if field.updated_at else None,
|
||||
}
|
||||
return JSONResponse(content=_field_to_dict(field))
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/{id}",
|
||||
operation_id="deleteMetadataSchema",
|
||||
summary="Delete metadata schema",
|
||||
description="[AC-MRS-16] 删除元数据字段定义,无需考虑历史数据兼容性",
|
||||
status_code=204,
|
||||
)
|
||||
async def delete_schema(
|
||||
tenant_id: Annotated[str, Depends(get_current_tenant_id)],
|
||||
session: Annotated[AsyncSession, Depends(get_session)],
|
||||
id: str,
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
[AC-MRS-16] 删除元数据字段定义
|
||||
"""
|
||||
logger.info(
|
||||
f"[AC-MRS-16] Deleting metadata field definition: "
|
||||
f"tenant={tenant_id}, id={id}"
|
||||
)
|
||||
|
||||
service = MetadataFieldDefinitionService(session)
|
||||
success = await service.delete_field_definition(tenant_id, id)
|
||||
|
||||
if not success:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content={
|
||||
"error_code": "NOT_FOUND",
|
||||
"message": f"Field definition not found: {id}",
|
||||
}
|
||||
)
|
||||
|
||||
return JSONResponse(status_code=204, content=None)
|
||||
|
||||
|
||||
@router.get(
|
||||
|
|
@ -263,26 +374,7 @@ async def get_active_schemas(
|
|||
fields = await service.get_active_field_definitions(tenant_id, scope)
|
||||
|
||||
return JSONResponse(
|
||||
content={
|
||||
"items": [
|
||||
{
|
||||
"id": str(f.id),
|
||||
"field_key": f.field_key,
|
||||
"label": f.label,
|
||||
"type": f.type,
|
||||
"required": f.required,
|
||||
"options": f.options,
|
||||
"default": f.default_value,
|
||||
"scope": f.scope,
|
||||
"is_filterable": f.is_filterable,
|
||||
"is_rank_feature": f.is_rank_feature,
|
||||
"status": f.status,
|
||||
"created_at": f.created_at.isoformat() if f.created_at else None,
|
||||
"updated_at": f.updated_at.isoformat() if f.updated_at else None,
|
||||
}
|
||||
for f in fields
|
||||
]
|
||||
}
|
||||
content=[_field_to_dict(f) for f in fields]
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -311,26 +403,7 @@ async def get_readable_schemas(
|
|||
fields = await service.get_field_definitions_for_read(tenant_id, scope)
|
||||
|
||||
return JSONResponse(
|
||||
content={
|
||||
"items": [
|
||||
{
|
||||
"id": str(f.id),
|
||||
"field_key": f.field_key,
|
||||
"label": f.label,
|
||||
"type": f.type,
|
||||
"required": f.required,
|
||||
"options": f.options,
|
||||
"default": f.default_value,
|
||||
"scope": f.scope,
|
||||
"is_filterable": f.is_filterable,
|
||||
"is_rank_feature": f.is_rank_feature,
|
||||
"status": f.status,
|
||||
"created_at": f.created_at.isoformat() if f.created_at else None,
|
||||
"updated_at": f.updated_at.isoformat() if f.updated_at else None,
|
||||
}
|
||||
for f in fields
|
||||
]
|
||||
}
|
||||
content=[_field_to_dict(f) for f in fields]
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -373,42 +446,3 @@ async def validate_metadata_for_create(
|
|||
"errors": errors,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/{field_id}",
|
||||
operation_id="deleteMetadataSchema",
|
||||
summary="Delete metadata schema",
|
||||
description="[AC-IDSMETA-13] 删除元数据字段定义",
|
||||
)
|
||||
async def delete_schema(
|
||||
field_id: str,
|
||||
tenant_id: Annotated[str, Depends(get_current_tenant_id)],
|
||||
session: Annotated[AsyncSession, Depends(get_session)],
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
[AC-IDSMETA-13] 删除元数据字段定义
|
||||
"""
|
||||
logger.info(
|
||||
f"[AC-IDSMETA-13] Deleting metadata field definition: "
|
||||
f"tenant={tenant_id}, field_id={field_id}"
|
||||
)
|
||||
|
||||
service = MetadataFieldDefinitionService(session)
|
||||
success = await service.delete_field_definition(tenant_id, field_id)
|
||||
|
||||
if not success:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content={
|
||||
"code": "NOT_FOUND",
|
||||
"message": f"Field definition not found: {field_id}",
|
||||
}
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
content={
|
||||
"success": True,
|
||||
"message": "Field definition deleted successfully",
|
||||
}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,234 @@
|
|||
"""
|
||||
Slot Definition API.
|
||||
[AC-MRS-07,08,16] 槽位定义管理接口
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Annotated, Any
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import get_session
|
||||
from app.core.exceptions import MissingTenantIdException
|
||||
from app.core.tenant import get_tenant_id
|
||||
from app.models.entities import SlotDefinitionCreate, SlotDefinitionUpdate
|
||||
from app.services.slot_definition_service import SlotDefinitionService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/admin/slot-definitions", tags=["SlotDefinition"])
|
||||
|
||||
|
||||
def get_current_tenant_id() -> str:
|
||||
"""Get current tenant ID from context."""
|
||||
tenant_id = get_tenant_id()
|
||||
if not tenant_id:
|
||||
raise MissingTenantIdException()
|
||||
return tenant_id
|
||||
|
||||
|
||||
def _slot_to_dict(slot: dict[str, Any] | Any) -> dict[str, Any]:
|
||||
"""Convert slot definition to dict"""
|
||||
if isinstance(slot, dict):
|
||||
return slot
|
||||
|
||||
return {
|
||||
"id": str(slot.id),
|
||||
"tenant_id": str(slot.tenant_id),
|
||||
"slot_key": slot.slot_key,
|
||||
"type": slot.type,
|
||||
"required": slot.required,
|
||||
"extract_strategy": slot.extract_strategy,
|
||||
"validation_rule": slot.validation_rule,
|
||||
"ask_back_prompt": slot.ask_back_prompt,
|
||||
"default_value": slot.default_value,
|
||||
"linked_field_id": str(slot.linked_field_id) if slot.linked_field_id else None,
|
||||
"created_at": slot.created_at.isoformat() if slot.created_at else None,
|
||||
"updated_at": slot.updated_at.isoformat() if slot.updated_at else None,
|
||||
}
|
||||
|
||||
|
||||
@router.get(
|
||||
"",
|
||||
operation_id="listSlotDefinitions",
|
||||
summary="List slot definitions",
|
||||
description="获取槽位定义列表",
|
||||
)
|
||||
async def list_slot_definitions(
|
||||
tenant_id: Annotated[str, Depends(get_current_tenant_id)],
|
||||
session: Annotated[AsyncSession, Depends(get_session)],
|
||||
required: Annotated[bool | None, Query(
|
||||
description="按是否必填过滤"
|
||||
)] = None,
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
列出槽位定义
|
||||
"""
|
||||
logger.info(
|
||||
f"Listing slot definitions: tenant={tenant_id}, required={required}"
|
||||
)
|
||||
|
||||
service = SlotDefinitionService(session)
|
||||
slots = await service.list_slot_definitions(tenant_id, required)
|
||||
|
||||
return JSONResponse(
|
||||
content=[_slot_to_dict(s) for s in slots]
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"",
|
||||
operation_id="createSlotDefinition",
|
||||
summary="Create slot definition",
|
||||
description="[AC-MRS-07,08] 创建新的槽位定义,可关联已有元数据字段",
|
||||
status_code=201,
|
||||
)
|
||||
async def create_slot_definition(
|
||||
tenant_id: Annotated[str, Depends(get_current_tenant_id)],
|
||||
session: Annotated[AsyncSession, Depends(get_session)],
|
||||
slot_create: SlotDefinitionCreate,
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
[AC-MRS-07,08] 创建槽位定义
|
||||
"""
|
||||
logger.info(
|
||||
f"[AC-MRS-07] Creating slot definition: "
|
||||
f"tenant={tenant_id}, slot_key={slot_create.slot_key}, "
|
||||
f"linked_field_id={slot_create.linked_field_id}"
|
||||
)
|
||||
|
||||
service = SlotDefinitionService(session)
|
||||
|
||||
try:
|
||||
slot = await service.create_slot_definition(tenant_id, slot_create)
|
||||
await session.commit()
|
||||
except ValueError as e:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={
|
||||
"error_code": "VALIDATION_ERROR",
|
||||
"message": str(e),
|
||||
}
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=201,
|
||||
content=_slot_to_dict(slot)
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{id}",
|
||||
operation_id="getSlotDefinition",
|
||||
summary="Get slot definition by ID",
|
||||
description="获取单个槽位定义",
|
||||
)
|
||||
async def get_slot_definition(
|
||||
tenant_id: Annotated[str, Depends(get_current_tenant_id)],
|
||||
session: Annotated[AsyncSession, Depends(get_session)],
|
||||
id: str,
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
获取单个槽位定义
|
||||
"""
|
||||
logger.info(
|
||||
f"Getting slot definition: tenant={tenant_id}, id={id}"
|
||||
)
|
||||
|
||||
service = SlotDefinitionService(session)
|
||||
slot = await service.get_slot_definition_with_field(tenant_id, id)
|
||||
|
||||
if not slot:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content={
|
||||
"error_code": "NOT_FOUND",
|
||||
"message": f"Slot definition {id} not found",
|
||||
}
|
||||
)
|
||||
|
||||
return JSONResponse(content=slot)
|
||||
|
||||
|
||||
@router.put(
|
||||
"/{id}",
|
||||
operation_id="updateSlotDefinition",
|
||||
summary="Update slot definition",
|
||||
description="更新槽位定义",
|
||||
)
|
||||
async def update_slot_definition(
|
||||
tenant_id: Annotated[str, Depends(get_current_tenant_id)],
|
||||
session: Annotated[AsyncSession, Depends(get_session)],
|
||||
id: str,
|
||||
slot_update: SlotDefinitionUpdate,
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
更新槽位定义
|
||||
"""
|
||||
logger.info(
|
||||
f"Updating slot definition: tenant={tenant_id}, id={id}"
|
||||
)
|
||||
|
||||
service = SlotDefinitionService(session)
|
||||
|
||||
try:
|
||||
slot = await service.update_slot_definition(tenant_id, id, slot_update)
|
||||
except ValueError as e:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={
|
||||
"error_code": "VALIDATION_ERROR",
|
||||
"message": str(e),
|
||||
}
|
||||
)
|
||||
|
||||
if not slot:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content={
|
||||
"error_code": "NOT_FOUND",
|
||||
"message": f"Slot definition {id} not found",
|
||||
}
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
|
||||
return JSONResponse(content=_slot_to_dict(slot))
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/{id}",
|
||||
operation_id="deleteSlotDefinition",
|
||||
summary="Delete slot definition",
|
||||
description="[AC-MRS-16] 删除槽位定义,无需考虑历史数据兼容性",
|
||||
status_code=204,
|
||||
)
|
||||
async def delete_slot_definition(
|
||||
tenant_id: Annotated[str, Depends(get_current_tenant_id)],
|
||||
session: Annotated[AsyncSession, Depends(get_session)],
|
||||
id: str,
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
[AC-MRS-16] 删除槽位定义
|
||||
"""
|
||||
logger.info(
|
||||
f"[AC-MRS-16] Deleting slot definition: tenant={tenant_id}, id={id}"
|
||||
)
|
||||
|
||||
service = SlotDefinitionService(session)
|
||||
success = await service.delete_slot_definition(tenant_id, id)
|
||||
|
||||
if not success:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content={
|
||||
"error_code": "NOT_FOUND",
|
||||
"message": f"Slot definition not found: {id}",
|
||||
}
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
|
||||
return JSONResponse(status_code=204, content=None)
|
||||
|
|
@ -0,0 +1,140 @@
|
|||
"""
|
||||
Runtime Slot API.
|
||||
[AC-MRS-09,10] 运行时槽位查询接口
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Annotated, Any
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import get_session
|
||||
from app.core.exceptions import MissingTenantIdException
|
||||
from app.core.tenant import get_tenant_id
|
||||
from app.services.mid.role_based_field_provider import RoleBasedFieldProvider, InvalidRoleError
|
||||
from app.services.slot_definition_service import SlotDefinitionService
|
||||
from app.schemas.metadata import VALID_FIELD_ROLES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/mid/slots", tags=["RuntimeSlot"])
|
||||
|
||||
|
||||
def get_current_tenant_id() -> str:
|
||||
"""Get current tenant ID from context."""
|
||||
tenant_id = get_tenant_id()
|
||||
if not tenant_id:
|
||||
raise MissingTenantIdException()
|
||||
return tenant_id
|
||||
|
||||
|
||||
@router.get(
|
||||
"/by-role",
|
||||
operation_id="getSlotsByRole",
|
||||
summary="Get slots by role",
|
||||
description="[AC-MRS-10] 运行时接口,按角色获取槽位定义及关联的元数据字段信息",
|
||||
)
|
||||
async def get_slots_by_role(
|
||||
tenant_id: Annotated[str, Depends(get_current_tenant_id)],
|
||||
session: Annotated[AsyncSession, Depends(get_session)],
|
||||
role: Annotated[str, Query(
|
||||
description="[AC-MRS-10] 字段角色: resource_filter/slot/prompt_var/routing_signal"
|
||||
)] = "slot",
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
[AC-MRS-10] 按角色获取槽位定义
|
||||
|
||||
Args:
|
||||
role: 字段角色,默认为 slot
|
||||
"""
|
||||
logger.info(
|
||||
f"[AC-MRS-10] Getting slots by role: tenant={tenant_id}, role={role}"
|
||||
)
|
||||
|
||||
provider = RoleBasedFieldProvider(session)
|
||||
|
||||
try:
|
||||
slots = await provider.get_slot_definitions_by_role(tenant_id, role)
|
||||
except InvalidRoleError as e:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={
|
||||
"error_code": "INVALID_ROLE",
|
||||
"message": str(e),
|
||||
"details": {
|
||||
"valid_roles": e.valid_roles
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
return JSONResponse(content=slots)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{slot_key}",
|
||||
operation_id="getSlotValue",
|
||||
summary="Get runtime slot value",
|
||||
description="[AC-MRS-09] 获取指定槽位的运行时值,包含来源、置信度、更新时间",
|
||||
)
|
||||
async def get_slot_value(
|
||||
tenant_id: Annotated[str, Depends(get_current_tenant_id)],
|
||||
session: Annotated[AsyncSession, Depends(get_session)],
|
||||
slot_key: str,
|
||||
user_id: Annotated[str | None, Query(
|
||||
description="用户 ID"
|
||||
)] = None,
|
||||
session_id: Annotated[str | None, Query(
|
||||
description="会话 ID"
|
||||
)] = None,
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
[AC-MRS-09] 获取运行时槽位值
|
||||
|
||||
Args:
|
||||
slot_key: 槽位键名
|
||||
user_id: 用户 ID
|
||||
session_id: 会话 ID
|
||||
"""
|
||||
logger.info(
|
||||
f"[AC-MRS-09] Getting slot value: tenant={tenant_id}, slot_key={slot_key}, "
|
||||
f"user_id={user_id}, session_id={session_id}"
|
||||
)
|
||||
|
||||
service = SlotDefinitionService(session)
|
||||
slot_def = await service.get_slot_definition_by_key(tenant_id, slot_key)
|
||||
|
||||
if not slot_def:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content={
|
||||
"error_code": "NOT_FOUND",
|
||||
"message": f"Slot '{slot_key}' not found",
|
||||
}
|
||||
)
|
||||
|
||||
value = slot_def.default_value
|
||||
source = "default"
|
||||
confidence = 1.0
|
||||
|
||||
if value is None:
|
||||
if slot_def.type == "string":
|
||||
value = ""
|
||||
elif slot_def.type == "number":
|
||||
value = 0
|
||||
elif slot_def.type == "boolean":
|
||||
value = False
|
||||
elif slot_def.type in ["enum", "array_enum"]:
|
||||
value = [] if slot_def.type == "array_enum" else ""
|
||||
|
||||
return JSONResponse(
|
||||
content={
|
||||
"key": slot_key,
|
||||
"value": value,
|
||||
"source": source,
|
||||
"confidence": confidence,
|
||||
"updated_at": datetime.utcnow().isoformat(),
|
||||
}
|
||||
)
|
||||
|
|
@ -12,6 +12,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
|||
from fastapi.responses import JSONResponse
|
||||
|
||||
from app.api import chat_router, health_router
|
||||
from app.api.mid import router as mid_router
|
||||
from app.api.admin import (
|
||||
api_key_router,
|
||||
dashboard_router,
|
||||
|
|
@ -29,6 +30,7 @@ from app.api.admin import (
|
|||
rag_router,
|
||||
script_flows_router,
|
||||
sessions_router,
|
||||
slot_definition_router,
|
||||
tenants_router,
|
||||
)
|
||||
from app.api.admin.kb_optimized import router as kb_optimized_router
|
||||
|
|
@ -165,8 +167,11 @@ app.include_router(prompt_templates_router)
|
|||
app.include_router(rag_router)
|
||||
app.include_router(script_flows_router)
|
||||
app.include_router(sessions_router)
|
||||
app.include_router(slot_definition_router)
|
||||
app.include_router(tenants_router)
|
||||
|
||||
app.include_router(mid_router)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
|
|
|||
|
|
@ -0,0 +1,364 @@
|
|||
"""
|
||||
Slot Definition Service.
|
||||
[AC-MRS-07, AC-MRS-08] 槽位定义管理服务
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.entities import (
|
||||
SlotDefinition,
|
||||
SlotDefinitionCreate,
|
||||
SlotDefinitionUpdate,
|
||||
MetadataFieldDefinition,
|
||||
MetadataFieldType,
|
||||
ExtractStrategy,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SlotDefinitionService:
|
||||
"""
|
||||
[AC-MRS-07, AC-MRS-08] 槽位定义服务
|
||||
|
||||
管理独立的槽位定义模型,与元数据字段解耦但可复用
|
||||
"""
|
||||
|
||||
SLOT_KEY_PATTERN = re.compile(r"^[a-z][a-z0-9_]*$")
|
||||
VALID_TYPES = ["string", "number", "boolean", "enum", "array_enum"]
|
||||
VALID_EXTRACT_STRATEGIES = ["rule", "llm", "user_input"]
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self._session = session
|
||||
|
||||
async def list_slot_definitions(
|
||||
self,
|
||||
tenant_id: str,
|
||||
required: bool | None = None,
|
||||
) -> list[SlotDefinition]:
|
||||
"""
|
||||
列出租户所有槽位定义
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
required: 按是否必填过滤
|
||||
|
||||
Returns:
|
||||
SlotDefinition 列表
|
||||
"""
|
||||
stmt = select(SlotDefinition).where(
|
||||
SlotDefinition.tenant_id == tenant_id,
|
||||
)
|
||||
|
||||
if required is not None:
|
||||
stmt = stmt.where(SlotDefinition.required == required)
|
||||
|
||||
stmt = stmt.order_by(SlotDefinition.created_at.desc())
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_slot_definition(
|
||||
self,
|
||||
tenant_id: str,
|
||||
slot_id: str,
|
||||
) -> SlotDefinition | None:
|
||||
"""
|
||||
获取单个槽位定义
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
slot_id: 槽位定义 ID
|
||||
|
||||
Returns:
|
||||
SlotDefinition 或 None
|
||||
"""
|
||||
try:
|
||||
slot_uuid = uuid.UUID(slot_id)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
stmt = select(SlotDefinition).where(
|
||||
SlotDefinition.tenant_id == tenant_id,
|
||||
SlotDefinition.id == slot_uuid,
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_slot_definition_by_key(
|
||||
self,
|
||||
tenant_id: str,
|
||||
slot_key: str,
|
||||
) -> SlotDefinition | None:
|
||||
"""
|
||||
通过 slot_key 获取槽位定义
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
slot_key: 槽位键名
|
||||
|
||||
Returns:
|
||||
SlotDefinition 或 None
|
||||
"""
|
||||
stmt = select(SlotDefinition).where(
|
||||
SlotDefinition.tenant_id == tenant_id,
|
||||
SlotDefinition.slot_key == slot_key,
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def create_slot_definition(
|
||||
self,
|
||||
tenant_id: str,
|
||||
slot_create: SlotDefinitionCreate,
|
||||
) -> SlotDefinition:
|
||||
"""
|
||||
[AC-MRS-07, AC-MRS-08] 创建槽位定义
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
slot_create: 创建数据
|
||||
|
||||
Returns:
|
||||
创建的 SlotDefinition
|
||||
|
||||
Raises:
|
||||
ValueError: 如果 slot_key 已存在或参数无效
|
||||
"""
|
||||
if not self.SLOT_KEY_PATTERN.match(slot_create.slot_key):
|
||||
raise ValueError(
|
||||
f"slot_key '{slot_create.slot_key}' 格式不正确,"
|
||||
"必须以小写字母开头,仅允许小写字母、数字和下划线"
|
||||
)
|
||||
|
||||
existing = await self.get_slot_definition_by_key(tenant_id, slot_create.slot_key)
|
||||
if existing:
|
||||
raise ValueError(f"slot_key '{slot_create.slot_key}' 已存在")
|
||||
|
||||
if slot_create.type not in self.VALID_TYPES:
|
||||
raise ValueError(
|
||||
f"无效的槽位类型 '{slot_create.type}',"
|
||||
f"有效类型为: {self.VALID_TYPES}"
|
||||
)
|
||||
|
||||
if slot_create.extract_strategy and slot_create.extract_strategy not in self.VALID_EXTRACT_STRATEGIES:
|
||||
raise ValueError(
|
||||
f"无效的提取策略 '{slot_create.extract_strategy}',"
|
||||
f"有效策略为: {self.VALID_EXTRACT_STRATEGIES}"
|
||||
)
|
||||
|
||||
linked_field = None
|
||||
if slot_create.linked_field_id:
|
||||
linked_field = await self._get_linked_field(slot_create.linked_field_id)
|
||||
if not linked_field:
|
||||
raise ValueError(
|
||||
f"[AC-MRS-08] 关联的元数据字段 '{slot_create.linked_field_id}' 不存在"
|
||||
)
|
||||
|
||||
slot = SlotDefinition(
|
||||
tenant_id=tenant_id,
|
||||
slot_key=slot_create.slot_key,
|
||||
type=slot_create.type,
|
||||
required=slot_create.required,
|
||||
extract_strategy=slot_create.extract_strategy,
|
||||
validation_rule=slot_create.validation_rule,
|
||||
ask_back_prompt=slot_create.ask_back_prompt,
|
||||
default_value=slot_create.default_value,
|
||||
linked_field_id=uuid.UUID(slot_create.linked_field_id) if slot_create.linked_field_id else None,
|
||||
)
|
||||
|
||||
self._session.add(slot)
|
||||
await self._session.flush()
|
||||
|
||||
logger.info(
|
||||
f"[AC-MRS-07] Created slot definition: tenant={tenant_id}, "
|
||||
f"slot_key={slot.slot_key}, required={slot.required}, "
|
||||
f"linked_field_id={slot.linked_field_id}"
|
||||
)
|
||||
|
||||
return slot
|
||||
|
||||
async def update_slot_definition(
|
||||
self,
|
||||
tenant_id: str,
|
||||
slot_id: str,
|
||||
slot_update: SlotDefinitionUpdate,
|
||||
) -> SlotDefinition | None:
|
||||
"""
|
||||
更新槽位定义
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
slot_id: 槽位定义 ID
|
||||
slot_update: 更新数据
|
||||
|
||||
Returns:
|
||||
更新后的 SlotDefinition 或 None
|
||||
"""
|
||||
slot = await self.get_slot_definition(tenant_id, slot_id)
|
||||
if not slot:
|
||||
return None
|
||||
|
||||
if slot_update.type is not None:
|
||||
if slot_update.type not in self.VALID_TYPES:
|
||||
raise ValueError(
|
||||
f"无效的槽位类型 '{slot_update.type}',"
|
||||
f"有效类型为: {self.VALID_TYPES}"
|
||||
)
|
||||
slot.type = slot_update.type
|
||||
|
||||
if slot_update.required is not None:
|
||||
slot.required = slot_update.required
|
||||
|
||||
if slot_update.extract_strategy is not None:
|
||||
if slot_update.extract_strategy not in self.VALID_EXTRACT_STRATEGIES:
|
||||
raise ValueError(
|
||||
f"无效的提取策略 '{slot_update.extract_strategy}',"
|
||||
f"有效策略为: {self.VALID_EXTRACT_STRATEGIES}"
|
||||
)
|
||||
slot.extract_strategy = slot_update.extract_strategy
|
||||
|
||||
if slot_update.validation_rule is not None:
|
||||
slot.validation_rule = slot_update.validation_rule
|
||||
|
||||
if slot_update.ask_back_prompt is not None:
|
||||
slot.ask_back_prompt = slot_update.ask_back_prompt
|
||||
|
||||
if slot_update.default_value is not None:
|
||||
slot.default_value = slot_update.default_value
|
||||
|
||||
if slot_update.linked_field_id is not None:
|
||||
if slot_update.linked_field_id:
|
||||
linked_field = await self._get_linked_field(slot_update.linked_field_id)
|
||||
if not linked_field:
|
||||
raise ValueError(
|
||||
f"[AC-MRS-08] 关联的元数据字段 '{slot_update.linked_field_id}' 不存在"
|
||||
)
|
||||
slot.linked_field_id = uuid.UUID(slot_update.linked_field_id)
|
||||
else:
|
||||
slot.linked_field_id = None
|
||||
|
||||
slot.updated_at = datetime.utcnow()
|
||||
await self._session.flush()
|
||||
|
||||
logger.info(
|
||||
f"[AC-MRS-07] Updated slot definition: tenant={tenant_id}, "
|
||||
f"slot_id={slot_id}"
|
||||
)
|
||||
|
||||
return slot
|
||||
|
||||
async def delete_slot_definition(
|
||||
self,
|
||||
tenant_id: str,
|
||||
slot_id: str,
|
||||
) -> bool:
|
||||
"""
|
||||
[AC-MRS-16] 删除槽位定义
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
slot_id: 槽位定义 ID
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
slot = await self.get_slot_definition(tenant_id, slot_id)
|
||||
if not slot:
|
||||
return False
|
||||
|
||||
await self._session.delete(slot)
|
||||
await self._session.flush()
|
||||
|
||||
logger.info(
|
||||
f"[AC-MRS-16] Deleted slot definition: tenant={tenant_id}, "
|
||||
f"slot_id={slot_id}"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
async def _get_linked_field(
|
||||
self,
|
||||
field_id: str,
|
||||
) -> MetadataFieldDefinition | None:
|
||||
"""
|
||||
获取关联的元数据字段
|
||||
|
||||
Args:
|
||||
field_id: 字段 ID
|
||||
|
||||
Returns:
|
||||
MetadataFieldDefinition 或 None
|
||||
"""
|
||||
try:
|
||||
field_uuid = uuid.UUID(field_id)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
stmt = select(MetadataFieldDefinition).where(
|
||||
MetadataFieldDefinition.id == field_uuid,
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_slot_definition_with_field(
|
||||
self,
|
||||
tenant_id: str,
|
||||
slot_id: str,
|
||||
) -> dict[str, Any] | None:
|
||||
"""
|
||||
获取槽位定义及其关联字段信息
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
slot_id: 槽位定义 ID
|
||||
|
||||
Returns:
|
||||
包含槽位定义和关联字段的字典
|
||||
"""
|
||||
slot = await self.get_slot_definition(tenant_id, slot_id)
|
||||
if not slot:
|
||||
return None
|
||||
|
||||
result = {
|
||||
"id": str(slot.id),
|
||||
"tenant_id": slot.tenant_id,
|
||||
"slot_key": slot.slot_key,
|
||||
"type": slot.type,
|
||||
"required": slot.required,
|
||||
"extract_strategy": slot.extract_strategy,
|
||||
"validation_rule": slot.validation_rule,
|
||||
"ask_back_prompt": slot.ask_back_prompt,
|
||||
"default_value": slot.default_value,
|
||||
"linked_field_id": str(slot.linked_field_id) if slot.linked_field_id else None,
|
||||
"created_at": slot.created_at.isoformat() if slot.created_at else None,
|
||||
"updated_at": slot.updated_at.isoformat() if slot.updated_at else None,
|
||||
"linked_field": None,
|
||||
}
|
||||
|
||||
if slot.linked_field_id:
|
||||
linked_field = await self._get_linked_field(str(slot.linked_field_id))
|
||||
if linked_field:
|
||||
result["linked_field"] = {
|
||||
"id": str(linked_field.id),
|
||||
"field_key": linked_field.field_key,
|
||||
"label": linked_field.label,
|
||||
"type": linked_field.type,
|
||||
"required": linked_field.required,
|
||||
"options": linked_field.options,
|
||||
"default_value": linked_field.default_value,
|
||||
"scope": linked_field.scope,
|
||||
"is_filterable": linked_field.is_filterable,
|
||||
"is_rank_feature": linked_field.is_rank_feature,
|
||||
"field_roles": linked_field.field_roles,
|
||||
"status": linked_field.status,
|
||||
}
|
||||
|
||||
return result
|
||||
|
|
@ -0,0 +1,226 @@
|
|||
"""
|
||||
Unit tests for RoleBasedFieldProvider service.
|
||||
[AC-MRS-04,05,10,11,12,13,14] 验证按角色查询字段定义功能
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.mid.role_based_field_provider import (
|
||||
RoleBasedFieldProvider,
|
||||
InvalidRoleError,
|
||||
)
|
||||
from app.models.entities import (
|
||||
MetadataFieldDefinition,
|
||||
MetadataFieldStatus,
|
||||
SlotDefinition,
|
||||
)
|
||||
from app.schemas.metadata import VALID_FIELD_ROLES
|
||||
|
||||
|
||||
class TestRoleBasedFieldProvider:
|
||||
"""[AC-MRS-04,05,10] RoleBasedFieldProvider 测试"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self):
|
||||
"""Mock AsyncSession"""
|
||||
session = MagicMock(spec=AsyncSession)
|
||||
session.execute = AsyncMock()
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def provider(self, mock_session):
|
||||
"""Create provider instance"""
|
||||
return RoleBasedFieldProvider(mock_session)
|
||||
|
||||
def test_validate_role_valid(self, provider):
|
||||
"""[AC-MRS-04] 验证有效角色"""
|
||||
for role in VALID_FIELD_ROLES:
|
||||
result = provider._validate_role(role)
|
||||
assert result == role
|
||||
|
||||
def test_validate_role_invalid(self, provider):
|
||||
"""[AC-MRS-05] 验证无效角色抛出异常"""
|
||||
with pytest.raises(InvalidRoleError) as exc_info:
|
||||
provider._validate_role("invalid_role")
|
||||
|
||||
assert "Invalid role 'invalid_role'" in str(exc_info.value)
|
||||
assert exc_info.value.valid_roles == VALID_FIELD_ROLES
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_fields_by_role(self, provider, mock_session):
|
||||
"""[AC-MRS-04] 按角色获取字段定义"""
|
||||
mock_field = MagicMock(spec=MetadataFieldDefinition)
|
||||
mock_field.id = "test-id"
|
||||
mock_field.field_key = "grade"
|
||||
mock_field.label = "年级"
|
||||
mock_field.field_roles = ["resource_filter", "slot"]
|
||||
mock_field.status = MetadataFieldStatus.ACTIVE.value
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = [mock_field]
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
fields = await provider.get_fields_by_role(
|
||||
"test-tenant",
|
||||
"resource_filter"
|
||||
)
|
||||
|
||||
assert len(fields) == 1
|
||||
assert fields[0].field_key == "grade"
|
||||
assert "resource_filter" in fields[0].field_roles
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_fields_by_role_invalid_role(self, provider):
|
||||
"""[AC-MRS-05] 无效角色返回 400 错误"""
|
||||
with pytest.raises(InvalidRoleError):
|
||||
await provider.get_fields_by_role(
|
||||
"test-tenant",
|
||||
"invalid_role"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_fields_by_role_include_deprecated(self, provider, mock_session):
|
||||
"""[AC-MRS-04] 包含已废弃字段"""
|
||||
mock_active = MagicMock(spec=MetadataFieldDefinition)
|
||||
mock_active.field_key = "active_field"
|
||||
mock_active.status = MetadataFieldStatus.ACTIVE.value
|
||||
|
||||
mock_deprecated = MagicMock(spec=MetadataFieldDefinition)
|
||||
mock_deprecated.field_key = "deprecated_field"
|
||||
mock_deprecated.status = MetadataFieldStatus.DEPRECATED.value
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = [mock_active, mock_deprecated]
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
fields = await provider.get_fields_by_role(
|
||||
"test-tenant",
|
||||
"resource_filter",
|
||||
include_deprecated=True
|
||||
)
|
||||
|
||||
assert len(fields) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_slot_definitions_by_role(self, provider, mock_session):
|
||||
"""[AC-MRS-10] 按角色获取槽位定义"""
|
||||
mock_field = MagicMock(spec=MetadataFieldDefinition)
|
||||
mock_field.id = MagicMock()
|
||||
mock_field.id.__str__ = lambda self: "field-id-123"
|
||||
mock_field.field_key = "grade"
|
||||
mock_field.label = "年级"
|
||||
mock_field.type = "string"
|
||||
mock_field.required = True
|
||||
mock_field.options = None
|
||||
mock_field.default_value = None
|
||||
mock_field.scope = ["kb_document"]
|
||||
mock_field.is_filterable = True
|
||||
mock_field.is_rank_feature = False
|
||||
mock_field.field_roles = ["slot"]
|
||||
mock_field.status = MetadataFieldStatus.ACTIVE.value
|
||||
|
||||
mock_slot = MagicMock(spec=SlotDefinition)
|
||||
mock_slot.id = MagicMock()
|
||||
mock_slot.id.__str__ = lambda self: "slot-id-456"
|
||||
mock_slot.tenant_id = "test-tenant"
|
||||
mock_slot.slot_key = "grade"
|
||||
mock_slot.type = "string"
|
||||
mock_slot.required = True
|
||||
mock_slot.extract_strategy = "llm"
|
||||
mock_slot.validation_rule = None
|
||||
mock_slot.ask_back_prompt = "请输入年级"
|
||||
mock_slot.default_value = None
|
||||
mock_slot.linked_field_id = mock_field.id
|
||||
mock_slot.created_at = None
|
||||
mock_slot.updated_at = None
|
||||
|
||||
field_result = MagicMock()
|
||||
field_result.scalars.return_value.all.return_value = [mock_field]
|
||||
|
||||
slot_result = MagicMock()
|
||||
slot_result.scalars.return_value.all.return_value = [mock_slot]
|
||||
|
||||
mock_session.execute.side_effect = [field_result, slot_result]
|
||||
|
||||
slots = await provider.get_slot_definitions_by_role("test-tenant", "slot")
|
||||
|
||||
assert len(slots) >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_resource_filter_fields(self, provider, mock_session):
|
||||
"""[AC-MRS-11] 获取资源过滤角色字段"""
|
||||
mock_field = MagicMock(spec=MetadataFieldDefinition)
|
||||
mock_field.field_key = "category"
|
||||
mock_field.field_roles = ["resource_filter"]
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = [mock_field]
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
fields = await provider.get_resource_filter_fields("test-tenant")
|
||||
|
||||
assert len(fields) == 1
|
||||
assert "resource_filter" in fields[0].field_roles
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_slot_fields(self, provider, mock_session):
|
||||
"""[AC-MRS-12] 获取槽位角色字段"""
|
||||
mock_field = MagicMock(spec=MetadataFieldDefinition)
|
||||
mock_field.field_key = "user_name"
|
||||
mock_field.field_roles = ["slot"]
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = [mock_field]
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
fields = await provider.get_slot_fields("test-tenant")
|
||||
|
||||
assert len(fields) == 1
|
||||
assert "slot" in fields[0].field_roles
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_routing_signal_fields(self, provider, mock_session):
|
||||
"""[AC-MRS-13] 获取路由信号角色字段"""
|
||||
mock_field = MagicMock(spec=MetadataFieldDefinition)
|
||||
mock_field.field_key = "priority"
|
||||
mock_field.field_roles = ["routing_signal"]
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = [mock_field]
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
fields = await provider.get_routing_signal_fields("test-tenant")
|
||||
|
||||
assert len(fields) == 1
|
||||
assert "routing_signal" in fields[0].field_roles
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_prompt_var_fields(self, provider, mock_session):
|
||||
"""[AC-MRS-14] 获取提示词变量角色字段"""
|
||||
mock_field = MagicMock(spec=MetadataFieldDefinition)
|
||||
mock_field.field_key = "user_name"
|
||||
mock_field.field_roles = ["prompt_var"]
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = [mock_field]
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
fields = await provider.get_prompt_var_fields("test-tenant")
|
||||
|
||||
assert len(fields) == 1
|
||||
assert "prompt_var" in fields[0].field_roles
|
||||
|
||||
|
||||
class TestInvalidRoleError:
|
||||
"""[AC-MRS-05] InvalidRoleError 测试"""
|
||||
|
||||
def test_error_message(self):
|
||||
"""验证错误消息格式"""
|
||||
error = InvalidRoleError("bad_role")
|
||||
|
||||
assert error.role == "bad_role"
|
||||
assert error.valid_roles == VALID_FIELD_ROLES
|
||||
assert "Invalid role 'bad_role'" in str(error)
|
||||
assert "resource_filter" in str(error)
|
||||
|
|
@ -0,0 +1,333 @@
|
|||
"""
|
||||
Unit tests for SlotDefinitionService.
|
||||
[AC-MRS-07,08,16] 验证槽位定义管理功能
|
||||
"""
|
||||
|
||||
import uuid
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.slot_definition_service import SlotDefinitionService
|
||||
from app.models.entities import (
|
||||
SlotDefinition,
|
||||
SlotDefinitionCreate,
|
||||
SlotDefinitionUpdate,
|
||||
MetadataFieldDefinition,
|
||||
)
|
||||
|
||||
|
||||
class TestSlotDefinitionService:
|
||||
"""[AC-MRS-07,08,16] SlotDefinitionService 测试"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self):
|
||||
"""Mock AsyncSession"""
|
||||
session = MagicMock(spec=AsyncSession)
|
||||
session.execute = AsyncMock()
|
||||
session.add = MagicMock()
|
||||
session.flush = AsyncMock()
|
||||
session.delete = AsyncMock()
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def service(self, mock_session):
|
||||
"""Create service instance"""
|
||||
return SlotDefinitionService(mock_session)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_slot_definitions(self, service, mock_session):
|
||||
"""列出槽位定义"""
|
||||
mock_slot = MagicMock(spec=SlotDefinition)
|
||||
mock_slot.id = uuid.uuid4()
|
||||
mock_slot.slot_key = "grade"
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = [mock_slot]
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
slots = await service.list_slot_definitions("test-tenant")
|
||||
|
||||
assert len(slots) == 1
|
||||
assert slots[0].slot_key == "grade"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_slot_definitions_filter_required(self, service, mock_session):
|
||||
"""按必填过滤槽位定义"""
|
||||
mock_slot = MagicMock(spec=SlotDefinition)
|
||||
mock_slot.required = True
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = [mock_slot]
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
slots = await service.list_slot_definitions("test-tenant", required=True)
|
||||
|
||||
assert len(slots) == 1
|
||||
assert slots[0].required is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_slot_definition(self, service, mock_session):
|
||||
"""获取单个槽位定义"""
|
||||
slot_id = uuid.uuid4()
|
||||
mock_slot = MagicMock(spec=SlotDefinition)
|
||||
mock_slot.id = slot_id
|
||||
mock_slot.slot_key = "grade"
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = mock_slot
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
slot = await service.get_slot_definition("test-tenant", str(slot_id))
|
||||
|
||||
assert slot is not None
|
||||
assert slot.slot_key == "grade"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_slot_definition_not_found(self, service, mock_session):
|
||||
"""获取不存在的槽位定义"""
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = None
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
slot = await service.get_slot_definition("test-tenant", str(uuid.uuid4()))
|
||||
|
||||
assert slot is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_slot_definition_by_key(self, service, mock_session):
|
||||
"""通过 slot_key 获取槽位定义"""
|
||||
mock_slot = MagicMock(spec=SlotDefinition)
|
||||
mock_slot.slot_key = "grade"
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = mock_slot
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
slot = await service.get_slot_definition_by_key("test-tenant", "grade")
|
||||
|
||||
assert slot is not None
|
||||
assert slot.slot_key == "grade"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_slot_definition(self, service, mock_session):
|
||||
"""[AC-MRS-07] 创建槽位定义"""
|
||||
slot_create = SlotDefinitionCreate(
|
||||
slot_key="grade",
|
||||
type="string",
|
||||
required=True,
|
||||
extract_strategy="llm",
|
||||
ask_back_prompt="请输入年级",
|
||||
)
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = None
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
slot = await service.create_slot_definition("test-tenant", slot_create)
|
||||
|
||||
assert slot is not None
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.flush.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_slot_definition_invalid_key(self, service):
|
||||
"""[AC-MRS-07] 创建无效 slot_key 抛出异常"""
|
||||
slot_create = SlotDefinitionCreate(
|
||||
slot_key="InvalidKey",
|
||||
type="string",
|
||||
required=True,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await service.create_slot_definition("test-tenant", slot_create)
|
||||
|
||||
assert "格式不正确" in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_slot_definition_duplicate_key(self, service, mock_session):
|
||||
"""[AC-MRS-07] 创建重复 slot_key 抛出异常"""
|
||||
existing_slot = MagicMock(spec=SlotDefinition)
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = existing_slot
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
slot_create = SlotDefinitionCreate(
|
||||
slot_key="grade",
|
||||
type="string",
|
||||
required=True,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await service.create_slot_definition("test-tenant", slot_create)
|
||||
|
||||
assert "已存在" in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_slot_definition_invalid_type(self, service, mock_session):
|
||||
"""[AC-MRS-07] 创建无效类型抛出异常"""
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = None
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
slot_create = SlotDefinitionCreate(
|
||||
slot_key="grade",
|
||||
type="invalid_type",
|
||||
required=True,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await service.create_slot_definition("test-tenant", slot_create)
|
||||
|
||||
assert "无效的槽位类型" in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_slot_definition_with_linked_field(self, service, mock_session):
|
||||
"""[AC-MRS-08] 创建槽位定义并关联元数据字段"""
|
||||
field_id = uuid.uuid4()
|
||||
mock_field = MagicMock(spec=MetadataFieldDefinition)
|
||||
mock_field.id = field_id
|
||||
|
||||
slot_result = MagicMock()
|
||||
slot_result.scalar_one_or_none.return_value = None
|
||||
|
||||
field_result = MagicMock()
|
||||
field_result.scalar_one_or_none.return_value = mock_field
|
||||
|
||||
mock_session.execute.side_effect = [slot_result, field_result]
|
||||
|
||||
slot_create = SlotDefinitionCreate(
|
||||
slot_key="grade",
|
||||
type="string",
|
||||
required=True,
|
||||
linked_field_id=str(field_id),
|
||||
)
|
||||
|
||||
slot = await service.create_slot_definition("test-tenant", slot_create)
|
||||
|
||||
assert slot is not None
|
||||
mock_session.add.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_slot_definition_linked_field_not_found(self, service, mock_session):
|
||||
"""[AC-MRS-08] 关联字段不存在抛出异常"""
|
||||
field_id = uuid.uuid4()
|
||||
|
||||
slot_result = MagicMock()
|
||||
slot_result.scalar_one_or_none.return_value = None
|
||||
|
||||
field_result = MagicMock()
|
||||
field_result.scalar_one_or_none.return_value = None
|
||||
|
||||
mock_session.execute.side_effect = [slot_result, field_result]
|
||||
|
||||
slot_create = SlotDefinitionCreate(
|
||||
slot_key="grade",
|
||||
type="string",
|
||||
required=True,
|
||||
linked_field_id=str(field_id),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await service.create_slot_definition("test-tenant", slot_create)
|
||||
|
||||
assert "关联的元数据字段" in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_slot_definition(self, service, mock_session):
|
||||
"""更新槽位定义"""
|
||||
slot_id = uuid.uuid4()
|
||||
mock_slot = MagicMock(spec=SlotDefinition)
|
||||
mock_slot.id = slot_id
|
||||
mock_slot.slot_key = "grade"
|
||||
mock_slot.type = "string"
|
||||
mock_slot.required = False
|
||||
mock_slot.extract_strategy = None
|
||||
mock_slot.validation_rule = None
|
||||
mock_slot.ask_back_prompt = None
|
||||
mock_slot.default_value = None
|
||||
mock_slot.linked_field_id = None
|
||||
mock_slot.updated_at = None
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = mock_slot
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
slot_update = SlotDefinitionUpdate(
|
||||
required=True,
|
||||
ask_back_prompt="请输入年级",
|
||||
)
|
||||
|
||||
slot = await service.update_slot_definition("test-tenant", str(slot_id), slot_update)
|
||||
|
||||
assert slot is not None
|
||||
mock_session.flush.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_slot_definition_not_found(self, service, mock_session):
|
||||
"""更新不存在的槽位定义"""
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = None
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
slot_update = SlotDefinitionUpdate(required=True)
|
||||
|
||||
slot = await service.update_slot_definition("test-tenant", str(uuid.uuid4()), slot_update)
|
||||
|
||||
assert slot is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_slot_definition(self, service, mock_session):
|
||||
"""[AC-MRS-16] 删除槽位定义"""
|
||||
slot_id = uuid.uuid4()
|
||||
mock_slot = MagicMock(spec=SlotDefinition)
|
||||
mock_slot.id = slot_id
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = mock_slot
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
success = await service.delete_slot_definition("test-tenant", str(slot_id))
|
||||
|
||||
assert success is True
|
||||
mock_session.delete.assert_called_once()
|
||||
mock_session.flush.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_slot_definition_not_found(self, service, mock_session):
|
||||
"""[AC-MRS-16] 删除不存在的槽位定义"""
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = None
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
success = await service.delete_slot_definition("test-tenant", str(uuid.uuid4()))
|
||||
|
||||
assert success is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_slot_definition_with_field(self, service, mock_session):
|
||||
"""获取槽位定义及关联字段信息"""
|
||||
slot_id = uuid.uuid4()
|
||||
mock_slot = MagicMock(spec=SlotDefinition)
|
||||
mock_slot.id = slot_id
|
||||
mock_slot.tenant_id = "test-tenant"
|
||||
mock_slot.slot_key = "grade"
|
||||
mock_slot.type = "string"
|
||||
mock_slot.required = True
|
||||
mock_slot.extract_strategy = "llm"
|
||||
mock_slot.validation_rule = None
|
||||
mock_slot.ask_back_prompt = "请输入年级"
|
||||
mock_slot.default_value = None
|
||||
mock_slot.linked_field_id = None
|
||||
mock_slot.created_at = None
|
||||
mock_slot.updated_at = None
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = mock_slot
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = await service.get_slot_definition_with_field("test-tenant", str(slot_id))
|
||||
|
||||
assert result is not None
|
||||
assert result["slot_key"] == "grade"
|
||||
assert result["linked_field"] is None
|
||||
Loading…
Reference in New Issue