# ai-robot-core/ai-service/app/services/embedding/factory.py

"""
Embedding provider factory and configuration manager.
[AC-AISVC-30, AC-AISVC-31] Factory pattern for dynamic provider loading.
Design reference: progress.md Section 7.1 - Architecture
- EmbeddingProviderFactory: creates providers based on config
- EmbeddingConfigManager: manages configuration with hot-reload support
"""
import json
import logging
from pathlib import Path
from typing import Any
from app.services.embedding.base import EmbeddingException, EmbeddingProvider
from app.services.embedding.nomic_provider import NomicEmbeddingProvider
from app.services.embedding.ollama_provider import OllamaEmbeddingProvider
from app.services.embedding.openai_provider import OpenAIEmbeddingProvider
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# JSON file in which EmbeddingConfigManager persists the active provider name
# and its config. The path is relative, so it is resolved against the process's
# current working directory when it is opened.
EMBEDDING_CONFIG_FILE = Path("config/embedding_config.json")
class EmbeddingProviderFactory:
    """
    Factory for creating embedding providers.

    [AC-AISVC-30] Supports dynamic loading based on configuration.

    Provider classes live in a class-level registry keyed by provider name;
    ``register_provider`` extends that registry at runtime.
    """

    # Shared registry: provider name -> provider class.
    _providers: dict[str, type[EmbeddingProvider]] = {
        "ollama": OllamaEmbeddingProvider,
        "openai": OpenAIEmbeddingProvider,
        "nomic": NomicEmbeddingProvider,
    }

    @classmethod
    def register_provider(cls, name: str, provider_class: type[EmbeddingProvider]) -> None:
        """
        Register a new embedding provider.

        [AC-AISVC-30] Allows runtime registration of providers.

        Args:
            name: Identifier under which the provider is registered. An
                existing entry with the same name is replaced.
            provider_class: Provider class to instantiate for ``name``.
        """
        cls._providers[name] = provider_class
        logger.info("Registered embedding provider: %s", name)

    @classmethod
    def get_available_providers(cls) -> list[str]:
        """
        Get list of available provider names.

        [AC-AISVC-38] Returns registered provider identifiers.
        """
        return list(cls._providers)

    @classmethod
    def get_provider_info(cls, name: str) -> dict[str, Any]:
        """
        Get provider information including config schema.

        [AC-AISVC-38] Returns provider metadata.

        Args:
            name: Registered provider identifier.

        Returns:
            Dict with ``name``, ``display_name``, ``description`` and
            ``config_schema`` (an object schema built from the provider's
            ``get_config_schema()`` result).

        Raises:
            EmbeddingException: If ``name`` is not registered.
        """
        if name not in cls._providers:
            raise EmbeddingException(
                f"Unknown provider: {name}",
                provider="factory",
            )

        provider_class = cls._providers[name]
        # Build the object without running __init__, so no provider config is
        # needed just to read the schema.
        # NOTE(review): this relies on get_config_schema() not reading
        # instance state set by __init__ -- confirm for custom providers.
        temp_instance = provider_class.__new__(provider_class)

        display_names = {
            "ollama": "Ollama 本地模型",
            "openai": "OpenAI Embedding",
            "nomic": "Nomic Embed (优化版)",
        }
        descriptions = {
            "ollama": "使用 Ollama 运行的本地嵌入模型,支持 nomic-embed-text 等开源模型",
            "openai": "使用 OpenAI 官方 Embedding API支持 text-embedding-3 系列模型",
            "nomic": "Nomic-embed-text v1.5 优化版,支持任务前缀和 Matryoshka 维度截断专为RAG优化",
        }

        raw_schema = temp_instance.get_config_schema()
        properties: dict[str, dict[str, Any]] = {}
        required: list[str] = []
        for key, field in raw_schema.items():
            prop: dict[str, Any] = {
                "type": field.get("type", "string"),
                "title": field.get("title", key),
                "description": field.get("description", ""),
                "default": field.get("default"),
            }
            # Optional keys are emitted only when the provider supplied them.
            if field.get("enum"):
                prop["enum"] = field["enum"]
            for bound in ("minimum", "maximum"):
                # Compare against None so a bound of 0 is still emitted.
                if field.get(bound) is not None:
                    prop[bound] = field[bound]
            if field.get("required"):
                required.append(key)
            properties[key] = prop

        config_schema: dict[str, Any] = {
            "type": "object",
            "properties": properties,
        }
        if required:
            config_schema["required"] = required

        return {
            "name": name,
            "display_name": display_names.get(name, name),
            "description": descriptions.get(name, ""),
            "config_schema": config_schema,
        }

    @classmethod
    def create_provider(
        cls,
        name: str,
        config: dict[str, Any],
    ) -> EmbeddingProvider:
        """
        Create an embedding provider instance.

        [AC-AISVC-30] Creates provider based on configuration.

        Args:
            name: Provider identifier (e.g., "ollama", "openai")
            config: Provider-specific configuration, passed to the provider
                class as keyword arguments.

        Returns:
            Configured EmbeddingProvider instance

        Raises:
            EmbeddingException: If provider is unknown or configuration is invalid
        """
        if name not in cls._providers:
            raise EmbeddingException(
                f"Unknown embedding provider: {name}. "
                f"Available: {cls.get_available_providers()}",
                provider="factory",
            )

        provider_class = cls._providers[name]
        try:
            instance = provider_class(**config)
        except Exception as e:
            # NOTE(review): details echoes the raw config; if it can carry
            # credentials (e.g. an API key), consider redacting it before it
            # reaches logs or API responses.
            raise EmbeddingException(
                f"Failed to create provider '{name}': {e}",
                provider="factory",
                details={"config": config},
            ) from e  # keep the original failure as the explicit cause

        logger.info("Created embedding provider: %s", name)
        return instance
class EmbeddingConfigManager:
    """
    Manager for embedding configuration.

    [AC-AISVC-31] Supports hot-reload of configuration with persistence.

    Holds the active provider name and config dict, lazily builds the provider
    through EmbeddingProviderFactory, and mirrors changes to
    EMBEDDING_CONFIG_FILE as JSON.
    """

    def __init__(self, default_provider: str = "ollama", default_config: dict[str, Any] | None = None):
        """
        Set up defaults, then overlay any configuration saved on disk.

        Args:
            default_provider: Provider name used when nothing is saved on disk.
            default_config: Provider config used when nothing is saved on
                disk. A missing or empty value selects the built-in defaults.
        """
        self._default_provider = default_provider
        # Copy so later mutation of the caller's dict cannot change our defaults.
        self._default_config = dict(default_config) if default_config else {
            "base_url": "http://localhost:11434",
            "model": "nomic-embed-text",
            "dimension": 768,
        }
        self._provider_name = default_provider
        self._config = self._default_config.copy()
        self._provider: EmbeddingProvider | None = None
        self._load_from_file()

    def _load_from_file(self) -> None:
        """
        Load configuration from file if it exists.

        Best-effort: on any read or parse failure a warning is logged and the
        current (default) values are kept. Values of the wrong type fall back
        to the defaults instead of being stored.
        """
        try:
            if not EMBEDDING_CONFIG_FILE.exists():
                return
            with open(EMBEDDING_CONFIG_FILE, encoding='utf-8') as f:
                saved = json.load(f)
        except Exception as e:
            # An unreadable or corrupt file must not prevent construction.
            logger.warning(f"Failed to load embedding config from file: {e}")
            return

        if not isinstance(saved, dict):
            logger.warning(
                "Failed to load embedding config from file: "
                f"expected a JSON object, got {type(saved).__name__}"
            )
            return

        provider_name = saved.get("provider", self._default_provider)
        if not isinstance(provider_name, str):
            logger.warning(
                "Ignoring non-string 'provider' in embedding config file; "
                f"using default provider={self._default_provider}"
            )
            provider_name = self._default_provider

        config = saved.get("config")
        if not isinstance(config, dict):
            # A missing key silently uses the defaults; a present but
            # malformed value (e.g. null) is worth a warning.
            if "config" in saved:
                logger.warning(
                    "Ignoring non-object 'config' in embedding config file; using default config"
                )
            config = self._default_config.copy()

        self._provider_name = provider_name
        self._config = config
        logger.info(f"Loaded embedding config from file: provider={self._provider_name}")

    def _save_to_file(self) -> None:
        """
        Save configuration to file.

        Best-effort: failures are logged and not raised.
        """
        try:
            # Serialise before opening the file: opening with 'w' truncates
            # it, so a non-serialisable value would otherwise destroy the
            # previously saved config.
            payload = json.dumps({
                "provider": self._provider_name,
                "config": self._config,
            }, indent=2, ensure_ascii=False)
            EMBEDDING_CONFIG_FILE.parent.mkdir(parents=True, exist_ok=True)
            with open(EMBEDDING_CONFIG_FILE, 'w', encoding='utf-8') as f:
                f.write(payload)
            logger.info(f"Saved embedding config to file: provider={self._provider_name}")
        except Exception as e:
            logger.error(f"Failed to save embedding config to file: {e}")

    def get_provider_name(self) -> str:
        """Get current provider name."""
        return self._provider_name

    def get_config(self) -> dict[str, Any]:
        """Get a shallow copy of the current configuration."""
        return self._config.copy()

    def get_full_config(self) -> dict[str, Any]:
        """
        Get full configuration including provider name.

        [AC-AISVC-39] Returns complete configuration for API response.
        """
        return {
            "provider": self._provider_name,
            "config": self._config.copy(),
        }

    async def get_provider(self) -> EmbeddingProvider:
        """
        Get or create the embedding provider.

        [AC-AISVC-29] Returns configured provider instance.

        The instance is created on first use from the current provider name
        and config, then reused until the config is updated or ``close()`` is
        called.

        Raises:
            EmbeddingException: If the provider cannot be created.
        """
        if self._provider is None:
            self._provider = EmbeddingProviderFactory.create_provider(
                self._provider_name,
                self._config,
            )
        return self._provider

    async def update_config(
        self,
        provider: str,
        config: dict[str, Any],
    ) -> bool:
        """
        Update embedding configuration.

        [AC-AISVC-31, AC-AISVC-40] Supports hot-reload with persistence.

        The new provider is built first; nothing changes if that fails. Once
        it exists it becomes active and is persisted, and only then is the
        previous provider closed. A failure to close the previous provider is
        logged and does not undo the update.

        Args:
            provider: New provider name
            config: New provider configuration

        Returns:
            True if update was successful

        Raises:
            EmbeddingException: If configuration is invalid
        """
        try:
            new_provider_instance = EmbeddingProviderFactory.create_provider(
                provider,
                config,
            )
        except Exception as e:
            # State has not been touched yet, so there is nothing to roll back.
            raise EmbeddingException(
                f"Failed to update config: {e}",
                provider="config_manager",
                details={"provider": provider, "config": config},
            ) from e

        previous_provider = self._provider
        self._provider_name = provider
        # Copy so later mutation of the caller's dict cannot alter our state.
        self._config = dict(config)
        self._provider = new_provider_instance
        self._save_to_file()
        logger.info(f"Updated embedding config: provider={provider}")

        if previous_provider is not None:
            try:
                await previous_provider.close()
            except Exception as e:
                # The new provider is already active; do not fail the update.
                logger.warning(f"Failed to close previous embedding provider: {e}")
        return True

    async def test_connection(
        self,
        test_text: str = "这是一个测试文本",
        provider: str | None = None,
        config: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """
        Test embedding connection.

        [AC-AISVC-41] Tests provider connectivity.

        A throwaway provider is created for the test and closed afterwards,
        whether or not the embedding call succeeds. The active provider is
        not used or modified.

        Args:
            test_text: Text to embed for testing
            provider: Provider to test (uses current if None)
            config: Config to test (uses current if None)

        Returns:
            Dict with test results including success, dimension, latency.
            Failures are reported in the dict, not raised.
        """
        import time

        test_provider_name = provider or self._provider_name
        test_config = config or self._config
        try:
            test_provider = EmbeddingProviderFactory.create_provider(
                test_provider_name,
                test_config,
            )
            try:
                start_time = time.perf_counter()
                embedding = await test_provider.embed(test_text)
                latency_ms = (time.perf_counter() - start_time) * 1000
            except Exception:
                # Release the throwaway provider without masking the
                # embedding error reported to the caller.
                try:
                    await test_provider.close()
                except Exception as close_error:
                    logger.warning(f"Failed to close test embedding provider: {close_error}")
                raise
            await test_provider.close()
            return {
                "success": True,
                "dimension": len(embedding),
                "latency_ms": latency_ms,
                "message": f"连接成功,向量维度: {len(embedding)}",
            }
        except Exception as e:
            return {
                "success": False,
                "dimension": 0,
                "latency_ms": 0,
                "error": str(e),
                "message": f"连接失败: {e}",
            }

    async def close(self) -> None:
        """
        Close the current provider.

        The cached reference is dropped before closing, so a provider whose
        ``close()`` raises is not reused by a later ``get_provider()`` call.
        """
        provider, self._provider = self._provider, None
        if provider is not None:
            await provider.close()
# Process-wide manager instance, created on first request.
_embedding_config_manager: EmbeddingConfigManager | None = None


def get_embedding_config_manager() -> EmbeddingConfigManager:
    """
    Return the shared EmbeddingConfigManager, building it on first call.

    [AC-AISVC-31] Singleton pattern for configuration management.

    The first call reads the application settings and seeds the manager with
    the "nomic" provider and the base URL, model and vector size taken from
    those settings.
    """
    global _embedding_config_manager
    if _embedding_config_manager is not None:
        return _embedding_config_manager

    # Imported here, not at module top, exactly as the original code does.
    from app.core.config import get_settings

    app_settings = get_settings()
    seed_config = {
        "base_url": app_settings.ollama_base_url,
        "model": app_settings.ollama_embedding_model,
        "dimension": app_settings.qdrant_vector_size,
    }
    _embedding_config_manager = EmbeddingConfigManager(
        default_provider="nomic",
        default_config=seed_config,
    )
    return _embedding_config_manager
async def get_embedding_provider() -> EmbeddingProvider:
    """
    Return the embedding provider held by the global config manager.

    [AC-AISVC-29] Shortcut that resolves the shared manager and awaits its
    ``get_provider()``.
    """
    active_provider = await get_embedding_config_manager().get_provider()
    return active_provider