Merge pull request '[AC-AISVC-50] Merge the first stable version' (#2) from feature/prompt-unification-and-logging into main
Reviewed-on: #2
This commit is contained in:
commit 60bf649d96
@ -0,0 +1,21 @@
# AI Service Environment Variables
# Copy this file to .env and modify as needed

# LLM Configuration (OpenAI)
AI_SERVICE_LLM_PROVIDER=openai
AI_SERVICE_LLM_API_KEY=your-api-key-here
AI_SERVICE_LLM_BASE_URL=https://api.openai.com/v1
AI_SERVICE_LLM_MODEL=gpt-4o-mini

# If using DeepSeek
# AI_SERVICE_LLM_PROVIDER=deepseek
# AI_SERVICE_LLM_API_KEY=your-deepseek-api-key
# AI_SERVICE_LLM_MODEL=deepseek-chat

# Ollama Configuration (for embedding model)
AI_SERVICE_OLLAMA_BASE_URL=http://ollama:11434
AI_SERVICE_OLLAMA_EMBEDDING_MODEL=nomic-embed-text

# Frontend API Key (required for admin panel authentication)
# Get this key from the backend logs after first startup, or from /admin/api-keys
VITE_APP_API_KEY=your-frontend-api-key-here
@ -162,5 +162,6 @@ cython_debug/

# Project specific
ai-service/uploads/
ai-service/config/
*.local
README.md
@ -1,3 +1,294 @@
# ai-robot-core
# AI Robot Core

Capability support for the AI middle-platform business
Capability support for the AI middle-platform business, providing core capabilities such as intelligent customer service, RAG knowledge-base retrieval, and LLM conversation.

## Project Structure

```
ai-robot-core/
├── ai-service/          # Python backend service
│   ├── app/             # FastAPI application
│   ├── tests/           # Test cases
│   ├── Dockerfile       # Backend image
│   └── pyproject.toml   # Python dependencies
├── ai-service-admin/    # Vue frontend admin UI
│   ├── src/             # Vue source code
│   ├── Dockerfile       # Frontend image
│   ├── nginx.conf       # Nginx configuration
│   └── package.json     # Node dependencies
├── docker-compose.yaml  # Container orchestration
├── .env.example         # Environment variable example
└── README.md
```

## Features

- **Multi-tenant support**: tenant isolation via the X-Tenant-Id header
- **RAG knowledge base**: retrieval-augmented generation backed by Qdrant
- **LLM integration**: supports multiple LLM providers such as OpenAI, DeepSeek, and Ollama
- **SSE streaming**: real-time responses via Server-Sent Events
- **Confidence evaluation**: automatically evaluates reply quality and suggests human handoff when confidence is low
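The multi-tenant isolation above relies on a strict `X-Tenant-Id` format (`name@ash@year`, as enforced by the backend middleware in this PR). A minimal Python sketch of that format check:

```python
import re

# Tenant ID format enforced by the backend middleware: name@ash@year
TENANT_ID_PATTERN = re.compile(r'^[^@]+@ash@\d{4}$')

def is_valid_tenant_id(tenant_id: str) -> bool:
    """Return True if the tenant ID matches the name@ash@year format."""
    return bool(TENANT_ID_PATTERN.match(tenant_id))

print(is_valid_tenant_id("szmp@ash@2026"))   # True
print(is_valid_tenant_id("default@2026"))    # False (missing @ash@)
```

Requests carrying a malformed tenant ID are rejected before reaching the business logic.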

## Quick Start

### Requirements

- Docker 20.10+
- Docker Compose 2.0+

### Deployment Steps

#### 1. Clone the Code

```bash
git clone http://49.232.209.156:3005/MerCry/ai-robot-core.git
cd ai-robot-core
```

#### 2. Configure Environment Variables

```bash
cp .env.example .env
```

Edit the `.env` file and configure the LLM API:

```env
# OpenAI configuration
AI_SERVICE_LLM_PROVIDER=openai
AI_SERVICE_LLM_API_KEY=your-openai-api-key
AI_SERVICE_LLM_BASE_URL=https://api.openai.com/v1
AI_SERVICE_LLM_MODEL=gpt-4o-mini

# Or use DeepSeek
# AI_SERVICE_LLM_PROVIDER=deepseek
# AI_SERVICE_LLM_API_KEY=your-deepseek-api-key
# AI_SERVICE_LLM_MODEL=deepseek-chat
```

#### 3. Start the Services

```bash
# Docker Compose V2 (recommended, built into Docker)
docker compose up -d --build

# Or Docker Compose V1 (legacy, installed separately)
docker-compose up -d --build
```

#### 4. Pull the Embedding Model

After the services start, pull the embedding model inside the Ollama container. `nomic-embed-text-v2-moe` is recommended for its better Chinese support:

```bash
# Pull the model inside the Ollama container
docker exec -it ai-ollama ollama pull toshk0/nomic-embed-text-v2-moe:Q6_K
```

**Available models**:

| Model | Dimension | Notes |
|------|------|------|
| `toshk0/nomic-embed-text-v2-moe:Q6_K` | 768 | Recommended; good Chinese support; supports task prefixes |
| `nomic-embed-text:v1.5` | 768 | Original; supports task prefixes and Matryoshka |
| `bge-large-zh` | 1024 | Chinese-specific; best results for Chinese |

#### 5. Configure the Embedding Model

Open the frontend admin UI and go to the **Embedding Model Configuration** page:

1. Select the provider: **Nomic Embed (optimized)**
2. Configure the parameters:
   - **API URL**: `http://ollama:11434` (Docker) or `http://localhost:11434` (local development)
   - **Model name**: `toshk0/nomic-embed-text-v2-moe:Q6_K`
   - **Vector dimension**: `768`
   - **Matryoshka truncation**: `true`
3. Click **Save Configuration**

> **Note**:
> - Using the Nomic Embed (optimized) provider enables the full set of RAG optimizations: task prefixes, Matryoshka multi-vectors, and two-stage retrieval.
> - The embedding configuration is persisted to `ai-service/config/embedding_config.json` and reloaded automatically after a service restart.
> - **Important**: after switching embedding models, delete the existing knowledge base and re-upload documents, because vectors produced by different models are incompatible.
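Matryoshka truncation means a single 768-dimension embedding can be cut down to its leading 256 or 512 components and re-normalized for a cheap first retrieval stage (the indexing code in this PR stores `full`, `dim_512`, and `dim_256` vectors per chunk). A sketch of the idea, not the service's actual implementation:

```python
import math

def truncate_and_normalize(vec: list[float], dim: int) -> list[float]:
    """Keep the first `dim` components and L2-normalize the result."""
    head = vec[:dim]
    norm = math.sqrt(sum(x * x for x in head)) or 1.0
    return [x / norm for x in head]

full = [0.1] * 768  # stand-in for a 768-dim document embedding
dim_256 = truncate_and_normalize(full, 256)  # coarse first-stage vector
dim_512 = truncate_and_normalize(full, 512)  # mid-resolution vector
```

Two-stage retrieval then searches the small vectors first and re-ranks the candidates with the full 768-dim vectors.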

#### 6. Verify the Services

```bash
# Check service status
docker ps

# Check the backend logs for the auto-generated API Key
docker logs -f ai-service | grep "Default API Key"
```

> **Important**: on first startup the backend automatically generates a default API Key. Copy it from the logs; it is needed for the frontend configuration.

#### 7. Configure the Frontend API Key

```bash
# Create the frontend environment variable file
cd ai-service-admin
cp .env.example .env
```

Edit `ai-service-admin/.env` and set `VITE_APP_API_KEY` to the API Key from the backend logs:

```env
VITE_APP_BASE_API=/api
VITE_APP_API_KEY=<API Key copied from the backend logs>
```

Then rebuild the frontend:

```bash
cd ..
docker compose up -d --build ai-service-admin
```

#### 8. Access the Services

| Service | URL | Notes |
|------|------|------|
| Frontend admin UI | http://<server-ip>:8181 | Vue admin console |
| Backend API | http://<server-ip>:8182 | FastAPI service (called by the Java channel side) |
| API docs | http://<server-ip>:8182/docs | Swagger UI |
| Qdrant console | http://<server-ip>:6333/dashboard | Vector database management |
| Ollama API | http://<server-ip>:11434 | Embedding model service |

> **Port notes**:
> - `8181`: frontend admin UI, which proxies the backend API internally
> - `8182`: backend API, called directly by the Java channel side

## Service Architecture

```
┌─────────────────────────────────────────────────────────┐
│                       User access                       │
└─────────────────────────────────────────────────────────┘
                            │
                            ▼
┌─────────────────────────────────────────────────────────┐
│               ai-service-admin (port 8181)              │
│  - Nginx static file serving                            │
│  - Reverse proxy /api/* → ai-service:8080               │
└─────────────────────────────────────────────────────────┘
                            │
                            ▼
┌─────────────────────────────────────────────────────────┐
│                  ai-service (port 8080)                 │
│  - FastAPI backend service                              │
│  - RAG / LLM / knowledge base management                │
└─────────────────────────────────────────────────────────┘
          │                  │                  │
          ▼                  ▼                  ▼
┌──────────────────┐ ┌──────────────────┐ ┌──────────────────┐
│    PostgreSQL    │ │      Qdrant      │ │      Ollama      │
│   (port 5432)    │ │   (port 6333)    │ │   (port 11434)   │
│ - session store  │ │ - vector store   │ │ - nomic-embed    │
│ - KB metadata    │ │ - document index │ │ - embedding model│
└──────────────────┘ └──────────────────┘ └──────────────────┘
```

## Common Commands

```bash
# Start all services
docker compose up -d

# Rebuild and start
docker compose up -d --build

# Check service status
docker compose ps

# View logs
docker compose logs -f ai-service
docker compose logs -f ai-service-admin

# Restart a service
docker compose restart ai-service

# Stop all services
docker compose down

# Stop and remove volumes (wipes data)
docker compose down -v
```

## Host Nginx Configuration (Optional)

To manage the entry point through the host's Nginx (domain name, SSL certificate), see `deploy/nginx.conf.example`:

```bash
# Copy the configuration file
sudo cp deploy/nginx.conf.example /etc/nginx/conf.d/ai-service.conf

# Edit the domain name in the configuration
sudo vim /etc/nginx/conf.d/ai-service.conf

# Test the configuration
sudo nginx -t

# Reload Nginx
sudo nginx -s reload
```

## Local Development

### Backend

```bash
cd ai-service

# Create a virtual environment
python -m venv .venv
source .venv/bin/activate  # Linux/Mac
# .venv\Scripts\activate   # Windows

# Install dependencies
pip install -e ".[dev]"

# Start the dev server
uvicorn app.main:app --reload --port 8000
```

### Frontend

```bash
cd ai-service-admin

# Install dependencies
npm install

# Start the dev server
npm run dev
```

## API Endpoints

### Core Endpoints

| Endpoint | Method | Notes |
|------|------|------|
| `/ai/chat` | POST | AI chat endpoint |
| `/admin/kb` | GET/POST | Knowledge base management |
| `/admin/rag/experiments/run` | POST | RAG lab |
| `/admin/llm/config` | GET/PUT | LLM configuration |
| `/admin/embedding/config` | GET/PUT | Embedding model configuration |

For full API documentation, visit http://<server-ip>:8182/docs
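Every request to these endpoints must carry the `X-Tenant-Id` and `X-API-Key` headers described in the deployment steps. A sketch of building a chat call with `urllib` (the `question` payload field is an assumption; check the Swagger UI for the exact request schema):

```python
import json
import urllib.request

def build_chat_request(base_url: str, tenant_id: str, api_key: str,
                       payload: dict) -> urllib.request.Request:
    """Build a POST /ai/chat request carrying the required auth headers."""
    return urllib.request.Request(
        url=f"{base_url}/ai/chat",
        data=json.dumps(payload).encode("utf-8"),
        headers={
            "Content-Type": "application/json",
            "X-Tenant-Id": tenant_id,  # name@ash@year format
            "X-API-Key": api_key,      # from the backend logs
        },
        method="POST",
    )

req = build_chat_request("http://localhost:8182", "default@ash@2026",
                         "your-api-key", {"question": "hello"})
# urllib.request.urlopen(req) would send it; omitted here.
```

Requests missing either header are rejected by the backend middleware (401 for a missing or invalid API key).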

## Environment Variables

| Variable | Default | Notes |
|--------|--------|------|
| `AI_SERVICE_LLM_PROVIDER` | openai | LLM provider |
| `AI_SERVICE_LLM_API_KEY` | - | API key |
| `AI_SERVICE_LLM_BASE_URL` | https://api.openai.com/v1 | API base URL |
| `AI_SERVICE_LLM_MODEL` | gpt-4o-mini | Model name |
| `AI_SERVICE_DATABASE_URL` | postgresql+asyncpg://... | Database connection |
| `AI_SERVICE_QDRANT_URL` | http://qdrant:6333 | Qdrant URL |
| `AI_SERVICE_LOG_LEVEL` | INFO | Log level |

## License

MIT
@ -0,0 +1,19 @@
node_modules
dist

.env
.env.local
.env.*.local

*.log

.idea/
.vscode/
*.swp
*.swo

.git
.gitignore

*.md
!README.md
@ -0,0 +1,8 @@
# API Base URL
VITE_APP_BASE_API=/api

# Default API Key for authentication
# IMPORTANT: You must set this to a valid API key from the backend
# The backend creates a default API key on first startup (check backend logs)
# Or you can create one via the API: POST /admin/api-keys
VITE_APP_API_KEY=your-api-key-here
@ -0,0 +1,28 @@
# AI Service Admin Frontend Dockerfile
FROM docker.1ms.run/node:20-alpine AS builder

WORKDIR /app

ARG VITE_APP_API_KEY
ARG VITE_APP_BASE_API=/api

ENV VITE_APP_API_KEY=$VITE_APP_API_KEY
ENV VITE_APP_BASE_API=$VITE_APP_BASE_API

COPY package*.json ./

RUN npm install && npm install @rollup/rollup-linux-x64-musl --save-optional

COPY . .

RUN npm run build

FROM docker.1ms.run/nginx:alpine

COPY --from=builder /app/dist /usr/share/nginx/html

COPY nginx.conf /etc/nginx/conf.d/default.conf

EXPOSE 80

CMD ["nginx", "-g", "daemon off;"]
@ -0,0 +1,29 @@
server {
    listen 80;
    server_name localhost;
    root /usr/share/nginx/html;
    index index.html;

    location / {
        try_files $uri $uri/ /index.html;
    }

    location /api/ {
        proxy_pass http://ai-service:8080/;
        proxy_http_version 1.1;
        proxy_set_header Upgrade $http_upgrade;
        proxy_set_header Connection 'upgrade';
        proxy_set_header Host $host;
        proxy_set_header X-Real-IP $remote_addr;
        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
        proxy_set_header X-Forwarded-Proto $scheme;
        proxy_cache_bypass $http_upgrade;
        proxy_read_timeout 300s;
        proxy_connect_timeout 75s;
        proxy_buffering off;
    }

    gzip on;
    gzip_types text/plain text/css application/json application/javascript text/xml application/xml application/xml+rss text/javascript;
    gzip_min_length 1000;
}

File diff suppressed because it is too large
@ -17,8 +17,8 @@
  },
  "devDependencies": {
    "@vitejs/plugin-vue": "^5.0.4",
    "typescript": "^5.2.2",
    "typescript": "~5.6.0",
    "vite": "^5.1.4",
    "vue-tsc": "^1.8.27"
    "vue-tsc": "^2.1.0"
  }
}
@ -98,7 +98,6 @@ const isValidTenantId = (tenantId: string): boolean => {
const fetchTenantList = async () => {
  loading.value = true
  try {
    // Check whether the current tenant ID format is valid
    if (!isValidTenantId(currentTenantId.value)) {
      console.warn('Invalid tenant ID format, resetting to default:', currentTenantId.value)
      currentTenantId.value = 'default@ash@2026'

@ -108,7 +107,6 @@ const fetchTenantList = async () => {
    const response = await getTenantList()
    tenantList.value = response.tenants || []

    // If the current tenant is not in the list, select the first one by default
    if (tenantList.value.length > 0 && !tenantList.value.find(t => t.id === currentTenantId.value)) {
      const firstTenant = tenantList.value[0]
      currentTenantId.value = firstTenant.id

@ -117,8 +115,7 @@ const fetchTenantList = async () => {
  } catch (error) {
    ElMessage.error('获取租户列表失败')
    console.error('Failed to fetch tenant list:', error)
    // Fall back to the default tenant on failure
    tenantList.value = [{ id: 'default@ash@2026', name: 'default (2026)' }]
    tenantList.value = [{ id: 'default@ash@2026', name: 'default (2026)', displayName: 'default', year: '2026', createdAt: new Date().toISOString() }]
  } finally {
    loading.value = false
  }
@ -13,7 +13,7 @@ export interface TenantListResponse {
  total: number
}

export function getTenantList() {
export function getTenantList(): Promise<TenantListResponse> {
  return request<TenantListResponse>({
    url: '/admin/tenants',
    method: 'get'
@ -92,6 +92,7 @@ const emit = defineEmits<{

const formRef = ref<FormInstance>()
const formData = ref<Record<string, any>>({})
const isUpdating = ref(false)

const schemaProperties = computed(() => {
  return props.schema?.properties || {}

@ -173,8 +174,11 @@ const initFormData = () => {

watch(
  () => props.modelValue,
  () => {
  (newVal) => {
    if (isUpdating.value) return
    if (JSON.stringify(newVal) !== JSON.stringify(formData.value)) {
      initFormData()
    }
  },
  { deep: true }
)

@ -190,7 +194,14 @@ watch(
watch(
  formData,
  (val) => {
    if (isUpdating.value) return
    if (JSON.stringify(val) !== JSON.stringify(props.modelValue)) {
      isUpdating.value = true
      emit('update:modelValue', val)
      Promise.resolve().then(() => {
        isUpdating.value = false
      })
    }
  },
  { deep: true }
)
@ -92,6 +92,7 @@ const emit = defineEmits<{

const formRef = ref<FormInstance>()
const formData = ref<Record<string, any>>({})
const isUpdating = ref(false)

const schemaProperties = computed(() => {
  return props.schema?.properties || {}

@ -173,8 +174,11 @@ const initFormData = () => {

watch(
  () => props.modelValue,
  () => {
  (newVal) => {
    if (isUpdating.value) return
    if (JSON.stringify(newVal) !== JSON.stringify(formData.value)) {
      initFormData()
    }
  },
  { deep: true }
)

@ -190,7 +194,14 @@ watch(
watch(
  formData,
  (val) => {
    if (isUpdating.value) return
    if (JSON.stringify(val) !== JSON.stringify(props.modelValue)) {
      isUpdating.value = true
      emit('update:modelValue', val)
      Promise.resolve().then(() => {
        isUpdating.value = false
      })
    }
  },
  { deep: true }
)
@ -74,7 +74,8 @@ export const useEmbeddingStore = defineStore('embedding', () => {
        provider: currentConfig.value.provider,
        config: currentConfig.value.config
      }
      await saveConfig(updateData)
      const response = await saveConfig(updateData)
      return response
    } catch (error) {
      console.error('Failed to save config:', error)
      throw error
@ -1,21 +1,22 @@
import axios from 'axios'
import axios, { type AxiosRequestConfig } from 'axios'
import { ElMessage, ElMessageBox } from 'element-plus'
import { useTenantStore } from '@/stores/tenant'

// Create the axios instance
const service = axios.create({
  baseURL: import.meta.env.VITE_APP_BASE_API || '/api',
  timeout: 60000
})

// Request interceptor
service.interceptors.request.use(
  (config) => {
    const tenantStore = useTenantStore()
    if (tenantStore.currentTenantId) {
      config.headers['X-Tenant-Id'] = tenantStore.currentTenantId
    }
    // TODO: inject Authorization here as well once a token exists
    const apiKey = import.meta.env.VITE_APP_API_KEY
    if (apiKey) {
      config.headers['X-API-Key'] = apiKey
    }
    return config
  },
  (error) => {

@ -24,11 +25,9 @@ service.interceptors.request.use(
  }
)

// Response interceptor
service.interceptors.response.use(
  (response) => {
    const res = response.data
    // Unified handling based on the backend code field can go here
    return res
  },
  (error) => {

@ -42,7 +41,6 @@ service.interceptors.response.use(
        cancelButtonText: '取消',
        type: 'warning'
      }).then(() => {
        // TODO: redirect to the login page or run logout logic
        location.href = '/login'
      })
    } else if (status === 403) {

@ -69,4 +67,13 @@ service.interceptors.response.use(
  }
)

export default service
interface RequestConfig extends AxiosRequestConfig {
  url: string
  method?: string
}

function request<T = any>(config: RequestConfig): Promise<T> {
  return service.request<any, T>(config)
}

export default request
@ -169,8 +169,19 @@ const handleSave = async () => {

  saving.value = true
  try {
    await embeddingStore.saveCurrentConfig()
    const response: any = await embeddingStore.saveCurrentConfig()
    ElMessage.success('配置保存成功')

    if (response?.warning || response?.requires_reindex) {
      ElMessageBox.alert(
        response.warning || '嵌入模型已更改,请重新上传文档以确保检索效果正常。',
        '重要提示',
        {
          confirmButtonText: '我知道了',
          type: 'warning',
        }
      )
    }
  } catch (error) {
    ElMessage.error('配置保存失败')
  } finally {
@ -102,10 +102,17 @@ interface DocumentItem {
  createTime: string
}

interface IndexJob {
  jobId: string
  status: string
  progress: number
  errorMsg?: string
}

const tableData = ref<DocumentItem[]>([])
const loading = ref(false)
const jobDialogVisible = ref(false)
const currentJob = ref<any>(null)
const currentJob = ref<IndexJob | null>(null)
const pollingJobs = ref<Set<string>>(new Set())
let pollingInterval: number | null = null


@ -150,10 +157,15 @@ const fetchDocuments = async () => {
  }
}

const fetchJobStatus = async (jobId: string) => {
const fetchJobStatus = async (jobId: string): Promise<IndexJob | null> => {
  try {
    const res = await getIndexJob(jobId)
    return res
    const res: any = await getIndexJob(jobId)
    return {
      jobId: res.jobId || jobId,
      status: res.status || 'pending',
      progress: res.progress || 0,
      errorMsg: res.errorMsg
    }
  } catch (error) {
    console.error('Failed to fetch job status:', error)
    return null

@ -246,19 +258,21 @@ const handleFileChange = async (event: Event) => {

  try {
    loading.value = true
    const res = await uploadDocument(formData)
    ElMessage.success(`文档上传成功!任务ID: ${res.jobId}`)
    const res: any = await uploadDocument(formData)
    const jobId = res.jobId as string
    ElMessage.success(`文档上传成功!任务ID: ${jobId}`)
    console.log('Upload response:', res)

    const newDoc: DocumentItem = {
      docId: res.docId || '',
      name: file.name,
      status: res.status || 'pending',
      jobId: res.jobId,
      status: (res.status as string) || 'pending',
      jobId: jobId,
      createTime: new Date().toLocaleString('zh-CN')
    }
    tableData.value.unshift(newDoc)

    startPolling(res.jobId)
    startPolling(jobId)
  } catch (error) {
    ElMessage.error('文档上传失败')
    console.error('Upload error:', error)
@ -327,7 +327,7 @@ const runStreamExperiment = async () => {
          } else if (parsed.type === 'error') {
            streamError.value = parsed.message || '流式输出错误'
            streaming.value = false
            ElMessage.error(streamError.value)
            ElMessage.error(streamError.value || '未知错误')
          }
        } catch {
          streamContent.value += data
@ -0,0 +1,10 @@
/// <reference types="vite/client" />

interface ImportMetaEnv {
  readonly VITE_APP_BASE_API: string
  readonly VITE_APP_API_KEY: string
}

interface ImportMeta {
  readonly env: ImportMetaEnv
}
@ -15,7 +15,8 @@
|
|||
"baseUrl": ".",
|
||||
"paths": {
|
||||
"@/*": ["src/*"]
|
||||
}
|
||||
},
|
||||
"types": ["vite/client"]
|
||||
},
|
||||
"include": ["src/**/*.ts", "src/**/*.d.ts", "src/**/*.tsx", "src/**/*.vue"],
|
||||
"references": [{ "path": "./tsconfig.node.json" }]
|
||||
|
|
|
@ -0,0 +1,53 @@
__pycache__
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

.pytest_cache
.coverage
htmlcov/
.tox/
.hypothesis/

.mypy_cache/
.ruff_cache/

.env
.env.local
.env.*.local

*.log
*.pot
*.pyc

.idea/
.vscode/
*.swp
*.swo

tests/
scripts/
*.md
!README.md

.git
.gitignore
.gitea

check_qdrant.py
@ -0,0 +1,32 @@
# AI Service Backend Dockerfile
FROM docker.1ms.run/python:3.11-slim AS builder

WORKDIR /app

RUN pip install --no-cache-dir uv

COPY pyproject.toml README.md ./

RUN uv pip install --system --no-cache-dir .

FROM docker.1ms.run/python:3.11-slim

WORKDIR /app

RUN groupadd -r appgroup && useradd -r -g appgroup appuser

COPY --from=builder /usr/local/lib/python3.11/site-packages /usr/local/lib/python3.11/site-packages
COPY --from=builder /usr/local/bin /usr/local/bin

COPY app ./app

RUN chown -R appuser:appgroup /app

USER appuser

EXPOSE 8080

ENV PYTHONUNBUFFERED=1
ENV PYTHONDONTWRITEBYTECODE=1

CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8080"]
@ -1,8 +1,9 @@
"""
Admin API routes for AI Service management.
[AC-ASA-01, AC-ASA-02, AC-ASA-05, AC-ASA-07, AC-ASA-08] Admin management endpoints.
[AC-ASA-01, AC-ASA-02, AC-ASA-05, AC-ASA-07, AC-ASA-08, AC-AISVC-50] Admin management endpoints.
"""

from app.api.admin.api_key import router as api_key_router
from app.api.admin.dashboard import router as dashboard_router
from app.api.admin.embedding import router as embedding_router
from app.api.admin.kb import router as kb_router

@ -11,4 +12,4 @@ from app.api.admin.rag import router as rag_router
from app.api.admin.sessions import router as sessions_router
from app.api.admin.tenants import router as tenants_router

__all__ = ["dashboard_router", "embedding_router", "kb_router", "llm_router", "rag_router", "sessions_router", "tenants_router"]
__all__ = ["api_key_router", "dashboard_router", "embedding_router", "kb_router", "llm_router", "rag_router", "sessions_router", "tenants_router"]
@ -0,0 +1,154 @@
"""
API Key management endpoints.
[AC-AISVC-50] CRUD operations for API keys.
"""

import logging
from typing import Annotated

from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.database import get_session
from app.models.entities import ApiKey, ApiKeyCreate
from app.services.api_key import get_api_key_service

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/admin/api-keys", tags=["API Keys"])


class ApiKeyResponse(BaseModel):
    """Response model for API key."""

    id: str = Field(..., description="API key ID")
    key: str = Field(..., description="API key value")
    name: str = Field(..., description="API key name")
    is_active: bool = Field(..., description="Whether the key is active")
    created_at: str = Field(..., description="Creation time")
    updated_at: str = Field(..., description="Last update time")


class ApiKeyListResponse(BaseModel):
    """Response model for API key list."""

    keys: list[ApiKeyResponse] = Field(..., description="List of API keys")
    total: int = Field(..., description="Total count")


class CreateApiKeyRequest(BaseModel):
    """Request model for creating API key."""

    name: str = Field(..., description="API key name/description")
    key: str | None = Field(default=None, description="Custom API key (auto-generated if not provided)")


class ToggleApiKeyRequest(BaseModel):
    """Request model for toggling API key status."""

    is_active: bool = Field(..., description="New active status")


def api_key_to_response(api_key: ApiKey) -> ApiKeyResponse:
    """Convert ApiKey entity to response model."""
    return ApiKeyResponse(
        id=str(api_key.id),
        key=api_key.key,
        name=api_key.name,
        is_active=api_key.is_active,
        created_at=api_key.created_at.isoformat(),
        updated_at=api_key.updated_at.isoformat(),
    )


@router.get("", response_model=ApiKeyListResponse)
async def list_api_keys(
    session: Annotated[AsyncSession, Depends(get_session)],
):
    """
    [AC-AISVC-50] List all API keys.
    """
    service = get_api_key_service()
    keys = await service.list_keys(session)

    return ApiKeyListResponse(
        keys=[api_key_to_response(k) for k in keys],
        total=len(keys),
    )


@router.post("", response_model=ApiKeyResponse, status_code=status.HTTP_201_CREATED)
async def create_api_key(
    request: CreateApiKeyRequest,
    session: Annotated[AsyncSession, Depends(get_session)],
):
    """
    [AC-AISVC-50] Create a new API key.
    """
    service = get_api_key_service()

    key_value = request.key or service.generate_key()

    key_create = ApiKeyCreate(
        key=key_value,
        name=request.name,
        is_active=True,
    )

    api_key = await service.create_key(session, key_create)
    logger.info(f"[AC-AISVC-50] Created API key: {api_key.name}")

    return api_key_to_response(api_key)


@router.delete("/{key_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_api_key(
    key_id: str,
    session: Annotated[AsyncSession, Depends(get_session)],
):
    """
    [AC-AISVC-50] Delete an API key.
    """
    service = get_api_key_service()

    deleted = await service.delete_key(session, key_id)

    if not deleted:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="API key not found",
        )


@router.patch("/{key_id}/toggle", response_model=ApiKeyResponse)
async def toggle_api_key(
    key_id: str,
    request: ToggleApiKeyRequest,
    session: Annotated[AsyncSession, Depends(get_session)],
):
    """
    [AC-AISVC-50] Toggle API key active status.
    """
    service = get_api_key_service()

    api_key = await service.toggle_key(session, key_id, request.is_active)

    if not api_key:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="API key not found",
        )

    return api_key_to_response(api_key)


@router.post("/reload-cache", status_code=status.HTTP_204_NO_CONTENT)
async def reload_api_key_cache(
    session: Annotated[AsyncSession, Depends(get_session)],
):
    """
    [AC-AISVC-50] Reload API key cache from database.
    """
    service = get_api_key_service()
    await service.reload_cache(session)
@ -78,12 +78,32 @@ async def update_embedding_config(

    manager = get_embedding_config_manager()

    old_config = manager.get_full_config()
    old_provider = old_config.get("provider")
    old_model = old_config.get("config", {}).get("model", "")

    new_model = config.get("model", "")

    try:
        await manager.update_config(provider, config)
        return {

        response = {
            "success": True,
            "message": f"Configuration updated to use {provider}",
        }

        if old_provider != provider or old_model != new_model:
            response["warning"] = (
                "嵌入模型已更改。由于不同模型生成的向量不兼容,"
                "请删除现有知识库并重新上传文档,以确保检索效果正常。"
            )
            response["requires_reindex"] = True
            logger.warning(
                f"[EMBEDDING] Model changed from {old_provider}/{old_model} to {provider}/{new_model}. "
                f"Documents need to be re-uploaded."
            )

        return response
    except EmbeddingException as e:
        raise InvalidRequestException(str(e))
@ -442,13 +442,15 @@ async def _index_document(tenant_id: str, job_id: str, doc_id: str, content: byt
    logger.info(f"[INDEX] Total chunks: {len(all_chunks)}")

    qdrant = await get_qdrant_client()
    await qdrant.ensure_collection_exists(tenant_id)
    await qdrant.ensure_collection_exists(tenant_id, use_multi_vector=True)

    from app.services.embedding.nomic_provider import NomicEmbeddingProvider
    use_multi_vector = isinstance(embedding_provider, NomicEmbeddingProvider)
    logger.info(f"[INDEX] Using multi-vector format: {use_multi_vector}")

    points = []
    total_chunks = len(all_chunks)
    for i, chunk in enumerate(all_chunks):
        embedding = await embedding_provider.embed(chunk.text)

        payload = {
            "text": chunk.text,
            "source": doc_id,

@ -461,6 +463,19 @@ async def _index_document(tenant_id: str, job_id: str, doc_id: str, content: byt
        if chunk.source:
            payload["filename"] = chunk.source

        if use_multi_vector:
            embedding_result = await embedding_provider.embed_document(chunk.text)
            points.append({
                "id": str(uuid.uuid4()),
                "vector": {
                    "full": embedding_result.embedding_full,
                    "dim_256": embedding_result.embedding_256,
                    "dim_512": embedding_result.embedding_512,
                },
                "payload": payload,
            })
        else:
            embedding = await embedding_provider.embed(chunk.text)
            points.append(
                PointStruct(
                    id=str(uuid.uuid4()),

@ -478,6 +493,9 @@ async def _index_document(tenant_id: str, job_id: str, doc_id: str, content: byt

    if points:
        logger.info(f"[INDEX] Upserting {len(points)} vectors to Qdrant...")
        if use_multi_vector:
            await qdrant.upsert_multi_vector(tenant_id, points)
        else:
            await qdrant.upsert_vectors(tenant_id, points)

    await kb_service.update_job_status(
|
|
@@ -1,6 +1,6 @@
 """
 Middleware for AI Service.
-[AC-AISVC-10, AC-AISVC-12] X-Tenant-Id header validation and tenant context injection.
+[AC-AISVC-10, AC-AISVC-12, AC-AISVC-50] X-Tenant-Id header validation, tenant context injection, and API Key authentication.
 """

 import logging
@@ -17,12 +17,20 @@ from app.core.tenant import clear_tenant_context, set_tenant_context
 logger = logging.getLogger(__name__)

 TENANT_ID_HEADER = "X-Tenant-Id"
+API_KEY_HEADER = "X-API-Key"
 ACCEPT_HEADER = "Accept"
 SSE_CONTENT_TYPE = "text/event-stream"

 # Tenant ID format: name@ash@year (e.g., szmp@ash@2026)
 TENANT_ID_PATTERN = re.compile(r'^[^@]+@ash@\d{4}$')

+PATHS_SKIP_API_KEY = {
+    "/health",
+    "/ai/health",
+    "/docs",
+    "/redoc",
+    "/openapi.json",
+}
+

 def validate_tenant_id_format(tenant_id: str) -> bool:
     """
@@ -41,6 +49,59 @@ def parse_tenant_id(tenant_id: str) -> tuple[str, str]:
     return parts[0], parts[2]


+class ApiKeyMiddleware(BaseHTTPMiddleware):
+    """
+    [AC-AISVC-50] Middleware to validate API Key for all requests.
+
+    Features:
+    - Validates X-API-Key header against in-memory cache
+    - Skips validation for health/docs endpoints
+    - Returns 401 for missing or invalid API key
+    """
+
+    async def dispatch(self, request: Request, call_next: Callable) -> Response:
+        if self._should_skip_api_key(request.url.path):
+            return await call_next(request)
+
+        api_key = request.headers.get(API_KEY_HEADER)
+
+        if not api_key or not api_key.strip():
+            logger.warning(f"[AC-AISVC-50] Missing X-API-Key header for {request.url.path}")
+            return JSONResponse(
+                status_code=status.HTTP_401_UNAUTHORIZED,
+                content=ErrorResponse(
+                    code=ErrorCode.UNAUTHORIZED.value,
+                    message="Missing required header: X-API-Key",
+                ).model_dump(exclude_none=True),
+            )
+
+        api_key = api_key.strip()
+
+        from app.services.api_key import get_api_key_service
+        service = get_api_key_service()
+
+        if not service.validate_key(api_key):
+            logger.warning(f"[AC-AISVC-50] Invalid API key for {request.url.path}")
+            return JSONResponse(
+                status_code=status.HTTP_401_UNAUTHORIZED,
+                content=ErrorResponse(
+                    code=ErrorCode.UNAUTHORIZED.value,
+                    message="Invalid API key",
+                ).model_dump(exclude_none=True),
+            )
+
+        return await call_next(request)
+
+    def _should_skip_api_key(self, path: str) -> bool:
+        """Check if the path should skip API key validation."""
+        if path in PATHS_SKIP_API_KEY:
+            return True
+        for skip_path in PATHS_SKIP_API_KEY:
+            if path.startswith(skip_path):
+                return True
+        return False
+
+
 class TenantContextMiddleware(BaseHTTPMiddleware):
     """
     [AC-AISVC-10, AC-AISVC-12] Middleware to extract and validate X-Tenant-Id header.
@@ -51,7 +112,7 @@ class TenantContextMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next: Callable) -> Response:
         clear_tenant_context()

-        if request.url.path == "/ai/health":
+        if request.url.path in ("/health", "/ai/health"):
             return await call_next(request)

         tenant_id = request.headers.get(TENANT_ID_HEADER)
@@ -68,7 +129,6 @@ class TenantContextMiddleware(BaseHTTPMiddleware):
         tenant_id = tenant_id.strip()

-        # Validate tenant ID format
         if not validate_tenant_id_format(tenant_id):
             logger.warning(f"[AC-AISVC-10] Invalid tenant ID format: {tenant_id}")
             return JSONResponse(
@@ -79,13 +139,11 @@ class TenantContextMiddleware(BaseHTTPMiddleware):
                 ).model_dump(exclude_none=True),
             )

-        # Auto-create tenant if not exists (for admin endpoints)
         if request.url.path.startswith("/admin/") or request.url.path.startswith("/ai/"):
             try:
                 await self._ensure_tenant_exists(request, tenant_id)
             except Exception as e:
                 logger.error(f"[AC-AISVC-10] Failed to ensure tenant exists: {e}")
-                # Continue processing even if tenant creation fails

         set_tenant_context(tenant_id)
         request.state.tenant_id = tenant_id
@@ -112,7 +170,6 @@ class TenantContextMiddleware(BaseHTTPMiddleware):
         name, year = parse_tenant_id(tenant_id)
-
         async with async_session_maker() as session:
             # Check if tenant exists
             stmt = select(Tenant).where(Tenant.tenant_id == tenant_id)
             result = await session.execute(stmt)
             existing_tenant = result.scalar_one_or_none()
@@ -121,7 +178,6 @@ class TenantContextMiddleware(BaseHTTPMiddleware):
             logger.debug(f"[AC-AISVC-10] Tenant already exists: {tenant_id}")
             return
-
         # Create new tenant
         new_tenant = Tenant(
             tenant_id=tenant_id,
             name=name,
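The tenant-ID format enforced by `TENANT_ID_PATTERN` above (`name@ash@year`) can be exercised in isolation. A minimal sketch — the pattern is copied from the diff; the sample IDs are made up:

```python
import re

# Same pattern as TENANT_ID_PATTERN in the middleware: name@ash@year
TENANT_ID_PATTERN = re.compile(r'^[^@]+@ash@\d{4}$')

def validate_tenant_id_format(tenant_id: str) -> bool:
    """Return True when the ID matches name@ash@year (e.g. szmp@ash@2026)."""
    return bool(TENANT_ID_PATTERN.match(tenant_id))

print(validate_tenant_id_format("szmp@ash@2026"))   # True
print(validate_tenant_id_format("szmp@ash@26"))     # False: year must be 4 digits
print(validate_tenant_id_format("a@b@ash@2026"))    # False: extra '@' in the name part
```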
@@ -8,7 +8,7 @@ import logging
 from typing import Any

 from qdrant_client import AsyncQdrantClient
-from qdrant_client.models import Distance, PointStruct, VectorParams, MultiVectorConfig
+from qdrant_client.models import Distance, PointStruct, VectorParams, QueryRequest

 from app.core.config import get_settings
@@ -61,8 +61,7 @@ class QdrantClient:
         collection_name = self.get_collection_name(tenant_id)

         try:
-            collections = await client.get_collections()
-            exists = any(c.name == collection_name for c in collections.collections)
+            exists = await client.collection_exists(collection_name)

             if not exists:
                 if use_multi_vector:
@@ -176,6 +175,7 @@ class QdrantClient:
         limit: int = 5,
         score_threshold: float | None = None,
         vector_name: str = "full",
+        with_vectors: bool = False,
     ) -> list[dict[str, Any]]:
         """
         [AC-AISVC-10] Search vectors in tenant's collection.
@@ -189,6 +189,7 @@ class QdrantClient:
             score_threshold: Minimum score threshold for results
             vector_name: Name of the vector to search (for multi-vector collections)
                 Default is "full" for 768-dim vectors in Matryoshka setup.
+            with_vectors: Whether to return vectors in results (for two-stage reranking)
         """
         client = await self.get_client()
@@ -211,39 +212,50 @@ class QdrantClient:
         try:
             logger.info(f"[AC-AISVC-10] Searching in collection: {collection_name}")

+            exists = await client.collection_exists(collection_name)
+            if not exists:
+                logger.warning(f"[AC-AISVC-10] Collection {collection_name} does not exist")
+                continue
+
             try:
-                results = await client.search(
+                results = await client.query_points(
                     collection_name=collection_name,
-                    query_vector=(vector_name, query_vector),
+                    query=query_vector,
+                    using=vector_name,
                     limit=limit,
+                    with_vectors=with_vectors,
                     score_threshold=score_threshold,
                 )
             except Exception as e:
-                if "vector name" in str(e).lower() or "Not existing vector" in str(e):
+                if "vector name" in str(e).lower() or "Not existing vector" in str(e) or "using" in str(e).lower():
                     logger.info(
                         f"[AC-AISVC-10] Collection {collection_name} doesn't have vector named '{vector_name}', "
                         f"trying without vector name (single-vector mode)"
                     )
-                    results = await client.search(
+                    results = await client.query_points(
                         collection_name=collection_name,
-                        query_vector=query_vector,
+                        query=query_vector,
                         limit=limit,
+                        with_vectors=with_vectors,
                         score_threshold=score_threshold,
                     )
                 else:
                     raise

             logger.info(
-                f"[AC-AISVC-10] Collection {collection_name} returned {len(results)} raw results"
+                f"[AC-AISVC-10] Collection {collection_name} returned {len(results.points)} raw results"
             )

-            hits = [
-                {
+            hits = []
+            for result in results.points:
+                hit = {
                     "id": str(result.id),
                     "score": result.score,
                     "payload": result.payload or {},
                 }
-                for result in results
-                if score_threshold is None or result.score >= score_threshold
-            ]
+                if with_vectors and result.vector:
+                    hit["vector"] = result.vector
+                hits.append(hit)
             all_hits.extend(hits)

             if hits:
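The search change above migrates from `client.search()` to `query_points()` and retries without the named vector when the collection was created in single-vector mode. The retry shape can be sketched against a stub client — the stub and its error message are hypothetical stand-ins, not the real qdrant-client API:

```python
import asyncio

class StubClient:
    """Hypothetical stand-in for AsyncQdrantClient: rejects named-vector queries."""
    async def query_points(self, collection_name, query, limit, using=None):
        if using is not None:
            raise ValueError(f"Not existing vector name: {using}")
        return [{"id": "1", "score": 0.9}][:limit]

async def search_with_fallback(client, collection, vector, name="full", limit=5):
    # Mirrors the diff: try the named vector first, retry without it on failure.
    try:
        return await client.query_points(collection, query=vector, limit=limit, using=name)
    except Exception as e:
        if "vector name" in str(e).lower() or "using" in str(e).lower():
            return await client.query_points(collection, query=vector, limit=limit)
        raise

hits = asyncio.run(search_with_fallback(StubClient(), "t_demo", [0.1, 0.2]))
print(hits)  # falls back to single-vector mode: [{'id': '1', 'score': 0.9}]
```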
@@ -12,7 +12,7 @@ from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse

 from app.api import chat_router, health_router
-from app.api.admin import dashboard_router, embedding_router, kb_router, llm_router, rag_router, sessions_router, tenants_router
+from app.api.admin import api_key_router, dashboard_router, embedding_router, kb_router, llm_router, rag_router, sessions_router, tenants_router
 from app.api.admin.kb_optimized import router as kb_optimized_router
 from app.core.config import get_settings
 from app.core.database import close_db, init_db
@@ -24,7 +24,7 @@ from app.core.exceptions import (
     generic_exception_handler,
     http_exception_handler,
 )
-from app.core.middleware import TenantContextMiddleware
+from app.core.middleware import ApiKeyMiddleware, TenantContextMiddleware
 from app.core.qdrant_client import close_qdrant_client

 settings = get_settings()
@@ -40,7 +40,7 @@ logger = logging.getLogger(__name__)
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     """
-    [AC-AISVC-01, AC-AISVC-11] Application lifespan manager.
+    [AC-AISVC-01, AC-AISVC-11, AC-AISVC-50] Application lifespan manager.
     Handles startup and shutdown of database and external connections.
     """
     logger.info(f"[AC-AISVC-01] Starting {settings.app_name} v{settings.app_version}")
@@ -51,6 +51,19 @@ async def lifespan(app: FastAPI):
     except Exception as e:
         logger.warning(f"[AC-AISVC-11] Database initialization skipped: {e}")

+    try:
+        from app.core.database import async_session_maker
+        from app.services.api_key import get_api_key_service
+
+        async with async_session_maker() as session:
+            api_key_service = get_api_key_service()
+            await api_key_service.initialize(session)
+            default_key = await api_key_service.create_default_key(session)
+            if default_key:
+                logger.info(f"[AC-AISVC-50] Default API key created: {default_key.key}")
+    except Exception as e:
+        logger.warning(f"[AC-AISVC-50] API key initialization skipped: {e}")
+
     yield

     await close_db()
@@ -87,6 +100,7 @@ app.add_middleware(
 )

 app.add_middleware(TenantContextMiddleware)
+app.add_middleware(ApiKeyMiddleware)

 app.add_exception_handler(AIServiceException, ai_service_exception_handler)
 app.add_exception_handler(HTTPException, http_exception_handler)
@@ -113,6 +127,7 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE
 app.include_router(health_router)
 app.include_router(chat_router)

+app.include_router(api_key_router)
 app.include_router(dashboard_router)
 app.include_router(embedding_router)
 app.include_router(kb_router)
@@ -50,6 +50,7 @@ class ErrorCode(str, Enum):
     INVALID_REQUEST = "INVALID_REQUEST"
     MISSING_TENANT_ID = "MISSING_TENANT_ID"
     INVALID_TENANT_ID = "INVALID_TENANT_ID"
+    UNAUTHORIZED = "UNAUTHORIZED"
     INTERNAL_ERROR = "INTERNAL_ERROR"
     SERVICE_UNAVAILABLE = "SERVICE_UNAVAILABLE"
     TIMEOUT = "TIMEOUT"
@@ -198,3 +198,27 @@ class DocumentCreate(SQLModel):
     file_path: str | None = None
     file_size: int | None = None
     file_type: str | None = None
+
+
+class ApiKey(SQLModel, table=True):
+    """
+    [AC-AISVC-50] API Key entity for lightweight authentication.
+    Keys are loaded into memory on startup for fast validation.
+    """
+
+    __tablename__ = "api_keys"
+
+    id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
+    key: str = Field(..., description="API Key (unique)", unique=True, index=True)
+    name: str = Field(..., description="Key name/description for identification")
+    is_active: bool = Field(default=True, description="Whether the key is active")
+    created_at: datetime = Field(default_factory=datetime.utcnow, description="Creation time")
+    updated_at: datetime = Field(default_factory=datetime.utcnow, description="Last update time")
+
+
+class ApiKeyCreate(SQLModel):
+    """Schema for creating a new API key."""
+
+    key: str
+    name: str
+    is_active: bool = True
@@ -0,0 +1,249 @@
+"""
+API Key management service.
+[AC-AISVC-50] Lightweight authentication with in-memory cache.
+"""
+
+import logging
+import secrets
+from datetime import datetime
+from typing import Optional
+
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from app.models.entities import ApiKey, ApiKeyCreate
+
+logger = logging.getLogger(__name__)
+
+
+class ApiKeyService:
+    """
+    [AC-AISVC-50] API Key management service.
+
+    Features:
+    - In-memory cache for fast validation
+    - Database persistence
+    - Hot-reload support
+    """
+
+    def __init__(self):
+        self._keys_cache: set[str] = set()
+        self._initialized: bool = False
+
+    async def initialize(self, session: AsyncSession) -> None:
+        """
+        Load all active API keys from database into memory.
+        Should be called on application startup.
+        """
+        result = await session.execute(
+            select(ApiKey).where(ApiKey.is_active == True)
+        )
+        keys = result.scalars().all()
+
+        self._keys_cache = {key.key for key in keys}
+        self._initialized = True
+
+        logger.info(f"[AC-AISVC-50] Loaded {len(self._keys_cache)} API keys into memory")
+
+    def validate_key(self, key: str) -> bool:
+        """
+        Validate an API key against the in-memory cache.
+
+        Args:
+            key: The API key to validate
+
+        Returns:
+            True if the key is valid, False otherwise
+        """
+        if not self._initialized:
+            logger.warning("[AC-AISVC-50] API key service not initialized")
+            return False
+
+        return key in self._keys_cache
+
+    def generate_key(self) -> str:
+        """
+        Generate a new secure API key.
+
+        Returns:
+            A URL-safe random string
+        """
+        return secrets.token_urlsafe(32)
+
+    async def create_key(
+        self,
+        session: AsyncSession,
+        key_create: ApiKeyCreate
+    ) -> ApiKey:
+        """
+        Create a new API key.
+
+        Args:
+            session: Database session
+            key_create: Key creation data
+
+        Returns:
+            The created ApiKey entity
+        """
+        api_key = ApiKey(
+            key=key_create.key,
+            name=key_create.name,
+            is_active=key_create.is_active,
+        )
+
+        session.add(api_key)
+        await session.commit()
+        await session.refresh(api_key)
+
+        if api_key.is_active:
+            self._keys_cache.add(api_key.key)
+
+        logger.info(f"[AC-AISVC-50] Created API key: {api_key.name}")
+        return api_key
+
+    async def create_default_key(self, session: AsyncSession) -> Optional[ApiKey]:
+        """
+        Create a default API key if none exists.
+
+        Returns:
+            The created ApiKey or None if keys already exist
+        """
+        result = await session.execute(select(ApiKey).limit(1))
+        existing = result.scalar_one_or_none()
+
+        if existing:
+            return None
+
+        default_key = secrets.token_urlsafe(32)
+        api_key = ApiKey(
+            key=default_key,
+            name="Default API Key",
+            is_active=True,
+        )
+
+        session.add(api_key)
+        await session.commit()
+        await session.refresh(api_key)
+
+        self._keys_cache.add(api_key.key)
+
+        logger.info(f"[AC-AISVC-50] Created default API key: {api_key.key}")
+        return api_key
+
+    async def delete_key(
+        self,
+        session: AsyncSession,
+        key_id: str
+    ) -> bool:
+        """
+        Delete an API key.
+
+        Args:
+            session: Database session
+            key_id: The key ID to delete
+
+        Returns:
+            True if deleted, False if not found
+        """
+        import uuid
+
+        try:
+            key_uuid = uuid.UUID(key_id)
+        except ValueError:
+            return False
+
+        result = await session.execute(
+            select(ApiKey).where(ApiKey.id == key_uuid)
+        )
+        api_key = result.scalar_one_or_none()
+
+        if not api_key:
+            return False
+
+        key_value = api_key.key
+        await session.delete(api_key)
+        await session.commit()
+
+        self._keys_cache.discard(key_value)
+
+        logger.info(f"[AC-AISVC-50] Deleted API key: {api_key.name}")
+        return True
+
+    async def toggle_key(
+        self,
+        session: AsyncSession,
+        key_id: str,
+        is_active: bool
+    ) -> Optional[ApiKey]:
+        """
+        Toggle API key active status.
+
+        Args:
+            session: Database session
+            key_id: The key ID to toggle
+            is_active: New active status
+
+        Returns:
+            The updated ApiKey or None if not found
+        """
+        import uuid
+
+        try:
+            key_uuid = uuid.UUID(key_id)
+        except ValueError:
+            return None
+
+        result = await session.execute(
+            select(ApiKey).where(ApiKey.id == key_uuid)
+        )
+        api_key = result.scalar_one_or_none()
+
+        if not api_key:
+            return None
+
+        api_key.is_active = is_active
+        api_key.updated_at = datetime.utcnow()
+
+        session.add(api_key)
+        await session.commit()
+        await session.refresh(api_key)
+
+        if is_active:
+            self._keys_cache.add(api_key.key)
+        else:
+            self._keys_cache.discard(api_key.key)
+
+        logger.info(f"[AC-AISVC-50] Toggled API key {api_key.name}: active={is_active}")
+        return api_key
+
+    async def list_keys(self, session: AsyncSession) -> list[ApiKey]:
+        """
+        List all API keys.
+
+        Args:
+            session: Database session
+
+        Returns:
+            List of all ApiKey entities
+        """
+        result = await session.execute(select(ApiKey))
+        return list(result.scalars().all())
+
+    async def reload_cache(self, session: AsyncSession) -> None:
+        """
+        Reload all API keys from database into memory.
+        """
+        self._keys_cache.clear()
+        await self.initialize(session)
+        logger.info("[AC-AISVC-50] API key cache reloaded")
+
+
+_api_key_service: ApiKeyService | None = None
+
+
+def get_api_key_service() -> ApiKeyService:
+    """Get the global API key service instance."""
+    global _api_key_service
+    if _api_key_service is None:
+        _api_key_service = ApiKeyService()
+    return _api_key_service
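The service above keeps validation off the database hot path by checking membership in an in-memory `set`. The core of that scheme is just a few lines — a simplified sketch without the persistence side:

```python
import secrets

class KeyCache:
    """Simplified in-memory API-key cache, mirroring ApiKeyService's hot path."""
    def __init__(self):
        self._keys: set[str] = set()

    def add(self) -> str:
        key = secrets.token_urlsafe(32)  # URL-safe random key, as in generate_key()
        self._keys.add(key)
        return key

    def validate(self, key: str) -> bool:
        return key in self._keys  # O(1) lookup per request

cache = KeyCache()
k = cache.add()
print(cache.validate(k))            # True
print(cache.validate("wrong-key"))  # False
```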
@@ -7,7 +7,9 @@ Design reference: progress.md Section 7.1 - Architecture
 - EmbeddingConfigManager: manages configuration with hot-reload support
 """

+import json
 import logging
+from pathlib import Path
 from typing import Any, Type

 from app.services.embedding.base import EmbeddingException, EmbeddingProvider
@@ -17,6 +19,8 @@ from app.services.embedding.nomic_provider import NomicEmbeddingProvider
 logger = logging.getLogger(__name__)

+EMBEDDING_CONFIG_FILE = Path("config/embedding_config.json")
+

 class EmbeddingProviderFactory:
     """
@@ -74,11 +78,38 @@ class EmbeddingProviderFactory:
             "nomic": "Nomic-embed-text v1.5 优化版,支持任务前缀和 Matryoshka 维度截断,专为RAG优化",
         }

+        raw_schema = temp_instance.get_config_schema()
+
+        properties = {}
+        required = []
+        for key, field in raw_schema.items():
+            properties[key] = {
+                "type": field.get("type", "string"),
+                "title": field.get("title", key),
+                "description": field.get("description", ""),
+                "default": field.get("default"),
+            }
+            if field.get("enum"):
+                properties[key]["enum"] = field.get("enum")
+            if field.get("minimum") is not None:
+                properties[key]["minimum"] = field.get("minimum")
+            if field.get("maximum") is not None:
+                properties[key]["maximum"] = field.get("maximum")
+            if field.get("required"):
+                required.append(key)
+
+        config_schema = {
+            "type": "object",
+            "properties": properties,
+        }
+        if required:
+            config_schema["required"] = required
+
         return {
             "name": name,
             "display_name": display_names.get(name, name),
             "description": descriptions.get(name, ""),
-            "config_schema": temp_instance.get_config_schema(),
+            "config_schema": config_schema,
         }

     @classmethod
@@ -125,18 +156,47 @@ class EmbeddingProviderFactory:
 class EmbeddingConfigManager:
     """
     Manager for embedding configuration.
-    [AC-AISVC-31] Supports hot-reload of configuration.
+    [AC-AISVC-31] Supports hot-reload of configuration with persistence.
     """

     def __init__(self, default_provider: str = "ollama", default_config: dict[str, Any] | None = None):
-        self._provider_name = default_provider
-        self._config = default_config or {
+        self._default_provider = default_provider
+        self._default_config = default_config or {
             "base_url": "http://localhost:11434",
             "model": "nomic-embed-text",
             "dimension": 768,
         }
+        self._provider_name = default_provider
+        self._config = self._default_config.copy()
         self._provider: EmbeddingProvider | None = None

+        self._load_from_file()
+
+    def _load_from_file(self) -> None:
+        """Load configuration from file if exists."""
+        try:
+            if EMBEDDING_CONFIG_FILE.exists():
+                with open(EMBEDDING_CONFIG_FILE, 'r', encoding='utf-8') as f:
+                    saved = json.load(f)
+                self._provider_name = saved.get("provider", self._default_provider)
+                self._config = saved.get("config", self._default_config.copy())
+                logger.info(f"Loaded embedding config from file: provider={self._provider_name}")
+        except Exception as e:
+            logger.warning(f"Failed to load embedding config from file: {e}")
+
+    def _save_to_file(self) -> None:
+        """Save configuration to file."""
+        try:
+            EMBEDDING_CONFIG_FILE.parent.mkdir(parents=True, exist_ok=True)
+            with open(EMBEDDING_CONFIG_FILE, 'w', encoding='utf-8') as f:
+                json.dump({
+                    "provider": self._provider_name,
+                    "config": self._config,
+                }, f, indent=2, ensure_ascii=False)
+            logger.info(f"Saved embedding config to file: provider={self._provider_name}")
+        except Exception as e:
+            logger.error(f"Failed to save embedding config to file: {e}")
+
     def get_provider_name(self) -> str:
         """Get current provider name."""
         return self._provider_name
@@ -174,7 +234,7 @@ class EmbeddingConfigManager:
     ) -> bool:
         """
         Update embedding configuration.
-        [AC-AISVC-31, AC-AISVC-40] Supports hot-reload.
+        [AC-AISVC-31, AC-AISVC-40] Supports hot-reload with persistence.

         Args:
             provider: New provider name
@@ -202,6 +262,8 @@ class EmbeddingConfigManager:
         self._config = config
         self._provider = new_provider_instance

+        self._save_to_file()
+
         logger.info(f"Updated embedding config: provider={provider}")
         return True
@@ -286,7 +348,7 @@ def get_embedding_config_manager() -> EmbeddingConfigManager:
     settings = get_settings()

     _embedding_config_manager = EmbeddingConfigManager(
-        default_provider="ollama",
+        default_provider="nomic",
         default_config={
             "base_url": settings.ollama_base_url,
             "model": settings.ollama_embedding_model,
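Both config managers in this PR persist their state as a small JSON document under `config/` so a hot-reloaded choice survives restarts. The save/load round trip can be sketched with a temporary directory (paths and values here are illustrative):

```python
import json
import tempfile
from pathlib import Path

def save_config(path: Path, provider: str, config: dict) -> None:
    # Same shape as _save_to_file(): {"provider": ..., "config": {...}}
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps({"provider": provider, "config": config},
                               indent=2, ensure_ascii=False), encoding="utf-8")

def load_config(path: Path, default_provider: str = "ollama") -> tuple[str, dict]:
    if not path.exists():
        return default_provider, {}
    saved = json.loads(path.read_text(encoding="utf-8"))
    return saved.get("provider", default_provider), saved.get("config", {})

with tempfile.TemporaryDirectory() as tmp:
    cfg_file = Path(tmp) / "config" / "embedding_config.json"
    save_config(cfg_file, "nomic", {"base_url": "http://localhost:11434", "dimension": 768})
    provider, cfg = load_config(cfg_file)
    print(provider, cfg["dimension"])  # nomic 768
```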
@@ -149,6 +149,7 @@ class NomicEmbeddingProvider(EmbeddingProvider):
         embedding_256 = self._truncate_and_normalize(embedding, 256)
         embedding_512 = self._truncate_and_normalize(embedding, 512)
+        embedding_full = self._truncate_and_normalize(embedding, len(embedding))

         logger.debug(
             f"Generated Nomic embedding: task={task.value}, "
@@ -156,7 +157,7 @@ class NomicEmbeddingProvider(EmbeddingProvider):
         )

         return NomicEmbeddingResult(
-            embedding_full=embedding,
+            embedding_full=embedding_full,
             embedding_256=embedding_256,
             embedding_512=embedding_512,
             dimension=len(embedding),
@@ -259,26 +260,31 @@ class NomicEmbeddingProvider(EmbeddingProvider):
         return {
             "base_url": {
                 "type": "string",
                 "title": "API 地址",
                 "description": "Ollama API 地址",
                 "default": "http://localhost:11434",
             },
             "model": {
                 "type": "string",
                 "title": "模型名称",
                 "description": "嵌入模型名称(推荐 nomic-embed-text v1.5)",
                 "default": "nomic-embed-text",
             },
             "dimension": {
                 "type": "integer",
                 "title": "向量维度",
                 "description": "向量维度(支持 256/512/768)",
                 "default": 768,
             },
             "timeout_seconds": {
                 "type": "integer",
                 "title": "超时时间",
                 "description": "请求超时时间(秒)",
                 "default": 60,
             },
+            "enable_matryoshka": {
+                "type": "boolean",
+                "title": "Matryoshka 截断",
+                "description": "启用 Matryoshka 维度截断",
+                "default": True,
+            },
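The `_truncate_and_normalize` helper used above implements the Matryoshka trick: keep a prefix of the 768-dim vector and rescale it back to unit length. A minimal pure-Python sketch — the real helper's internals are not shown in the diff, so this is an assumed implementation of the standard technique:

```python
import math

def truncate_and_normalize(embedding: list[float], dim: int) -> list[float]:
    """Keep the first `dim` components, then rescale to unit L2 norm."""
    truncated = embedding[:dim]
    norm = math.sqrt(sum(x * x for x in truncated))
    if norm == 0.0:
        return truncated
    return [x / norm for x in truncated]

vec = [3.0, 4.0, 0.0, 12.0]
short = truncate_and_normalize(vec, 2)
print(short)  # [0.6, 0.8] — the norm of [3, 4] is 5
print(math.isclose(sum(x * x for x in short), 1.0))  # True
```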
@@ -130,21 +130,25 @@ class OllamaEmbeddingProvider(EmbeddingProvider):
         return {
             "base_url": {
                 "type": "string",
+                "title": "API 地址",
                 "description": "Ollama API 地址",
                 "default": "http://localhost:11434",
             },
             "model": {
                 "type": "string",
+                "title": "模型名称",
                 "description": "嵌入模型名称",
                 "default": "nomic-embed-text",
             },
             "dimension": {
                 "type": "integer",
+                "title": "向量维度",
                 "description": "向量维度",
                 "default": 768,
             },
             "timeout_seconds": {
                 "type": "integer",
+                "title": "超时时间",
                 "description": "请求超时时间(秒)",
                 "default": 60,
             },
@@ -159,28 +159,33 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
         return {
             "api_key": {
                 "type": "string",
+                "title": "API 密钥",
                 "description": "OpenAI API 密钥",
                 "required": True,
                 "secret": True,
             },
             "model": {
                 "type": "string",
+                "title": "模型名称",
                 "description": "嵌入模型名称",
                 "default": "text-embedding-3-small",
                 "enum": list(self.MODEL_DIMENSIONS.keys()),
             },
             "base_url": {
                 "type": "string",
+                "title": "API 地址",
                 "description": "OpenAI API 地址(支持兼容接口)",
                 "default": "https://api.openai.com/v1",
             },
             "dimension": {
                 "type": "integer",
+                "title": "向量维度",
                 "description": "向量维度(仅 text-embedding-3 系列支持自定义)",
                 "default": 1536,
             },
             "timeout_seconds": {
                 "type": "integer",
+                "title": "超时时间",
                 "description": "请求超时时间(秒)",
                 "default": 60,
             },
@@ -5,8 +5,10 @@ LLM Provider Factory and Configuration Management.
 Design pattern: Factory pattern for pluggable LLM providers.
 """

+import json
 import logging
 from dataclasses import dataclass, field
+from pathlib import Path
 from typing import Any

 from app.services.llm.base import LLMClient, LLMConfig
@@ -14,6 +16,8 @@ from app.services.llm.openai_client import OpenAIClient
 logger = logging.getLogger(__name__)

+LLM_CONFIG_FILE = Path("config/llm_config.json")
+

 @dataclass
 class LLMProviderInfo:
@@ -257,7 +261,7 @@ class LLMProviderFactory:
 class LLMConfigManager:
     """
     Manager for LLM configuration.
-    [AC-ASA-16, AC-ASA-17, AC-ASA-18] Configuration management with hot-reload.
+    [AC-ASA-16, AC-ASA-17, AC-ASA-18] Configuration management with hot-reload and persistence.
     """

     def __init__(self):
@@ -275,11 +279,40 @@ class LLMConfigManager:
         }
         self._client: LLMClient | None = None

+        self._load_from_file()
+
+    def _load_from_file(self) -> None:
+        """Load configuration from file if exists."""
+        try:
+            if LLM_CONFIG_FILE.exists():
+                with open(LLM_CONFIG_FILE, 'r', encoding='utf-8') as f:
+                    saved = json.load(f)
+                self._current_provider = saved.get("provider", self._current_provider)
+                saved_config = saved.get("config", {})
+                if saved_config:
+                    self._current_config.update(saved_config)
+                logger.info(f"[AC-ASA-16] Loaded LLM config from file: provider={self._current_provider}")
+        except Exception as e:
+            logger.warning(f"[AC-ASA-16] Failed to load LLM config from file: {e}")
+
+    def _save_to_file(self) -> None:
+        """Save configuration to file."""
+        try:
+            LLM_CONFIG_FILE.parent.mkdir(parents=True, exist_ok=True)
+            with open(LLM_CONFIG_FILE, 'w', encoding='utf-8') as f:
+                json.dump({
+                    "provider": self._current_provider,
+                    "config": self._current_config,
+                }, f, indent=2, ensure_ascii=False)
+            logger.info(f"[AC-ASA-16] Saved LLM config to file: provider={self._current_provider}")
+        except Exception as e:
+            logger.error(f"[AC-ASA-16] Failed to save LLM config to file: {e}")
+
     def get_current_config(self) -> dict[str, Any]:
         """Get current LLM configuration."""
         return {
             "provider": self._current_provider,
-            "config": self._current_config,
+            "config": self._current_config.copy(),
         }

     async def update_config(
@@ -289,7 +322,7 @@ class LLMConfigManager:
     ) -> bool:
         """
         Update LLM configuration.
-        [AC-ASA-16] Hot-reload configuration.
+        [AC-ASA-16] Hot-reload configuration with persistence.

         Args:
             provider: Provider name
@@ -311,6 +344,8 @@ class LLMConfigManager:
         self._current_provider = provider
         self._current_config = validated_config

+        self._save_to_file()
+
         logger.info(f"[AC-ASA-16] LLM config updated: provider={provider}")
         return True
@@ -365,7 +400,7 @@ class LLMConfigManager:
         test_provider = provider or self._current_provider
         test_config = config if config else self._current_config

-        logger.info(f"[AC-ASA-17] Test connection: provider={test_provider}, config={test_config}")
+        logger.info(f"[AC-ASA-17] Test connection: provider={test_provider}, model={test_config.get('model')}")

         if test_provider not in LLM_PROVIDERS:
             return {

@@ -119,13 +119,7 @@ class OrchestratorService:
            max_evidence_tokens=getattr(settings, "rag_max_evidence_tokens", 2000),
            enable_rag=True,
        )
-       self._llm_config = LLMConfig(
-           model=getattr(settings, "llm_model", "gpt-4o-mini"),
-           max_tokens=getattr(settings, "llm_max_tokens", 2048),
-           temperature=getattr(settings, "llm_temperature", 0.7),
-           timeout_seconds=getattr(settings, "llm_timeout_seconds", 30),
-           max_retries=getattr(settings, "llm_max_retries", 3),
-       )
+       self._llm_config: LLMConfig | None = None

    async def generate(
        self,

@@ -345,7 +339,6 @@
        try:
            ctx.llm_response = await self._llm_client.generate(
                messages=messages,
-               config=self._llm_config,
            )
            ctx.diagnostics["llm_mode"] = "live"
            ctx.diagnostics["llm_model"] = ctx.llm_response.model

@@ -627,7 +620,7 @@
        """
        messages = self._build_llm_messages(ctx)

-       async for chunk in self._llm_client.stream_generate(messages, self._llm_config):
+       async for chunk in self._llm_client.stream_generate(messages):
            if not state_machine.can_send_message():
                break

@@ -84,7 +84,13 @@ class RRFCombiner:
                    "bm25_rank": -1,
                    "payload": result.get("payload", {}),
                    "id": chunk_id,
                    "vector": result.get("vector"),
                }
            else:
                combined_scores[chunk_id]["vector_score"] = result.get("score", 0.0)
                combined_scores[chunk_id]["vector_rank"] = rank
                if result.get("vector"):
                    combined_scores[chunk_id]["vector"] = result.get("vector")

            combined_scores[chunk_id]["score"] += rrf_score

@@ -101,6 +107,7 @@
                    "bm25_rank": rank,
                    "payload": result.get("payload", {}),
                    "id": chunk_id,
                    "vector": result.get("vector"),
                }
            else:
                combined_scores[chunk_id]["bm25_score"] = result.get("score", 0.0)
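The combiner above merges the vector and BM25 rankings with Reciprocal Rank Fusion, where each document scores `sum_i(w_i / (k + rank_i))` over the lists that returned it. A dependency-free sketch of the core idea (the `k=60` default and unit weights are illustrative, not the project's settings):

```python
def rrf_combine(vector_ids, bm25_ids, k=60, vector_weight=1.0, bm25_weight=1.0):
    """Reciprocal Rank Fusion: score(d) = sum over lists of w / (k + rank(d))."""
    scores = {}
    for rank, doc_id in enumerate(vector_ids, start=1):
        scores[doc_id] = scores.get(doc_id, 0.0) + vector_weight / (k + rank)
    for rank, doc_id in enumerate(bm25_ids, start=1):
        scores[doc_id] = scores.get(doc_id, 0.0) + bm25_weight / (k + rank)
    # Highest fused score first.
    return sorted(scores, key=scores.get, reverse=True)


# A document ranked well by BOTH retrievers beats one that tops only a single list.
print(rrf_combine(["a", "b", "c"], ["b", "c", "d"]))  # ['b', 'c', 'a', 'd']
```

The constant `k` damps the advantage of the very top ranks, which is why a document appearing mid-list in both rankings can outrank a single-list winner.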

@@ -131,7 +138,6 @@ class OptimizedRetriever(BaseRetriever):
    def __init__(
        self,
        qdrant_client: QdrantClient | None = None,
        embedding_provider: NomicEmbeddingProvider | None = None,
        top_k: int | None = None,
        score_threshold: float | None = None,
        min_hits: int | None = None,

@@ -141,7 +147,6 @@ class OptimizedRetriever(BaseRetriever):
        rrf_k: int | None = None,
    ):
        self._qdrant_client = qdrant_client
        self._embedding_provider = embedding_provider
        self._top_k = top_k or settings.rag_top_k
        self._score_threshold = score_threshold or settings.rag_score_threshold
        self._min_hits = min_hits or settings.rag_min_hits

@@ -157,19 +162,17 @@ class OptimizedRetriever(BaseRetriever):
        return self._qdrant_client

    async def _get_embedding_provider(self) -> NomicEmbeddingProvider:
        if self._embedding_provider is None:
            from app.services.embedding.factory import get_embedding_config_manager
            manager = get_embedding_config_manager()
            provider = await manager.get_provider()
            if isinstance(provider, NomicEmbeddingProvider):
                self._embedding_provider = provider
                return provider
            else:
-               self._embedding_provider = NomicEmbeddingProvider(
+               return NomicEmbeddingProvider(
                    base_url=settings.ollama_base_url,
                    model=settings.ollama_embedding_model,
                    dimension=settings.qdrant_vector_size,
                )
        return self._embedding_provider

    async def retrieve(self, ctx: RetrievalContext) -> RetrievalResult:
        """

@@ -199,7 +202,15 @@ class OptimizedRetriever(BaseRetriever):
            f"dim_256={'available' if embedding_result.embedding_256 else 'not available'}"
        )

-       if self._two_stage_enabled:
+       if self._two_stage_enabled and self._hybrid_enabled:
+           logger.info("[RAG-OPT] Using two-stage + hybrid retrieval strategy")
+           results = await self._two_stage_hybrid_retrieve(
+               ctx.tenant_id,
+               embedding_result,
+               ctx.query,
+               self._top_k,
+           )
+       elif self._two_stage_enabled:
            logger.info("[RAG-OPT] Using two-stage retrieval strategy")
            results = await self._two_stage_retrieve(
                ctx.tenant_id,

@@ -300,20 +311,27 @@ class OptimizedRetriever(BaseRetriever):
        stage1_start = time.perf_counter()
        candidates = await self._search_with_dimension(
            client, tenant_id, embedding_result.embedding_256, "dim_256",
-           top_k * self._two_stage_expand_factor
+           top_k * self._two_stage_expand_factor,
+           with_vectors=True,
        )
        stage1_latency = (time.perf_counter() - stage1_start) * 1000

-       logger.debug(
+       logger.info(
            f"[RAG-OPT] Stage 1: {len(candidates)} candidates in {stage1_latency:.2f}ms"
        )

        stage2_start = time.perf_counter()
        reranked = []
        for candidate in candidates:
-           stored_full_embedding = candidate.get("payload", {}).get("embedding_full", [])
-           if stored_full_embedding:
-               import numpy as np
+           vector_data = candidate.get("vector", {})
+           stored_full_embedding = None
+
+           if isinstance(vector_data, dict):
+               stored_full_embedding = vector_data.get("full", [])
+           elif isinstance(vector_data, list):
+               stored_full_embedding = vector_data
+
+           if stored_full_embedding and len(stored_full_embedding) > 0:
                similarity = self._cosine_similarity(
                    embedding_result.embedding_full,
                    stored_full_embedding

@@ -326,7 +344,7 @@ class OptimizedRetriever(BaseRetriever):
        results = reranked[:top_k]
        stage2_latency = (time.perf_counter() - stage2_start) * 1000

-       logger.debug(
+       logger.info(
            f"[RAG-OPT] Stage 2: {len(results)} final results in {stage2_latency:.2f}ms"
        )
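Stage 2 reranks the oversampled 256-dim candidates by cosine similarity between the full query embedding and each stored full vector. A dependency-free sketch of that similarity function (the project's `_cosine_similarity` may be implemented differently, e.g. with numpy):

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two vectors: dot(a, b) / (|a| * |b|)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0  # degenerate vector: treat as no similarity
    return dot / (norm_a * norm_b)


print(cosine_similarity([1.0, 0.0], [1.0, 0.0]))  # same direction -> 1.0
print(cosine_similarity([1.0, 0.0], [0.0, 1.0]))  # orthogonal -> 0.0
```

Because cosine similarity is scale-invariant, reranking with it on 768-dim vectors recovers precision that the truncated 256-dim first pass trades away for speed.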

@@ -374,6 +392,92 @@ class OptimizedRetriever(BaseRetriever):

        return combined[:top_k]

    async def _two_stage_hybrid_retrieve(
        self,
        tenant_id: str,
        embedding_result: NomicEmbeddingResult,
        query: str,
        top_k: int,
    ) -> list[dict[str, Any]]:
        """
        Two-stage + Hybrid retrieval strategy.

        Stage 1: Fast retrieval with 256-dim vectors + BM25 in parallel
        Stage 2: RRF fusion + Precise reranking with 768-dim vectors

        This combines the best of both worlds:
        - Two-stage: Speed from 256-dim, precision from 768-dim reranking
        - Hybrid: Semantic matching from vectors, keyword matching from BM25
        """
        import time

        client = await self._get_client()

        stage1_start = time.perf_counter()

        vector_task = self._search_with_dimension(
            client, tenant_id, embedding_result.embedding_256, "dim_256",
            top_k * self._two_stage_expand_factor,
            with_vectors=True,
        )

        bm25_task = self._bm25_search(client, tenant_id, query, top_k * self._two_stage_expand_factor)

        vector_results, bm25_results = await asyncio.gather(
            vector_task, bm25_task, return_exceptions=True
        )

        if isinstance(vector_results, Exception):
            logger.warning(f"[RAG-OPT] Vector search failed: {vector_results}")
            vector_results = []

        if isinstance(bm25_results, Exception):
            logger.warning(f"[RAG-OPT] BM25 search failed: {bm25_results}")
            bm25_results = []

        stage1_latency = (time.perf_counter() - stage1_start) * 1000
        logger.info(
            f"[RAG-OPT] Two-stage Hybrid Stage 1: vector={len(vector_results)}, bm25={len(bm25_results)}, latency={stage1_latency:.2f}ms"
        )

        stage2_start = time.perf_counter()

        combined = self._rrf_combiner.combine(
            vector_results,
            bm25_results,
            vector_weight=settings.rag_vector_weight,
            bm25_weight=settings.rag_bm25_weight,
        )

        reranked = []
        for candidate in combined[:top_k * 2]:
            vector_data = candidate.get("vector", {})
            stored_full_embedding = None

            if isinstance(vector_data, dict):
                stored_full_embedding = vector_data.get("full", [])
            elif isinstance(vector_data, list):
                stored_full_embedding = vector_data

            if stored_full_embedding and len(stored_full_embedding) > 0:
                similarity = self._cosine_similarity(
                    embedding_result.embedding_full,
                    stored_full_embedding
                )
                candidate["score"] = similarity
                candidate["stage"] = "two_stage_hybrid_reranked"
            reranked.append(candidate)

        reranked.sort(key=lambda x: x.get("score", 0), reverse=True)
        results = reranked[:top_k]
        stage2_latency = (time.perf_counter() - stage2_start) * 1000

        logger.info(
            f"[RAG-OPT] Two-stage Hybrid Stage 2 (reranking): {len(results)} final results in {stage2_latency:.2f}ms"
        )

        return results

    async def _vector_retrieve(
        self,
        tenant_id: str,
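The hybrid method runs both search legs concurrently via `asyncio.gather(..., return_exceptions=True)`, so a failure in one leg degrades that leg to an empty result list instead of cancelling the other. The pattern in isolation (the coroutine names below are illustrative):

```python
import asyncio


async def vector_search():
    return ["v1", "v2"]


async def bm25_search():
    raise RuntimeError("index unavailable")  # simulated failure of one leg


async def hybrid():
    vec, bm25 = await asyncio.gather(
        vector_search(), bm25_search(), return_exceptions=True
    )
    # With return_exceptions=True, a failed leg comes back as the exception
    # object in the result tuple rather than propagating as a raise.
    if isinstance(vec, Exception):
        vec = []
    if isinstance(bm25, Exception):
        bm25 = []
    return vec, bm25


print(asyncio.run(hybrid()))  # (['v1', 'v2'], [])
```

Without `return_exceptions=True`, the first exception would propagate out of `gather` and the successful vector results would be discarded along with it.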

@@ -393,45 +497,37 @@ class OptimizedRetriever(BaseRetriever):
        query_vector: list[float],
        vector_name: str,
        limit: int,
+       with_vectors: bool = False,
    ) -> list[dict[str, Any]]:
        """Search using specified vector dimension."""
        try:
-           qdrant = await client.get_client()
-           collection_name = client.get_collection_name(tenant_id)
-
            logger.info(
-               f"[RAG-OPT] Searching collection={collection_name}, "
-               f"vector_name={vector_name}, limit={limit}, vector_dim={len(query_vector)}"
+               f"[RAG-OPT] Searching with vector_name={vector_name}, "
+               f"limit={limit}, vector_dim={len(query_vector)}, with_vectors={with_vectors}"
            )

-           results = await qdrant.search(
-               collection_name=collection_name,
-               query_vector=(vector_name, query_vector),
+           results = await client.search(
+               tenant_id=tenant_id,
+               query_vector=query_vector,
                limit=limit,
+               vector_name=vector_name,
+               with_vectors=with_vectors,
            )

            logger.info(
-               f"[RAG-OPT] Search returned {len(results)} results from collection={collection_name}"
+               f"[RAG-OPT] Search returned {len(results)} results"
            )

            if len(results) > 0:
                for i, r in enumerate(results[:3]):
                    logger.debug(
-                       f"[RAG-OPT] Result {i+1}: id={r.id}, score={r.score:.4f}"
+                       f"[RAG-OPT] Result {i+1}: id={r['id']}, score={r['score']:.4f}"
                    )

-           return [
-               {
-                   "id": str(result.id),
-                   "score": result.score,
-                   "payload": result.payload or {},
-               }
-               for result in results
-           ]
+           return results
        except Exception as e:
            logger.error(
-               f"[RAG-OPT] Search with {vector_name} failed: {e}, "
-               f"collection_name={client.get_collection_name(tenant_id)}",
+               f"[RAG-OPT] Search with {vector_name} failed: {e}",
                exc_info=True
            )
            return []

@@ -14,12 +14,13 @@ dependencies = [
    "tenacity>=8.2.0",
    "sqlmodel>=0.0.14",
    "asyncpg>=0.29.0",
-   "qdrant-client>=1.7.0",
+   "qdrant-client>=1.9.0,<2.0.0",
    "tiktoken>=0.5.0",
    "openpyxl>=3.1.0",
    "python-docx>=1.1.0",
    "pymupdf>=1.23.0",
+   "pdfplumber>=0.10.0",
    "python-multipart>=0.0.6",
]

[project.optional-dependencies]

@@ -0,0 +1,89 @@
"""
Script to clean up Qdrant collections and data.
"""

import asyncio
import logging
import sys

sys.path.insert(0, "Q:\\agentProject\\ai-robot-core\\ai-service")

from app.core.config import get_settings
from app.core.qdrant_client import get_qdrant_client

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


async def list_collections():
    """List all collections in Qdrant."""
    client = await get_qdrant_client()
    qdrant = await client.get_client()

    collections = await qdrant.get_collections()
    return [c.name for c in collections.collections]


async def delete_collection(collection_name: str):
    """Delete a specific collection."""
    client = await get_qdrant_client()
    qdrant = await client.get_client()

    try:
        await qdrant.delete_collection(collection_name)
        logger.info(f"Deleted collection: {collection_name}")
        return True
    except Exception as e:
        logger.error(f"Failed to delete collection {collection_name}: {e}")
        return False


async def delete_all_collections():
    """Delete all collections."""
    collections = await list_collections()
    logger.info(f"Found {len(collections)} collections: {collections}")

    for name in collections:
        await delete_collection(name)

    logger.info("All collections deleted")


async def delete_tenant_collection(tenant_id: str):
    """Delete collection for a specific tenant."""
    client = await get_qdrant_client()
    collection_name = client.get_collection_name(tenant_id)

    success = await delete_collection(collection_name)
    if success:
        logger.info(f"Deleted collection for tenant: {tenant_id}")
    return success


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Cleanup Qdrant data")
    parser.add_argument("--all", action="store_true", help="Delete all collections")
    parser.add_argument("--tenant", type=str, help="Delete collection for specific tenant")
    parser.add_argument("--list", action="store_true", help="List all collections")

    args = parser.parse_args()

    if args.list:
        collections = asyncio.run(list_collections())
        print(f"Collections: {collections}")
    elif args.all:
        confirm = input("Are you sure you want to delete ALL collections? (yes/no): ")
        if confirm.lower() == "yes":
            asyncio.run(delete_all_collections())
        else:
            print("Cancelled")
    elif args.tenant:
        confirm = input(f"Delete collection for tenant '{args.tenant}'? (yes/no): ")
        if confirm.lower() == "yes":
            asyncio.run(delete_tenant_collection(args.tenant))
        else:
            print("Cancelled")
    else:
        parser.print_help()

@@ -28,6 +28,13 @@ CREATE TABLE IF NOT EXISTS chat_messages (
    session_id VARCHAR NOT NULL,
    role VARCHAR NOT NULL,
    content TEXT NOT NULL,
+   prompt_tokens INTEGER,
+   completion_tokens INTEGER,
+   total_tokens INTEGER,
+   latency_ms INTEGER,
+   first_token_ms INTEGER,
+   is_error BOOLEAN NOT NULL DEFAULT FALSE,
+   error_message VARCHAR,
    created_at TIMESTAMP WITHOUT TIME ZONE NOT NULL
);

@@ -74,6 +81,18 @@ CREATE TABLE IF NOT EXISTS index_jobs (
    updated_at TIMESTAMP WITHOUT TIME ZONE NOT NULL
);
+
+-- ============================================
+-- API Keys Table [AC-AISVC-50]
+-- ============================================
+CREATE TABLE IF NOT EXISTS api_keys (
+    id UUID NOT NULL PRIMARY KEY,
+    key VARCHAR NOT NULL UNIQUE,
+    name VARCHAR NOT NULL,
+    is_active BOOLEAN NOT NULL DEFAULT TRUE,
+    created_at TIMESTAMP WITHOUT TIME ZONE NOT NULL,
+    updated_at TIMESTAMP WITHOUT TIME ZONE NOT NULL
+);

-- ============================================
-- Indexes
-- ============================================

@@ -100,6 +119,10 @@ CREATE INDEX IF NOT EXISTS ix_index_jobs_tenant_id ON index_jobs (tenant_id);
CREATE INDEX IF NOT EXISTS ix_index_jobs_tenant_doc ON index_jobs (tenant_id, doc_id);
CREATE INDEX IF NOT EXISTS ix_index_jobs_tenant_status ON index_jobs (tenant_id, status);
+
+-- API Keys Indexes [AC-AISVC-50]
+CREATE INDEX IF NOT EXISTS ix_api_keys_key ON api_keys (key);
+CREATE INDEX IF NOT EXISTS ix_api_keys_is_active ON api_keys (is_active);

-- ============================================
-- Verification
-- ============================================

@@ -0,0 +1,29 @@
-- Migration: Add missing columns to the chat_messages table
-- Execute this on an existing database to add the new columns

-- Add token tracking columns
ALTER TABLE chat_messages ADD COLUMN IF NOT EXISTS prompt_tokens INTEGER;
ALTER TABLE chat_messages ADD COLUMN IF NOT EXISTS completion_tokens INTEGER;
ALTER TABLE chat_messages ADD COLUMN IF NOT EXISTS total_tokens INTEGER;

-- Add latency tracking columns
ALTER TABLE chat_messages ADD COLUMN IF NOT EXISTS latency_ms INTEGER;
ALTER TABLE chat_messages ADD COLUMN IF NOT EXISTS first_token_ms INTEGER;

-- Add error tracking columns
ALTER TABLE chat_messages ADD COLUMN IF NOT EXISTS is_error BOOLEAN NOT NULL DEFAULT FALSE;
ALTER TABLE chat_messages ADD COLUMN IF NOT EXISTS error_message VARCHAR;

-- Create the API Keys table if it does not exist
CREATE TABLE IF NOT EXISTS api_keys (
    id UUID NOT NULL PRIMARY KEY,
    key VARCHAR NOT NULL UNIQUE,
    name VARCHAR NOT NULL,
    is_active BOOLEAN NOT NULL DEFAULT TRUE,
    created_at TIMESTAMP WITHOUT TIME ZONE NOT NULL,
    updated_at TIMESTAMP WITHOUT TIME ZONE NOT NULL
);

-- Create API Keys indexes
CREATE INDEX IF NOT EXISTS ix_api_keys_key ON api_keys (key);
CREATE INDEX IF NOT EXISTS ix_api_keys_is_active ON api_keys (is_active);
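The `api_keys` table stores the keys that the admin panel authenticates with; how a key value is generated is not shown in this diff. A typical sketch (an assumption, not the project's actual code) uses the standard library's `secrets` module for cryptographically strong randomness:

```python
import secrets


def generate_api_key(prefix: str = "sk", nbytes: int = 32) -> str:
    """Generate a URL-safe random API key such as 'sk-4fQx...'.

    secrets.token_urlsafe(32) yields 43 URL-safe characters, which is
    plenty of entropy for an opaque bearer key. The prefix only aids
    identification in logs and config files.
    """
    return f"{prefix}-{secrets.token_urlsafe(nbytes)}"


key = generate_api_key()
print(key)  # e.g. sk-NJ3k... (value differs on every call)
```

Storing the plain key with a `UNIQUE` index, as the schema above does, keeps lookup simple; hashing keys at rest would be a further hardening step.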

@@ -0,0 +1,134 @@
# AI Service Nginx Configuration
# Place this file at /etc/nginx/conf.d/ai-service.conf
# or include it from the main configuration file

# Backend API upstream (called by the Java channel side)
upstream ai_service_backend {
    server 127.0.0.1:8182;
}

# Frontend admin UI upstream
upstream ai_service_admin {
    server 127.0.0.1:8181;
}

# Frontend admin UI
server {
    listen 80;
    server_name your-domain.com;  # Replace with your domain or server IP

    # Access logs
    access_log /var/log/nginx/ai-service-admin.access.log;
    error_log /var/log/nginx/ai-service-admin.error.log;

    location / {
        proxy_pass http://ai_service_admin;
        proxy_http_version 1.1;
        proxy_set_header Upgrade $http_upgrade;
        proxy_set_header Connection 'upgrade';
        proxy_set_header Host $host;
        proxy_set_header X-Real-IP $remote_addr;
        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
        proxy_set_header X-Forwarded-Proto $scheme;
        proxy_cache_bypass $http_upgrade;

        # SSE streaming support
        proxy_read_timeout 300s;
        proxy_connect_timeout 75s;
        proxy_buffering off;
    }
}

# Backend API (called by the Java channel side)
# With a domain you can use a separate path or subdomain,
# e.g. api.your-domain.com or your-domain.com/api/
server {
    listen 80;
    server_name api.your-domain.com;  # Replace with the API subdomain

    # Access logs
    access_log /var/log/nginx/ai-service-api.access.log;
    error_log /var/log/nginx/ai-service-api.error.log;

    location / {
        proxy_pass http://ai_service_backend;
        proxy_http_version 1.1;
        proxy_set_header Upgrade $http_upgrade;
        proxy_set_header Connection 'upgrade';
        proxy_set_header Host $host;
        proxy_set_header X-Real-IP $remote_addr;
        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
        proxy_set_header X-Forwarded-Proto $scheme;
        proxy_cache_bypass $http_upgrade;

        # SSE streaming support
        proxy_read_timeout 300s;
        proxy_connect_timeout 75s;
        proxy_buffering off;
    }
}

# ============================================================
# HTTPS configuration example (using Let's Encrypt)
# ============================================================

# server {
#     listen 443 ssl http2;
#     server_name your-domain.com;
#
#     ssl_certificate /etc/letsencrypt/live/your-domain.com/fullchain.pem;
#     ssl_certificate_key /etc/letsencrypt/live/your-domain.com/privkey.pem;
#
#     ssl_protocols TLSv1.2 TLSv1.3;
#     ssl_ciphers ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256;
#     ssl_prefer_server_ciphers off;
#
#     location / {
#         proxy_pass http://ai_service_admin;
#         proxy_http_version 1.1;
#         proxy_set_header Upgrade $http_upgrade;
#         proxy_set_header Connection 'upgrade';
#         proxy_set_header Host $host;
#         proxy_set_header X-Real-IP $remote_addr;
#         proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
#         proxy_set_header X-Forwarded-Proto $scheme;
#         proxy_cache_bypass $http_upgrade;
#         proxy_read_timeout 300s;
#         proxy_connect_timeout 75s;
#         proxy_buffering off;
#     }
# }

# server {
#     listen 443 ssl http2;
#     server_name api.your-domain.com;
#
#     ssl_certificate /etc/letsencrypt/live/your-domain.com/fullchain.pem;
#     ssl_certificate_key /etc/letsencrypt/live/your-domain.com/privkey.pem;
#
#     ssl_protocols TLSv1.2 TLSv1.3;
#     ssl_ciphers ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256;
#     ssl_prefer_server_ciphers off;
#
#     location / {
#         proxy_pass http://ai_service_backend;
#         proxy_http_version 1.1;
#         proxy_set_header Upgrade $http_upgrade;
#         proxy_set_header Connection 'upgrade';
#         proxy_set_header Host $host;
#         proxy_set_header X-Real-IP $remote_addr;
#         proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
#         proxy_set_header X-Forwarded-Proto $scheme;
#         proxy_cache_bypass $http_upgrade;
#         proxy_read_timeout 300s;
#         proxy_connect_timeout 75s;
#         proxy_buffering off;
#     }
# }

# Redirect HTTP to HTTPS
# Note: $host (not $server_name) preserves the requested hostname when a
# server block lists multiple server_name entries.
# server {
#     listen 80;
#     server_name your-domain.com api.your-domain.com;
#     return 301 https://$host$request_uri;
# }
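`proxy_buffering off` is the key directive for the streaming endpoints: with buffering on, Nginx would hold SSE chunks until its buffer fills, stalling the token stream. On the wire, SSE is just newline-delimited `data:` lines with a blank line terminating each event; a minimal parser sketch for such a stream (illustrative, not the admin frontend's actual client code):

```python
def parse_sse_events(raw: str):
    """Split a raw SSE stream into event payloads.

    Per the SSE format, consecutive `data:` lines belong to one event
    (joined with newlines) and a blank line terminates the event.
    """
    events, data_lines = [], []
    for line in raw.splitlines():
        if line.startswith("data:"):
            data_lines.append(line[5:].lstrip())
        elif line == "" and data_lines:
            events.append("\n".join(data_lines))
            data_lines = []
    return events


stream = "data: hello\n\ndata: wor\ndata: ld\n\n"
print(parse_sse_events(stream))  # ['hello', 'wor\nld']
```

The long `proxy_read_timeout 300s` complements this: an LLM response can idle between tokens, and a short read timeout would cut the stream mid-answer.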

@@ -0,0 +1,108 @@
services:
  ai-service:
    build:
      context: ./ai-service
      dockerfile: Dockerfile
    container_name: ai-service
    restart: unless-stopped
    ports:
      - "8182:8080"
    environment:
      - AI_SERVICE_DEBUG=false
      - AI_SERVICE_LOG_LEVEL=INFO
      - AI_SERVICE_DATABASE_URL=postgresql+asyncpg://postgres:postgres@postgres:5432/ai_service
      - AI_SERVICE_QDRANT_URL=http://qdrant:6333
      - AI_SERVICE_LLM_PROVIDER=${AI_SERVICE_LLM_PROVIDER:-openai}
      - AI_SERVICE_LLM_API_KEY=${AI_SERVICE_LLM_API_KEY:-}
      - AI_SERVICE_LLM_BASE_URL=${AI_SERVICE_LLM_BASE_URL:-https://api.openai.com/v1}
      - AI_SERVICE_LLM_MODEL=${AI_SERVICE_LLM_MODEL:-gpt-4o-mini}
      - AI_SERVICE_OLLAMA_BASE_URL=${AI_SERVICE_OLLAMA_BASE_URL:-http://ollama:11434}
    volumes:
      - ai_service_config:/app/config
    depends_on:
      postgres:
        condition: service_healthy
      qdrant:
        condition: service_started
    networks:
      - ai-network
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:8080/ai/health"]
      interval: 30s
      timeout: 10s
      retries: 3
      start_period: 40s

  ai-service-admin:
    build:
      context: ./ai-service-admin
      dockerfile: Dockerfile
      args:
        VITE_APP_API_KEY: ${VITE_APP_API_KEY:-}
        VITE_APP_BASE_API: /api
    container_name: ai-service-admin
    restart: unless-stopped
    ports:
      - "8183:80"
    depends_on:
      - ai-service
    networks:
      - ai-network

  postgres:
    image: postgres:15-alpine
    container_name: ai-postgres
    restart: unless-stopped
    environment:
      - POSTGRES_USER=postgres
      - POSTGRES_PASSWORD=postgres
      - POSTGRES_DB=ai_service
    volumes:
      - postgres_data:/var/lib/postgresql/data
      - ./ai-service/scripts/init_db.sql:/docker-entrypoint-initdb.d/init_db.sql:ro
    ports:
      - "5432:5432"
    networks:
      - ai-network
    healthcheck:
      test: ["CMD-SHELL", "pg_isready -U postgres -d ai_service"]
      interval: 10s
      timeout: 5s
      retries: 5

  qdrant:
    image: qdrant/qdrant:latest
    container_name: ai-qdrant
    restart: unless-stopped
    ports:
      - "6333:6333"
      - "6334:6334"
    volumes:
      - qdrant_data:/qdrant/storage
    networks:
      - ai-network

  ollama:
    image: ollama/ollama:latest
    container_name: ai-ollama
    restart: unless-stopped
    ports:
      - "11434:11434"
    volumes:
      - ollama_data:/root/.ollama
    networks:
      - ai-network
    deploy:
      resources:
        reservations:
          memory: 1G

networks:
  ai-network:
    driver: bridge

volumes:
  postgres_data:
  qdrant_data:
  ollama_data:
  ai_service_config:
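Throughout the compose file, `${VAR:-default}` substitution lets a missing or empty `.env` entry fall back to a sane default (note the colon: `${VAR:-d}` treats empty the same as unset, while `${VAR-d}` only covers unset). The same layering expressed in Python, for reference (illustrative helper, not project code):

```python
import os


def env_or_default(name: str, default: str) -> str:
    """Mirror compose's ${VAR:-default}: empty OR unset falls back to default."""
    value = os.environ.get(name, "")
    return value if value else default


os.environ.pop("AI_SERVICE_LLM_MODEL", None)
print(env_or_default("AI_SERVICE_LLM_MODEL", "gpt-4o-mini"))   # gpt-4o-mini

os.environ["AI_SERVICE_LLM_MODEL"] = "deepseek-chat"
print(env_or_default("AI_SERVICE_LLM_MODEL", "gpt-4o-mini"))   # deepseek-chat
```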