ai-robot-core/ai-service/app/services/llm/factory.py

"""
LLM Provider Factory and Configuration Management.
[AC-ASA-14, AC-ASA-15, AC-ASA-16, AC-ASA-17, AC-ASA-18] LLM provider management.
Design pattern: Factory pattern for pluggable LLM providers.
"""
import logging
import time
from dataclasses import dataclass
from typing import Any

from app.services.llm.base import LLMClient, LLMConfig
from app.services.llm.openai_client import OpenAIClient

logger = logging.getLogger(__name__)


@dataclass
class LLMProviderInfo:
    """Information about an LLM provider."""

    name: str
    display_name: str
    description: str
    config_schema: dict[str, Any]


LLM_PROVIDERS: dict[str, LLMProviderInfo] = {
    "openai": LLMProviderInfo(
        name="openai",
        display_name="OpenAI",
        description="OpenAI GPT-series models (GPT-4, GPT-3.5, etc.)",
        config_schema={
            "api_key": {
                "type": "string",
                "description": "API Key",
                "required": True,
                "secret": True,
            },
            "base_url": {
                "type": "string",
                "description": "API Base URL",
                "default": "https://api.openai.com/v1",
            },
            "model": {
                "type": "string",
                "description": "Model name",
                "default": "gpt-4o-mini",
            },
            "max_tokens": {
                "type": "integer",
                "description": "Maximum output tokens",
                "default": 2048,
            },
            "temperature": {
                "type": "number",
                "description": "Temperature (0-2)",
                "default": 0.7,
            },
        },
    ),
    "ollama": LLMProviderInfo(
        name="ollama",
        display_name="Ollama",
        description="Ollama local models (Llama, Qwen, etc.)",
        config_schema={
            "base_url": {
                "type": "string",
                "description": "Ollama API address",
                "default": "http://localhost:11434/v1",
            },
            "model": {
                "type": "string",
                "description": "Model name",
                "default": "llama3.2",
            },
            "max_tokens": {
                "type": "integer",
                "description": "Maximum output tokens",
                "default": 2048,
            },
            "temperature": {
                "type": "number",
                "description": "Temperature (0-2)",
                "default": 0.7,
            },
        },
    ),
    "azure": LLMProviderInfo(
        name="azure",
        display_name="Azure OpenAI",
        description="Azure OpenAI service",
        config_schema={
            "api_key": {
                "type": "string",
                "description": "API Key",
                "required": True,
                "secret": True,
            },
            "base_url": {
                "type": "string",
                "description": "Azure Endpoint",
                "required": True,
            },
            "model": {
                "type": "string",
                "description": "Deployment name",
                "required": True,
            },
            "api_version": {
                "type": "string",
                "description": "API version",
                "default": "2024-02-15-preview",
            },
            "max_tokens": {
                "type": "integer",
                "description": "Maximum output tokens",
                "default": 2048,
            },
            "temperature": {
                "type": "number",
                "description": "Temperature (0-2)",
                "default": 0.7,
            },
        },
    ),
}
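
# Extension sketch (illustrative; "anthropic" is a hypothetical provider, not
# part of this module): registering a new provider means adding an entry here
# and a matching branch in LLMProviderFactory.create_client below.
#
#   LLM_PROVIDERS["anthropic"] = LLMProviderInfo(
#       name="anthropic",
#       display_name="Anthropic",
#       description="Anthropic Claude models",
#       config_schema={...},
#   )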


class LLMProviderFactory:
    """
    Factory for creating LLM clients.

    [AC-ASA-14, AC-ASA-15] Dynamic provider creation.
    """

    @classmethod
    def get_providers(cls) -> list[LLMProviderInfo]:
        """Get all registered LLM providers."""
        return list(LLM_PROVIDERS.values())

    @classmethod
    def get_provider_info(cls, name: str) -> LLMProviderInfo | None:
        """Get provider info by name."""
        return LLM_PROVIDERS.get(name)

    @classmethod
    def create_client(
        cls,
        provider: str,
        config: dict[str, Any],
    ) -> LLMClient:
        """
        Create an LLM client for the specified provider.

        [AC-ASA-15] Factory method for client creation.

        Args:
            provider: Provider name (openai, ollama, azure)
            config: Provider configuration

        Returns:
            LLMClient instance

        Raises:
            ValueError: If provider is not supported
        """
        if provider not in LLM_PROVIDERS:
            raise ValueError(f"Unsupported LLM provider: {provider}")

        # All three registered providers expose an OpenAI-compatible API,
        # so a single client class serves them.
        if provider in ("openai", "ollama", "azure"):
            return OpenAIClient(
                api_key=config.get("api_key"),
                base_url=config.get("base_url"),
                model=config.get("model"),
                default_config=LLMConfig(
                    model=config.get("model", "gpt-4o-mini"),
                    max_tokens=config.get("max_tokens", 2048),
                    temperature=config.get("temperature", 0.7),
                ),
            )

        # Defensive fallback for registry entries without a client mapping.
        raise ValueError(f"Unsupported LLM provider: {provider}")
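
# Usage sketch (illustrative; the Ollama values mirror the schema defaults
# above, and the calls assume an async context):
#
#   client = LLMProviderFactory.create_client(
#       "ollama",
#       {"base_url": "http://localhost:11434/v1", "model": "llama3.2"},
#   )
#   response = await client.generate(
#       messages=[{"role": "user", "content": "ping"}],
#   )
#   await client.close()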


class LLMConfigManager:
    """
    Manager for LLM configuration.

    [AC-ASA-16, AC-ASA-17, AC-ASA-18] Configuration management with hot-reload.
    """

    def __init__(self):
        self._current_provider: str = "openai"
        self._current_config: dict[str, Any] = {}
        self._client: LLMClient | None = None

    def get_current_config(self) -> dict[str, Any]:
        """Get current LLM configuration."""
        return {
            "provider": self._current_provider,
            "config": self._current_config,
        }

    async def update_config(
        self,
        provider: str,
        config: dict[str, Any],
    ) -> bool:
        """
        Update LLM configuration.

        [AC-ASA-16] Hot-reload configuration.

        Args:
            provider: Provider name
            config: New configuration

        Returns:
            True if the update succeeded

        Raises:
            ValueError: If the provider is unsupported or a required key is missing
        """
        if provider not in LLM_PROVIDERS:
            raise ValueError(f"Unsupported LLM provider: {provider}")

        provider_info = LLM_PROVIDERS[provider]
        validated_config = self._validate_config(provider_info, config)

        # Close the stale client so the next get_client() picks up the new config.
        if self._client:
            await self._client.close()
            self._client = None

        self._current_provider = provider
        self._current_config = validated_config
        logger.info(f"[AC-ASA-16] LLM config updated: provider={provider}")
        return True

    def _validate_config(
        self,
        provider_info: LLMProviderInfo,
        config: dict[str, Any],
    ) -> dict[str, Any]:
        """Validate configuration against the provider schema, filling defaults."""
        validated = {}
        for key, schema in provider_info.config_schema.items():
            if key in config:
                validated[key] = config[key]
            elif "default" in schema:
                validated[key] = schema["default"]
            elif schema.get("required"):
                raise ValueError(f"Missing required config: {key}")
        return validated

    def get_client(self) -> LLMClient:
        """Get or create an LLM client with the current config."""
        if self._client is None:
            self._client = LLMProviderFactory.create_client(
                self._current_provider,
                self._current_config,
            )
        return self._client

    async def test_connection(
        self,
        test_prompt: str = "Hello, please briefly introduce yourself.",
        provider: str | None = None,
        config: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """
        Test LLM connection.

        [AC-ASA-17, AC-ASA-18] Connection testing.

        Args:
            test_prompt: Test prompt to send
            provider: Optional provider to test (uses current if not specified)
            config: Optional config to test (uses current if not specified)

        Returns:
            Test result with success status, response, and metrics
        """
        test_provider = provider or self._current_provider
        test_config = config or self._current_config

        if test_provider not in LLM_PROVIDERS:
            return {
                "success": False,
                "error": f"Unsupported provider: {test_provider}",
            }

        try:
            # Use a throwaway client so the test never disturbs the active one.
            client = LLMProviderFactory.create_client(test_provider, test_config)
            start_time = time.time()
            response = await client.generate(
                messages=[{"role": "user", "content": test_prompt}],
            )
            latency_ms = (time.time() - start_time) * 1000
            await client.close()
            return {
                "success": True,
                "response": response.content,
                "latency_ms": round(latency_ms, 2),
                "prompt_tokens": response.usage.get("prompt_tokens", 0),
                "completion_tokens": response.usage.get("completion_tokens", 0),
                "total_tokens": response.usage.get("total_tokens", 0),
                "model": response.model,
                "message": f"Connection successful, model: {response.model}",
            }
        except Exception as e:
            logger.error(f"[AC-ASA-18] LLM test failed: {e}")
            return {
                "success": False,
                "error": str(e),
                "message": f"Connection failed: {e}",
            }

    async def close(self) -> None:
        """Close the current client."""
        if self._client:
            await self._client.close()
            self._client = None
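
# Validation sketch (illustrative): with the "openai" schema above, passing
# only {"api_key": "sk-..."} to update_config yields a config with every
# default filled in:
#
#   {"api_key": "sk-...", "base_url": "https://api.openai.com/v1",
#    "model": "gpt-4o-mini", "max_tokens": 2048, "temperature": 0.7}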


_llm_config_manager: LLMConfigManager | None = None


def get_llm_config_manager() -> LLMConfigManager:
    """Get or create LLM config manager instance."""
    global _llm_config_manager
    if _llm_config_manager is None:
        _llm_config_manager = LLMConfigManager()
    return _llm_config_manager
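
# End-to-end sketch (illustrative; the API key is a placeholder and the calls
# assume an async context):
#
#   manager = get_llm_config_manager()
#   await manager.update_config("openai", {"api_key": "<your-key>"})
#   result = await manager.test_connection()
#   if result["success"]:
#       client = manager.get_client()  # lazily rebuilt with the new config
#   await manager.close()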