""" LLM Provider Factory and Configuration Management. [AC-ASA-14, AC-ASA-15, AC-ASA-16, AC-ASA-17, AC-ASA-18] LLM provider management. Design pattern: Factory pattern for pluggable LLM providers. """ import logging from dataclasses import dataclass, field from typing import Any from app.services.llm.base import LLMClient, LLMConfig from app.services.llm.openai_client import OpenAIClient logger = logging.getLogger(__name__) @dataclass class LLMProviderInfo: """Information about an LLM provider.""" name: str display_name: str description: str config_schema: dict[str, Any] LLM_PROVIDERS: dict[str, LLMProviderInfo] = { "openai": LLMProviderInfo( name="openai", display_name="OpenAI", description="OpenAI GPT 系列模型 (GPT-4, GPT-3.5 等)", config_schema={ "api_key": { "type": "string", "description": "API Key", "required": True, "secret": True, }, "base_url": { "type": "string", "description": "API Base URL", "default": "https://api.openai.com/v1", }, "model": { "type": "string", "description": "模型名称", "default": "gpt-4o-mini", }, "max_tokens": { "type": "integer", "description": "最大输出 Token 数", "default": 2048, }, "temperature": { "type": "number", "description": "温度参数 (0-2)", "default": 0.7, }, }, ), "ollama": LLMProviderInfo( name="ollama", display_name="Ollama", description="Ollama 本地模型 (Llama, Qwen 等)", config_schema={ "base_url": { "type": "string", "description": "Ollama API 地址", "default": "http://localhost:11434/v1", }, "model": { "type": "string", "description": "模型名称", "default": "llama3.2", }, "max_tokens": { "type": "integer", "description": "最大输出 Token 数", "default": 2048, }, "temperature": { "type": "number", "description": "温度参数 (0-2)", "default": 0.7, }, }, ), "azure": LLMProviderInfo( name="azure", display_name="Azure OpenAI", description="Azure OpenAI 服务", config_schema={ "api_key": { "type": "string", "description": "API Key", "required": True, "secret": True, }, "base_url": { "type": "string", "description": "Azure Endpoint", "required": True, }, "model": { "type": "string", "description": "部署名称", "required": True, }, "api_version": { "type": "string", "description": "API 版本", "default": "2024-02-15-preview", }, "max_tokens": { "type": "integer", "description": "最大输出 Token 数", "default": 2048, }, "temperature": { "type": "number", "description": "温度参数 (0-2)", "default": 0.7, }, }, ), } class LLMProviderFactory: """ Factory for creating LLM clients. [AC-ASA-14, AC-ASA-15] Dynamic provider creation. """ @classmethod def get_providers(cls) -> list[LLMProviderInfo]: """Get all registered LLM providers.""" return list(LLM_PROVIDERS.values()) @classmethod def get_provider_info(cls, name: str) -> LLMProviderInfo | None: """Get provider info by name.""" return LLM_PROVIDERS.get(name) @classmethod def create_client( cls, provider: str, config: dict[str, Any], ) -> LLMClient: """ Create an LLM client for the specified provider. [AC-ASA-15] Factory method for client creation. 


class LLMProviderFactory:
    """
    Factory for creating LLM clients.

    [AC-ASA-14, AC-ASA-15] Dynamic provider creation.
    """

    @classmethod
    def get_providers(cls) -> list[LLMProviderInfo]:
        """Get all registered LLM providers."""
        return list(LLM_PROVIDERS.values())

    @classmethod
    def get_provider_info(cls, name: str) -> LLMProviderInfo | None:
        """Get provider info by name."""
        return LLM_PROVIDERS.get(name)

    @classmethod
    def create_client(
        cls,
        provider: str,
        config: dict[str, Any],
    ) -> LLMClient:
        """
        Create an LLM client for the specified provider.

        [AC-ASA-15] Factory method for client creation.

        Args:
            provider: Provider name (openai, ollama, azure)
            config: Provider configuration

        Returns:
            LLMClient instance

        Raises:
            ValueError: If provider is not supported
        """
        if provider not in LLM_PROVIDERS:
            raise ValueError(f"Unsupported LLM provider: {provider}")

        # All currently registered providers expose an OpenAI-compatible
        # API, so they share one client implementation. Note that Azure's
        # api_version passes schema validation but is not forwarded here;
        # OpenAIClient is driven by api_key/base_url/model only.
        if provider in ("openai", "ollama", "azure"):
            return OpenAIClient(
                api_key=config.get("api_key"),
                base_url=config.get("base_url"),
                model=config.get("model"),
                default_config=LLMConfig(
                    model=config.get("model", "gpt-4o-mini"),
                    max_tokens=config.get("max_tokens", 2048),
                    temperature=config.get("temperature", 0.7),
                ),
            )

        raise ValueError(f"Unsupported LLM provider: {provider}")


class LLMConfigManager:
    """
    Manager for LLM configuration.

    [AC-ASA-16, AC-ASA-17, AC-ASA-18] Configuration management with
    hot-reload.
    """

    def __init__(self):
        self._current_provider: str = "openai"
        self._current_config: dict[str, Any] = {}
        self._client: LLMClient | None = None

    def get_current_config(self) -> dict[str, Any]:
        """Get current LLM configuration."""
        return {
            "provider": self._current_provider,
            "config": self._current_config,
        }

    async def update_config(
        self,
        provider: str,
        config: dict[str, Any],
    ) -> bool:
        """
        Update LLM configuration.

        [AC-ASA-16] Hot-reload configuration.

        Args:
            provider: Provider name
            config: New configuration

        Returns:
            True if the update succeeded

        Raises:
            ValueError: If the provider is unsupported or a required
                config key is missing
        """
        if provider not in LLM_PROVIDERS:
            raise ValueError(f"Unsupported LLM provider: {provider}")

        provider_info = LLM_PROVIDERS[provider]
        validated_config = self._validate_config(provider_info, config)

        # Close the old client so the next get_client() call rebuilds it
        # with the new configuration (hot reload).
        if self._client:
            await self._client.close()
            self._client = None

        self._current_provider = provider
        self._current_config = validated_config

        logger.info(f"[AC-ASA-16] LLM config updated: provider={provider}")
        return True

    def _validate_config(
        self,
        provider_info: LLMProviderInfo,
        config: dict[str, Any],
    ) -> dict[str, Any]:
        """Validate configuration against the provider schema.

        Unknown keys are dropped; missing optional keys fall back to
        their schema defaults.
        """
        validated = {}
        for key, schema in provider_info.config_schema.items():
            if key in config:
                validated[key] = config[key]
            elif "default" in schema:
                validated[key] = schema["default"]
            elif schema.get("required"):
                raise ValueError(f"Missing required config: {key}")
        return validated

    def get_client(self) -> LLMClient:
        """Get or create the LLM client for the current config."""
        if self._client is None:
            self._client = LLMProviderFactory.create_client(
                self._current_provider,
                self._current_config,
            )
        return self._client
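
    # Usage sketch (hypothetical call site): the client built above is
    # cached on the instance and reused until update_config() invalidates
    # it, e.g.
    #
    #     manager = get_llm_config_manager()
    #     client = manager.get_client()           # built lazily
    #     assert manager.get_client() is client   # cached thereafter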

    async def test_connection(
        self,
        test_prompt: str = "Hello, please briefly introduce yourself.",
        provider: str | None = None,
        config: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """
        Test LLM connection.

        [AC-ASA-17, AC-ASA-18] Connection testing.

        Args:
            test_prompt: Test prompt to send
            provider: Optional provider to test (defaults to the current one)
            config: Optional config to test (defaults to the current one)

        Returns:
            Test result with success status, response, and metrics
        """
        test_provider = provider or self._current_provider
        test_config = config or self._current_config

        if test_provider not in LLM_PROVIDERS:
            return {
                "success": False,
                "error": f"Unsupported provider: {test_provider}",
            }

        client: LLMClient | None = None
        try:
            # A throwaway client is used so the test never disturbs the
            # cached client serving live traffic.
            client = LLMProviderFactory.create_client(test_provider, test_config)

            start_time = time.perf_counter()
            response = await client.generate(
                messages=[{"role": "user", "content": test_prompt}],
            )
            latency_ms = (time.perf_counter() - start_time) * 1000

            return {
                "success": True,
                "response": response.content,
                "latency_ms": round(latency_ms, 2),
                "prompt_tokens": response.usage.get("prompt_tokens", 0),
                "completion_tokens": response.usage.get("completion_tokens", 0),
                "total_tokens": response.usage.get("total_tokens", 0),
                "model": response.model,
                "message": f"Connection successful, model: {response.model}",
            }
        except Exception as e:
            logger.error(f"[AC-ASA-18] LLM test failed: {e}")
            return {
                "success": False,
                "error": str(e),
                "message": f"Connection failed: {e}",
            }
        finally:
            # Close the throwaway client even when generate() raised.
            if client is not None:
                await client.close()

    async def close(self) -> None:
        """Close the current client."""
        if self._client:
            await self._client.close()
            self._client = None


_llm_config_manager: LLMConfigManager | None = None


def get_llm_config_manager() -> LLMConfigManager:
    """Get or create the LLM config manager singleton."""
    global _llm_config_manager
    if _llm_config_manager is None:
        _llm_config_manager = LLMConfigManager()
    return _llm_config_manager
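

# End-to-end usage sketch (hypothetical values; not executed at import time):
#
#     manager = get_llm_config_manager()
#     await manager.update_config(
#         "ollama",
#         {"base_url": "http://localhost:11434/v1", "model": "llama3.2"},
#     )
#     result = await manager.test_connection()
#     if result["success"]:
#         print(result["latency_ms"], result["model"])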