335 lines
12 KiB
Python
335 lines
12 KiB
Python
"""
|
|
OpenAI-compatible LLM client implementation.
|
|
[AC-AISVC-02, AC-AISVC-06] Concrete implementation using httpx for OpenAI API.
|
|
|
|
Design reference: design.md Section 8.1 - LLMClient interface
|
|
- Uses langchain-openai or official SDK pattern
|
|
- Supports generate and stream_generate
|
|
"""
|
|
|
|
import json
|
|
import logging
|
|
from collections.abc import AsyncGenerator
|
|
from typing import Any
|
|
|
|
import httpx
|
|
from tenacity import (
|
|
retry,
|
|
retry_if_exception_type,
|
|
stop_after_attempt,
|
|
wait_exponential,
|
|
)
|
|
|
|
from app.core.config import get_settings
|
|
from app.core.exceptions import AIServiceException, ErrorCode, TimeoutException
|
|
from app.services.llm.base import LLMClient, LLMConfig, LLMResponse, LLMStreamChunk
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class LLMException(AIServiceException):
    """Raised when an LLM operation fails.

    Maps to ErrorCode.LLM_ERROR with HTTP status 503 (service unavailable).
    """

    def __init__(self, message: str, details: list[dict] | None = None):
        super().__init__(
            message=message,
            code=ErrorCode.LLM_ERROR,
            details=details,
            status_code=503,
        )
|
|
|
|
|
|
class OpenAIClient(LLMClient):
    """
    OpenAI-compatible LLM client.
    [AC-AISVC-02, AC-AISVC-06] Implements LLMClient interface for OpenAI API.

    Supports:
    - OpenAI API (official)
    - OpenAI-compatible endpoints (Azure, local models, etc.)
    """

    def __init__(
        self,
        api_key: str | None = None,
        base_url: str | None = None,
        model: str | None = None,
        default_config: LLMConfig | None = None,
    ):
        """Initialize the client, falling back to application settings.

        Args:
            api_key: API key; defaults to settings.llm_api_key.
            base_url: Endpoint base URL (trailing "/" stripped); defaults to
                settings.llm_base_url.
            model: Model identifier; defaults to settings.llm_model.
            default_config: Default generation parameters; built from
                settings when not supplied.
        """
        settings = get_settings()
        self._api_key = api_key or settings.llm_api_key
        self._base_url = (base_url or settings.llm_base_url).rstrip("/")
        self._model = model or settings.llm_model
        self._default_config = default_config or LLMConfig(
            model=self._model,
            max_tokens=settings.llm_max_tokens,
            temperature=settings.llm_temperature,
            timeout_seconds=settings.llm_timeout_seconds,
            max_retries=settings.llm_max_retries,
        )
        # Created lazily; shared across requests until close().
        self._client: httpx.AsyncClient | None = None

    def _get_client(self) -> httpx.AsyncClient:
        """Get or lazily create the shared HTTP client.

        BUGFIX: the previous version baked the *first* request's timeout into
        the cached client, so later per-call config timeouts were silently
        ignored. The timeout is now passed per request instead (see
        generate/stream_generate).
        """
        if self._client is None:
            self._client = httpx.AsyncClient(
                headers={
                    "Authorization": f"Bearer {self._api_key}",
                    "Content-Type": "application/json",
                },
            )
        return self._client

    def _build_request_body(
        self,
        messages: list[dict[str, str]],
        config: LLMConfig,
        stream: bool = False,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Build the OpenAI chat-completions request body.

        Precedence (lowest to highest): config fields, config.extra_params,
        per-call kwargs.
        """
        body: dict[str, Any] = {
            "model": config.model,
            "messages": messages,
            "max_tokens": config.max_tokens,
            "temperature": config.temperature,
            "top_p": config.top_p,
            "stream": stream,
        }
        body.update(config.extra_params)
        body.update(kwargs)
        return body

    @staticmethod
    def _log_messages(tag: str, header: str, messages: list[dict[str, str]]) -> None:
        """Log the full prompt, one line pair per message, for traceability.

        NOTE(review): this emits complete prompt content at INFO level;
        consider DEBUG if prompts can contain sensitive data.
        """
        logger.info(f"[{tag}] ========== {header} ==========")
        for i, msg in enumerate(messages):
            role = msg.get("role", "unknown")
            content = msg.get("content", "")
            logger.info(f"[{tag}] [{i}] role={role}, content_length={len(content)}")
            logger.info(f"[{tag}] [{i}] content:\n{content}")
        logger.info(f"[{tag}] ======================================")

    @retry(
        # BUGFIX: generate() converts httpx.TimeoutException into the project
        # TimeoutException before tenacity sees it, so the old retry condition
        # (httpx.TimeoutException) could never fire. Retry on the exception
        # that actually propagates, and reraise it unchanged so callers still
        # catch TimeoutException rather than tenacity.RetryError.
        retry=retry_if_exception_type(TimeoutException),
        # NOTE(review): config.max_retries is not wired here — decorator
        # arguments are fixed at class-definition time.
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=1, max=10),
        reraise=True,
    )
    async def generate(
        self,
        messages: list[dict[str, str]],
        config: LLMConfig | None = None,
        **kwargs: Any,
    ) -> LLMResponse:
        """
        Generate a non-streaming response.
        [AC-AISVC-02] Returns complete response for ChatResponse.

        Args:
            messages: List of chat messages with 'role' and 'content'.
            config: Optional LLM configuration overrides.
            **kwargs: Additional provider-specific parameters.

        Returns:
            LLMResponse with generated content and metadata.

        Raises:
            LLMException: If generation fails.
            TimeoutException: If the request times out (after retries).
        """
        effective_config = config or self._default_config
        client = self._get_client()

        body = self._build_request_body(messages, effective_config, stream=False, **kwargs)

        logger.info(f"[AC-AISVC-02] Generating response with model={effective_config.model}")
        self._log_messages("AC-AISVC-02", "FULL PROMPT TO AI", messages)

        try:
            response = await client.post(
                f"{self._base_url}/chat/completions",
                json=body,
                # Per-request timeout so per-call config overrides take effect.
                timeout=httpx.Timeout(effective_config.timeout_seconds),
            )
            response.raise_for_status()
            data = response.json()

        except httpx.TimeoutException as e:
            logger.error(f"[AC-AISVC-02] LLM request timeout: {e}")
            raise TimeoutException(message=f"LLM request timed out: {e}") from e

        except httpx.HTTPStatusError as e:
            logger.error(f"[AC-AISVC-02] LLM API error: {e}")
            error_detail = self._parse_error_response(e.response)
            raise LLMException(
                message=f"LLM API error: {error_detail}",
                details=[{"status_code": e.response.status_code, "response": error_detail}],
            ) from e

        except json.JSONDecodeError as e:
            logger.error(f"[AC-AISVC-02] Failed to parse LLM response: {e}")
            raise LLMException(message=f"Failed to parse LLM response: {e}") from e

        # Extract the first choice; any structural surprise becomes LLMException.
        try:
            choice = data["choices"][0]
            content = choice["message"]["content"]
            usage = data.get("usage", {})
            finish_reason = choice.get("finish_reason", "stop")

            logger.info(
                f"[AC-AISVC-02] Generated response: "
                f"tokens={usage.get('total_tokens', 'N/A')}, "
                f"finish_reason={finish_reason}"
            )

            return LLMResponse(
                content=content,
                model=data.get("model", effective_config.model),
                usage=usage,
                finish_reason=finish_reason,
                metadata={"raw_response": data},
            )

        except (KeyError, IndexError) as e:
            logger.error(f"[AC-AISVC-02] Unexpected LLM response format: {e}")
            raise LLMException(
                message=f"Unexpected LLM response format: {e}",
                details=[{"response": str(data)}],
            ) from e

    async def stream_generate(
        self,
        messages: list[dict[str, str]],
        config: LLMConfig | None = None,
        **kwargs: Any,
    ) -> AsyncGenerator[LLMStreamChunk, None]:
        """
        Generate a streaming response.
        [AC-AISVC-06, AC-AISVC-07] Yields incremental chunks for SSE.

        Args:
            messages: List of chat messages with 'role' and 'content'.
            config: Optional LLM configuration overrides.
            **kwargs: Additional provider-specific parameters.

        Yields:
            LLMStreamChunk with incremental content.

        Raises:
            LLMException: If generation fails.
            TimeoutException: If request times out.
        """
        effective_config = config or self._default_config
        client = self._get_client()

        body = self._build_request_body(messages, effective_config, stream=True, **kwargs)

        logger.info(f"[AC-AISVC-06] Starting streaming generation with model={effective_config.model}")
        self._log_messages("AC-AISVC-06", "FULL PROMPT TO AI (STREAMING)", messages)

        try:
            async with client.stream(
                "POST",
                f"{self._base_url}/chat/completions",
                json=body,
                # Per-request timeout so per-call config overrides take effect.
                timeout=httpx.Timeout(effective_config.timeout_seconds),
            ) as response:
                response.raise_for_status()

                # SSE-style lines: "data: {...}" payloads, "data: [DONE]" terminator.
                async for line in response.aiter_lines():
                    if not line or line == "data: [DONE]":
                        continue
                    if not line.startswith("data: "):
                        continue
                    json_str = line[6:]
                    try:
                        chunk_data = json.loads(json_str)
                    except json.JSONDecodeError as e:
                        logger.warning(f"[AC-AISVC-06] Failed to parse stream chunk: {e}")
                        continue
                    chunk = self._parse_stream_chunk(chunk_data, effective_config.model)
                    if chunk:
                        yield chunk

        except httpx.TimeoutException as e:
            logger.error(f"[AC-AISVC-06] LLM streaming request timeout: {e}")
            raise TimeoutException(message=f"LLM streaming request timed out: {e}") from e

        except httpx.HTTPStatusError as e:
            logger.error(f"[AC-AISVC-06] LLM streaming API error: {e}")
            # BUGFIX: a streamed response body is not read automatically;
            # without aread(), .json()/.text inside _parse_error_response
            # raise httpx.ResponseNotRead and mask the real API error.
            await e.response.aread()
            error_detail = self._parse_error_response(e.response)
            raise LLMException(
                message=f"LLM streaming API error: {error_detail}",
                details=[{"status_code": e.response.status_code, "response": error_detail}],
            ) from e

        logger.info("[AC-AISVC-06] Streaming generation completed")

    def _parse_stream_chunk(
        self,
        data: dict[str, Any],
        model: str,
    ) -> LLMStreamChunk | None:
        """Parse one streaming chunk from the OpenAI API.

        Returns None for keep-alive/empty chunks (no content and no
        finish_reason) and for chunks that cannot be parsed.
        """
        try:
            choices = data.get("choices", [])
            if not choices:
                return None

            delta = choices[0].get("delta", {})
            content = delta.get("content", "")
            finish_reason = choices[0].get("finish_reason")

            if not content and not finish_reason:
                return None

            return LLMStreamChunk(
                delta=content,
                model=data.get("model", model),
                finish_reason=finish_reason,
                metadata={"raw_chunk": data},
            )

        except (KeyError, IndexError) as e:
            logger.warning(f"[AC-AISVC-06] Failed to parse stream chunk: {e}")
            return None

    def _parse_error_response(self, response: httpx.Response) -> str:
        """Extract a human-readable error message from an API error response.

        Prefers the OpenAI-style {"error": {"message": ...}} shape and falls
        back to the raw response text.
        """
        try:
            data = response.json()
            if "error" in data:
                error = data["error"]
                if isinstance(error, dict):
                    return error.get("message", str(error))
                return str(error)
            return response.text
        except Exception:
            # Best-effort: non-JSON bodies fall through to raw text.
            return response.text

    async def close(self) -> None:
        """Close the shared HTTP client and drop the cached reference."""
        if self._client:
            await self._client.aclose()
            self._client = None
|
|
|
|
|
|
_llm_client: OpenAIClient | None = None
|
|
|
|
|
|
def get_llm_client() -> OpenAIClient:
    """Return the process-wide OpenAIClient, creating it on first use."""
    global _llm_client
    client = _llm_client
    if client is None:
        client = OpenAIClient()
        _llm_client = client
    return client
|
|
|
|
|
|
async def close_llm_client() -> None:
    """Dispose of the module-level LLM client, if one was created."""
    global _llm_client
    if _llm_client is not None:
        await _llm_client.aclose() if False else await _llm_client.close()
        _llm_client = None
|