"""
Embedding management API endpoints.
[AC-AISVC-38, AC-AISVC-39, AC-AISVC-40, AC-AISVC-41] Embedding model management.
"""

import logging
from typing import Any

from fastapi import APIRouter, Depends, Header, HTTPException

from app.core.exceptions import InvalidRequestException
from app.services.embedding import (
    EmbeddingException,
    EmbeddingProviderFactory,
    get_embedding_config_manager,
)

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/embedding", tags=["Embedding Management"])


def get_tenant_id(x_tenant_id: str = Header(..., alias="X-Tenant-Id")) -> str:
    """
    Extract the tenant ID from the ``X-Tenant-Id`` request header.

    Used as a FastAPI dependency by every endpoint in this router.

    Returns:
        The header value, unchanged.

    Raises:
        HTTPException: 400 when the header is present but empty.
    """
    if not x_tenant_id:
        raise HTTPException(status_code=400, detail="X-Tenant-Id header is required")
    return x_tenant_id


@router.get("/providers")
async def list_embedding_providers(
    tenant_id: str = Depends(get_tenant_id),
) -> dict[str, Any]:
    """
    Get available embedding providers.
    [AC-AISVC-38] Returns all registered providers with config schemas.

    ``tenant_id`` is not read here; the dependency only enforces that the
    tenant header was supplied.

    Returns:
        ``{"providers": [...]}`` with one ``get_provider_info`` entry per
        name reported by ``EmbeddingProviderFactory.get_available_providers``.
    """
    providers = [
        EmbeddingProviderFactory.get_provider_info(name)
        for name in EmbeddingProviderFactory.get_available_providers()
    ]
    return {"providers": providers}


@router.get("/config")
async def get_embedding_config(
    tenant_id: str = Depends(get_tenant_id),
) -> dict[str, Any]:
    """
    Get current embedding configuration.
    [AC-AISVC-39] Returns current provider and config.

    Returns:
        Whatever the embedding config manager's ``get_full_config`` returns.
    """
    return get_embedding_config_manager().get_full_config()


@router.put("/config")
async def update_embedding_config(
    request: dict[str, Any],
    tenant_id: str = Depends(get_tenant_id),
) -> dict[str, Any]:
    """
    Update embedding configuration.
    [AC-AISVC-40, AC-AISVC-31] Updates config with hot-reload support.

    Args:
        request: JSON body. ``provider`` (required) names the provider to
            switch to; ``config`` (optional object) holds its settings.
        tenant_id: Value of the tenant header; only its presence is enforced.

    Returns:
        ``{"success": True, "message": ...}``. When the provider or the
        ``model`` setting differs from the previous configuration, the
        response also carries ``warning`` and ``requires_reindex=True``.

    Raises:
        InvalidRequestException: if ``provider`` is missing or unknown, if
            ``config`` is not an object, or if the config manager rejects
            the update with an ``EmbeddingException``.
    """
    provider = request.get("provider")
    if not provider:
        raise InvalidRequestException("provider is required")

    # An explicit JSON null is treated like an omitted key; any other
    # non-object value is rejected up front instead of failing later with an
    # AttributeError on ``config.get``.
    config = request.get("config")
    if config is None:
        config = {}
    if not isinstance(config, dict):
        raise InvalidRequestException("config must be an object")

    available = EmbeddingProviderFactory.get_available_providers()
    if provider not in available:
        raise InvalidRequestException(
            f"Unknown provider: {provider}. "
            f"Available: {available}"
        )

    manager = get_embedding_config_manager()

    # Snapshot the previous provider/model before updating so a model switch
    # can be detected afterwards.
    old_config = manager.get_full_config()
    old_provider = old_config.get("provider")
    old_model = (old_config.get("config") or {}).get("model", "")
    new_model = config.get("model", "")

    try:
        await manager.update_config(provider, config)
    except EmbeddingException as e:
        # Keep the original exception as the cause for debugging.
        raise InvalidRequestException(str(e)) from e

    response: dict[str, Any] = {
        "success": True,
        "message": f"Configuration updated to use {provider}",
    }

    if old_provider != provider or old_model != new_model:
        # User-facing warning (Chinese): the embedding model changed; vectors
        # produced by different models are incompatible, so the existing
        # knowledge base should be deleted and the documents re-uploaded.
        response["warning"] = (
            "嵌入模型已更改。由于不同模型生成的向量不兼容,"
            "请删除现有知识库并重新上传文档,以确保检索效果正常。"
        )
        response["requires_reindex"] = True
        logger.warning(
            "[EMBEDDING] Model changed from %s/%s to %s/%s. "
            "Documents need to be re-uploaded.",
            old_provider,
            old_model,
            provider,
            new_model,
        )

    return response


@router.post("/test")
async def test_embedding(
    request: dict[str, Any] | None = None,
    tenant_id: str = Depends(get_tenant_id),
) -> dict[str, Any]:
    """
    Test embedding connection.
    [AC-AISVC-41] Tests provider connectivity and returns dimension info.

    Args:
        request: Optional JSON body with ``test_text``, ``provider`` and
            ``config``. A missing body is treated as an empty object; a
            missing ``test_text`` falls back to a built-in sample sentence.
        tenant_id: Value of the tenant header; only its presence is enforced.

    Returns:
        The result of the config manager's ``test_connection`` call.
    """
    request = request or {}

    manager = get_embedding_config_manager()
    return await manager.test_connection(
        test_text=request.get("test_text", "这是一个测试文本"),
        provider=request.get("provider"),
        config=request.get("config"),
    )


@router.get("/formats")
async def get_supported_document_formats(
    tenant_id: str = Depends(get_tenant_id),
) -> dict[str, Any]:
    """
    Get supported document formats for embedding.
    Returns list of supported file extensions.

    Returns:
        ``{"formats": ..., "parsers": ...}`` built from the document
        service's format list and ``DocumentParserFactory.get_parser_info``.
    """
    # Imported inside the function, as in the original. The service helper
    # shares this endpoint's name, so it is aliased to make clear which
    # callable is being invoked.
    from app.services.document import DocumentParserFactory
    from app.services.document import (
        get_supported_document_formats as _list_supported_formats,
    )

    return {
        "formats": _list_supported_formats(),
        "parsers": DocumentParserFactory.get_parser_info(),
    }