feat: 添加嵌入配置持久化及模型切换警告 [AC-AISVC-50]

- 添加嵌入配置持久化到config/embedding_config.json
- 服务启动时自动加载保存的配置
- 切换模型时前端显示警告提示需要重新上传文档
- 修复OptimizedRetriever缓存问题,每次检索获取最新配置
- 清理调试用的Python临时文件
- 更新.gitignore忽略config目录
This commit is contained in:
MerCry 2026-02-26 18:01:03 +08:00
parent fd04ed2cef
commit d660c19ab9
8 changed files with 91 additions and 24 deletions

1
.gitignore vendored
View File

@ -162,5 +162,6 @@ cython_debug/
# Project specific # Project specific
ai-service/uploads/ ai-service/uploads/
ai-service/config/
*.local *.local

View File

@ -105,7 +105,10 @@ docker exec -it ai-ollama ollama pull toshk0/nomic-embed-text-v2-moe:Q6_K
- **Matryoshka 截断**`true` - **Matryoshka 截断**`true`
3. 点击 **保存配置** 3. 点击 **保存配置**
> **注意**: 使用 Nomic Embed (优化版) provider 可启用完整的 RAG 优化功能任务前缀、Matryoshka 多向量、两阶段检索。 > **注意**:
> - 使用 Nomic Embed (优化版) provider 可启用完整的 RAG 优化功能:任务前缀、Matryoshka 多向量、两阶段检索。
> - 嵌入模型配置会持久化保存到 `ai-service/config/embedding_config.json`,服务重启后自动加载。
> - **重要**: 切换嵌入模型后,需要删除现有知识库并重新上传文档,因为不同模型生成的向量不兼容。
#### 6. 验证服务 #### 6. 验证服务

View File

@ -74,7 +74,8 @@ export const useEmbeddingStore = defineStore('embedding', () => {
provider: currentConfig.value.provider, provider: currentConfig.value.provider,
config: currentConfig.value.config config: currentConfig.value.config
} }
await saveConfig(updateData) const response = await saveConfig(updateData)
return response
} catch (error) { } catch (error) {
console.error('Failed to save config:', error) console.error('Failed to save config:', error)
throw error throw error

View File

@ -169,8 +169,19 @@ const handleSave = async () => {
saving.value = true saving.value = true
try { try {
await embeddingStore.saveCurrentConfig() const response: any = await embeddingStore.saveCurrentConfig()
ElMessage.success('配置保存成功') ElMessage.success('配置保存成功')
if (response?.warning || response?.requires_reindex) {
ElMessageBox.alert(
response.warning || '嵌入模型已更改,请重新上传文档以确保检索效果正常。',
'重要提示',
{
confirmButtonText: '我知道了',
type: 'warning',
}
)
}
} catch (error) { } catch (error) {
ElMessage.error('配置保存失败') ElMessage.error('配置保存失败')
} finally { } finally {

View File

@ -14,7 +14,7 @@ export default defineConfig({
port: 3000, port: 3000,
proxy: { proxy: {
'/api': { '/api': {
target: 'http://localhost:8088', target: 'http://localhost:8000',
changeOrigin: true, changeOrigin: true,
rewrite: (path) => path.replace(/^\/api/, ''), rewrite: (path) => path.replace(/^\/api/, ''),
}, },

View File

@ -78,12 +78,32 @@ async def update_embedding_config(
manager = get_embedding_config_manager() manager = get_embedding_config_manager()
old_config = manager.get_full_config()
old_provider = old_config.get("provider")
old_model = old_config.get("config", {}).get("model", "")
new_model = config.get("model", "")
try: try:
await manager.update_config(provider, config) await manager.update_config(provider, config)
return {
response = {
"success": True, "success": True,
"message": f"Configuration updated to use {provider}", "message": f"Configuration updated to use {provider}",
} }
if old_provider != provider or old_model != new_model:
response["warning"] = (
"嵌入模型已更改。由于不同模型生成的向量不兼容,"
"请删除现有知识库并重新上传文档,以确保检索效果正常。"
)
response["requires_reindex"] = True
logger.warning(
f"[EMBEDDING] Model changed from {old_provider}/{old_model} to {provider}/{new_model}. "
f"Documents need to be re-uploaded."
)
return response
except EmbeddingException as e: except EmbeddingException as e:
raise InvalidRequestException(str(e)) raise InvalidRequestException(str(e))

View File

@ -7,7 +7,9 @@ Design reference: progress.md Section 7.1 - Architecture
- EmbeddingConfigManager: manages configuration with hot-reload support - EmbeddingConfigManager: manages configuration with hot-reload support
""" """
import json
import logging import logging
from pathlib import Path
from typing import Any, Type from typing import Any, Type
from app.services.embedding.base import EmbeddingException, EmbeddingProvider from app.services.embedding.base import EmbeddingException, EmbeddingProvider
@ -17,6 +19,8 @@ from app.services.embedding.nomic_provider import NomicEmbeddingProvider
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
EMBEDDING_CONFIG_FILE = Path("config/embedding_config.json")
class EmbeddingProviderFactory: class EmbeddingProviderFactory:
""" """
@ -152,18 +156,47 @@ class EmbeddingProviderFactory:
class EmbeddingConfigManager: class EmbeddingConfigManager:
""" """
Manager for embedding configuration. Manager for embedding configuration.
[AC-AISVC-31] Supports hot-reload of configuration. [AC-AISVC-31] Supports hot-reload of configuration with persistence.
""" """
def __init__(self, default_provider: str = "ollama", default_config: dict[str, Any] | None = None): def __init__(self, default_provider: str = "ollama", default_config: dict[str, Any] | None = None):
self._provider_name = default_provider self._default_provider = default_provider
self._config = default_config or { self._default_config = default_config or {
"base_url": "http://localhost:11434", "base_url": "http://localhost:11434",
"model": "nomic-embed-text", "model": "nomic-embed-text",
"dimension": 768, "dimension": 768,
} }
self._provider_name = default_provider
self._config = self._default_config.copy()
self._provider: EmbeddingProvider | None = None self._provider: EmbeddingProvider | None = None
self._load_from_file()
def _load_from_file(self) -> None:
    """Load a previously persisted configuration, if one exists.

    Reads ``EMBEDDING_CONFIG_FILE`` and, when it holds a well-formed
    payload, replaces the in-memory provider name and config with the
    saved values. A missing ``"provider"`` or ``"config"`` key falls back
    to the corresponding default.

    Loading is best-effort: any read/parse failure or malformed payload is
    logged as a warning and the current (default) values are kept, so the
    manager can always be constructed.
    """
    try:
        if not EMBEDDING_CONFIG_FILE.exists():
            return
        with open(EMBEDDING_CONFIG_FILE, 'r', encoding='utf-8') as f:
            saved = json.load(f)
    except Exception as e:
        # Deliberately broad: this runs from __init__, and an unreadable or
        # corrupt file must never prevent the service from starting.
        logger.warning(f"Failed to load embedding config from file: {e}")
        return

    # The file may be valid JSON yet not have the expected shape (e.g. a
    # list, or "config": null). Validate everything before assigning so a
    # bad file can never leave the manager with a non-dict config.
    if not isinstance(saved, dict):
        logger.warning(
            "Failed to load embedding config from file: "
            f"expected a JSON object, got {type(saved).__name__}"
        )
        return

    provider_name = saved.get("provider", self._default_provider)
    config = saved.get("config", self._default_config.copy())
    if not isinstance(provider_name, str) or not isinstance(config, dict):
        logger.warning(
            "Failed to load embedding config from file: "
            "'provider' must be a string and 'config' must be an object"
        )
        return

    self._provider_name = provider_name
    self._config = config
    logger.info(f"Loaded embedding config from file: provider={self._provider_name}")
def _save_to_file(self) -> None:
    """Persist the current provider name and config to disk.

    Writes ``{"provider": ..., "config": ...}`` as UTF-8 JSON to
    ``EMBEDDING_CONFIG_FILE``, creating the parent directory if needed.

    Saving is best-effort: failures are logged as errors and not raised,
    so a persistence problem does not undo an in-memory config update.
    """
    tmp_file = EMBEDDING_CONFIG_FILE.with_name(EMBEDDING_CONFIG_FILE.name + ".tmp")
    try:
        # Serialize before touching the filesystem: if the config contains
        # a value that is not JSON-serializable, the existing file must be
        # left intact rather than truncated.
        payload = json.dumps(
            {
                "provider": self._provider_name,
                "config": self._config,
            },
            indent=2,
            ensure_ascii=False,
        )
        EMBEDDING_CONFIG_FILE.parent.mkdir(parents=True, exist_ok=True)
        # Write to a sibling temp file and swap it in, so an interrupted
        # write never leaves a half-written config for the next startup.
        tmp_file.write_text(payload, encoding='utf-8')
        tmp_file.replace(EMBEDDING_CONFIG_FILE)
        logger.info(f"Saved embedding config to file: provider={self._provider_name}")
    except Exception as e:
        logger.error(f"Failed to save embedding config to file: {e}")
        try:
            # Do not leave a stale temp file behind after a failed save.
            tmp_file.unlink(missing_ok=True)
        except OSError:
            # Cleanup is best-effort as well; the failure is already logged.
            pass
def get_provider_name(self) -> str: def get_provider_name(self) -> str:
"""Get current provider name.""" """Get current provider name."""
return self._provider_name return self._provider_name
@ -201,7 +234,7 @@ class EmbeddingConfigManager:
) -> bool: ) -> bool:
""" """
Update embedding configuration. Update embedding configuration.
[AC-AISVC-31, AC-AISVC-40] Supports hot-reload. [AC-AISVC-31, AC-AISVC-40] Supports hot-reload with persistence.
Args: Args:
provider: New provider name provider: New provider name
@ -229,6 +262,8 @@ class EmbeddingConfigManager:
self._config = config self._config = config
self._provider = new_provider_instance self._provider = new_provider_instance
self._save_to_file()
logger.info(f"Updated embedding config: provider={provider}") logger.info(f"Updated embedding config: provider={provider}")
return True return True

View File

@ -138,7 +138,6 @@ class OptimizedRetriever(BaseRetriever):
def __init__( def __init__(
self, self,
qdrant_client: QdrantClient | None = None, qdrant_client: QdrantClient | None = None,
embedding_provider: NomicEmbeddingProvider | None = None,
top_k: int | None = None, top_k: int | None = None,
score_threshold: float | None = None, score_threshold: float | None = None,
min_hits: int | None = None, min_hits: int | None = None,
@ -148,7 +147,6 @@ class OptimizedRetriever(BaseRetriever):
rrf_k: int | None = None, rrf_k: int | None = None,
): ):
self._qdrant_client = qdrant_client self._qdrant_client = qdrant_client
self._embedding_provider = embedding_provider
self._top_k = top_k or settings.rag_top_k self._top_k = top_k or settings.rag_top_k
self._score_threshold = score_threshold or settings.rag_score_threshold self._score_threshold = score_threshold or settings.rag_score_threshold
self._min_hits = min_hits or settings.rag_min_hits self._min_hits = min_hits or settings.rag_min_hits
@ -164,19 +162,17 @@ class OptimizedRetriever(BaseRetriever):
return self._qdrant_client return self._qdrant_client
async def _get_embedding_provider(self) -> NomicEmbeddingProvider: async def _get_embedding_provider(self) -> NomicEmbeddingProvider:
if self._embedding_provider is None: from app.services.embedding.factory import get_embedding_config_manager
from app.services.embedding.factory import get_embedding_config_manager manager = get_embedding_config_manager()
manager = get_embedding_config_manager() provider = await manager.get_provider()
provider = await manager.get_provider() if isinstance(provider, NomicEmbeddingProvider):
if isinstance(provider, NomicEmbeddingProvider): return provider
self._embedding_provider = provider else:
else: return NomicEmbeddingProvider(
self._embedding_provider = NomicEmbeddingProvider( base_url=settings.ollama_base_url,
base_url=settings.ollama_base_url, model=settings.ollama_embedding_model,
model=settings.ollama_embedding_model, dimension=settings.qdrant_vector_size,
dimension=settings.qdrant_vector_size, )
)
return self._embedding_provider
async def retrieve(self, ctx: RetrievalContext) -> RetrievalResult: async def retrieve(self, ctx: RetrievalContext) -> RetrievalResult:
""" """