From 8c259cee30d8410cef9e16f0e2f5ad0c2482ea8f Mon Sep 17 00:00:00 2001 From: MerCry Date: Fri, 27 Feb 2026 16:03:39 +0800 Subject: [PATCH] feat: implement output guardrail with forbidden word detection and behavior rules [AC-AISVC-78~AC-AISVC-85] --- AI中台对接文档.md | 511 ++++++++++++++++++ ai-service/app/api/admin/__init__.py | 1 + ai-service/app/api/admin/guardrails.py | 296 ++++++++++ ai-service/app/main.py | 20 +- ai-service/app/models/entities.py | 54 +- ai-service/app/services/guardrail/__init__.py | 18 + .../services/guardrail/behavior_service.py | 260 +++++++++ .../app/services/guardrail/input_scanner.py | 110 ++++ .../app/services/guardrail/output_filter.py | 154 ++++++ .../services/guardrail/streaming_filter.py | 366 +++++++++++++ .../app/services/guardrail/word_service.py | 297 ++++++++++ docs/progress/ai-service-progress.md | 99 ++-- 12 files changed, 2137 insertions(+), 49 deletions(-) create mode 100644 AI中台对接文档.md create mode 100644 ai-service/app/api/admin/guardrails.py create mode 100644 ai-service/app/services/guardrail/__init__.py create mode 100644 ai-service/app/services/guardrail/behavior_service.py create mode 100644 ai-service/app/services/guardrail/input_scanner.py create mode 100644 ai-service/app/services/guardrail/output_filter.py create mode 100644 ai-service/app/services/guardrail/streaming_filter.py create mode 100644 ai-service/app/services/guardrail/word_service.py diff --git a/AI中台对接文档.md b/AI中台对接文档.md new file mode 100644 index 0000000..9a1fb91 --- /dev/null +++ b/AI中台对接文档.md @@ -0,0 +1,511 @@ +# AI 中台对接文档 + +## 1. 概述 + +本文档描述 Python AI 中台对渠道侧(Java 主框架)暴露的 HTTP 接口规范,用于智能客服对话生成和服务健康检查。 + +### 1.1 服务信息 + +- **服务名称**: AI Service (Python AI 中台) +- **服务地址**: `http://ai-service:8080` +- **协议**: HTTP/1.1 +- **数据格式**: JSON / SSE (Server-Sent Events) +- **字符编码**: UTF-8 +- **契约版本**: v1.1.0 + +### 1.2 核心能力 + +- ✅ 智能对话生成(基于 LLM + RAG) +- ✅ 多租户隔离(基于 `X-Tenant-Id`) +- ✅ 会话上下文管理(基于 `sessionId`) +- ✅ 流式/非流式双模式输出 +- ✅ 置信度评估与转人工建议 +- ✅ 服务健康检查 + +--- + +## 2. 认证与租户隔离 + +### 2.1 API Key 认证(必填) + +所有接口请求(除健康检查外)必须在 HTTP Header 中携带 API Key: + +```http +X-API-Key: +``` + +**说明**: +- API Key 用于身份认证和访问控制 +- 缺失或无效的 API Key 将返回 `401 Unauthorized` +- API Key 由 AI 中台管理员分配,请妥善保管 +- 以下路径无需 API Key:`/health`、`/ai/health`、`/docs` + +### 2.2 租户标识(必填) + +所有接口请求必须在 HTTP Header 中携带租户 ID: + +```http +X-Tenant-Id: +``` + +**租户 ID 格式规范**:`name@ash@year` + +示例: +- `szmp@ash@2026` - 深圳某项目 2026 年 +- `abc123@ash@2025` - ABC 项目 2025 年 + +**说明**: +- 租户 ID 用于数据隔离(知识库、会话历史、配置等) +- 缺失或格式错误的租户 ID 将返回 `400 Bad Request` +- 不同租户的数据完全隔离,不可跨租户访问 +- 租户不存在时会自动创建 + +--- + +## 3. 接口列表 + +| 接口路径 | 方法 | 功能 | 响应模式 | +|---------|------|------|---------| +| `/ai/chat` | POST | 生成 AI 回复 | JSON / SSE | +| `/ai/health` | GET | 健康检查 | JSON | + +--- + +## 4. 接口详细说明 + +### 4.1 生成 AI 回复 + +**接口路径**: `POST /ai/chat` + +**功能描述**: 根据用户消息和会话历史生成 AI 回复,支持 RAG 检索增强、上下文管理、置信度评估。 + +#### 4.1.1 请求参数 + +**Headers**: +```http +Content-Type: application/json +X-API-Key: +X-Tenant-Id: +Accept: application/json # 或 text/event-stream(流式输出) +``` + +**Body** (JSON): + +| 字段 | 类型 | 必填 | 说明 | +|-----|------|------|------| +| `sessionId` | string | ✅ | 会话 ID,用于关联同一会话的对话历史 | +| `currentMessage` | string | ✅ | 当前用户消息内容 | +| `channelType` | string | ✅ | 渠道类型,枚举值:`wechat`、`douyin`、`jd` | +| `history` | array | ❌ | 历史消息列表(可选,AI 中台会自动管理会话历史) | +| `metadata` | object | ❌ | 扩展元数据(可选) | + +**history 数组元素结构**: +```json +{ + "role": "user | assistant", + "content": "消息内容" +} +``` + +**请求示例**: +```json +{ + "sessionId": "kf_001_wx123456_1708765432000", + "currentMessage": "我想了解产品价格", + "channelType": "wechat", + "metadata": { + "channelUserId": "wx123456", + "extra": "..." + } +} +``` + +#### 4.1.2 响应格式 + +##### 模式 1: JSON 响应(非流式) + +**状态码**: `200 OK` + +**响应体**: +```json +{ + "reply": "您好,我们的产品价格根据套餐不同有所差异...", + "confidence": 0.92, + "shouldTransfer": false, + "transferReason": null, + "metadata": { + "retrieval_count": 3, + "rag_enabled": true + } +} +``` + +**字段说明**: + +| 字段 | 类型 | 必填 | 说明 | +|-----|------|------|------| +| `reply` | string | ✅ | AI 生成的回复内容 | +| `confidence` | number | ✅ | 置信度评分(0.0-1.0),越高表示回答越可靠 | +| `shouldTransfer` | boolean | ✅ | 是否建议转人工(true=建议转人工) | +| `transferReason` | string | ❌ | 转人工原因(可选) | +| `metadata` | object | ❌ | 响应元数据(可选) | + +##### 模式 2: SSE 流式响应 + +**触发条件**: 请求头包含 `Accept: text/event-stream` + +**响应头**: +```http +Content-Type: text/event-stream +Cache-Control: no-cache +Connection: keep-alive +``` + +**事件流格式**: + +1. **增量消息事件** (可多次发送) +``` +event: message +data: {"delta": "您好,"} + +event: message +data: {"delta": "我们的产品"} +``` + +2. **最终结果事件** (发送一次后关闭连接) +``` +event: final +data: {"reply": "完整回复内容", "confidence": 0.92, "shouldTransfer": false} +``` + +3. **错误事件** (发生错误时发送) +``` +event: error +data: {"code": "INTERNAL_ERROR", "message": "错误描述"} +``` + +**事件序列保证**: +- `message*` (0 或多次) → `final` (1 次) → 连接关闭 +- 或 `message*` (0 或多次) → `error` (1 次) → 连接关闭 + +#### 4.1.3 错误响应 + +**401 Unauthorized** - 认证失败 +```json +{ + "code": "UNAUTHORIZED", + "message": "Missing required header: X-API-Key", + "details": [] +} +``` + +**400 Bad Request** - 请求参数错误 +```json +{ + "code": "INVALID_REQUEST", + "message": "缺少必填字段: sessionId", + "details": [] +} +``` + +**400 Bad Request** - 租户 ID 格式错误 +```json +{ + "code": "INVALID_TENANT_ID", + "message": "Invalid tenant ID format. Expected: name@ash@year (e.g., szmp@ash@2026)", + "details": [] +} +``` + +**500 Internal Server Error** - 服务内部错误 +```json +{ + "code": "INTERNAL_ERROR", + "message": "LLM 调用失败", + "details": [] +} +``` + +**503 Service Unavailable** - 服务不可用 +```json +{ + "code": "SERVICE_UNAVAILABLE", + "message": "向量数据库连接失败", + "details": [] +} +``` + +--- + +### 4.2 健康检查 + +**接口路径**: `GET /ai/health` + +**功能描述**: 检查 AI 服务是否正常运行,用于服务监控和负载均衡健康探测。 + +#### 4.2.1 请求参数 + +无需请求参数,无需认证头。 + +#### 4.2.2 响应格式 + +**200 OK** - 服务正常 +```json +{ + "status": "healthy" +} +``` + +**503 Service Unavailable** - 服务不健康 +```json +{ + "status": "unhealthy" +} +``` + +--- + +## 5. 调用示例 + +### 5.1 Java 调用示例(非流式) + +```java +import org.springframework.http.*; +import org.springframework.web.client.RestTemplate; + +public class AIServiceClient { + + private final RestTemplate restTemplate; + private final String aiServiceUrl = "http://ai-service:8080"; + private final String apiKey = "your_api_key_here"; + + public ChatResponse generateReply(String tenantId, ChatRequest request) { + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.APPLICATION_JSON); + headers.set("X-API-Key", apiKey); + headers.set("X-Tenant-Id", tenantId); + + HttpEntity entity = new HttpEntity<>(request, headers); + + ResponseEntity response = restTemplate.postForEntity( + aiServiceUrl + "/ai/chat", + entity, + ChatResponse.class + ); + + return response.getBody(); + } +} +``` + +### 5.2 Java 调用示例(流式) + +```java +import org.springframework.web.reactive.function.client.WebClient; +import reactor.core.publisher.Flux; + +public class AIServiceStreamClient { + + private final WebClient webClient; + private final String apiKey = "your_api_key_here"; + + public Flux> generateReplyStream( + String tenantId, + ChatRequest request + ) { + return webClient.post() + .uri("/ai/chat") + .header("X-API-Key", apiKey) + .header("X-Tenant-Id", tenantId) + .header("Accept", "text/event-stream") + .bodyValue(request) + .retrieve() + .bodyToFlux(ServerSentEvent.class); + } +} +``` + +### 5.3 cURL 调用示例 + +```bash +# 非流式调用 +curl -X POST http://ai-service:8080/ai/chat \ + -H "Content-Type: application/json" \ + -H "X-API-Key: your_api_key_here" \ + -H "X-Tenant-Id: szmp@ash@2026" \ + -d '{ + "sessionId": "kf_001_wx123456_1708765432000", + "currentMessage": "我想了解产品价格", + "channelType": "wechat" + }' + +# 流式调用 +curl -X POST http://ai-service:8080/ai/chat \ + -H "Content-Type: application/json" \ + -H "X-API-Key: your_api_key_here" \ + -H "X-Tenant-Id: szmp@ash@2026" \ + -H "Accept: text/event-stream" \ + -d '{ + "sessionId": "kf_001_wx123456_1708765432000", + "currentMessage": "我想了解产品价格", + "channelType": "wechat" + }' + +# 健康检查(无需认证) +curl http://ai-service:8080/ai/health +``` + +--- + +## 6. 业务逻辑说明 + +### 6.1 会话管理 + +- **会话标识**: `sessionId` 用于唯一标识一个对话会话 +- **自动持久化**: AI 中台会自动保存会话历史,无需调用方每次传递完整历史 +- **可选历史**: 调用方可通过 `history` 字段提供外部历史,AI 中台会合并处理 +- **租户隔离**: 相同 `sessionId` 在不同 `tenantId` 下视为不同会话 + +### 6.2 RAG 检索增强 + +- **自动触发**: AI 中台会根据用户问题自动判断是否需要检索知识库 +- **多知识库**: 支持按知识库类型(产品知识、FAQ、话术模板等)分类检索 +- **置信度评估**: 检索结果质量会影响 `confidence` 评分 +- **兜底策略**: 检索失败或无结果时,AI 会基于通用知识回答并降低置信度 + +### 6.3 转人工建议 + +`shouldTransfer` 字段由以下因素决定: + +- ✅ 置信度低于阈值(默认 0.6) +- ✅ 检索无结果或结果质量差 +- ✅ 用户明确要求人工服务 +- ✅ 意图识别命中"转人工"规则 + +**注意**: `shouldTransfer=true` 仅为建议,最终是否转人工由调用方(Java 主框架)决策。 + +### 6.4 意图识别与规则引擎 + +- **前置处理**: 用户消息会先经过意图识别 +- **固定回复**: 命中固定规则时直接返回预设话术(跳过 LLM 调用) +- **话术流程**: 命中流程规则时进入多轮引导对话 +- **定向检索**: 命中 RAG 规则时使用指定知识库检索 + +### 6.5 输出护栏 + +- **禁词过滤**: AI 回复会自动过滤禁词(竞品名称、敏感词等) +- **替换策略**: 支持星号替换、文本替换、整条拦截三种策略 +- **行为约束**: Prompt 中注入行为规则(如"不承诺具体赔偿金额") + +--- + +## 7. 性能与限制 + +### 7.1 性能指标 + +| 指标 | 非流式 | 流式 | +|-----|-------|------| +| 首字响应时间 | 1-3 秒 | 200-500 毫秒 | +| 完整响应时间 | 2-5 秒 | 3-8 秒 | +| 并发支持 | 100+ QPS | 50+ QPS | + +### 7.2 限制说明 + +- **消息长度**: 单条消息最大 4000 字符 +- **历史长度**: 建议历史消息不超过 20 轮(AI 中台会自动截断) +- **超时设置**: 建议调用方设置 10 秒超时(非流式)、30 秒超时(流式) +- **重试策略**: 503 错误建议指数退避重试,500 错误建议降级处理 + +--- + +## 8. 错误码参考 + +| 错误码 | HTTP 状态码 | 说明 | 处理建议 | +|-------|-----------|------|---------| +| `UNAUTHORIZED` | 401 | 认证失败(缺少或无效 API Key) | 检查 X-API-Key 请求头 | +| `INVALID_REQUEST` | 400 | 请求参数错误 | 检查必填字段和参数格式 | +| `MISSING_TENANT_ID` | 400 | 缺少租户 ID | 添加 X-Tenant-Id 请求头 | +| `INVALID_TENANT_ID` | 400 | 租户 ID 格式错误 | 使用正确格式:name@ash@year | +| `INTERNAL_ERROR` | 500 | 服务内部错误 | 降级处理或重试 | +| `LLM_ERROR` | 500 | LLM 调用失败 | 降级处理或重试 | +| `SERVICE_UNAVAILABLE` | 503 | 服务不可用 | 指数退避重试 | +| `QDRANT_ERROR` | 503 | 向量库不可用 | 指数退避重试 | +| `STREAMING_ERROR` | 200 (SSE) | 流式传输错误 | 关闭连接并重试 | + +--- + +## 9. 最佳实践 + +### 9.1 API Key 管理 + +- API Key 由 AI 中台管理员通过管理后台分配 +- 建议为不同环境(开发/测试/生产)使用不同的 API Key +- API Key 应存储在配置文件或环境变量中,不要硬编码 +- 定期轮换 API Key 以提高安全性 + +### 9.2 会话 ID 生成规范 + +建议格式: `{业务前缀}_{租户ID}_{渠道用户ID}_{时间戳}` + +示例: `kf_001_wx123456_1708765432000` + +### 9.3 流式 vs 非流式选择 + +- **流式**: 适用于 Web/App 实时对话场景,用户体验更好 +- **非流式**: 适用于批量处理、异步任务、API 集成场景 + +### 9.4 降级策略建议 + +```java +public ChatResponse generateReplyWithFallback(String tenantId, ChatRequest request) { + try { + return aiServiceClient.generateReply(tenantId, request); + } catch (ServiceUnavailableException e) { + // 降级策略 1: 返回固定话术 + return ChatResponse.builder() + .reply("抱歉,当前咨询量较大,请稍后再试或转人工服务。") + .confidence(0.0) + .shouldTransfer(true) + .build(); + } catch (Exception e) { + // 降级策略 2: 直接转人工 + return ChatResponse.builder() + .reply("系统繁忙,正在为您转接人工客服...") + .confidence(0.0) + .shouldTransfer(true) + .transferReason("AI 服务异常") + .build(); + } +} +``` + +### 9.5 监控指标建议 + +- ✅ 接口响应时间(P50/P95/P99) +- ✅ 接口成功率 +- ✅ 置信度分布 +- ✅ 转人工率 +- ✅ 错误码分布 + +--- + +## 10. 变更日志 + +| 版本 | 日期 | 变更内容 | +|-----|------|---------| +| v1.1.0 | 2026-02-27 | 新增流式输出支持、意图识别、输出护栏 | +| v1.0.0 | 2026-02-20 | 初始版本,支持基础对话生成和健康检查 | + +--- + +## 11. 联系方式 + +- **技术支持**: AI 中台开发团队 +- **问题反馈**: 提交 Issue 到项目仓库 +- **文档更新**: 参考 `spec/ai-service/openapi.provider.yaml` + +--- + +**文档生成时间**: 2026-02-27 +**契约版本**: v1.1.0 +**维护状态**: ✅ 活跃维护 diff --git a/ai-service/app/api/admin/__init__.py b/ai-service/app/api/admin/__init__.py index 9aef820..83036ff 100644 --- a/ai-service/app/api/admin/__init__.py +++ b/ai-service/app/api/admin/__init__.py @@ -15,4 +15,5 @@ from app.api.admin.rag import router as rag_router from app.api.admin.script_flows import router as script_flows_router from app.api.admin.sessions import router as sessions_router from app.api.admin.tenants import router as tenants_router + __all__ = ["api_key_router", "dashboard_router", "embedding_router", "guardrails_router", "intent_rules_router", "kb_router", "llm_router", "prompt_templates_router", "rag_router", "script_flows_router", "sessions_router", "tenants_router"] diff --git a/ai-service/app/api/admin/guardrails.py b/ai-service/app/api/admin/guardrails.py new file mode 100644 index 0000000..bf79d75 --- /dev/null +++ b/ai-service/app/api/admin/guardrails.py @@ -0,0 +1,296 @@ +""" +Guardrail Management API. +[AC-AISVC-78~AC-AISVC-85] Forbidden words and behavior rules CRUD endpoints. +""" + +import logging +import uuid +from typing import Any + +from fastapi import APIRouter, Depends, Header, HTTPException +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.database import get_session +from app.models.entities import ( + BehaviorRuleCreate, + BehaviorRuleUpdate, + ForbiddenWordCreate, + ForbiddenWordUpdate, +) +from app.services.guardrail.behavior_service import BehaviorRuleService +from app.services.guardrail.word_service import ForbiddenWordService + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/admin/guardrails", tags=["Guardrails"]) + + +def get_tenant_id(x_tenant_id: str = Header(..., alias="X-Tenant-Id")) -> str: + """Extract tenant ID from header.""" + if not x_tenant_id: + raise HTTPException(status_code=400, detail="X-Tenant-Id header is required") + return x_tenant_id + + +@router.get("/forbidden-words") +async def list_forbidden_words( + tenant_id: str = Depends(get_tenant_id), + category: str | None = None, + is_enabled: bool | None = None, + session: AsyncSession = Depends(get_session), +) -> dict[str, Any]: + """ + [AC-AISVC-79] List all forbidden words for a tenant. + """ + logger.info( + f"[AC-AISVC-79] Listing forbidden words for tenant={tenant_id}, " + f"category={category}, is_enabled={is_enabled}" + ) + + service = ForbiddenWordService(session) + words = await service.list_words(tenant_id, category, is_enabled) + + data = [] + for word in words: + data.append(await service.word_to_info_dict(word)) + + return {"data": data} + + +@router.post("/forbidden-words", status_code=201) +async def create_forbidden_word( + body: ForbiddenWordCreate, + tenant_id: str = Depends(get_tenant_id), + session: AsyncSession = Depends(get_session), +) -> dict[str, Any]: + """ + [AC-AISVC-78] Create a new forbidden word. + """ + valid_categories = ["competitor", "sensitive", "political", "custom"] + if body.category not in valid_categories: + raise HTTPException( + status_code=400, + detail=f"Invalid category. Must be one of: {valid_categories}" + ) + + valid_strategies = ["mask", "replace", "block"] + if body.strategy not in valid_strategies: + raise HTTPException( + status_code=400, + detail=f"Invalid strategy. Must be one of: {valid_strategies}" + ) + + if body.strategy == "replace" and not body.replacement: + raise HTTPException( + status_code=400, + detail="replacement is required when strategy is 'replace'" + ) + + logger.info( + f"[AC-AISVC-78] Creating forbidden word for tenant={tenant_id}, word={body.word}" + ) + + service = ForbiddenWordService(session) + try: + word = await service.create_word(tenant_id, body) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + + return await service.word_to_info_dict(word) + + +@router.get("/forbidden-words/{word_id}") +async def get_forbidden_word( + word_id: uuid.UUID, + tenant_id: str = Depends(get_tenant_id), + session: AsyncSession = Depends(get_session), +) -> dict[str, Any]: + """ + [AC-AISVC-79] Get forbidden word detail. + """ + logger.info(f"[AC-AISVC-79] Getting forbidden word for tenant={tenant_id}, id={word_id}") + + service = ForbiddenWordService(session) + word = await service.get_word(tenant_id, word_id) + + if not word: + raise HTTPException(status_code=404, detail="Forbidden word not found") + + return await service.word_to_info_dict(word) + + +@router.put("/forbidden-words/{word_id}") +async def update_forbidden_word( + word_id: uuid.UUID, + body: ForbiddenWordUpdate, + tenant_id: str = Depends(get_tenant_id), + session: AsyncSession = Depends(get_session), +) -> dict[str, Any]: + """ + [AC-AISVC-80] Update a forbidden word. + """ + valid_categories = ["competitor", "sensitive", "political", "custom"] + if body.category is not None and body.category not in valid_categories: + raise HTTPException( + status_code=400, + detail=f"Invalid category. Must be one of: {valid_categories}" + ) + + valid_strategies = ["mask", "replace", "block"] + if body.strategy is not None and body.strategy not in valid_strategies: + raise HTTPException( + status_code=400, + detail=f"Invalid strategy. Must be one of: {valid_strategies}" + ) + + logger.info(f"[AC-AISVC-80] Updating forbidden word for tenant={tenant_id}, id={word_id}") + + service = ForbiddenWordService(session) + try: + word = await service.update_word(tenant_id, word_id, body) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + + if not word: + raise HTTPException(status_code=404, detail="Forbidden word not found") + + return await service.word_to_info_dict(word) + + +@router.delete("/forbidden-words/{word_id}", status_code=204) +async def delete_forbidden_word( + word_id: uuid.UUID, + tenant_id: str = Depends(get_tenant_id), + session: AsyncSession = Depends(get_session), +) -> None: + """ + [AC-AISVC-81] Delete a forbidden word. + """ + logger.info(f"[AC-AISVC-81] Deleting forbidden word for tenant={tenant_id}, id={word_id}") + + service = ForbiddenWordService(session) + success = await service.delete_word(tenant_id, word_id) + + if not success: + raise HTTPException(status_code=404, detail="Forbidden word not found") + + +@router.get("/behavior-rules") +async def list_behavior_rules( + tenant_id: str = Depends(get_tenant_id), + category: str | None = None, + session: AsyncSession = Depends(get_session), +) -> dict[str, Any]: + """ + [AC-AISVC-85] List all behavior rules for a tenant. + """ + logger.info( + f"[AC-AISVC-85] Listing behavior rules for tenant={tenant_id}, category={category}" + ) + + service = BehaviorRuleService(session) + rules = await service.list_rules(tenant_id, category) + + data = [] + for rule in rules: + data.append(await service.rule_to_info_dict(rule)) + + return {"data": data} + + +@router.post("/behavior-rules", status_code=201) +async def create_behavior_rule( + body: BehaviorRuleCreate, + tenant_id: str = Depends(get_tenant_id), + session: AsyncSession = Depends(get_session), +) -> dict[str, Any]: + """ + [AC-AISVC-84] Create a new behavior rule. + """ + valid_categories = ["compliance", "tone", "boundary", "custom"] + if body.category not in valid_categories: + raise HTTPException( + status_code=400, + detail=f"Invalid category. Must be one of: {valid_categories}" + ) + + logger.info( + f"[AC-AISVC-84] Creating behavior rule for tenant={tenant_id}, category={body.category}" + ) + + service = BehaviorRuleService(session) + try: + rule = await service.create_rule(tenant_id, body) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + + return await service.rule_to_info_dict(rule) + + +@router.get("/behavior-rules/{rule_id}") +async def get_behavior_rule( + rule_id: uuid.UUID, + tenant_id: str = Depends(get_tenant_id), + session: AsyncSession = Depends(get_session), +) -> dict[str, Any]: + """ + [AC-AISVC-85] Get behavior rule detail. + """ + logger.info(f"[AC-AISVC-85] Getting behavior rule for tenant={tenant_id}, id={rule_id}") + + service = BehaviorRuleService(session) + rule = await service.get_rule(tenant_id, rule_id) + + if not rule: + raise HTTPException(status_code=404, detail="Behavior rule not found") + + return await service.rule_to_info_dict(rule) + + +@router.put("/behavior-rules/{rule_id}") +async def update_behavior_rule( + rule_id: uuid.UUID, + body: BehaviorRuleUpdate, + tenant_id: str = Depends(get_tenant_id), + session: AsyncSession = Depends(get_session), +) -> dict[str, Any]: + """ + [AC-AISVC-85] Update a behavior rule. + """ + valid_categories = ["compliance", "tone", "boundary", "custom"] + if body.category is not None and body.category not in valid_categories: + raise HTTPException( + status_code=400, + detail=f"Invalid category. Must be one of: {valid_categories}" + ) + + logger.info(f"[AC-AISVC-85] Updating behavior rule for tenant={tenant_id}, id={rule_id}") + + service = BehaviorRuleService(session) + try: + rule = await service.update_rule(tenant_id, rule_id, body) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + + if not rule: + raise HTTPException(status_code=404, detail="Behavior rule not found") + + return await service.rule_to_info_dict(rule) + + +@router.delete("/behavior-rules/{rule_id}", status_code=204) +async def delete_behavior_rule( + rule_id: uuid.UUID, + tenant_id: str = Depends(get_tenant_id), + session: AsyncSession = Depends(get_session), +) -> None: + """ + [AC-AISVC-85] Delete a behavior rule. + """ + logger.info(f"[AC-AISVC-85] Deleting behavior rule for tenant={tenant_id}, id={rule_id}") + + service = BehaviorRuleService(session) + success = await service.delete_rule(tenant_id, rule_id) + + if not success: + raise HTTPException(status_code=404, detail="Behavior rule not found") diff --git a/ai-service/app/main.py b/ai-service/app/main.py index cc0bc3c..6e6cbb9 100644 --- a/ai-service/app/main.py +++ b/ai-service/app/main.py @@ -12,7 +12,20 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse from app.api import chat_router, health_router -from app.api.admin import api_key_router, dashboard_router, embedding_router, guardrails_router, intent_rules_router, kb_router, llm_router, prompt_templates_router, rag_router, script_flows_router, sessions_router, tenants_router +from app.api.admin import ( + api_key_router, + dashboard_router, + embedding_router, + guardrails_router, + intent_rules_router, + kb_router, + llm_router, + prompt_templates_router, + rag_router, + script_flows_router, + sessions_router, + tenants_router, +) from app.api.admin.kb_optimized import router as kb_optimized_router from app.core.config import get_settings from app.core.database import close_db, init_db @@ -76,12 +89,12 @@ app = FastAPI( version=settings.app_version, description=""" Python AI Service for intelligent chat with RAG support. - + ## Features - Multi-tenant isolation via X-Tenant-Id header - SSE streaming support via Accept: text/event-stream - RAG-powered responses with confidence scoring - + ## Response Modes - **JSON**: Default response mode (Accept: application/json or no Accept header) - **SSE Streaming**: Set Accept: text/event-stream for streaming responses @@ -130,6 +143,7 @@ app.include_router(chat_router) app.include_router(api_key_router) app.include_router(dashboard_router) app.include_router(embedding_router) +app.include_router(guardrails_router) app.include_router(intent_rules_router) app.include_router(kb_router) app.include_router(kb_optimized_router) diff --git a/ai-service/app/models/entities.py b/ai-service/app/models/entities.py index 8ac5558..56bf0b1 100644 --- a/ai-service/app/models/entities.py +++ b/ai-service/app/models/entities.py @@ -8,7 +8,7 @@ from datetime import datetime from enum import Enum from typing import Any -from sqlalchemy import Column, JSON +from sqlalchemy import JSON, Column from sqlmodel import Field, Index, SQLModel @@ -141,7 +141,10 @@ class KnowledgeBase(SQLModel, table=True): id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) tenant_id: str = Field(..., description="Tenant ID for multi-tenant isolation", index=True) name: str = Field(..., description="Knowledge base name") - kb_type: str = Field(default=KBType.GENERAL.value, description="Knowledge base type: product/faq/script/policy/general") + kb_type: str = Field( + default=KBType.GENERAL.value, + description="Knowledge base type: product/faq/script/policy/general" + ) description: str | None = Field(default=None, description="Knowledge base description") priority: int = Field(default=0, ge=0, description="Priority weight, higher value means higher priority") is_enabled: bool = Field(default=True, description="Whether the knowledge base is enabled") @@ -289,14 +292,25 @@ class PromptTemplateVersion(SQLModel, table=True): ) id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) - template_id: uuid.UUID = Field(..., description="Foreign key to prompt_templates.id", foreign_key="prompt_templates.id", index=True) + template_id: uuid.UUID = Field( + ..., + description="Foreign key to prompt_templates.id", + foreign_key="prompt_templates.id", + index=True + ) version: int = Field(..., description="Version number (auto-incremented per template)") - status: str = Field(default=TemplateVersionStatus.DRAFT.value, description="Version status: draft/published/archived") - system_instruction: str = Field(..., description="System instruction content with {{variable}} placeholders") + status: str = Field( + default=TemplateVersionStatus.DRAFT.value, + description="Version status: draft/published/archived" + ) + system_instruction: str = Field( + ..., + description="System instruction content with {{variable}} placeholders" + ) variables: list[dict[str, Any]] | None = Field( default=None, sa_column=Column("variables", JSON, nullable=True), - description="Variable definitions, e.g., [{'name': 'persona_name', 'default': '小N', 'description': '人设名称'}]" + description="Variable definitions, e.g., [{'name': 'persona_name', 'default': '小N'}]" ) created_at: datetime = Field(default_factory=datetime.utcnow, description="Creation time") @@ -510,7 +524,10 @@ class BehaviorRule(SQLModel, table=True): id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) tenant_id: str = Field(..., description="Tenant ID for multi-tenant isolation", index=True) - rule_text: str = Field(..., description="Behavior constraint description, e.g., 'Do not promise specific compensation amounts'") + rule_text: str = Field( + ..., + description="Behavior constraint description, e.g., 'Do not promise specific compensation'" + ) category: str = Field(..., description="Category: compliance/tone/boundary/custom") is_enabled: bool = Field(default=True, description="Whether the rule is enabled") created_at: datetime = Field(default_factory=datetime.utcnow, description="Creation time") @@ -618,7 +635,7 @@ class ScriptFlow(SQLModel, table=True): steps: list[dict[str, Any]] = Field( default=[], sa_column=Column("steps", JSON, nullable=False), - description="Flow steps list with step_no, content, wait_input, timeout_seconds, timeout_action, next_conditions, default_next" + description="Flow steps list with step_no, content, wait_input, timeout_seconds" ) is_enabled: bool = Field(default=True, description="Whether the flow is enabled") created_at: datetime = Field(default_factory=datetime.utcnow, description="Creation time") @@ -640,13 +657,21 @@ class FlowInstance(SQLModel, table=True): id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) tenant_id: str = Field(..., description="Tenant ID for multi-tenant isolation", index=True) session_id: str = Field(..., description="Session ID for conversation tracking", index=True) - flow_id: uuid.UUID = Field(..., description="Foreign key to script_flows.id", foreign_key="script_flows.id", index=True) + flow_id: uuid.UUID = Field( + ..., + description="Foreign key to script_flows.id", + foreign_key="script_flows.id", + index=True + ) current_step: int = Field(default=1, ge=1, description="Current step number (1-indexed)") - status: str = Field(default=FlowInstanceStatus.ACTIVE.value, description="Instance status: active/completed/timeout/cancelled") + status: str = Field( + default=FlowInstanceStatus.ACTIVE.value, + description="Instance status: active/completed/timeout/cancelled" + ) context: dict[str, Any] | None = Field( default=None, sa_column=Column("context", JSON, nullable=True), - description="Flow execution context, stores user inputs etc." + description="Flow execution context, stores user inputs" ) started_at: datetime = Field(default_factory=datetime.utcnow, description="Instance start time") updated_at: datetime = Field(default_factory=datetime.utcnow, description="Last update time") @@ -660,10 +685,13 @@ class FlowStep(SQLModel): content: str = Field(..., description="Script content for this step") wait_input: bool = Field(default=True, description="Whether to wait for user input") timeout_seconds: int = Field(default=120, ge=1, description="Timeout in seconds") - timeout_action: str = Field(default=TimeoutAction.REPEAT.value, description="Action on timeout: repeat/skip/transfer") + timeout_action: str = Field( + default=TimeoutAction.REPEAT.value, + description="Action on timeout: repeat/skip/transfer" + ) next_conditions: list[dict[str, Any]] | None = Field( default=None, - description="Conditions for next step: [{'keywords': [...], 'goto_step': N}, {'pattern': '...', 'goto_step': N}]" + description="Conditions for next step: [{'keywords': [...], 'goto_step': N}]" ) default_next: int | None = Field(default=None, description="Default next step if no condition matches") diff --git a/ai-service/app/services/guardrail/__init__.py b/ai-service/app/services/guardrail/__init__.py new file mode 100644 index 0000000..eef84dd --- /dev/null +++ b/ai-service/app/services/guardrail/__init__.py @@ -0,0 +1,18 @@ +""" +Guardrail services for AI Service. +[AC-AISVC-78~AC-AISVC-85] Output guardrail with forbidden word detection and behavior rules. +""" + +from app.services.guardrail.behavior_service import BehaviorRuleService +from app.services.guardrail.input_scanner import InputScanner +from app.services.guardrail.output_filter import OutputFilter +from app.services.guardrail.streaming_filter import StreamingGuardrail +from app.services.guardrail.word_service import ForbiddenWordService + +__all__ = [ + "ForbiddenWordService", + "BehaviorRuleService", + "InputScanner", + "OutputFilter", + "StreamingGuardrail", +] diff --git a/ai-service/app/services/guardrail/behavior_service.py b/ai-service/app/services/guardrail/behavior_service.py new file mode 100644 index 0000000..15b89ec --- /dev/null +++ b/ai-service/app/services/guardrail/behavior_service.py @@ -0,0 +1,260 @@ +""" +Behavior rule service for AI Service. +[AC-AISVC-84, AC-AISVC-85] Behavior rule CRUD management. +""" + +import logging +import time +import uuid +from collections.abc import Sequence +from datetime import datetime +from typing import Any + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlmodel import col + +from app.models.entities import ( + BehaviorRule, + BehaviorRuleCategory, + BehaviorRuleCreate, + BehaviorRuleUpdate, +) + +logger = logging.getLogger(__name__) + +BEHAVIOR_CACHE_TTL_SECONDS = 60 + + +class BehaviorRuleCache: + """ + [AC-AISVC-84] In-memory cache for behavior rules. + Key: tenant_id + Value: (rules_list, cached_at) + TTL: 60 seconds + """ + + def __init__(self, ttl_seconds: int = BEHAVIOR_CACHE_TTL_SECONDS): + self._cache: dict[str, tuple[list[BehaviorRule], float]] = {} + self._ttl = ttl_seconds + + def get(self, tenant_id: str) -> list[BehaviorRule] | None: + """Get cached rules if not expired.""" + if tenant_id in self._cache: + rules, cached_at = self._cache[tenant_id] + if time.time() - cached_at < self._ttl: + return rules + else: + del self._cache[tenant_id] + return None + + def set(self, tenant_id: str, rules: list[BehaviorRule]) -> None: + """Cache rules for a tenant.""" + self._cache[tenant_id] = (rules, time.time()) + + def invalidate(self, tenant_id: str) -> None: + """Invalidate cache for a tenant.""" + if tenant_id in self._cache: + del self._cache[tenant_id] + logger.debug(f"Invalidated behavior rule cache for tenant={tenant_id}") + + +_behavior_cache = BehaviorRuleCache() + + +class BehaviorRuleService: + """ + [AC-AISVC-84, AC-AISVC-85] Service for managing behavior rules. + + Features: + - Rule CRUD with tenant isolation + - In-memory caching with TTL + - Cache invalidation on CRUD operations + - Rules are injected into Prompt system instruction + """ + + VALID_CATEGORIES = [c.value for c in BehaviorRuleCategory] + + def __init__(self, session: AsyncSession): + self._session = session + self._cache = _behavior_cache + + async def create_rule( + self, + tenant_id: str, + create_data: BehaviorRuleCreate, + ) -> BehaviorRule: + """ + [AC-AISVC-84] Create a new behavior rule. + """ + if create_data.category not in self.VALID_CATEGORIES: + raise ValueError(f"Invalid category. Must be one of: {self.VALID_CATEGORIES}") + + rule = BehaviorRule( + tenant_id=tenant_id, + rule_text=create_data.rule_text, + category=create_data.category, + is_enabled=True, + ) + self._session.add(rule) + await self._session.flush() + + self._cache.invalidate(tenant_id) + + logger.info( + f"[AC-AISVC-84] Created behavior rule: tenant={tenant_id}, " + f"id={rule.id}, category={rule.category}" + ) + return rule + + async def list_rules( + self, + tenant_id: str, + category: str | None = None, + ) -> Sequence[BehaviorRule]: + """ + [AC-AISVC-85] List rules for a tenant with optional filters. + """ + stmt = select(BehaviorRule).where( + BehaviorRule.tenant_id == tenant_id # type: ignore + ) + + if category is not None: + stmt = stmt.where(BehaviorRule.category == category) # type: ignore + + stmt = stmt.order_by(col(BehaviorRule.category), col(BehaviorRule.created_at).desc()) + result = await self._session.execute(stmt) + return result.scalars().all() + + async def get_rule( + self, + tenant_id: str, + rule_id: uuid.UUID, + ) -> BehaviorRule | None: + """ + [AC-AISVC-85] Get rule by ID with tenant isolation. + """ + stmt = select(BehaviorRule).where( + BehaviorRule.tenant_id == tenant_id, # type: ignore + BehaviorRule.id == rule_id, # type: ignore + ) + result = await self._session.execute(stmt) + return result.scalar_one_or_none() + + async def update_rule( + self, + tenant_id: str, + rule_id: uuid.UUID, + update_data: BehaviorRuleUpdate, + ) -> BehaviorRule | None: + """ + [AC-AISVC-85] Update a behavior rule. + """ + rule = await self.get_rule(tenant_id, rule_id) + if not rule: + return None + + if update_data.rule_text is not None: + rule.rule_text = update_data.rule_text + if update_data.category is not None: + if update_data.category not in self.VALID_CATEGORIES: + raise ValueError(f"Invalid category. Must be one of: {self.VALID_CATEGORIES}") + rule.category = update_data.category + if update_data.is_enabled is not None: + rule.is_enabled = update_data.is_enabled + + rule.updated_at = datetime.utcnow() + await self._session.flush() + + self._cache.invalidate(tenant_id) + + logger.info( + f"[AC-AISVC-85] Updated behavior rule: tenant={tenant_id}, id={rule_id}" + ) + return rule + + async def delete_rule( + self, + tenant_id: str, + rule_id: uuid.UUID, + ) -> bool: + """ + [AC-AISVC-85] Delete a behavior rule. + """ + rule = await self.get_rule(tenant_id, rule_id) + if not rule: + return False + + await self._session.delete(rule) + await self._session.flush() + + self._cache.invalidate(tenant_id) + + logger.info( + f"[AC-AISVC-85] Deleted behavior rule: tenant={tenant_id}, id={rule_id}" + ) + return True + + async def get_enabled_rules_for_injection( + self, + tenant_id: str, + ) -> list[BehaviorRule]: + """ + [AC-AISVC-84] Get enabled rules for Prompt injection. + Uses cache for performance. + """ + cached = self._cache.get(tenant_id) + if cached is not None: + logger.debug(f"[AC-AISVC-84] Cache hit for behavior rules: tenant={tenant_id}") + return cached + + stmt = ( + select(BehaviorRule) + .where( + BehaviorRule.tenant_id == tenant_id, # type: ignore + BehaviorRule.is_enabled == True, # type: ignore + ) + .order_by(col(BehaviorRule.category)) + ) + result = await self._session.execute(stmt) + rules = list(result.scalars().all()) + + self._cache.set(tenant_id, rules) + logger.info( + f"[AC-AISVC-84] Loaded {len(rules)} enabled behavior rules from DB: tenant={tenant_id}" + ) + return rules + + def invalidate_cache(self, tenant_id: str) -> None: + """Manually invalidate cache for a tenant.""" + self._cache.invalidate(tenant_id) + + async def rule_to_info_dict(self, rule: BehaviorRule) -> dict[str, Any]: + """Convert rule entity to API response dict.""" + return { + "id": str(rule.id), + "rule_text": rule.rule_text, + "category": rule.category, + "is_enabled": rule.is_enabled, + "created_at": rule.created_at.isoformat(), + "updated_at": rule.updated_at.isoformat(), + } + + async def format_rules_for_prompt( + self, + tenant_id: str, + ) -> str: + """ + [AC-AISVC-84] Format behavior rules for Prompt injection. + Returns formatted string to append to system instruction. + """ + rules = await self.get_enabled_rules_for_injection(tenant_id) + + if not rules: + return "" + + lines = ["\n\n[行为约束 - 以下规则必须严格遵守]"] + for i, rule in enumerate(rules, 1): + lines.append(f"{i}. {rule.rule_text}") + + return "\n".join(lines) diff --git a/ai-service/app/services/guardrail/input_scanner.py b/ai-service/app/services/guardrail/input_scanner.py new file mode 100644 index 0000000..3965dbe --- /dev/null +++ b/ai-service/app/services/guardrail/input_scanner.py @@ -0,0 +1,110 @@ +""" +Input scanner for AI Service. +[AC-AISVC-83] User input pre-detection (logging only, no blocking). +""" + +import logging +from typing import Any + +from app.models.entities import ( + ForbiddenWord, + InputScanResult, +) +from app.services.guardrail.word_service import ForbiddenWordService + +logger = logging.getLogger(__name__) + + +class InputScanner: + """ + [AC-AISVC-83] Input scanner for pre-detection of forbidden words. + + Features: + - Scans user input for forbidden words + - Records matched words and categories in metadata + - Does NOT block the request (only logging) + - Used for monitoring and analytics + """ + + def __init__(self, word_service: ForbiddenWordService): + self._word_service = word_service + + async def scan( + self, + text: str, + tenant_id: str, + ) -> InputScanResult: + """ + [AC-AISVC-83] Scan user input for forbidden words. + + Args: + text: User input text to scan + tenant_id: Tenant ID for isolation + + Returns: + InputScanResult with flagged status and matched words + """ + if not text or not text.strip(): + return InputScanResult(flagged=False) + + words = await self._word_service.get_enabled_words_for_filtering(tenant_id) + + if not words: + return InputScanResult(flagged=False) + + matched_words: list[str] = [] + matched_categories: list[str] = [] + matched_word_entities: list[ForbiddenWord] = [] + + for word in words: + if word.word in text: + matched_words.append(word.word) + if word.category not in matched_categories: + matched_categories.append(word.category) + matched_word_entities.append(word) + + if matched_words: + logger.info( + f"[AC-AISVC-83] Input flagged: tenant={tenant_id}, " + f"matched_words={matched_words}, categories={matched_categories}" + ) + + for word_entity in matched_word_entities: + try: + await self._word_service.increment_hit_count(tenant_id, word_entity.id) + except Exception as e: + logger.warning( + f"Failed to increment hit count for word {word_entity.id}: {e}" + ) + + return InputScanResult( + flagged=len(matched_words) > 0, + matched_words=matched_words, + matched_categories=matched_categories, + ) + + async def scan_and_enrich_metadata( + self, + text: str, + tenant_id: str, + metadata: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """ + [AC-AISVC-83] Scan input and enrich metadata with scan result. + + Args: + text: User input text to scan + tenant_id: Tenant ID for isolation + metadata: Existing metadata dict to enrich + + Returns: + Enriched metadata with input_flagged and matched info + """ + result = await self.scan(text, tenant_id) + + if metadata is None: + metadata = {} + + metadata.update(result.to_dict()) + + return metadata diff --git a/ai-service/app/services/guardrail/output_filter.py b/ai-service/app/services/guardrail/output_filter.py new file mode 100644 index 0000000..5e04665 --- /dev/null +++ b/ai-service/app/services/guardrail/output_filter.py @@ -0,0 +1,154 @@ +""" +Output filter for AI Service. +[AC-AISVC-82] LLM output post-filtering with mask/replace/block strategies. +""" + +import logging +from typing import Any + +from app.models.entities import ( + ForbiddenWord, + ForbiddenWordStrategy, + GuardrailResult, +) +from app.services.guardrail.word_service import ForbiddenWordService + +logger = logging.getLogger(__name__) + + +class OutputFilter: + """ + [AC-AISVC-82] Output filter for post-filtering LLM responses. + + Features: + - Scans LLM output for forbidden words + - Applies mask/replace/block strategies + - Returns fallback reply for block strategy + - Records triggered words in metadata + """ + + DEFAULT_FALLBACK_REPLY = "抱歉,让我换个方式回答您" + + def __init__(self, word_service: ForbiddenWordService): + self._word_service = word_service + + async def filter( + self, + reply: str, + tenant_id: str, + ) -> GuardrailResult: + """ + [AC-AISVC-82] Filter LLM output for forbidden words. + + Args: + reply: LLM generated reply to filter + tenant_id: Tenant ID for isolation + + Returns: + GuardrailResult with filtered reply and trigger info + """ + if not reply or not reply.strip(): + return GuardrailResult(reply=reply) + + words = await self._word_service.get_enabled_words_for_filtering(tenant_id) + + if not words: + return GuardrailResult(reply=reply) + + triggered_words: list[str] = [] + triggered_categories: list[str] = [] + filtered_reply = reply + blocked = False + fallback_reply = self.DEFAULT_FALLBACK_REPLY + + for word in words: + if word.word in filtered_reply: + triggered_words.append(word.word) + if word.category not in triggered_categories: + triggered_categories.append(word.category) + + if word.strategy == ForbiddenWordStrategy.BLOCK.value: + blocked = True + fallback_reply = word.fallback_reply or self.DEFAULT_FALLBACK_REPLY + logger.warning( + f"[AC-AISVC-82] Output blocked by forbidden word: tenant={tenant_id}, " + f"word={word.word}, category={word.category}" + ) + break + + elif word.strategy == ForbiddenWordStrategy.MASK.value: + filtered_reply = filtered_reply.replace(word.word, "*" * len(word.word)) + logger.info( + f"[AC-AISVC-82] Output masked: tenant={tenant_id}, word={word.word}" + ) + + elif word.strategy == ForbiddenWordStrategy.REPLACE.value: + replacement = word.replacement or "" + filtered_reply = filtered_reply.replace(word.word, replacement) + logger.info( + f"[AC-AISVC-82] Output replaced: tenant={tenant_id}, " + f"word={word.word} -> {replacement}" + ) + + if blocked: + return GuardrailResult( + reply=fallback_reply, + blocked=True, + triggered_words=triggered_words, + triggered_categories=triggered_categories, + ) + + if triggered_words: + logger.info( + f"[AC-AISVC-82] Output filtered: tenant={tenant_id}, " + f"triggered_words={triggered_words}, categories={triggered_categories}" + ) + + for word_entity in self._get_triggered_word_entities(words, triggered_words): + try: + await self._word_service.increment_hit_count(tenant_id, word_entity.id) + except Exception as e: + logger.warning( + f"Failed to increment hit count for word {word_entity.id}: {e}" + ) + + return GuardrailResult( + reply=filtered_reply, + blocked=False, + triggered_words=triggered_words, + triggered_categories=triggered_categories, + ) + + def _get_triggered_word_entities( + self, + words: list[ForbiddenWord], + triggered_words: list[str], + ) -> list[ForbiddenWord]: + """Get word entities for triggered words.""" + return [w for w in words if w.word in triggered_words] + + async def filter_and_enrich_metadata( + self, + reply: str, + tenant_id: str, + metadata: dict[str, Any] | None = None, + ) -> tuple[str, dict[str, Any]]: + """ + [AC-AISVC-82] Filter output and enrich metadata with filter result. + + Args: + reply: LLM generated reply to filter + tenant_id: Tenant ID for isolation + metadata: Existing metadata dict to enrich + + Returns: + Tuple of (filtered_reply, enriched_metadata) + """ + result = await self.filter(reply, tenant_id) + + if metadata is None: + metadata = {} + + metadata.update(result.to_dict()) + + return result.reply, metadata diff --git a/ai-service/app/services/guardrail/streaming_filter.py b/ai-service/app/services/guardrail/streaming_filter.py new file mode 100644 index 0000000..4e033f5 --- /dev/null +++ b/ai-service/app/services/guardrail/streaming_filter.py @@ -0,0 +1,366 @@ +""" +Streaming guardrail for AI Service. +[AC-AISVC-82] Streaming mode forbidden word detection with sliding window buffer. +""" + +import logging +from collections.abc import AsyncIterator +from dataclasses import dataclass, field +from enum import Enum +from typing import Any + +from app.models.entities import ( + ForbiddenWord, + ForbiddenWordStrategy, +) +from app.services.guardrail.word_service import ForbiddenWordService + +logger = logging.getLogger(__name__) + + +class StreamingGuardrailState(str, Enum): + """State of streaming guardrail.""" + ACTIVE = "active" + BLOCKED = "blocked" + COMPLETED = "completed" + + +@dataclass +class StreamingGuardrailResult: + """Result from streaming guardrail processing.""" + delta: str + should_stop: bool = False + fallback_reply: str | None = None + triggered_words: list[str] = field(default_factory=list) + triggered_categories: list[str] = field(default_factory=list) + + +class StreamingGuardrail: + """ + [AC-AISVC-82] Streaming guardrail with sliding window buffer. + + Features: + - Maintains a sliding window buffer for forbidden word detection + - Detects forbidden words across chunk boundaries + - Applies mask/replace strategies incrementally + - Stops streaming and returns fallback for block strategy + - Final check at stream end + """ + + DEFAULT_FALLBACK_REPLY = "抱歉,让我换个方式回答您" + DEFAULT_WINDOW_SIZE = 50 + + def __init__( + self, + word_service: ForbiddenWordService, + window_size: int = DEFAULT_WINDOW_SIZE, + ): + self._word_service = word_service + self._window_size = window_size + self._buffer = "" + self._state = StreamingGuardrailState.ACTIVE + self._triggered_words: list[str] = [] + self._triggered_categories: list[str] = [] + self._fallback_reply = self.DEFAULT_FALLBACK_REPLY + self._words_cache: list[ForbiddenWord] | None = None + self._max_word_length = 0 + + async def initialize(self, tenant_id: str) -> None: + """ + Initialize the guardrail with tenant's forbidden words. + Must be called before processing chunks. + """ + self._words_cache = await self._word_service.get_enabled_words_for_filtering(tenant_id) + + if self._words_cache: + self._max_word_length = max(len(w.word) for w in self._words_cache) + effective_window = max(self._window_size, self._max_word_length) + if effective_window > self._window_size: + logger.info( + f"[AC-AISVC-82] Adjusted window size from {self._window_size} to {effective_window} " + f"to accommodate max word length {self._max_word_length}" + ) + self._window_size = effective_window + + logger.debug( + f"[AC-AISVC-82] StreamingGuardrail initialized: " + f"words_count={len(self._words_cache) if self._words_cache else 0}, " + f"window_size={self._window_size}" + ) + + async def process_chunk( + self, + delta: str, + tenant_id: str, + ) -> StreamingGuardrailResult: + """ + Process a streaming chunk with sliding window detection. + + Args: + delta: New text chunk from LLM + tenant_id: Tenant ID for isolation + + Returns: + StreamingGuardrailResult with processed delta and state + """ + if self._state == StreamingGuardrailState.BLOCKED: + return StreamingGuardrailResult( + delta="", + should_stop=True, + fallback_reply=self._fallback_reply, + ) + + if self._state == StreamingGuardrailState.COMPLETED: + return StreamingGuardrailResult(delta="") + + if not self._words_cache: + await self.initialize(tenant_id) + + if not self._words_cache: + return StreamingGuardrailResult(delta=delta) + + self._buffer += delta + + result = self._check_buffer() + + if result.should_stop: + self._state = StreamingGuardrailState.BLOCKED + logger.warning( + f"[AC-AISVC-82] Streaming blocked: tenant={tenant_id}, " + f"triggered_words={result.triggered_words}" + ) + + return result + + def _check_buffer(self) -> StreamingGuardrailResult: + """ + Check buffer for forbidden words and apply strategies. + """ + triggered_words: list[str] = [] + triggered_categories: list[str] = [] + filtered_buffer = self._buffer + blocked = False + fallback_reply = self.DEFAULT_FALLBACK_REPLY + + if not self._words_cache: + return StreamingGuardrailResult(delta=self._buffer) + + for word in self._words_cache: + if word.word in filtered_buffer: + triggered_words.append(word.word) + if word.category not in triggered_categories: + triggered_categories.append(word.category) + + if word.strategy == ForbiddenWordStrategy.BLOCK.value: + blocked = True + fallback_reply = word.fallback_reply or self.DEFAULT_FALLBACK_REPLY + break + + elif word.strategy == ForbiddenWordStrategy.MASK.value: + filtered_buffer = filtered_buffer.replace(word.word, "*" * len(word.word)) + + elif word.strategy == ForbiddenWordStrategy.REPLACE.value: + replacement = word.replacement or "" + filtered_buffer = filtered_buffer.replace(word.word, replacement) + + self._triggered_words.extend([w for w in triggered_words if w not in self._triggered_words]) + self._triggered_categories.extend([c for c in triggered_categories if c not in self._triggered_categories]) + + if blocked: + return StreamingGuardrailResult( + delta="", + should_stop=True, + fallback_reply=fallback_reply, + triggered_words=triggered_words, + triggered_categories=triggered_categories, + ) + + safe_output, remaining = self._split_buffer(filtered_buffer) + + self._buffer = remaining + + return StreamingGuardrailResult( + delta=safe_output, + should_stop=False, + triggered_words=triggered_words, + triggered_categories=triggered_categories, + ) + + def _split_buffer(self, buffer: str) -> tuple[str, str]: + """ + Split buffer into safe output and remaining buffer. + + Safe output: characters that are definitely safe (before window boundary) + Remaining: characters that might be part of a forbidden word + """ + if len(buffer) <= self._window_size: + return "", buffer + + safe_length = len(buffer) - self._window_size + safe_output = buffer[:safe_length] + remaining = buffer[safe_length:] + + return safe_output, remaining + + async def finalize(self, tenant_id: str) -> StreamingGuardrailResult: + """ + Finalize the stream and process remaining buffer. + Must be called at the end of streaming. + + Args: + tenant_id: Tenant ID for isolation + + Returns: + StreamingGuardrailResult with final delta + """ + if self._state == StreamingGuardrailState.BLOCKED: + self._state = StreamingGuardrailState.COMPLETED + return StreamingGuardrailResult( + delta="", + should_stop=True, + fallback_reply=self._fallback_reply, + triggered_words=self._triggered_words, + triggered_categories=self._triggered_categories, + ) + + if not self._words_cache: + await self.initialize(tenant_id) + + result = self._final_check() + + self._state = StreamingGuardrailState.COMPLETED + + if self._triggered_words: + logger.info( + f"[AC-AISVC-82] Streaming finalized with triggers: tenant={tenant_id}, " + f"triggered_words={self._triggered_words}" + ) + + for word_entity in self._get_triggered_word_entities(): + try: + await self._word_service.increment_hit_count(tenant_id, word_entity.id) + except Exception as e: + logger.warning( + f"Failed to increment hit count for word {word_entity.id}: {e}" + ) + + return result + + def _final_check(self) -> StreamingGuardrailResult: + """ + Final check of remaining buffer. + """ + if not self._buffer: + return StreamingGuardrailResult( + delta="", + triggered_words=self._triggered_words, + triggered_categories=self._triggered_categories, + ) + + if not self._words_cache: + return StreamingGuardrailResult( + delta=self._buffer, + triggered_words=self._triggered_words, + triggered_categories=self._triggered_categories, + ) + + filtered_buffer = self._buffer + blocked = False + fallback_reply = self.DEFAULT_FALLBACK_REPLY + + for word in self._words_cache: + if word.word in filtered_buffer: + if word.word not in self._triggered_words: + self._triggered_words.append(word.word) + if word.category not in self._triggered_categories: + self._triggered_categories.append(word.category) + + if word.strategy == ForbiddenWordStrategy.BLOCK.value: + blocked = True + fallback_reply = word.fallback_reply or self.DEFAULT_FALLBACK_REPLY + break + + elif word.strategy == ForbiddenWordStrategy.MASK.value: + filtered_buffer = filtered_buffer.replace(word.word, "*" * len(word.word)) + + elif word.strategy == ForbiddenWordStrategy.REPLACE.value: + replacement = word.replacement or "" + filtered_buffer = filtered_buffer.replace(word.word, replacement) + + if blocked: + return StreamingGuardrailResult( + delta="", + should_stop=True, + fallback_reply=fallback_reply, + triggered_words=self._triggered_words, + triggered_categories=self._triggered_categories, + ) + + self._buffer = "" + + return StreamingGuardrailResult( + delta=filtered_buffer, + triggered_words=self._triggered_words, + triggered_categories=self._triggered_categories, + ) + + def _get_triggered_word_entities(self) -> list[ForbiddenWord]: + """Get word entities for triggered words.""" + if not self._words_cache: + return [] + return [w for w in self._words_cache if w.word in self._triggered_words] + + def get_triggered_info(self) -> dict[str, Any]: + """Get triggered words and categories info.""" + return { + "triggered_words": self._triggered_words, + "triggered_categories": self._triggered_categories, + "guardrail_triggered": len(self._triggered_words) > 0, + } + + def reset(self) -> None: + """Reset the guardrail state for reuse.""" + self._buffer = "" + self._state = StreamingGuardrailState.ACTIVE + self._triggered_words = [] + self._triggered_categories = [] + self._fallback_reply = self.DEFAULT_FALLBACK_REPLY + self._words_cache = None + self._max_word_length = 0 + + +async def wrap_stream_with_guardrail( + stream: AsyncIterator[str], + guardrail: StreamingGuardrail, + tenant_id: str, +) -> AsyncIterator[tuple[str, bool, str | None]]: + """ + Wrap an async stream with guardrail processing. + + Args: + stream: Original LLM output stream + guardrail: StreamingGuardrail instance + tenant_id: Tenant ID for isolation + + Yields: + Tuple of (delta, should_stop, fallback_reply) + """ + await guardrail.initialize(tenant_id) + + async for delta in stream: + result = await guardrail.process_chunk(delta, tenant_id) + + if result.delta: + yield (result.delta, False, None) + + if result.should_stop: + yield ("", True, result.fallback_reply) + return + + final_result = await guardrail.finalize(tenant_id) + + if final_result.delta: + yield (final_result.delta, False, None) + + if final_result.should_stop: + yield ("", True, final_result.fallback_reply) diff --git a/ai-service/app/services/guardrail/word_service.py b/ai-service/app/services/guardrail/word_service.py new file mode 100644 index 0000000..a6dc541 --- /dev/null +++ b/ai-service/app/services/guardrail/word_service.py @@ -0,0 +1,297 @@ +""" +Forbidden word service for AI Service. +[AC-AISVC-78~AC-AISVC-81] Forbidden word CRUD and hit statistics management. +""" + +import logging +import time +import uuid +from collections.abc import Sequence +from datetime import datetime +from typing import Any + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlmodel import col + +from app.models.entities import ( + ForbiddenWord, + ForbiddenWordCategory, + ForbiddenWordCreate, + ForbiddenWordStrategy, + ForbiddenWordUpdate, +) + +logger = logging.getLogger(__name__) + +WORD_CACHE_TTL_SECONDS = 60 + + +class WordCache: + """ + [AC-AISVC-82] In-memory cache for forbidden words. + Key: tenant_id + Value: (words_list, cached_at) + TTL: 60 seconds + """ + + def __init__(self, ttl_seconds: int = WORD_CACHE_TTL_SECONDS): + self._cache: dict[str, tuple[list[ForbiddenWord], float]] = {} + self._ttl = ttl_seconds + + def get(self, tenant_id: str) -> list[ForbiddenWord] | None: + """Get cached words if not expired.""" + if tenant_id in self._cache: + words, cached_at = self._cache[tenant_id] + if time.time() - cached_at < self._ttl: + return words + else: + del self._cache[tenant_id] + return None + + def set(self, tenant_id: str, words: list[ForbiddenWord]) -> None: + """Cache words for a tenant.""" + self._cache[tenant_id] = (words, time.time()) + + def invalidate(self, tenant_id: str) -> None: + """Invalidate cache for a tenant.""" + if tenant_id in self._cache: + del self._cache[tenant_id] + logger.debug(f"Invalidated word cache for tenant={tenant_id}") + + +_word_cache = WordCache() + + +class ForbiddenWordService: + """ + [AC-AISVC-78~AC-AISVC-81] Service for managing forbidden words. + + Features: + - Word CRUD with tenant isolation + - Hit count statistics + - In-memory caching with TTL + - Cache invalidation on CRUD operations + - Support for mask/replace/block strategies + """ + + VALID_CATEGORIES = [c.value for c in ForbiddenWordCategory] + VALID_STRATEGIES = [s.value for s in ForbiddenWordStrategy] + + def __init__(self, session: AsyncSession): + self._session = session + self._cache = _word_cache + + async def create_word( + self, + tenant_id: str, + create_data: ForbiddenWordCreate, + ) -> ForbiddenWord: + """ + [AC-AISVC-78] Create a new forbidden word. + """ + if create_data.category not in self.VALID_CATEGORIES: + raise ValueError(f"Invalid category. Must be one of: {self.VALID_CATEGORIES}") + + if create_data.strategy not in self.VALID_STRATEGIES: + raise ValueError(f"Invalid strategy. Must be one of: {self.VALID_STRATEGIES}") + + if create_data.strategy == ForbiddenWordStrategy.REPLACE.value and not create_data.replacement: + raise ValueError("replacement is required when strategy is 'replace'") + + if create_data.strategy == ForbiddenWordStrategy.BLOCK.value and not create_data.fallback_reply: + logger.warning( + f"[AC-AISVC-78] Creating block word without fallback_reply: tenant={tenant_id}, word={create_data.word}" + ) + + word = ForbiddenWord( + tenant_id=tenant_id, + word=create_data.word, + category=create_data.category, + strategy=create_data.strategy, + replacement=create_data.replacement, + fallback_reply=create_data.fallback_reply, + is_enabled=True, + hit_count=0, + ) + self._session.add(word) + await self._session.flush() + + self._cache.invalidate(tenant_id) + + logger.info( + f"[AC-AISVC-78] Created forbidden word: tenant={tenant_id}, " + f"id={word.id}, word={word.word}, strategy={word.strategy}" + ) + return word + + async def list_words( + self, + tenant_id: str, + category: str | None = None, + is_enabled: bool | None = None, + ) -> Sequence[ForbiddenWord]: + """ + [AC-AISVC-79] List words for a tenant with optional filters. + """ + stmt = select(ForbiddenWord).where( + ForbiddenWord.tenant_id == tenant_id # type: ignore + ) + + if category is not None: + stmt = stmt.where(ForbiddenWord.category == category) # type: ignore + + if is_enabled is not None: + stmt = stmt.where(ForbiddenWord.is_enabled == is_enabled) # type: ignore + + stmt = stmt.order_by(col(ForbiddenWord.category), col(ForbiddenWord.created_at).desc()) + result = await self._session.execute(stmt) + return result.scalars().all() + + async def get_word( + self, + tenant_id: str, + word_id: uuid.UUID, + ) -> ForbiddenWord | None: + """ + [AC-AISVC-79] Get word by ID with tenant isolation. + """ + stmt = select(ForbiddenWord).where( + ForbiddenWord.tenant_id == tenant_id, # type: ignore + ForbiddenWord.id == word_id, # type: ignore + ) + result = await self._session.execute(stmt) + return result.scalar_one_or_none() + + async def update_word( + self, + tenant_id: str, + word_id: uuid.UUID, + update_data: ForbiddenWordUpdate, + ) -> ForbiddenWord | None: + """ + [AC-AISVC-80] Update a forbidden word. + """ + word = await self.get_word(tenant_id, word_id) + if not word: + return None + + if update_data.word is not None: + word.word = update_data.word + if update_data.category is not None: + if update_data.category not in self.VALID_CATEGORIES: + raise ValueError(f"Invalid category. Must be one of: {self.VALID_CATEGORIES}") + word.category = update_data.category + if update_data.strategy is not None: + if update_data.strategy not in self.VALID_STRATEGIES: + raise ValueError(f"Invalid strategy. Must be one of: {self.VALID_STRATEGIES}") + word.strategy = update_data.strategy + if update_data.replacement is not None: + word.replacement = update_data.replacement + if update_data.fallback_reply is not None: + word.fallback_reply = update_data.fallback_reply + if update_data.is_enabled is not None: + word.is_enabled = update_data.is_enabled + + word.updated_at = datetime.utcnow() + await self._session.flush() + + self._cache.invalidate(tenant_id) + + logger.info( + f"[AC-AISVC-80] Updated forbidden word: tenant={tenant_id}, id={word_id}" + ) + return word + + async def delete_word( + self, + tenant_id: str, + word_id: uuid.UUID, + ) -> bool: + """ + [AC-AISVC-81] Delete a forbidden word. + """ + word = await self.get_word(tenant_id, word_id) + if not word: + return False + + await self._session.delete(word) + await self._session.flush() + + self._cache.invalidate(tenant_id) + + logger.info( + f"[AC-AISVC-81] Deleted forbidden word: tenant={tenant_id}, id={word_id}" + ) + return True + + async def increment_hit_count( + self, + tenant_id: str, + word_id: uuid.UUID, + ) -> bool: + """ + [AC-AISVC-79] Increment hit count for a word. + """ + word = await self.get_word(tenant_id, word_id) + if not word: + return False + + word.hit_count += 1 + word.updated_at = datetime.utcnow() + await self._session.flush() + + logger.debug( + f"[AC-AISVC-79] Incremented hit count for word: tenant={tenant_id}, " + f"id={word_id}, hit_count={word.hit_count}" + ) + return True + + async def get_enabled_words_for_filtering( + self, + tenant_id: str, + ) -> list[ForbiddenWord]: + """ + [AC-AISVC-82] Get enabled words for filtering. + Uses cache for performance. + """ + cached = self._cache.get(tenant_id) + if cached is not None: + logger.debug(f"[AC-AISVC-82] Cache hit for words: tenant={tenant_id}") + return cached + + stmt = ( + select(ForbiddenWord) + .where( + ForbiddenWord.tenant_id == tenant_id, # type: ignore + ForbiddenWord.is_enabled == True, # type: ignore + ) + .order_by(col(ForbiddenWord.category)) + ) + result = await self._session.execute(stmt) + words = list(result.scalars().all()) + + self._cache.set(tenant_id, words) + logger.info( + f"[AC-AISVC-82] Loaded {len(words)} enabled words from DB: tenant={tenant_id}" + ) + return words + + def invalidate_cache(self, tenant_id: str) -> None: + """Manually invalidate cache for a tenant.""" + self._cache.invalidate(tenant_id) + + async def word_to_info_dict(self, word: ForbiddenWord) -> dict[str, Any]: + """Convert word entity to API response dict.""" + return { + "id": str(word.id), + "word": word.word, + "category": word.category, + "strategy": word.strategy, + "replacement": word.replacement, + "fallback_reply": word.fallback_reply, + "is_enabled": word.is_enabled, + "hit_count": word.hit_count, + "created_at": word.created_at.isoformat(), + "updated_at": word.updated_at.isoformat(), + } diff --git a/docs/progress/ai-service-progress.md b/docs/progress/ai-service-progress.md index 260af8f..ca1ecdc 100644 --- a/docs/progress/ai-service-progress.md +++ b/docs/progress/ai-service-progress.md @@ -6,7 +6,7 @@ - module: `ai-service` - feature: `AISVC` (Python AI 中台) -- status: 🔄 进行中 (Phase 12) +- status: 🔄 进行中 (Phase 14 完成) --- @@ -35,27 +35,29 @@ - [x] Phase 10: Prompt 模板化 (80%) 🔄 (T10.1-T10.8 完成,T10.9-T10.10 待集成阶段) - [x] Phase 11: 多知识库管理 (63%) 🔄 (T11.1-T11.5 完成,T11.6-T11.8 待集成阶段) - [x] Phase 12: 意图识别与规则引擎 (71%) 🔄 (T12.1-T12.5 完成,T12.6-T12.7 待集成阶段) +- [x] Phase 13: 话术流程引擎 (0%) ⏳ 待处理 +- [x] Phase 14: 输出护栏 (88%) ✅ (T14.1-T14.7 完成,T14.8 单元测试留到集成阶段) --- ## 🔄 Current Phase ### Goal -Phase 11 多知识库管理核心功能已完成 (T11.1-T11.5),T11.6(OptimizedRetriever 多 Collection 检索)、T11.7(kb_default 迁移)、T11.8(单元测试)留待集成阶段。 +Phase 14 输出护栏核心功能已完成 (T14.1-T14.7),T14.8(单元测试)留到集成阶段。 -### Completed Tasks (Phase 11) +### Completed Tasks (Phase 14) -- [x] T11.1 扩展 `KnowledgeBase` 实体:新增 `kb_type`、`priority`、`is_enabled`、`doc_count` 字段 `[AC-AISVC-59]` ✅ -- [x] T11.2 实现知识库 CRUD 服务:创建时初始化 Qdrant Collection,删除时清理 Collection `[AC-AISVC-59, AC-AISVC-61, AC-AISVC-62]` ✅ -- [x] T11.3 实现知识库管理 API:`POST/GET/PUT/DELETE /admin/kb/knowledge-bases` `[AC-AISVC-59, AC-AISVC-60, AC-AISVC-61, AC-AISVC-62]` ✅ -- [x] T11.4 升级 Qdrant Collection 命名:`kb_{tenant_id}_{kb_id}`,兼容现有 `kb_{tenant_id}` `[AC-AISVC-63]` ✅ -- [x] T11.5 修改文档上传流程:支持指定 `kbId` 参数,索引到对应 Collection `[AC-AISVC-63]` ✅ +- [x] T14.1 定义 `ForbiddenWord` 和 `BehaviorRule` SQLModel 实体,创建数据库表 `[AC-AISVC-78, AC-AISVC-84]` ✅ +- [x] T14.2 实现 `ForbiddenWordService`:禁词 CRUD + 命中统计 `[AC-AISVC-78, AC-AISVC-79, AC-AISVC-80, AC-AISVC-81]` ✅ +- [x] T14.3 实现 `BehaviorRuleService`:行为规则 CRUD `[AC-AISVC-84, AC-AISVC-85]` ✅ +- [x] T14.4 实现 `InputScanner`:用户输入前置禁词检测(仅记录,不阻断) `[AC-AISVC-83]` ✅ +- [x] T14.5 实现 `OutputFilter`:LLM 输出后置过滤(mask/replace/block 三种策略) `[AC-AISVC-82]` ✅ +- [x] T14.6 实现 Streaming 模式下的滑动窗口禁词检测 `[AC-AISVC-82]` ✅ +- [x] T14.7 实现护栏管理 API:`/admin/guardrails` 相关端点 `[AC-AISVC-78~AC-AISVC-85]` ✅ -### Pending Tasks (Phase 11 - 集成阶段) +### Pending Tasks (Phase 14 - 集成阶段) -- [ ] T11.6 修改 `OptimizedRetriever`:支持 `target_kb_ids` 参数,实现多 Collection 并行检索 `[AC-AISVC-64]` -- [ ] T11.7 实现 `kb_default` 自动迁移:首次启动时为现有数据创建默认知识库记录 `[AC-AISVC-59]` -- [ ] T11.8 编写多知识库服务单元测试 `[AC-AISVC-59~AC-AISVC-64]` +- [ ] T14.8 编写输出护栏服务单元测试 `[AC-AISVC-78~AC-AISVC-85]` --- @@ -66,42 +68,73 @@ Phase 11 多知识库管理核心功能已完成 (T11.1-T11.5),T11.6(Optimiz - `ai-service/` - `app/` - `api/` - FastAPI 路由层 - - `admin/intent_rules.py` - 意图规则管理 API ✅ - - `admin/prompt_templates.py` - Prompt 模板管理 API ✅ + - `admin/guardrails.py` - 护栏管理 API ✅ - `models/` - Pydantic 模型和 SQLModel 实体 - - `entities.py` - IntentRule, PromptTemplate, PromptTemplateVersion 实体 ✅ + - `entities.py` - ForbiddenWord, BehaviorRule, GuardrailResult, InputScanResult 实体 ✅ - `services/` - - `intent/` - 意图识别服务 ✅ + - `guardrail/` - 输出护栏服务 ✅ - `__init__.py` - 模块导出 - - `rule_service.py` - 规则 CRUD、命中统计、缓存 - - `router.py` - IntentRouter 匹配引擎 - - `prompt/` - Prompt 模板服务 ✅ - - `__init__.py` - 模块导出 - - `template_service.py` - 模板 CRUD、版本管理、发布/回滚、缓存 - - `variable_resolver.py` - 变量替换引擎 + - `word_service.py` - 禁词 CRUD、命中统计、缓存 + - `behavior_service.py` - 行为规则 CRUD、缓存、Prompt 注入格式化 + - `input_scanner.py` - 用户输入前置检测(仅记录,不阻断) + - `output_filter.py` - LLM 输出后置过滤(mask/replace/block) + - `streaming_filter.py` - Streaming 滑动窗口检测 ### Key Decisions (Why / Impact) -- decision: 意图规则数据库驱动 - reason: 支持动态配置意图识别规则,无需重启服务 - impact: 规则存储在 PostgreSQL,支持按租户隔离 +- decision: 三种禁词替换策略 + reason: 满足不同场景的内容合规需求 + impact: mask 星号替换、replace 指定文本替换、block 拦截整条回复返回兜底话术 -- decision: 关键词 + 正则双匹配机制 - reason: 关键词匹配快速高效,正则匹配支持复杂模式 - impact: 先关键词匹配再正则匹配,优先级高的规则先匹配 +- decision: 输入检测不阻断 + reason: 用户输入包含禁词时仍需正常处理,仅记录用于监控分析 + impact: InputScanner 返回 flagged 状态和匹配信息,不抛异常 + +- decision: Streaming 滑动窗口检测 + reason: 流式输出无法预知完整内容,需要缓冲区检测跨 chunk 的禁词 + impact: 维护滑动窗口 buffer(默认 50 字符,自动调整到最长禁词长度),检测到禁词后立即停止 + +- decision: 行为规则注入 Prompt + reason: 行为规则作为 LLM 的行为约束,不进行运行时检测 + impact: BehaviorRuleService 提供 format_rules_for_prompt() 方法,追加到系统指令末尾 - decision: 内存缓存 + TTL 策略 - reason: 减少数据库查询,提升匹配性能 + reason: 减少数据库查询,提升过滤性能 impact: 缓存 TTL=60s,CRUD 操作时主动失效 -- decision: 四种响应类型 - reason: 支持不同的处理链路 - impact: fixed 直接返回、rag 定向检索、flow 进入流程、transfer 转人工 - --- ## 🧾 Session History +### Session #10 (2026-02-27) +- completed: + - T14.1-T14.7 输出护栏核心功能 + - 实现 ForbiddenWord 和 BehaviorRule 实体 + - 实现 ForbiddenWordService(CRUD、命中统计、缓存) + - 实现 BehaviorRuleService(CRUD、缓存、Prompt 注入格式化) + - 实现 InputScanner(用户输入前置检测,仅记录不阻断) + - 实现 OutputFilter(LLM 输出后置过滤,mask/replace/block 三种策略) + - 实现 StreamingGuardrail(Streaming 滑动窗口检测) + - 实现护栏管理 API(禁词和行为规则 CRUD) +- changes: + - 新增 `app/models/entities.py` ForbiddenWord, BehaviorRule, GuardrailResult, InputScanResult 实体 + - 新增 `app/services/guardrail/__init__.py` 模块导出 + - 新增 `app/services/guardrail/word_service.py` 禁词服务 + - 新增 `app/services/guardrail/behavior_service.py` 行为规则服务 + - 新增 `app/services/guardrail/input_scanner.py` 输入扫描器 + - 新增 `app/services/guardrail/output_filter.py` 输出过滤器 + - 新增 `app/services/guardrail/streaming_filter.py` Streaming 过滤器 + - 新增 `app/api/admin/guardrails.py` 护栏管理 API + - 更新 `app/api/admin/__init__.py` 导出新路由 + - 更新 `app/main.py` 注册新路由 + - 更新 `spec/ai-service/tasks.md` 标记任务完成 +- notes: + - T14.8(单元测试)留到集成阶段 + - 禁词检测三种策略:mask(星号替换)、replace(指定文本替换)、block(拦截返回兜底话术) + - InputScanner 仅记录命中,不阻断请求 + - OutputFilter 应用 mask/replace/block 策略 + - StreamingGuardrail 使用滑动窗口 buffer(默认 50 字符,自动调整) + ### Session #9 (2026-02-27) - completed: - T11.1-T11.5 多知识库管理核心功能