ai-robot-core/ai-service/app/api/admin/embedding.py

153 lines
4.4 KiB
Python

"""
Embedding management API endpoints.
[AC-AISVC-38, AC-AISVC-39, AC-AISVC-40, AC-AISVC-41] Embedding model management.
"""
import logging
from typing import Any
from fastapi import APIRouter, Depends, Header, HTTPException
from app.core.exceptions import InvalidRequestException
from app.services.embedding import (
EmbeddingException,
EmbeddingProviderFactory,
get_embedding_config_manager,
)
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Router that groups the endpoints defined in this module under the
# "/embedding" path prefix and the "Embedding Management" tag.
router = APIRouter(prefix="/embedding", tags=["Embedding Management"])
def get_tenant_id(x_tenant_id: str = Header(..., alias="X-Tenant-Id")) -> str:
    """
    Resolve the tenant identifier from the ``X-Tenant-Id`` request header.

    Args:
        x_tenant_id: Value of the ``X-Tenant-Id`` header.

    Returns:
        The header value, unchanged.

    Raises:
        HTTPException: With status 400 when the header value is empty.
    """
    if x_tenant_id:
        return x_tenant_id
    # An empty value is treated the same as a missing tenant.
    raise HTTPException(status_code=400, detail="X-Tenant-Id header is required")
@router.get("/providers")
async def list_embedding_providers(
    tenant_id: str = Depends(get_tenant_id),
) -> dict[str, Any]:
    """
    List the embedding providers known to the provider factory.
    [AC-AISVC-38] Returns all registered providers with config schemas.

    Args:
        tenant_id: Resolved by the ``get_tenant_id`` dependency. It is not
            read in the body; declaring it makes the dependency run for
            every request to this endpoint.

    Returns:
        ``{"providers": [...]}`` with one entry per available provider name,
        each entry being whatever
        ``EmbeddingProviderFactory.get_provider_info`` returns for that name.
    """
    factory = EmbeddingProviderFactory
    return {
        "providers": [
            factory.get_provider_info(provider_name)
            for provider_name in factory.get_available_providers()
        ]
    }
@router.get("/config")
async def get_embedding_config(
    tenant_id: str = Depends(get_tenant_id),
) -> dict[str, Any]:
    """
    Return the embedding configuration currently held by the config manager.
    [AC-AISVC-39] Returns current provider and config.

    Args:
        tenant_id: Resolved by the ``get_tenant_id`` dependency. It is not
            read in the body; declaring it makes the dependency run for
            every request to this endpoint.

    Returns:
        The result of ``get_full_config()`` on the embedding config manager,
        passed through unchanged.
    """
    return get_embedding_config_manager().get_full_config()
@router.put("/config")
async def update_embedding_config(
    request: dict[str, Any],
    tenant_id: str = Depends(get_tenant_id),
) -> dict[str, Any]:
    """
    Update embedding configuration.
    [AC-AISVC-40, AC-AISVC-31] Updates config with hot-reload support.

    Args:
        request: Request body. ``provider`` is required and must be one of
            the names reported by
            ``EmbeddingProviderFactory.get_available_providers()``.
            ``config`` is optional; a missing or ``null`` value is treated
            as an empty object.
        tenant_id: Resolved by the ``get_tenant_id`` dependency. It is not
            read in the body; declaring it makes the dependency run for
            every request to this endpoint.

    Returns:
        A dict with ``success`` and ``message``. When the provider or the
        ``model`` entry of the config differs from the previous
        configuration, ``warning`` and ``requires_reindex`` are added too.

    Raises:
        InvalidRequestException: If ``provider`` is missing or unknown, if
            ``config`` is present but not an object, or if the config
            manager rejects the update with an ``EmbeddingException``.
    """
    provider = request.get("provider")
    if not provider:
        raise InvalidRequestException("provider is required")

    # Fetch the provider list once; it is used for both the check and the message.
    available = EmbeddingProviderFactory.get_available_providers()
    if provider not in available:
        raise InvalidRequestException(
            f"Unknown provider: {provider}. "
            f"Available: {available}"
        )

    # An explicit JSON null must behave like an omitted config, and any other
    # non-object value must be rejected here rather than failing on `.get` below.
    config = request.get("config")
    if config is None:
        config = {}
    elif not isinstance(config, dict):
        raise InvalidRequestException("config must be a JSON object")

    manager = get_embedding_config_manager()
    old_config = manager.get_full_config()
    old_provider = old_config.get("provider")
    # Tolerate a previous configuration whose "config" entry is absent or None.
    old_model = (old_config.get("config") or {}).get("model", "")
    new_model = config.get("model", "")

    # Only the update call itself is guarded, so unrelated errors are not
    # mistaken for request validation errors.
    try:
        await manager.update_config(provider, config)
    except EmbeddingException as e:
        raise InvalidRequestException(str(e)) from e

    response: dict[str, Any] = {
        "success": True,
        "message": f"Configuration updated to use {provider}",
    }
    if old_provider != provider or old_model != new_model:
        # A different provider or model was selected: surface the re-index
        # warning to the caller and record the change in the log.
        response["warning"] = (
            "嵌入模型已更改。由于不同模型生成的向量不兼容，"
            "请删除现有知识库并重新上传文档，以确保检索效果正常。"
        )
        response["requires_reindex"] = True
        logger.warning(
            "[EMBEDDING] Model changed from %s/%s to %s/%s. "
            "Documents need to be re-uploaded.",
            old_provider,
            old_model,
            provider,
            new_model,
        )
    return response
@router.post("/test")
async def test_embedding(
    request: dict[str, Any] | None = None,
    tenant_id: str = Depends(get_tenant_id),
) -> dict[str, Any]:
    """
    Run a connection test through the embedding config manager.
    [AC-AISVC-41] Tests provider connectivity and returns dimension info.

    Args:
        request: Optional request body. Reads ``test_text`` (falls back to a
            built-in sample sentence), ``provider`` and ``config`` (both
            ``None`` when absent). A missing or empty body is treated as an
            empty dict.
        tenant_id: Resolved by the ``get_tenant_id`` dependency. It is not
            read in the body; declaring it makes the dependency run for
            every request to this endpoint.

    Returns:
        The awaited result of the manager's ``test_connection`` call,
        passed through unchanged.
    """
    payload = request if request else {}
    sample_text = payload.get("test_text", "这是一个测试文本")
    override_config = payload.get("config")
    override_provider = payload.get("provider")
    return await get_embedding_config_manager().test_connection(
        test_text=sample_text,
        provider=override_provider,
        config=override_config,
    )
@router.get("/formats")
async def get_supported_document_formats(
    tenant_id: str = Depends(get_tenant_id),
) -> dict[str, Any]:
    """
    Report the document formats and parsers exposed by the document service.
    Returns list of supported file extensions.

    Args:
        tenant_id: Resolved by the ``get_tenant_id`` dependency. It is not
            read in the body; declaring it makes the dependency run for
            every request to this endpoint.

    Returns:
        A dict with ``formats`` (the result of the document service's
        ``get_supported_document_formats()``) and ``parsers`` (the result of
        ``DocumentParserFactory.get_parser_info()``).
    """
    # Imported at call time rather than at module import. The service function
    # shares this endpoint's name, so it is aliased to keep the two apart.
    from app.services.document import DocumentParserFactory
    from app.services.document import (
        get_supported_document_formats as list_supported_formats,
    )

    return {
        "formats": list_supported_formats(),
        "parsers": DocumentParserFactory.get_parser_info(),
    }