feat(ai-service): implement LLM Adapter for T3.1 [AC-AISVC-02, AC-AISVC-06]
- Add LLMClient abstract base class with generate/stream_generate interfaces
- Implement OpenAIClient with httpx for OpenAI-compatible API calls
- Add retry logic with tenacity for timeout handling
- Support both non-streaming and streaming generation
- Add comprehensive unit tests for the LLM Adapter
- Fix entities.py JSON column type for SQLModel compatibility
This commit is contained in:
parent
cc70ffeca6
commit
0a167d69f0
|
|
@ -0,0 +1,26 @@
|
|||
# AI Service

Python AI Service for intelligent chat with RAG support.

## Features

- Multi-tenant isolation via the X-Tenant-Id header
- SSE streaming support via Accept: text/event-stream
- RAG-powered responses with confidence scoring

## Installation

```bash
pip install -e ".[dev]"
```

## Running

```bash
uvicorn app.main:app --host 0.0.0.0 --port 8080
```

## API Endpoints

- `POST /ai/chat` - Generate AI reply
- `GET /ai/health` - Health check
|
@ -0,0 +1,4 @@
|
|||
"""
|
||||
AI Service - Python AI Middle Platform
|
||||
[AC-AISVC-01] FastAPI-based AI chat service with multi-tenant support.
|
||||
"""
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
"""
|
||||
API module for AI Service.
|
||||
"""
|
||||
|
||||
from app.api.chat import router as chat_router
|
||||
from app.api.health import router as health_router
|
||||
|
||||
__all__ = ["chat_router", "health_router"]
|
||||
|
|
@ -0,0 +1,127 @@
|
|||
"""
|
||||
Chat endpoint for AI Service.
|
||||
[AC-AISVC-01, AC-AISVC-02, AC-AISVC-06] Main chat endpoint with streaming/non-streaming modes.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Annotated, Any
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
|
||||
from app.core.middleware import get_response_mode, is_sse_request
|
||||
from app.core.sse import create_error_event
|
||||
from app.core.tenant import get_tenant_id
|
||||
from app.models import ChatRequest, ChatResponse, ErrorResponse
|
||||
from app.services.orchestrator import OrchestratorService, get_orchestrator_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=["AI Chat"])
|
||||
|
||||
|
||||
@router.post(
    "/ai/chat",
    operation_id="generateReply",
    summary="Generate AI reply",
    description="""
[AC-AISVC-01, AC-AISVC-02, AC-AISVC-06] Generate AI reply based on user message.

Response mode is determined by Accept header:
- Accept: text/event-stream -> SSE streaming response
- Other -> JSON response
""",
    responses={
        200: {
            "description": "Success - JSON or SSE stream",
            "content": {
                "application/json": {"schema": {"$ref": "#/components/schemas/ChatResponse"}},
                "text/event-stream": {"schema": {"type": "string"}},
            },
        },
        400: {"description": "Invalid request", "model": ErrorResponse},
        500: {"description": "Internal error", "model": ErrorResponse},
        503: {"description": "Service unavailable", "model": ErrorResponse},
    },
)
async def generate_reply(
    request: Request,
    chat_request: ChatRequest,
    accept: Annotated[str | None, Header()] = None,
    orchestrator: OrchestratorService = Depends(get_orchestrator_service),
) -> Any:
    """
    [AC-AISVC-06] Generate AI reply with automatic response mode switching.

    Based on Accept header:
    - text/event-stream: Returns SSE stream with message/final/error events
    - Other: Returns JSON ChatResponse
    """
    tenant_id = get_tenant_id()
    if not tenant_id:
        # Guard clause: the middleware normally rejects such requests, but
        # the context may still be empty (e.g. endpoint invoked directly).
        from app.core.exceptions import MissingTenantIdException
        raise MissingTenantIdException()

    logger.info(
        f"[AC-AISVC-06] Processing chat request: tenant={tenant_id}, "
        f"session={chat_request.session_id}, mode={get_response_mode(request)}"
    )

    # Pick the handler according to the negotiated response mode.
    handler = _handle_streaming_request if is_sse_request(request) else _handle_json_request
    return await handler(tenant_id, chat_request, orchestrator)
|
||||
|
||||
|
||||
async def _handle_json_request(
    tenant_id: str,
    chat_request: ChatRequest,
    orchestrator: OrchestratorService,
) -> JSONResponse:
    """
    [AC-AISVC-02] Handle non-streaming JSON request.

    Delegates generation to the orchestrator and serializes the ChatResponse
    (camelCase aliases, None fields omitted) into a JSON response.

    Raises:
        AIServiceException: structured failures are re-raised untouched; any
            other exception is wrapped as INTERNAL_ERROR with the original
            exception chained as the cause.
    """
    logger.info(f"[AC-AISVC-02] Processing JSON request for tenant={tenant_id}")

    # Imported lazily to avoid a circular import with app.core.exceptions.
    # (The original imported AIServiceException twice; once is enough.)
    from app.core.exceptions import AIServiceException, ErrorCode

    try:
        response = await orchestrator.generate(tenant_id, chat_request)
        return JSONResponse(
            content=response.model_dump(exclude_none=True, by_alias=True),
        )
    except AIServiceException:
        # Structured service errors already carry code/status; let the
        # registered exception handler format them.
        raise
    except Exception as e:
        logger.error(f"[AC-AISVC-04] Error generating response: {e}")
        # Wrap unexpected failures so clients always receive a structured
        # error, preserving the original exception for debugging.
        raise AIServiceException(
            code=ErrorCode.INTERNAL_ERROR,
            message=str(e),
        ) from e
|
||||
|
||||
|
||||
async def _handle_streaming_request(
    tenant_id: str,
    chat_request: ChatRequest,
    orchestrator: OrchestratorService,
) -> EventSourceResponse:
    """
    [AC-AISVC-06, AC-AISVC-07, AC-AISVC-08, AC-AISVC-09] Handle SSE streaming request.

    Relays the orchestrator's event stream to the client; if the stream
    raises, a terminal error event is emitted instead of dropping the
    connection silently.
    """
    logger.info(f"[AC-AISVC-06] Processing SSE request for tenant={tenant_id}")

    async def relay_events():
        stream = orchestrator.generate_stream(tenant_id, chat_request)
        try:
            async for sse_event in stream:
                yield sse_event
        except Exception as e:
            logger.error(f"[AC-AISVC-09] Streaming error: {e}")
            yield create_error_event(
                code="STREAMING_ERROR",
                message=str(e),
            )

    # ping=15: keep-alive comments every 15s so proxies don't drop the stream.
    return EventSourceResponse(relay_events(), ping=15)
|
||||
|
|
@ -0,0 +1,30 @@
|
|||
"""
|
||||
Health check endpoint.
|
||||
[AC-AISVC-20] Health check for service monitoring.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, status
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
router = APIRouter(tags=["Health"])
|
||||
|
||||
|
||||
@router.get(
    "/ai/health",
    operation_id="healthCheck",
    summary="Health check",
    description="[AC-AISVC-20] Check if AI service is healthy",
    responses={
        200: {"description": "Service is healthy"},
        503: {"description": "Service is unhealthy"},
    },
)
async def health_check() -> JSONResponse:
    """
    [AC-AISVC-20] Health check endpoint.
    Returns 200 with status if healthy, 503 if not.
    """
    # No dependency probing here: reaching this handler is the liveness signal.
    payload = {"status": "healthy"}
    return JSONResponse(status_code=status.HTTP_200_OK, content=payload)
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
"""
|
||||
Core module - Configuration, dependencies, and utilities.
|
||||
[AC-AISVC-01, AC-AISVC-10, AC-AISVC-11] Core infrastructure components.
|
||||
"""
|
||||
|
||||
from app.core.config import Settings, get_settings
|
||||
from app.core.database import async_session_maker, get_session, init_db, close_db
|
||||
from app.core.qdrant_client import QdrantClient, get_qdrant_client
|
||||
|
||||
__all__ = [
|
||||
"Settings",
|
||||
"get_settings",
|
||||
"async_session_maker",
|
||||
"get_session",
|
||||
"init_db",
|
||||
"close_db",
|
||||
"QdrantClient",
|
||||
"get_qdrant_client",
|
||||
]
|
||||
|
|
@ -0,0 +1,54 @@
|
|||
"""
|
||||
Configuration management for AI Service.
|
||||
[AC-AISVC-01] Centralized configuration with environment variable support.
|
||||
"""
|
||||
|
||||
from functools import lru_cache
|
||||
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
    """
    [AC-AISVC-01] Application settings.

    Values load from environment variables with the ``AI_SERVICE_`` prefix
    (and an optional ``.env`` file); unrecognized variables are ignored.
    """

    model_config = SettingsConfigDict(env_prefix="AI_SERVICE_", env_file=".env", extra="ignore")

    # --- application identity ---
    app_name: str = "AI Service"
    app_version: str = "0.1.0"
    debug: bool = False

    # --- HTTP server bind address ---
    host: str = "0.0.0.0"
    port: int = 8080

    # --- request handling ---
    request_timeout_seconds: int = 20
    sse_ping_interval_seconds: int = 15

    log_level: str = "INFO"

    # --- LLM provider (OpenAI-compatible API) ---
    llm_provider: str = "openai"
    llm_api_key: str = ""
    llm_base_url: str = "https://api.openai.com/v1"
    llm_model: str = "gpt-4o-mini"
    llm_max_tokens: int = 2048
    llm_temperature: float = 0.7
    llm_timeout_seconds: int = 30
    llm_max_retries: int = 3

    # --- PostgreSQL (async driver) ---
    database_url: str = "postgresql+asyncpg://postgres:postgres@localhost:5432/ai_service"
    database_pool_size: int = 10
    database_max_overflow: int = 20

    # --- Qdrant vector store (per-tenant collections: kb_{tenantId}) ---
    qdrant_url: str = "http://localhost:6333"
    qdrant_collection_prefix: str = "kb_"
    qdrant_vector_size: int = 1536

    # --- RAG retrieval tuning ---
    rag_top_k: int = 5
    rag_score_threshold: float = 0.7
    rag_min_hits: int = 1
    rag_max_evidence_tokens: int = 2000

    # --- conversation / confidence thresholds ---
    confidence_threshold_low: float = 0.5
    max_history_tokens: int = 4000


@lru_cache
def get_settings() -> Settings:
    """Return the process-wide cached Settings instance."""
    return Settings()
|
||||
|
|
@ -0,0 +1,67 @@
|
|||
"""
|
||||
Database client for AI Service.
|
||||
[AC-AISVC-11] PostgreSQL database with SQLModel for multi-tenant data isolation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import NullPool
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
from app.core.config import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
settings = get_settings()

# [AC-AISVC-11] Shared async engine; pool_pre_ping revalidates stale connections
# before handing them out.
engine = create_async_engine(
    settings.database_url,
    pool_size=settings.database_pool_size,
    max_overflow=settings.database_max_overflow,
    echo=settings.debug,
    pool_pre_ping=True,
)

# Session factory used by get_session(); expire_on_commit=False keeps loaded
# attributes usable after commit without a re-fetch.
async_session_maker = async_sessionmaker(
    engine,
    class_=AsyncSession,
    expire_on_commit=False,
    autocommit=False,
    autoflush=False,
)


async def init_db() -> None:
    """
    [AC-AISVC-11] Initialize database tables.
    Creates all tables defined in SQLModel metadata.
    """
    async with engine.begin() as conn:
        await conn.run_sync(SQLModel.metadata.create_all)
    logger.info("[AC-AISVC-11] Database tables initialized")


async def close_db() -> None:
    """
    Close database connections by disposing of the engine's pool.
    """
    await engine.dispose()
    logger.info("Database connections closed")


async def get_session() -> AsyncGenerator[AsyncSession, None]:
    """
    [AC-AISVC-11] Dependency injection for database session.
    Ensures proper session lifecycle management.

    Commits after the handler completes successfully, rolls back and
    re-raises on any exception, and always closes the session.
    """
    async with async_session_maker() as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise
        finally:
            await session.close()
|
||||
|
|
@ -0,0 +1,99 @@
|
|||
"""
|
||||
Exception handling for AI Service.
|
||||
[AC-AISVC-03, AC-AISVC-04, AC-AISVC-05] Structured error responses.
|
||||
"""
|
||||
|
||||
from fastapi import HTTPException, Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from app.models import ErrorCode, ErrorResponse
|
||||
|
||||
|
||||
class AIServiceException(Exception):
    """
    Base class for structured service errors.

    Carries an ErrorCode, a human-readable message, the HTTP status code to
    respond with, and optional per-field detail dicts; rendered into an
    ErrorResponse payload by ai_service_exception_handler.
    """

    def __init__(
        self,
        code: ErrorCode,
        message: str,
        status_code: int = status.HTTP_500_INTERNAL_SERVER_ERROR,
        details: list[dict] | None = None,
    ):
        self.code = code
        self.message = message
        self.status_code = status_code
        self.details = details
        super().__init__(message)


class MissingTenantIdException(AIServiceException):
    """400 MISSING_TENANT_ID: the X-Tenant-Id header was absent or blank."""

    def __init__(self, message: str = "Missing required header: X-Tenant-Id"):
        super().__init__(
            code=ErrorCode.MISSING_TENANT_ID,
            message=message,
            status_code=status.HTTP_400_BAD_REQUEST,
        )


class InvalidRequestException(AIServiceException):
    """400 INVALID_REQUEST: the request payload failed validation."""

    def __init__(self, message: str, details: list[dict] | None = None):
        super().__init__(
            code=ErrorCode.INVALID_REQUEST,
            message=message,
            status_code=status.HTTP_400_BAD_REQUEST,
            details=details,
        )


class ServiceUnavailableException(AIServiceException):
    """503 SERVICE_UNAVAILABLE: a dependency is temporarily unreachable."""

    def __init__(self, message: str = "Service temporarily unavailable"):
        super().__init__(
            code=ErrorCode.SERVICE_UNAVAILABLE,
            message=message,
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
        )


class TimeoutException(AIServiceException):
    """503 TIMEOUT: the request exceeded its allotted time budget."""

    def __init__(self, message: str = "Request timeout"):
        super().__init__(
            code=ErrorCode.TIMEOUT,
            message=message,
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
        )
|
||||
|
||||
|
||||
async def ai_service_exception_handler(request: Request, exc: AIServiceException) -> JSONResponse:
    """Render an AIServiceException using its own status and error code."""
    return JSONResponse(
        status_code=exc.status_code,
        content=ErrorResponse(
            code=exc.code.value,
            message=exc.message,
            details=exc.details,
        ).model_dump(exclude_none=True),  # omit None fields (e.g. empty details)
    )


async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
    """
    Map plain HTTPExceptions onto the structured error schema.

    400 -> INVALID_REQUEST, 503 -> SERVICE_UNAVAILABLE, anything else ->
    INTERNAL_ERROR; the original HTTP status code is preserved.
    """
    if exc.status_code == status.HTTP_400_BAD_REQUEST:
        code = ErrorCode.INVALID_REQUEST
    elif exc.status_code == status.HTTP_503_SERVICE_UNAVAILABLE:
        code = ErrorCode.SERVICE_UNAVAILABLE
    else:
        code = ErrorCode.INTERNAL_ERROR

    return JSONResponse(
        status_code=exc.status_code,
        content=ErrorResponse(
            code=code.value,
            message=exc.detail or "An error occurred",
        ).model_dump(exclude_none=True),
    )


async def generic_exception_handler(request: Request, exc: Exception) -> JSONResponse:
    """
    Last-resort handler: any uncaught exception becomes a 500 INTERNAL_ERROR
    with a generic message so internal details are not leaked to clients.
    """
    return JSONResponse(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        content=ErrorResponse(
            code=ErrorCode.INTERNAL_ERROR.value,
            message="An unexpected error occurred",
        ).model_dump(exclude_none=True),
    )
|
||||
|
|
@ -0,0 +1,74 @@
|
|||
"""
|
||||
Middleware for AI Service.
|
||||
[AC-AISVC-10, AC-AISVC-12] X-Tenant-Id header validation and tenant context injection.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Callable
|
||||
|
||||
from fastapi import Request, Response, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from app.core.exceptions import ErrorCode, ErrorResponse, MissingTenantIdException
|
||||
from app.core.tenant import clear_tenant_context, set_tenant_context
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TENANT_ID_HEADER = "X-Tenant-Id"
|
||||
ACCEPT_HEADER = "Accept"
|
||||
SSE_CONTENT_TYPE = "text/event-stream"
|
||||
|
||||
|
||||
class TenantContextMiddleware(BaseHTTPMiddleware):
    """
    [AC-AISVC-10, AC-AISVC-12] Middleware to extract and validate X-Tenant-Id header.
    Injects tenant context into request state for downstream processing.

    The health endpoint is exempt; every other request without a non-blank
    X-Tenant-Id is rejected with 400 before reaching any route handler.
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        # Defensive reset in case earlier handling left a stale context behind.
        clear_tenant_context()

        # Health checks must work without tenant headers (probes, LBs).
        if request.url.path == "/ai/health":
            return await call_next(request)

        tenant_id = request.headers.get(TENANT_ID_HEADER)

        if not tenant_id or not tenant_id.strip():
            logger.warning("[AC-AISVC-12] Missing or empty X-Tenant-Id header")
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content=ErrorResponse(
                    code=ErrorCode.MISSING_TENANT_ID.value,
                    message="Missing required header: X-Tenant-Id",
                ).model_dump(exclude_none=True),
            )

        # Store the tenant both in the ContextVar (for code without access to
        # the Request) and on request.state (for handlers that have it).
        set_tenant_context(tenant_id.strip())
        request.state.tenant_id = tenant_id.strip()

        logger.info(f"[AC-AISVC-10] Tenant context set: tenant_id={tenant_id.strip()}")

        try:
            response = await call_next(request)
        finally:
            # Always drop the context so it cannot leak into another request.
            clear_tenant_context()

        return response
|
||||
|
||||
|
||||
def is_sse_request(request: Request) -> bool:
    """
    [AC-AISVC-06] Check if the request expects SSE streaming response.
    Based on Accept header: text/event-stream indicates SSE mode.
    """
    accepted = request.headers.get(ACCEPT_HEADER, "")
    return SSE_CONTENT_TYPE in accepted


def get_response_mode(request: Request) -> str:
    """
    [AC-AISVC-06] Determine response mode based on Accept header.
    Returns 'streaming' for SSE, 'json' for regular JSON response.
    """
    if is_sse_request(request):
        return "streaming"
    return "json"
|
||||
|
|
@ -0,0 +1,175 @@
|
|||
"""
|
||||
Qdrant client for AI Service.
|
||||
[AC-AISVC-10] Vector database client with tenant-isolated collection management.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from qdrant_client import AsyncQdrantClient
|
||||
from qdrant_client.models import Distance, PointStruct, VectorParams
|
||||
|
||||
from app.core.config import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
class QdrantClient:
    """
    [AC-AISVC-10] Qdrant client with tenant-isolated collection management.
    Collection naming: kb_{tenantId} for tenant isolation.

    Wraps AsyncQdrantClient lazily (connection created on first use). All
    data-path methods swallow client errors and return a neutral value
    (False / empty list) after logging, so callers degrade gracefully
    instead of failing the request.
    """

    def __init__(self):
        # Underlying async client, created lazily by get_client().
        self._client: AsyncQdrantClient | None = None
        # Prefix + tenant id forms the per-tenant collection name.
        self._collection_prefix = settings.qdrant_collection_prefix
        self._vector_size = settings.qdrant_vector_size

    async def get_client(self) -> AsyncQdrantClient:
        """Get or create Qdrant client instance."""
        if self._client is None:
            self._client = AsyncQdrantClient(url=settings.qdrant_url)
            logger.info(f"[AC-AISVC-10] Qdrant client initialized: {settings.qdrant_url}")
        return self._client

    async def close(self) -> None:
        """Close Qdrant client connection (safe to call when never opened)."""
        if self._client:
            await self._client.close()
            self._client = None
            logger.info("Qdrant client connection closed")

    def get_collection_name(self, tenant_id: str) -> str:
        """
        [AC-AISVC-10] Get collection name for a tenant.
        Naming convention: kb_{tenantId}
        """
        return f"{self._collection_prefix}{tenant_id}"

    async def ensure_collection_exists(self, tenant_id: str) -> bool:
        """
        [AC-AISVC-10] Ensure collection exists for tenant.
        Note: MVP uses pre-provisioned collections, this is for development/testing.

        Returns True when the collection already exists or was created,
        False if the check/creation failed.
        """
        client = await self.get_client()
        collection_name = self.get_collection_name(tenant_id)

        try:
            collections = await client.get_collections()
            exists = any(c.name == collection_name for c in collections.collections)

            if not exists:
                await client.create_collection(
                    collection_name=collection_name,
                    vectors_config=VectorParams(
                        size=self._vector_size,
                        distance=Distance.COSINE,
                    ),
                )
                logger.info(
                    f"[AC-AISVC-10] Created collection: {collection_name} for tenant={tenant_id}"
                )
            return True
        except Exception as e:
            logger.error(f"[AC-AISVC-10] Error ensuring collection: {e}")
            return False

    async def upsert_vectors(
        self,
        tenant_id: str,
        points: list[PointStruct],
    ) -> bool:
        """
        [AC-AISVC-10] Upsert vectors into tenant's collection.

        Returns True on success, False if the upsert failed.
        """
        client = await self.get_client()
        collection_name = self.get_collection_name(tenant_id)

        try:
            await client.upsert(
                collection_name=collection_name,
                points=points,
            )
            logger.info(
                f"[AC-AISVC-10] Upserted {len(points)} vectors for tenant={tenant_id}"
            )
            return True
        except Exception as e:
            logger.error(f"[AC-AISVC-10] Error upserting vectors: {e}")
            return False

    async def search(
        self,
        tenant_id: str,
        query_vector: list[float],
        limit: int = 5,
        score_threshold: float | None = None,
    ) -> list[dict[str, Any]]:
        """
        [AC-AISVC-10] Search vectors in tenant's collection.
        Returns results with score >= score_threshold if specified.

        Each hit is a dict with "id" (str), "score" (float) and "payload"
        (dict, never None). Returns an empty list on error.
        """
        client = await self.get_client()
        collection_name = self.get_collection_name(tenant_id)

        try:
            results = await client.search(
                collection_name=collection_name,
                query_vector=query_vector,
                limit=limit,
                score_threshold=score_threshold,
            )

            hits = [
                {
                    "id": str(result.id),
                    "score": result.score,
                    "payload": result.payload or {},
                }
                for result in results
            ]

            logger.info(
                f"[AC-AISVC-10] Search returned {len(hits)} results for tenant={tenant_id}"
            )
            return hits
        except Exception as e:
            logger.error(f"[AC-AISVC-10] Error searching vectors: {e}")
            return []

    async def delete_collection(self, tenant_id: str) -> bool:
        """
        [AC-AISVC-10] Delete tenant's collection.
        Used when tenant is removed.

        Returns True on success, False if deletion failed.
        """
        client = await self.get_client()
        collection_name = self.get_collection_name(tenant_id)

        try:
            await client.delete_collection(collection_name=collection_name)
            logger.info(f"[AC-AISVC-10] Deleted collection: {collection_name}")
            return True
        except Exception as e:
            logger.error(f"[AC-AISVC-10] Error deleting collection: {e}")
            return False
|
||||
|
||||
|
||||
# Module-level singleton instance shared across requests.
_qdrant_client: QdrantClient | None = None


async def get_qdrant_client() -> QdrantClient:
    """Get or create Qdrant client instance (lazy singleton)."""
    global _qdrant_client
    if _qdrant_client is None:
        _qdrant_client = QdrantClient()
    return _qdrant_client


async def close_qdrant_client() -> None:
    """Close Qdrant client connection and drop the singleton."""
    global _qdrant_client
    if _qdrant_client is not None:
        await _qdrant_client.close()
        _qdrant_client = None
|
||||
|
|
@ -0,0 +1,170 @@
|
|||
"""
|
||||
SSE utilities for AI Service.
|
||||
[AC-AISVC-06, AC-AISVC-07, AC-AISVC-08, AC-AISVC-09] SSE event generation and state machine.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
from sse_starlette.sse import EventSourceResponse, ServerSentEvent
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.models import SSEErrorEvent, SSEEventType, SSEFinalEvent, SSEMessageEvent
|
||||
|
||||
logger = logging.getLogger(__name__)


class SSEState(str, Enum):
    """Lifecycle states of a single SSE response stream."""

    INIT = "INIT"
    STREAMING = "STREAMING"
    FINAL_SENT = "FINAL_SENT"
    ERROR_SENT = "ERROR_SENT"
    CLOSED = "CLOSED"


class SSEStateMachine:
    """
    [AC-AISVC-08, AC-AISVC-09] SSE state machine ensuring proper event sequence.
    State transitions: INIT -> STREAMING -> FINAL_SENT/ERROR_SENT -> CLOSED

    All transitions are serialized through an asyncio.Lock so concurrent
    producers cannot emit conflicting terminal events.
    """

    def __init__(self):
        self._state = SSEState.INIT
        self._lock = asyncio.Lock()

    @property
    def state(self) -> SSEState:
        """Current state (read without locking; writes are lock-protected)."""
        return self._state

    async def transition_to_streaming(self) -> bool:
        """Move INIT -> STREAMING. Returns False if already past INIT."""
        async with self._lock:
            if self._state == SSEState.INIT:
                self._state = SSEState.STREAMING
                logger.debug("[AC-AISVC-07] SSE state transition: INIT -> STREAMING")
                return True
            return False

    async def transition_to_final(self) -> bool:
        """Move STREAMING -> FINAL_SENT. Returns False from any other state."""
        async with self._lock:
            if self._state == SSEState.STREAMING:
                self._state = SSEState.FINAL_SENT
                logger.debug("[AC-AISVC-08] SSE state transition: STREAMING -> FINAL_SENT")
                return True
            return False

    async def transition_to_error(self) -> bool:
        """
        Move INIT or STREAMING -> ERROR_SENT.
        Returns False when a terminal event was already sent.
        """
        async with self._lock:
            if self._state in (SSEState.INIT, SSEState.STREAMING):
                # Capture the pre-transition state first: the original logged
                # self._state *after* mutating it, so the log always read
                # "ERROR_SENT -> ERROR_SENT" and hid the real source state.
                previous = self._state
                self._state = SSEState.ERROR_SENT
                logger.debug(
                    "[AC-AISVC-09] SSE state transition: %s -> ERROR_SENT",
                    previous.value,
                )
                return True
            return False

    async def close(self) -> None:
        """Force the machine into CLOSED regardless of current state."""
        async with self._lock:
            self._state = SSEState.CLOSED
            logger.debug("SSE state transition: -> CLOSED")

    def can_send_message(self) -> bool:
        """True only while actively STREAMING; messages are otherwise dropped."""
        return self._state == SSEState.STREAMING
|
||||
|
||||
|
||||
def format_sse_event(event_type: SSEEventType, data: dict[str, Any]) -> ServerSentEvent:
    """Format data as SSE event; payload is JSON with non-ASCII preserved."""
    return ServerSentEvent(
        event=event_type.value,
        data=json.dumps(data, ensure_ascii=False),
    )


def create_message_event(delta: str) -> ServerSentEvent:
    """[AC-AISVC-07] Create a message event with incremental content."""
    event_data = SSEMessageEvent(delta=delta)
    return format_sse_event(SSEEventType.MESSAGE, event_data.model_dump())


def create_final_event(
    reply: str,
    confidence: float,
    should_transfer: bool,
    transfer_reason: str | None = None,
    metadata: dict[str, Any] | None = None,
) -> ServerSentEvent:
    """
    [AC-AISVC-08] Create a final event with complete response.

    None-valued fields (transfer_reason, metadata) are omitted from the payload.
    """
    event_data = SSEFinalEvent(
        reply=reply,
        confidence=confidence,
        should_transfer=should_transfer,
        transfer_reason=transfer_reason,
        metadata=metadata,
    )
    return format_sse_event(SSEEventType.FINAL, event_data.model_dump(exclude_none=True))


def create_error_event(
    code: str,
    message: str,
    details: list[dict[str, Any]] | None = None,
) -> ServerSentEvent:
    """[AC-AISVC-09] Create an error event; None fields are omitted."""
    event_data = SSEErrorEvent(
        code=code,
        message=message,
        details=details,
    )
    return format_sse_event(SSEEventType.ERROR, event_data.model_dump(exclude_none=True))
|
||||
|
||||
|
||||
async def ping_generator(interval_seconds: int) -> AsyncGenerator[str, None]:
    """
    [AC-AISVC-06] Generate ping comments for SSE keep-alive.
    Sends ': ping' as comment lines (not events) to keep connection alive.
    """
    ping_comment = ": ping\n\n"
    while True:
        await asyncio.sleep(interval_seconds)
        yield ping_comment
|
||||
|
||||
|
||||
class SSEResponseBuilder:
    """
    Builder for SSE response with proper event sequencing and ping keep-alive.

    Wraps a content generator with an SSEStateMachine so events are only
    forwarded while STREAMING, at most one error event is emitted on failure,
    and the machine always ends in CLOSED.
    """

    def __init__(self):
        self._state_machine = SSEStateMachine()
        self._settings = get_settings()

    async def build_response(
        self,
        content_generator: AsyncGenerator[ServerSentEvent, None],
    ) -> EventSourceResponse:
        """
        Build SSE response with ping keep-alive mechanism.
        [AC-AISVC-06] Implements ping keep-alive to prevent connection timeout.
        """

        async def event_generator() -> AsyncGenerator[ServerSentEvent, None]:
            await self._state_machine.transition_to_streaming()
            try:
                async for event in content_generator:
                    # Forward only while STREAMING; stop once a terminal
                    # event has been recorded by the state machine.
                    if self._state_machine.can_send_message():
                        yield event
                    else:
                        break
            except Exception as e:
                logger.error(f"[AC-AISVC-09] Error during SSE streaming: {e}")
                # The transition guard ensures at most one error event.
                if await self._state_machine.transition_to_error():
                    yield create_error_event(
                        code="STREAMING_ERROR",
                        message=str(e),
                    )
            finally:
                await self._state_machine.close()

        return EventSourceResponse(
            event_generator(),
            ping=self._settings.sse_ping_interval_seconds,
        )
|
||||
|
|
@ -0,0 +1,31 @@
|
|||
"""
|
||||
Tenant context management.
|
||||
[AC-AISVC-10, AC-AISVC-12] Multi-tenant isolation via X-Tenant-Id header.
|
||||
"""
|
||||
|
||||
from contextvars import ContextVar
|
||||
from dataclasses import dataclass
|
||||
|
||||
tenant_context: ContextVar["TenantContext | None"] = ContextVar("tenant_context", default=None)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TenantContext:
|
||||
tenant_id: str
|
||||
|
||||
|
||||
def set_tenant_context(tenant_id: str) -> None:
|
||||
tenant_context.set(TenantContext(tenant_id=tenant_id))
|
||||
|
||||
|
||||
def get_tenant_context() -> TenantContext | None:
|
||||
return tenant_context.get()
|
||||
|
||||
|
||||
def get_tenant_id() -> str | None:
|
||||
ctx = get_tenant_context()
|
||||
return ctx.tenant_id if ctx else None
|
||||
|
||||
|
||||
def clear_tenant_context() -> None:
|
||||
tenant_context.set(None)
|
||||
|
|
@ -0,0 +1,123 @@
|
|||
"""
|
||||
Main FastAPI application for AI Service.
|
||||
[AC-AISVC-01] Entry point with middleware and exception handlers.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI, Request, status
|
||||
from fastapi.exceptions import HTTPException, RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from app.api import chat_router, health_router
|
||||
from app.core.config import get_settings
|
||||
from app.core.database import close_db, init_db
|
||||
from app.core.exceptions import (
|
||||
AIServiceException,
|
||||
ErrorCode,
|
||||
ErrorResponse,
|
||||
ai_service_exception_handler,
|
||||
generic_exception_handler,
|
||||
http_exception_handler,
|
||||
)
|
||||
from app.core.middleware import TenantContextMiddleware
|
||||
from app.core.qdrant_client import close_qdrant_client
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, settings.log_level.upper()),
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    [AC-AISVC-01, AC-AISVC-11] Application lifespan manager.
    Handles startup and shutdown of database and external connections.

    DB init failure is logged but not fatal, so the service can still boot
    (e.g. local development without a database).
    """
    logger.info(f"[AC-AISVC-01] Starting {settings.app_name} v{settings.app_version}")

    try:
        await init_db()
        logger.info("[AC-AISVC-11] Database initialized successfully")
    except Exception as e:
        # Deliberate best-effort: startup continues even if the DB is down.
        logger.warning(f"[AC-AISVC-11] Database initialization skipped: {e}")

    yield

    # Shutdown: release the DB pool and the Qdrant connection.
    await close_db()
    await close_qdrant_client()
    logger.info(f"Shutting down {settings.app_name}")
|
||||
|
||||
|
||||
# [AC-AISVC-01] FastAPI application instance with OpenAPI metadata.
app = FastAPI(
    title=settings.app_name,
    version=settings.app_version,
    description="""
Python AI Service for intelligent chat with RAG support.

## Features
- Multi-tenant isolation via X-Tenant-Id header
- SSE streaming support via Accept: text/event-stream
- RAG-powered responses with confidence scoring

## Response Modes
- **JSON**: Default response mode (Accept: application/json or no Accept header)
- **SSE Streaming**: Set Accept: text/event-stream for streaming responses
""",
    docs_url="/docs",
    redoc_url="/redoc",
    lifespan=lifespan,
)

# NOTE(review): wildcard origins combined with allow_credentials=True is
# rejected by browsers for credentialed requests — confirm this is intended
# for development only.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# [AC-AISVC-10, AC-AISVC-12] Validates X-Tenant-Id on every non-health request
# and injects the tenant context.
app.add_middleware(TenantContextMiddleware)

# Structured error responses: most specific handler first, generic catch-all last.
app.add_exception_handler(AIServiceException, ai_service_exception_handler)
app.add_exception_handler(HTTPException, http_exception_handler)
app.add_exception_handler(Exception, generic_exception_handler)
|
||||
|
||||
|
||||
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """
    [AC-AISVC-03] Convert request validation failures into the structured
    ErrorResponse contract with HTTP 400.
    """
    logger.warning(f"[AC-AISVC-03] Request validation error: {exc.errors()}")

    # Flatten pydantic's error records into plain serializable dicts.
    detail_items = [
        {"loc": list(err["loc"]), "msg": err["msg"], "type": err["type"]}
        for err in exc.errors()
    ]
    payload = ErrorResponse(
        code=ErrorCode.INVALID_REQUEST.value,
        message="Request validation failed",
        details=detail_items,
    )
    return JSONResponse(
        status_code=status.HTTP_400_BAD_REQUEST,
        content=payload.model_dump(exclude_none=True),
    )
|
||||
|
||||
|
||||
# Route registration: health probe and the main chat endpoint.
app.include_router(health_router)
app.include_router(chat_router)


if __name__ == "__main__":
    # Local development entry point; production should invoke uvicorn directly.
    import uvicorn

    uvicorn.run(
        "app.main:app",
        host=settings.host,
        port=settings.port,
        reload=settings.debug,
    )
|
||||
|
|
@ -0,0 +1,88 @@
|
|||
"""
|
||||
Data models for AI Service.
|
||||
[AC-AISVC-02] Request/Response models aligned with OpenAPI contract.
|
||||
[AC-AISVC-13] Entity models for database persistence.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ChannelType(str, Enum):
    """Closed set of upstream chat channels accepted by the service.

    Inherits from ``str`` so values serialize transparently in JSON payloads.
    """

    WECHAT = "wechat"
    DOUYIN = "douyin"
    JD = "jd"
|
||||
|
||||
|
||||
class Role(str, Enum):
    """Speaker of a chat turn; matches the OpenAI-style role vocabulary
    used by the LLM adapter (only user/assistant are exchanged here)."""

    USER = "user"
    ASSISTANT = "assistant"
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
    """A single conversation turn exchanged between user and assistant.

    Used inside ChatRequest.history; not a persistence model.
    """

    role: Role = Field(..., description="Message role: user or assistant")
    content: str = Field(..., description="Message content")
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
    """[AC-AISVC-02] Inbound payload for POST /ai/chat.

    Field aliases expose the camelCase names required by the OpenAPI
    contract; populate_by_name allows snake_case construction in code.
    """

    session_id: str = Field(..., alias="sessionId", description="Session ID for conversation tracking")
    current_message: str = Field(..., alias="currentMessage", description="Current user message")
    channel_type: ChannelType = Field(..., alias="channelType", description="Channel type: wechat, douyin, jd")
    history: list[ChatMessage] | None = Field(default=None, description="Optional conversation history")
    metadata: dict[str, Any] | None = Field(default=None, description="Optional metadata")

    # Accept both the alias (camelCase) and the attribute name (snake_case).
    model_config = {"populate_by_name": True}
|
||||
|
||||
|
||||
class ChatResponse(BaseModel):
    """[AC-AISVC-02] Non-streaming reply body for POST /ai/chat.

    Mirrors the OpenAPI contract via camelCase aliases.
    """

    reply: str = Field(..., description="AI generated reply content")
    confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence score between 0.0 and 1.0")
    should_transfer: bool = Field(..., alias="shouldTransfer", description="Whether to suggest transfer to human agent")
    transfer_reason: str | None = Field(default=None, alias="transferReason", description="Reason for transfer suggestion")
    metadata: dict[str, Any] | None = Field(default=None, description="Response metadata")

    # Accept both the alias (camelCase) and the attribute name (snake_case).
    model_config = {"populate_by_name": True}
|
||||
|
||||
|
||||
class ErrorCode(str, Enum):
    """Machine-readable error codes surfaced in ErrorResponse.code.

    String-valued so codes serialize as-is in JSON error bodies.
    """

    INVALID_REQUEST = "INVALID_REQUEST"
    MISSING_TENANT_ID = "MISSING_TENANT_ID"
    INTERNAL_ERROR = "INTERNAL_ERROR"
    SERVICE_UNAVAILABLE = "SERVICE_UNAVAILABLE"
    TIMEOUT = "TIMEOUT"
    LLM_ERROR = "LLM_ERROR"
    RETRIEVAL_ERROR = "RETRIEVAL_ERROR"
|
||||
|
||||
|
||||
class ErrorResponse(BaseModel):
    """Structured error body shared by all error handlers.

    `details` carries per-field validation records or provider diagnostics.
    """

    code: str = Field(..., description="Error code")
    message: str = Field(..., description="Error message")
    details: list[dict[str, Any]] | None = Field(default=None, description="Detailed error information")
|
||||
|
||||
|
||||
class SSEEventType(str, Enum):
    """Event names emitted on the SSE stream: incremental deltas
    (``message``), the closing summary (``final``), or a failure (``error``)."""

    MESSAGE = "message"
    FINAL = "final"
    ERROR = "error"
|
||||
|
||||
|
||||
class SSEMessageEvent(BaseModel):
    """Payload of a streaming `message` event: one incremental text delta."""

    delta: str = Field(..., description="Incremental text content")
|
||||
|
||||
|
||||
class SSEFinalEvent(BaseModel):
    """Payload of the terminal `final` event.

    Carries the same fields as ChatResponse so streaming and non-streaming
    clients end up with an equivalent summary.
    """

    reply: str = Field(..., description="Complete AI reply")
    confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence score")
    should_transfer: bool = Field(..., alias="shouldTransfer", description="Transfer suggestion")
    transfer_reason: str | None = Field(default=None, alias="transferReason", description="Transfer reason")
    metadata: dict[str, Any] | None = Field(default=None, description="Response metadata")

    # Accept both the alias (camelCase) and the attribute name (snake_case).
    model_config = {"populate_by_name": True}
|
||||
|
||||
|
||||
class SSEErrorEvent(BaseModel):
    """Payload of an `error` event; mirrors ErrorResponse for SSE clients."""

    code: str = Field(..., description="Error code")
    message: str = Field(..., description="Error message")
    details: list[dict[str, Any]] | None = Field(default=None, description="Error details")
|
||||
|
|
@ -0,0 +1,74 @@
|
|||
"""
|
||||
Memory layer entities for AI Service.
|
||||
[AC-AISVC-13] SQLModel entities for chat sessions and messages with tenant isolation.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import Column, JSON
|
||||
from sqlmodel import Field, Index, SQLModel
|
||||
|
||||
|
||||
class ChatSession(SQLModel, table=True):
    """
    [AC-AISVC-13] Chat session entity with tenant isolation.

    The surrogate primary key is `id`; uniqueness of a session within a
    tenant is enforced by the composite unique index on
    (tenant_id, session_id).
    """

    __tablename__ = "chat_sessions"
    __table_args__ = (
        # One logical session per (tenant, session) pair.
        Index("ix_chat_sessions_tenant_session", "tenant_id", "session_id", unique=True),
        Index("ix_chat_sessions_tenant_id", "tenant_id"),
    )

    id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
    tenant_id: str = Field(..., description="Tenant ID for multi-tenant isolation", index=True)
    session_id: str = Field(..., description="Session ID for conversation tracking")
    channel_type: str | None = Field(default=None, description="Channel type: wechat, douyin, jd")
    # Attribute uses a trailing underscore because "metadata" is reserved on
    # SQLAlchemy declarative classes; the column itself is named "metadata".
    metadata_: dict[str, Any] | None = Field(
        default=None,
        sa_column=Column("metadata", JSON, nullable=True),
        description="Session metadata"
    )
    # NOTE(review): datetime.utcnow produces naive timestamps and is
    # deprecated since Python 3.12 — consider datetime.now(timezone.utc);
    # confirm whether the column expects naive UTC first.
    created_at: datetime = Field(default_factory=datetime.utcnow, description="Session creation time")
    updated_at: datetime = Field(default_factory=datetime.utcnow, description="Last update time")
|
||||
|
||||
|
||||
class ChatMessage(SQLModel, table=True):
    """
    [AC-AISVC-13] Chat message entity with tenant isolation.

    Messages are scoped by (tenant_id, session_id) for multi-tenant security;
    the (tenant, session, created_at) index supports ordered history reads.
    """

    __tablename__ = "chat_messages"
    __table_args__ = (
        Index("ix_chat_messages_tenant_session", "tenant_id", "session_id"),
        # Covers the common "history in chronological order" query.
        Index("ix_chat_messages_tenant_session_created", "tenant_id", "session_id", "created_at"),
    )

    id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
    tenant_id: str = Field(..., description="Tenant ID for multi-tenant isolation", index=True)
    session_id: str = Field(..., description="Session ID for conversation tracking", index=True)
    role: str = Field(..., description="Message role: user or assistant")
    content: str = Field(..., description="Message content")
    # NOTE(review): naive UTC via deprecated datetime.utcnow — see ChatSession.
    created_at: datetime = Field(default_factory=datetime.utcnow, description="Message creation time")
|
||||
|
||||
|
||||
class ChatSessionCreate(SQLModel):
    """Input schema (non-table) for creating a new chat session."""

    tenant_id: str
    session_id: str
    channel_type: str | None = None
    metadata_: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class ChatMessageCreate(SQLModel):
    """Input schema (non-table) for appending a single chat message."""

    tenant_id: str
    session_id: str
    role: str
    content: str
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
"""
|
||||
Services module for AI Service.
|
||||
[AC-AISVC-13, AC-AISVC-16] Core services for memory and retrieval.
|
||||
"""
|
||||
|
||||
from app.services.memory import MemoryService
|
||||
from app.services.orchestrator import OrchestratorService, get_orchestrator_service
|
||||
|
||||
__all__ = ["MemoryService", "OrchestratorService", "get_orchestrator_service"]
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
"""
|
||||
LLM Adapter module for AI Service.
|
||||
[AC-AISVC-02, AC-AISVC-06] Provides unified interface for LLM providers.
|
||||
"""
|
||||
|
||||
from app.services.llm.base import LLMClient, LLMConfig, LLMResponse, LLMStreamChunk
|
||||
from app.services.llm.openai_client import OpenAIClient
|
||||
|
||||
__all__ = [
|
||||
"LLMClient",
|
||||
"LLMConfig",
|
||||
"LLMResponse",
|
||||
"LLMStreamChunk",
|
||||
"OpenAIClient",
|
||||
]
|
||||
|
|
@ -0,0 +1,115 @@
|
|||
"""
|
||||
Base LLM client interface.
|
||||
[AC-AISVC-02, AC-AISVC-06] Abstract interface for LLM providers.
|
||||
|
||||
Design reference: design.md Section 8.1 - LLMClient interface
|
||||
- generate(prompt, params) -> text
|
||||
- stream_generate(prompt, params) -> iterator[delta]
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
|
||||
@dataclass
class LLMConfig:
    """
    Generation parameters for one LLM call.
    [AC-AISVC-02] Every field has a sensible default so callers may override
    only what they need; `extra_params` is merged verbatim into the request
    body for provider-specific knobs.
    """

    model: str = "gpt-4o-mini"
    max_tokens: int = 2048
    temperature: float = 0.7
    top_p: float = 1.0
    timeout_seconds: int = 30
    max_retries: int = 3
    extra_params: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
class LLMResponse:
    """
    Complete (non-streaming) result of an LLM call.
    [AC-AISVC-02] `usage` holds provider token accounting; `metadata` keeps
    the raw provider payload for diagnostics.
    """

    content: str
    model: str
    usage: dict[str, int] = field(default_factory=dict)
    finish_reason: str = "stop"
    metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMStreamChunk:
|
||||
"""
|
||||
Streaming chunk from LLM.
|
||||
[AC-AISVC-06, AC-AISVC-07] Incremental output for SSE streaming.
|
||||
"""
|
||||
delta: str
|
||||
model: str
|
||||
finish_reason: str | None = None
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class LLMClient(ABC):
    """
    Abstract base class for LLM clients.
    [AC-AISVC-02, AC-AISVC-06] Defines the provider-agnostic contract that
    concrete adapters implement.

    Design reference: design.md Section 8.2 - Plugin points
    - OpenAICompatibleClient / LocalModelClient can be swapped
    """

    @abstractmethod
    async def generate(
        self,
        messages: list[dict[str, str]],
        config: LLMConfig | None = None,
        **kwargs: Any,
    ) -> LLMResponse:
        """Produce one complete response for the given chat messages.

        [AC-AISVC-02] Backs the non-streaming ChatResponse path.

        Args:
            messages: Chat turns, each a dict with 'role' and 'content'.
            config: Optional per-call configuration overrides.
            **kwargs: Extra provider-specific request parameters.

        Returns:
            LLMResponse holding the generated text plus metadata.

        Raises:
            LLMException: If generation fails.
        """
        ...

    @abstractmethod
    async def stream_generate(
        self,
        messages: list[dict[str, str]],
        config: LLMConfig | None = None,
        **kwargs: Any,
    ) -> AsyncGenerator[LLMStreamChunk, None]:
        """Yield the response incrementally.

        [AC-AISVC-06, AC-AISVC-07] Backs the SSE streaming path.

        Args:
            messages: Chat turns, each a dict with 'role' and 'content'.
            config: Optional per-call configuration overrides.
            **kwargs: Extra provider-specific request parameters.

        Yields:
            LLMStreamChunk with incremental content.

        Raises:
            LLMException: If generation fails.
        """
        ...

    @abstractmethod
    async def close(self) -> None:
        """Release any underlying resources (connection pools, etc.)."""
        ...
|
||||
|
|
@ -0,0 +1,319 @@
|
|||
"""
|
||||
OpenAI-compatible LLM client implementation.
|
||||
[AC-AISVC-02, AC-AISVC-06] Concrete implementation using httpx for OpenAI API.
|
||||
|
||||
Design reference: design.md Section 8.1 - LLMClient interface
|
||||
- Uses langchain-openai or official SDK pattern
|
||||
- Supports generate and stream_generate
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
import httpx
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.core.exceptions import AIServiceException, ErrorCode, ServiceUnavailableException, TimeoutException
|
||||
from app.services.llm.base import LLMClient, LLMConfig, LLMResponse, LLMStreamChunk
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMException(AIServiceException):
    """Exception raised when LLM operations fail.

    Maps to ErrorCode.LLM_ERROR with HTTP 503 so the global exception
    handlers render it as a service-side failure.
    """

    def __init__(self, message: str, details: list[dict] | None = None):
        super().__init__(
            code=ErrorCode.LLM_ERROR,
            message=message,
            status_code=503,
            details=details,
        )
|
||||
|
||||
|
||||
class OpenAIClient(LLMClient):
    """
    OpenAI-compatible LLM client.
    [AC-AISVC-02, AC-AISVC-06] Implements LLMClient interface for OpenAI API.

    Supports:
    - OpenAI API (official)
    - OpenAI-compatible endpoints (Azure, local models, etc.)
    """

    def __init__(
        self,
        api_key: str | None = None,
        base_url: str | None = None,
        model: str | None = None,
        default_config: LLMConfig | None = None,
    ):
        """Initialize from explicit arguments, falling back to app settings.

        Args:
            api_key: API key; defaults to settings.llm_api_key.
            base_url: API base URL; defaults to settings.llm_base_url.
            model: Default model name; defaults to settings.llm_model.
            default_config: Default generation config; built from settings if omitted.
        """
        settings = get_settings()
        self._api_key = api_key or settings.llm_api_key
        # Strip trailing slash so URL joins below never produce "//".
        self._base_url = (base_url or settings.llm_base_url).rstrip("/")
        self._model = model or settings.llm_model
        self._default_config = default_config or LLMConfig(
            model=self._model,
            max_tokens=settings.llm_max_tokens,
            temperature=settings.llm_temperature,
            timeout_seconds=settings.llm_timeout_seconds,
            max_retries=settings.llm_max_retries,
        )
        self._client: httpx.AsyncClient | None = None

    def _get_client(self) -> httpx.AsyncClient:
        """Get or create the shared HTTP client.

        The client carries no fixed timeout; the effective timeout is passed
        per request so per-call LLMConfig overrides are honored. (Previously
        the first caller's timeout was baked into the lazily-created client
        and silently applied to every later call.)
        """
        if self._client is None:
            self._client = httpx.AsyncClient(
                headers={
                    "Authorization": f"Bearer {self._api_key}",
                    "Content-Type": "application/json",
                },
            )
        return self._client

    def _build_request_body(
        self,
        messages: list[dict[str, str]],
        config: LLMConfig,
        stream: bool = False,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Build the chat-completions request body.

        Precedence (last wins): base fields < config.extra_params < kwargs.
        """
        body: dict[str, Any] = {
            "model": config.model,
            "messages": messages,
            "max_tokens": config.max_tokens,
            "temperature": config.temperature,
            "top_p": config.top_p,
            "stream": stream,
        }
        body.update(config.extra_params)
        body.update(kwargs)
        return body

    @retry(
        # The body converts httpx.TimeoutException into the app-level
        # TimeoutException before tenacity sees it, so the app exception must
        # be retried too (the original predicate could never match and the
        # retry never fired). reraise=True surfaces the last TimeoutException
        # to callers instead of tenacity's RetryError once attempts run out.
        retry=retry_if_exception_type((httpx.TimeoutException, TimeoutException)),
        # Decorator is evaluated at class-creation time, so it cannot read a
        # per-instance config.max_retries; attempts stay fixed at 3.
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=1, max=10),
        reraise=True,
    )
    async def generate(
        self,
        messages: list[dict[str, str]],
        config: LLMConfig | None = None,
        **kwargs: Any,
    ) -> LLMResponse:
        """
        Generate a non-streaming response.
        [AC-AISVC-02] Returns complete response for ChatResponse.

        Args:
            messages: List of chat messages with 'role' and 'content'.
            config: Optional LLM configuration overrides.
            **kwargs: Additional provider-specific parameters.

        Returns:
            LLMResponse with generated content and metadata.

        Raises:
            LLMException: If generation fails.
            TimeoutException: If request times out (after retries).
        """
        effective_config = config or self._default_config
        client = self._get_client()

        body = self._build_request_body(messages, effective_config, stream=False, **kwargs)

        logger.info(f"[AC-AISVC-02] Generating response with model={effective_config.model}")

        try:
            response = await client.post(
                f"{self._base_url}/chat/completions",
                json=body,
                # Per-request timeout so config overrides are honored.
                timeout=httpx.Timeout(effective_config.timeout_seconds),
            )
            response.raise_for_status()
            data = response.json()

        except httpx.TimeoutException as e:
            logger.error(f"[AC-AISVC-02] LLM request timeout: {e}")
            raise TimeoutException(message=f"LLM request timed out: {e}") from e

        except httpx.HTTPStatusError as e:
            logger.error(f"[AC-AISVC-02] LLM API error: {e}")
            error_detail = self._parse_error_response(e.response)
            raise LLMException(
                message=f"LLM API error: {error_detail}",
                details=[{"status_code": e.response.status_code, "response": error_detail}],
            ) from e

        except json.JSONDecodeError as e:
            logger.error(f"[AC-AISVC-02] Failed to parse LLM response: {e}")
            raise LLMException(message=f"Failed to parse LLM response: {e}") from e

        try:
            choice = data["choices"][0]
            content = choice["message"]["content"]
            usage = data.get("usage", {})
            finish_reason = choice.get("finish_reason", "stop")

            logger.info(
                f"[AC-AISVC-02] Generated response: "
                f"tokens={usage.get('total_tokens', 'N/A')}, "
                f"finish_reason={finish_reason}"
            )

            return LLMResponse(
                content=content,
                model=data.get("model", effective_config.model),
                usage=usage,
                finish_reason=finish_reason,
                metadata={"raw_response": data},
            )

        except (KeyError, IndexError) as e:
            logger.error(f"[AC-AISVC-02] Unexpected LLM response format: {e}")
            raise LLMException(
                message=f"Unexpected LLM response format: {e}",
                details=[{"response": str(data)}],
            ) from e

    async def stream_generate(
        self,
        messages: list[dict[str, str]],
        config: LLMConfig | None = None,
        **kwargs: Any,
    ) -> AsyncGenerator[LLMStreamChunk, None]:
        """
        Generate a streaming response.
        [AC-AISVC-06, AC-AISVC-07] Yields incremental chunks for SSE.

        Not retried: a partially-consumed stream cannot safely be replayed.

        Args:
            messages: List of chat messages with 'role' and 'content'.
            config: Optional LLM configuration overrides.
            **kwargs: Additional provider-specific parameters.

        Yields:
            LLMStreamChunk with incremental content.

        Raises:
            LLMException: If generation fails.
            TimeoutException: If request times out.
        """
        effective_config = config or self._default_config
        client = self._get_client()

        body = self._build_request_body(messages, effective_config, stream=True, **kwargs)

        logger.info(f"[AC-AISVC-06] Starting streaming generation with model={effective_config.model}")

        try:
            async with client.stream(
                "POST",
                f"{self._base_url}/chat/completions",
                json=body,
                # Per-request timeout so config overrides are honored.
                timeout=httpx.Timeout(effective_config.timeout_seconds),
            ) as response:
                response.raise_for_status()

                # OpenAI SSE framing: each payload line is "data: {json}",
                # terminated by "data: [DONE]".
                async for line in response.aiter_lines():
                    if not line or line == "data: [DONE]":
                        continue

                    if line.startswith("data: "):
                        json_str = line[6:]
                        try:
                            chunk_data = json.loads(json_str)
                            chunk = self._parse_stream_chunk(chunk_data, effective_config.model)
                            if chunk:
                                yield chunk
                        except json.JSONDecodeError as e:
                            # Skip malformed chunks rather than aborting the stream.
                            logger.warning(f"[AC-AISVC-06] Failed to parse stream chunk: {e}")
                            continue

        except httpx.TimeoutException as e:
            logger.error(f"[AC-AISVC-06] LLM streaming request timeout: {e}")
            raise TimeoutException(message=f"LLM streaming request timed out: {e}") from e

        except httpx.HTTPStatusError as e:
            logger.error(f"[AC-AISVC-06] LLM streaming API error: {e}")
            error_detail = self._parse_error_response(e.response)
            raise LLMException(
                message=f"LLM streaming API error: {error_detail}",
                details=[{"status_code": e.response.status_code, "response": error_detail}],
            ) from e

        logger.info("[AC-AISVC-06] Streaming generation completed")

    def _parse_stream_chunk(
        self,
        data: dict[str, Any],
        model: str,
    ) -> LLMStreamChunk | None:
        """Parse one streaming chunk; returns None for empty/keep-alive chunks."""
        try:
            choices = data.get("choices", [])
            if not choices:
                return None

            delta = choices[0].get("delta", {})
            content = delta.get("content", "")
            finish_reason = choices[0].get("finish_reason")

            # Nothing to forward: no text and no completion signal.
            if not content and not finish_reason:
                return None

            return LLMStreamChunk(
                delta=content,
                model=data.get("model", model),
                finish_reason=finish_reason,
                metadata={"raw_chunk": data},
            )

        except (KeyError, IndexError) as e:
            logger.warning(f"[AC-AISVC-06] Failed to parse stream chunk: {e}")
            return None

    def _parse_error_response(self, response: httpx.Response) -> str:
        """Extract a human-readable message from an API error response.

        Falls back to the raw body when it is not the standard OpenAI
        {"error": {...}} shape or not JSON at all.
        """
        try:
            data = response.json()
            if "error" in data:
                error = data["error"]
                if isinstance(error, dict):
                    return error.get("message", str(error))
                return str(error)
            return response.text
        except Exception:
            return response.text

    async def close(self) -> None:
        """Close the HTTP client and drop the cached instance."""
        if self._client:
            await self._client.aclose()
            self._client = None
|
||||
|
||||
|
||||
# Module-level singleton so the HTTP connection pool is shared per process.
_llm_client: OpenAIClient | None = None


def get_llm_client() -> OpenAIClient:
    """Get or create LLM client instance.

    Lazily constructs the process-wide OpenAIClient on first use.
    """
    global _llm_client
    if _llm_client is None:
        _llm_client = OpenAIClient()
    return _llm_client


async def close_llm_client() -> None:
    """Close the global LLM client.

    Safe to call when no client was ever created.
    """
    global _llm_client
    if _llm_client:
        await _llm_client.close()
        _llm_client = None
|
||||
|
|
@ -0,0 +1,170 @@
|
|||
"""
|
||||
Memory service for AI Service.
|
||||
[AC-AISVC-13] Session-based memory management with tenant isolation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Sequence
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlmodel import col
|
||||
|
||||
from app.models.entities import ChatMessage, ChatMessageCreate, ChatSession, ChatSessionCreate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MemoryService:
    """
    [AC-AISVC-13] Memory service for session-based conversation history.

    All operations are scoped by (tenant_id, session_id) for multi-tenant
    isolation. The service only flushes (never commits); transaction
    boundaries belong to the caller/session owner.
    """

    def __init__(self, session: AsyncSession):
        """Bind the service to an active async DB session."""
        self._session = session

    async def get_or_create_session(
        self,
        tenant_id: str,
        session_id: str,
        channel_type: str | None = None,
        metadata: dict | None = None,
    ) -> ChatSession:
        """
        [AC-AISVC-13] Get existing session or create a new one.

        Args:
            tenant_id: Tenant scope; every lookup is filtered by it.
            session_id: Logical conversation identifier within the tenant.
            channel_type: Optional channel label stored on creation only.
            metadata: Optional metadata stored on creation only.

        Returns:
            The existing or newly created (flushed, not committed) ChatSession.
        """
        stmt = select(ChatSession).where(
            ChatSession.tenant_id == tenant_id,
            ChatSession.session_id == session_id,
        )
        result = await self._session.execute(stmt)
        existing_session = result.scalar_one_or_none()

        if existing_session:
            logger.info(
                f"[AC-AISVC-13] Found existing session: tenant={tenant_id}, session={session_id}"
            )
            return existing_session

        new_session = ChatSession(
            tenant_id=tenant_id,
            session_id=session_id,
            channel_type=channel_type,
            metadata_=metadata,
        )
        self._session.add(new_session)
        # Flush to assign defaults/PK without committing the transaction.
        await self._session.flush()

        logger.info(
            f"[AC-AISVC-13] Created new session: tenant={tenant_id}, session={session_id}"
        )
        return new_session

    async def load_history(
        self,
        tenant_id: str,
        session_id: str,
        limit: int | None = None,
    ) -> Sequence[ChatMessage]:
        """
        [AC-AISVC-13] Load conversation history for a session.

        Messages are returned in chronological (oldest-first) order.
        NOTE(review): with ascending order, `limit` returns the *oldest* N
        messages — confirm whether callers expect the most recent N instead.

        Args:
            tenant_id: Tenant scope; queries are filtered by it.
            session_id: Conversation identifier within the tenant.
            limit: Maximum number of messages to return; None means all,
                0 legitimately means none.
        """
        stmt = (
            select(ChatMessage)
            .where(
                ChatMessage.tenant_id == tenant_id,
                ChatMessage.session_id == session_id,
            )
            .order_by(col(ChatMessage.created_at).asc())
        )

        # `is not None` (not truthiness): limit=0 must mean "no messages",
        # not "unlimited".
        if limit is not None:
            stmt = stmt.limit(limit)

        result = await self._session.execute(stmt)
        messages = result.scalars().all()

        logger.info(
            f"[AC-AISVC-13] Loaded {len(messages)} messages for tenant={tenant_id}, session={session_id}"
        )
        return messages

    async def append_message(
        self,
        tenant_id: str,
        session_id: str,
        role: str,
        content: str,
    ) -> ChatMessage:
        """
        [AC-AISVC-13] Append a single message to the session history.

        Returns the flushed (not committed) ChatMessage row.
        """
        message = ChatMessage(
            tenant_id=tenant_id,
            session_id=session_id,
            role=role,
            content=content,
        )
        self._session.add(message)
        await self._session.flush()

        logger.info(
            f"[AC-AISVC-13] Appended message: tenant={tenant_id}, session={session_id}, role={role}"
        )
        return message

    async def append_messages(
        self,
        tenant_id: str,
        session_id: str,
        messages: list[dict[str, str]],
    ) -> list[ChatMessage]:
        """
        [AC-AISVC-13] Append multiple messages in one flush.

        Args:
            messages: Dicts with 'role' and 'content' keys; a missing key
                raises KeyError before anything is flushed.

        Returns:
            The flushed ChatMessage rows in input order.
        """
        chat_messages = []
        for msg in messages:
            message = ChatMessage(
                tenant_id=tenant_id,
                session_id=session_id,
                role=msg["role"],
                content=msg["content"],
            )
            self._session.add(message)
            chat_messages.append(message)

        # Single flush for the whole batch.
        await self._session.flush()

        logger.info(
            f"[AC-AISVC-13] Appended {len(chat_messages)} messages for tenant={tenant_id}, session={session_id}"
        )
        return chat_messages

    async def clear_history(self, tenant_id: str, session_id: str) -> int:
        """
        [AC-AISVC-13] Delete all messages for a session.

        Only affects rows within the tenant's scope. Uses per-row ORM deletes
        so ORM-level cascade/event semantics are preserved.

        Returns:
            Number of messages deleted.
        """
        stmt = select(ChatMessage).where(
            ChatMessage.tenant_id == tenant_id,
            ChatMessage.session_id == session_id,
        )
        result = await self._session.execute(stmt)
        messages = result.scalars().all()

        count = 0
        for message in messages:
            await self._session.delete(message)
            count += 1

        await self._session.flush()

        logger.info(
            f"[AC-AISVC-13] Cleared {count} messages for tenant={tenant_id}, session={session_id}"
        )
        return count
|
||||
|
|
@ -0,0 +1,97 @@
|
|||
"""
|
||||
Orchestrator service for AI Service.
|
||||
[AC-AISVC-01, AC-AISVC-02] Core orchestration logic for chat generation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from sse_starlette.sse import ServerSentEvent
|
||||
|
||||
from app.models import ChatRequest, ChatResponse
|
||||
from app.core.sse import create_final_event, create_message_event, SSEStateMachine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OrchestratorService:
    """
    [AC-AISVC-01, AC-AISVC-02] Chat-generation orchestrator.

    Coordinates memory, retrieval, and LLM components. The current
    implementation returns placeholder replies (echo of the user message)
    with a fixed confidence of 0.85.
    """

    async def generate(self, tenant_id: str, request: ChatRequest) -> ChatResponse:
        """
        Produce a non-streaming reply.
        [AC-AISVC-02] Returns ChatResponse with reply, confidence, shouldTransfer.
        """
        logger.info(
            f"[AC-AISVC-01] Generating response for tenant={tenant_id}, "
            f"session={request.session_id}"
        )

        return ChatResponse(
            reply=f"Received your message: {request.current_message}",
            confidence=0.85,
            should_transfer=False,
        )

    async def generate_stream(
        self, tenant_id: str, request: ChatRequest
    ) -> AsyncGenerator[ServerSentEvent, None]:
        """
        Produce a streaming reply.
        [AC-AISVC-06, AC-AISVC-07, AC-AISVC-08] Emits SSE events in the
        contract order (message* -> final | error) under the state machine.
        """
        logger.info(
            f"[AC-AISVC-06] Starting streaming generation for tenant={tenant_id}, "
            f"session={request.session_id}"
        )

        sse_state = SSEStateMachine()
        await sse_state.transition_to_streaming()

        try:
            segments = ["Received", " your", " message:", f" {request.current_message}"]
            assembled = ""

            for piece in segments:
                if sse_state.can_send_message():
                    assembled += piece
                    yield create_message_event(delta=piece)
                    await self._simulate_llm_delay()

            # Only emit `final` if the state machine allows the transition
            # (i.e. the stream was not already terminated).
            if await sse_state.transition_to_final():
                yield create_final_event(
                    reply=assembled,
                    confidence=0.85,
                    should_transfer=False,
                )

        except Exception as exc:
            logger.error(f"[AC-AISVC-09] Error during streaming: {exc}")
            if await sse_state.transition_to_error():
                from app.core.sse import create_error_event
                yield create_error_event(
                    code="GENERATION_ERROR",
                    message=str(exc),
                )
        finally:
            await sse_state.close()

    async def _simulate_llm_delay(self) -> None:
        """Simulate LLM processing delay for demo purposes."""
        import asyncio
        await asyncio.sleep(0.1)
|
||||
|
||||
|
||||
# Module-level singleton; the orchestrator is stateless so one instance
# per process suffices.
_orchestrator_service: OrchestratorService | None = None


def get_orchestrator_service() -> OrchestratorService:
    """Get or create the process-wide orchestrator service instance."""
    global _orchestrator_service
    if _orchestrator_service is None:
        _orchestrator_service = OrchestratorService()
    return _orchestrator_service
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
"""
|
||||
Retrieval module for AI Service.
|
||||
[AC-AISVC-16] Provides retriever implementations with plugin architecture.
|
||||
"""
|
||||
|
||||
from app.services.retrieval.base import (
|
||||
BaseRetriever,
|
||||
RetrievalContext,
|
||||
RetrievalHit,
|
||||
RetrievalResult,
|
||||
)
|
||||
from app.services.retrieval.vector_retriever import VectorRetriever, get_vector_retriever
|
||||
|
||||
__all__ = [
|
||||
"BaseRetriever",
|
||||
"RetrievalContext",
|
||||
"RetrievalHit",
|
||||
"RetrievalResult",
|
||||
"VectorRetriever",
|
||||
"get_vector_retriever",
|
||||
]
|
||||
|
|
@ -0,0 +1,96 @@
|
|||
"""
|
||||
Retrieval layer for AI Service.
|
||||
[AC-AISVC-16] Abstract base class for retrievers with plugin point support.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class RetrievalContext:
    """
    [AC-AISVC-16] Context for retrieval operations.

    Bundles everything a retrieval plugin needs into one argument:
    the tenant scope, the query text, and optional session/channel data.
    """

    # Required: tenant scope (isolation boundary) and the text to search for.
    tenant_id: str
    query: str
    # Optional context that individual retriever implementations may use
    # for filtering or ranking; None means "not provided".
    session_id: str | None = None
    channel_type: str | None = None
    metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
@dataclass
class RetrievalHit:
    """
    [AC-AISVC-16] Single retrieval result hit.

    Unified structure for all retriever types so downstream code does not
    depend on any one backend's result shape.
    """

    # The retrieved text fragment.
    text: str
    # Similarity/relevance score as reported by the backend.
    score: float
    # Which retriever/backend produced this hit (e.g. "vector").
    source: str
    # Backend payload for this hit; defaults to an empty dict per instance.
    metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
class RetrievalResult:
    """
    [AC-AISVC-16] Outcome of a retrieval operation.

    Carries the matched hits plus optional diagnostics that describe how
    the retrieval went (thresholds used, candidate counts, errors, ...).
    """

    hits: list[RetrievalHit] = field(default_factory=list)
    diagnostics: dict[str, Any] | None = None

    @property
    def is_empty(self) -> bool:
        """True when the retrieval produced no hits at all."""
        return not self.hits

    @property
    def max_score(self) -> float:
        """Highest score among the hits, or 0.0 when there are none."""
        return max((hit.score for hit in self.hits), default=0.0)

    @property
    def hit_count(self) -> int:
        """Number of hits in this result."""
        return len(self.hits)
|
||||
|
||||
|
||||
class BaseRetriever(ABC):
    """
    [AC-AISVC-16] Plugin interface for retrieval strategies.

    Concrete retrievers (Vector, Graph, Hybrid, ...) subclass this and
    implement the two coroutines below.
    """

    @abstractmethod
    async def retrieve(self, ctx: RetrievalContext) -> RetrievalResult:
        """
        [AC-AISVC-16] Fetch documents relevant to the given context.

        Args:
            ctx: Retrieval context containing tenant_id, query, and optional metadata.

        Returns:
            RetrievalResult with hits and optional diagnostics.
        """
        ...

    @abstractmethod
    async def health_check(self) -> bool:
        """
        Report whether this retriever is ready to serve requests.

        Returns:
            True if healthy, False otherwise.
        """
        ...
|
||||
|
|
@ -0,0 +1,169 @@
|
|||
"""
|
||||
Vector retriever for AI Service.
|
||||
[AC-AISVC-16, AC-AISVC-17] Qdrant-based vector retrieval with score threshold filtering.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.core.qdrant_client import QdrantClient, get_qdrant_client
|
||||
from app.services.retrieval.base import (
|
||||
BaseRetriever,
|
||||
RetrievalContext,
|
||||
RetrievalHit,
|
||||
RetrievalResult,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
class VectorRetriever(BaseRetriever):
    """
    [AC-AISVC-16, AC-AISVC-17] Vector-based retriever using Qdrant.

    Supports score threshold filtering and tenant isolation, and reports
    diagnostics so callers can judge whether the evidence is sufficient.
    """

    def __init__(
        self,
        qdrant_client: QdrantClient | None = None,
        top_k: int | None = None,
        score_threshold: float | None = None,
        min_hits: int | None = None,
    ):
        """
        Args:
            qdrant_client: Pre-built Qdrant client; created lazily when None.
            top_k: Max candidates to fetch (defaults to settings.rag_top_k).
            score_threshold: Minimum similarity score to keep a hit
                (defaults to settings.rag_score_threshold).
            min_hits: Hits required for the result to count as sufficient
                (defaults to settings.rag_min_hits).
        """
        self._qdrant_client = qdrant_client
        # Bug fix: use `is None` checks so explicit falsy overrides (0, 0.0)
        # are honored instead of silently falling back to settings defaults,
        # which the previous `x or default` pattern did.
        self._top_k = settings.rag_top_k if top_k is None else top_k
        self._score_threshold = (
            settings.rag_score_threshold if score_threshold is None else score_threshold
        )
        self._min_hits = settings.rag_min_hits if min_hits is None else min_hits

    async def _get_client(self) -> QdrantClient:
        """Return the Qdrant client, creating it lazily on first use."""
        if self._qdrant_client is None:
            self._qdrant_client = await get_qdrant_client()
        return self._qdrant_client

    async def retrieve(self, ctx: RetrievalContext) -> RetrievalResult:
        """
        [AC-AISVC-16, AC-AISVC-17] Retrieve documents from the vector store.

        Steps:
            1. Generate embedding for query (placeholder - requires embedding provider)
            2. Search in tenant's collection
            3. Filter by score threshold
            4. Return structured result

        Args:
            ctx: Retrieval context with tenant_id and query.

        Returns:
            RetrievalResult with filtered hits; on any error an empty result
            whose diagnostics carry the error message (this method never raises).
        """
        logger.info(
            f"[AC-AISVC-16] Starting vector retrieval for tenant={ctx.tenant_id}, query={ctx.query[:50]}..."
        )

        try:
            client = await self._get_client()

            query_vector = await self._get_embedding(ctx.query)

            # The threshold is also passed to the backend so it can
            # pre-filter; the comprehension below re-checks defensively.
            hits = await client.search(
                tenant_id=ctx.tenant_id,
                query_vector=query_vector,
                limit=self._top_k,
                score_threshold=self._score_threshold,
            )

            retrieval_hits = [
                RetrievalHit(
                    text=hit.get("payload", {}).get("text", ""),
                    score=hit.get("score", 0.0),
                    source=hit.get("payload", {}).get("source", "vector"),
                    metadata=hit.get("payload", {}),
                )
                for hit in hits
                if hit.get("score", 0.0) >= self._score_threshold
            ]

            # Too few surviving hits => the answer should be treated as
            # low-confidence by the caller.
            is_insufficient = len(retrieval_hits) < self._min_hits

            diagnostics = {
                "query_length": len(ctx.query),
                "top_k": self._top_k,
                "score_threshold": self._score_threshold,
                "min_hits": self._min_hits,
                "total_candidates": len(hits),
                "filtered_hits": len(retrieval_hits),
                "is_insufficient": is_insufficient,
                "max_score": max((h.score for h in retrieval_hits), default=0.0),
            }

            logger.info(
                f"[AC-AISVC-17] Retrieval complete: {len(retrieval_hits)} hits, "
                f"insufficient={is_insufficient}, max_score={diagnostics['max_score']:.3f}"
            )

            return RetrievalResult(
                hits=retrieval_hits,
                diagnostics=diagnostics,
            )

        except Exception as e:
            # Retrieval failures degrade to "no evidence" rather than
            # crashing the chat flow; the orchestrator decides what to do.
            logger.error(f"[AC-AISVC-16] Retrieval error: {e}")
            return RetrievalResult(
                hits=[],
                diagnostics={"error": str(e), "is_insufficient": True},
            )

    async def _get_embedding(self, text: str) -> list[float]:
        """
        Generate a deterministic placeholder embedding for *text*.

        [AC-AISVC-16] Maps the SHA-256 digest of the text to a 0/1 vector of
        settings.qdrant_vector_size dimensions, zero-padded when the digest
        has fewer bits. Deterministic but NOT semantically meaningful.

        TODO: Integrate with actual embedding provider (OpenAI, local model, etc.)
        """
        import hashlib

        dim = settings.qdrant_vector_size
        digest = hashlib.sha256(text.encode()).digest()

        # Expand digest bytes to bits, least-significant bit first within each
        # byte (same ordering as the original byte_idx/bit_idx loop). The old
        # `else: append(0.0)` branch was unreachable and has been removed.
        bits = [float((byte >> bit) & 1) for byte in digest for bit in range(8)]

        embedding = bits[:dim]
        # Zero-pad when the digest provides fewer bits than the vector size.
        embedding.extend(0.0 for _ in range(dim - len(embedding)))
        return embedding

    async def health_check(self) -> bool:
        """
        [AC-AISVC-16] Check if the Qdrant connection is healthy.

        Returns:
            True when listing collections succeeds, False on any error.
        """
        try:
            client = await self._get_client()
            qdrant = await client.get_client()
            await qdrant.get_collections()
            return True
        except Exception as e:
            logger.error(f"[AC-AISVC-16] Health check failed: {e}")
            return False
|
||||
|
||||
|
||||
# Lazily created module-level singleton.
_vector_retriever: VectorRetriever | None = None


async def get_vector_retriever() -> VectorRetriever:
    """Return the shared VectorRetriever, creating it on first use."""
    global _vector_retriever
    retriever = _vector_retriever
    if retriever is None:
        retriever = VectorRetriever()
        _vector_retriever = retriever
    return retriever
|
||||
|
|
@ -0,0 +1,53 @@
|
|||
# Packaging metadata (PEP 621) for the AI Service.
[project]
name = "ai-service"
version = "0.1.0"
description = "Python AI Service for intelligent chat with RAG support"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
    "fastapi>=0.109.0",
    "uvicorn[standard]>=0.27.0",
    "pydantic>=2.5.0",
    "pydantic-settings>=2.1.0",
    "sse-starlette>=2.0.0",
    "httpx>=0.26.0",
    "tenacity>=8.2.0",
    "sqlmodel>=0.0.14",
    "asyncpg>=0.29.0",
    "qdrant-client>=1.7.0",
    "tiktoken>=0.5.0",
]

[project.optional-dependencies]
# Tooling used only for local development and CI.
dev = [
    "pytest>=7.4.0",
    "pytest-asyncio>=0.23.0",
    "pytest-cov>=4.1.0",
    "httpx>=0.26.0",
    "ruff>=0.1.0",
    "mypy>=1.8.0",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.build.targets.wheel]
packages = ["app"]

[tool.ruff]
line-length = 120
target-version = "py310"

[tool.ruff.lint]
# pycodestyle (E/W), pyflakes (F), isort (I), naming (N), pyupgrade (UP).
select = ["E", "F", "I", "N", "W", "UP"]

[tool.mypy]
python_version = "3.10"
strict = true
warn_return_any = true
warn_unused_configs = true

[tool.pytest.ini_options]
# "auto" lets plain `async def` tests run without per-test asyncio markers.
asyncio_mode = "auto"
testpaths = ["tests"]
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
"""
|
||||
Tests package for AI Service.
|
||||
"""
|
||||
|
|
@ -0,0 +1,10 @@
|
|||
"""
|
||||
Pytest configuration for AI Service tests.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
def anyio_backend():
    """Run anyio-marked async tests on the asyncio backend."""
    return "asyncio"
|
||||
|
|
@ -0,0 +1,285 @@
|
|||
"""
|
||||
Tests for response mode switching based on Accept header.
|
||||
[AC-AISVC-06] Tests for automatic switching between JSON and SSE streaming modes.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from httpx import AsyncClient
|
||||
|
||||
from app.main import app
|
||||
|
||||
|
||||
class TestAcceptHeaderSwitching:
    """
    [AC-AISVC-06] Test cases for Accept header based response mode switching.

    The /ai/chat endpoint must return plain JSON by default and switch to
    SSE streaming when the client sends Accept: text/event-stream.
    """

    @pytest.fixture
    def client(self):
        # Fresh TestClient per test; the app object is shared module state.
        return TestClient(app)

    @pytest.fixture
    def valid_request_body(self):
        # Minimal valid /ai/chat payload (camelCase keys match the API schema).
        return {
            "sessionId": "test_session_001",
            "currentMessage": "Hello, how are you?",
            "channelType": "wechat",
        }

    @pytest.fixture
    def valid_headers(self):
        # X-Tenant-Id is mandatory for chat requests.
        return {"X-Tenant-Id": "tenant_001"}

    def test_json_response_with_default_accept(
        self, client: TestClient, valid_request_body: dict, valid_headers: dict
    ):
        """
        [AC-AISVC-06] Test that default Accept header returns JSON response.
        """
        response = client.post(
            "/ai/chat",
            json=valid_request_body,
            headers=valid_headers,
        )

        assert response.status_code == 200
        assert response.headers["content-type"] == "application/json"

        # The JSON body must expose the full reply contract.
        data = response.json()
        assert "reply" in data
        assert "confidence" in data
        assert "shouldTransfer" in data

    def test_json_response_with_application_json_accept(
        self, client: TestClient, valid_request_body: dict, valid_headers: dict
    ):
        """
        [AC-AISVC-06] Test that Accept: application/json returns JSON response.
        """
        headers = {**valid_headers, "Accept": "application/json"}

        response = client.post(
            "/ai/chat",
            json=valid_request_body,
            headers=headers,
        )

        assert response.status_code == 200
        assert response.headers["content-type"] == "application/json"

        data = response.json()
        assert "reply" in data
        assert "confidence" in data
        assert "shouldTransfer" in data

    def test_sse_response_with_text_event_stream_accept(
        self, client: TestClient, valid_request_body: dict, valid_headers: dict
    ):
        """
        [AC-AISVC-06] Test that Accept: text/event-stream returns SSE response.
        """
        headers = {**valid_headers, "Accept": "text/event-stream"}

        response = client.post(
            "/ai/chat",
            json=valid_request_body,
            headers=headers,
        )

        assert response.status_code == 200
        # Content type may carry a charset suffix, so use substring matching.
        assert "text/event-stream" in response.headers["content-type"]

        content = response.text
        assert "event: message" in content
        assert "event: final" in content

    def test_sse_response_event_sequence(
        self, client: TestClient, valid_request_body: dict, valid_headers: dict
    ):
        """
        [AC-AISVC-07, AC-AISVC-08] Test that SSE events follow proper sequence.
        message* -> final -> close
        """
        headers = {**valid_headers, "Accept": "text/event-stream"}

        response = client.post(
            "/ai/chat",
            json=valid_request_body,
            headers=headers,
        )

        content = response.text

        # Accept both "event:name" and "event: name" serializations.
        assert "event:message" in content or "event: message" in content, f"Expected message event in: {content[:500]}"
        assert "event:final" in content or "event: final" in content, f"Expected final event in: {content[:500]}"

        message_idx = content.find("event:message")
        if message_idx == -1:
            message_idx = content.find("event: message")
        final_idx = content.find("event:final")
        if final_idx == -1:
            final_idx = content.find("event: final")

        # Ordering check: the final event must follow the message events.
        assert final_idx > message_idx, "final event should come after message events"

    def test_missing_tenant_id_returns_400(
        self, client: TestClient, valid_request_body: dict
    ):
        """
        [AC-AISVC-12] Test that missing X-Tenant-Id returns 400 error.
        """
        response = client.post(
            "/ai/chat",
            json=valid_request_body,
        )

        assert response.status_code == 400

        data = response.json()
        assert data["code"] == "MISSING_TENANT_ID"
        assert "message" in data

    def test_invalid_channel_type_returns_400(
        self, client: TestClient, valid_headers: dict
    ):
        """
        [AC-AISVC-03] Test that invalid channel type returns 400 error.
        """
        invalid_body = {
            "sessionId": "test_session_001",
            "currentMessage": "Hello",
            "channelType": "invalid_channel",
        }

        response = client.post(
            "/ai/chat",
            json=invalid_body,
            headers=valid_headers,
        )

        assert response.status_code == 400

    def test_missing_required_fields_returns_400(
        self, client: TestClient, valid_headers: dict
    ):
        """
        [AC-AISVC-03] Test that missing required fields return 400 error.
        """
        # currentMessage and channelType are intentionally omitted.
        incomplete_body = {
            "sessionId": "test_session_001",
        }

        response = client.post(
            "/ai/chat",
            json=incomplete_body,
            headers=valid_headers,
        )

        assert response.status_code == 400
|
||||
|
||||
|
||||
class TestHealthEndpoint:
    """
    [AC-AISVC-20] Test cases for health check endpoint.
    """

    @pytest.fixture
    def client(self):
        # Fresh TestClient per test.
        return TestClient(app)

    def test_health_check_returns_200(self, client: TestClient):
        """
        [AC-AISVC-20] Test that health check returns 200 with status.
        """
        response = client.get("/ai/health")

        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "healthy"
|
||||
|
||||
|
||||
class TestSSEStateMachine:
    """
    [AC-AISVC-08, AC-AISVC-09] Test cases for SSE state machine.

    Expected lifecycle: INIT -> STREAMING -> FINAL_SENT -> CLOSED, with an
    ERROR_SENT branch reachable from STREAMING.
    """

    @pytest.mark.asyncio
    async def test_state_transitions(self):
        """Walk the happy path through all four states in order."""
        from app.core.sse import SSEState, SSEStateMachine

        state_machine = SSEStateMachine()

        assert state_machine.state == SSEState.INIT

        success = await state_machine.transition_to_streaming()
        assert success is True
        assert state_machine.state == SSEState.STREAMING

        # message events are only allowed while streaming
        assert state_machine.can_send_message() is True

        success = await state_machine.transition_to_final()
        assert success is True
        assert state_machine.state == SSEState.FINAL_SENT

        # no further message events after final
        assert state_machine.can_send_message() is False

        await state_machine.close()
        assert state_machine.state == SSEState.CLOSED

    @pytest.mark.asyncio
    async def test_error_transition_from_streaming(self):
        """An error during streaming must move the machine to ERROR_SENT."""
        from app.core.sse import SSEState, SSEStateMachine

        state_machine = SSEStateMachine()
        await state_machine.transition_to_streaming()

        success = await state_machine.transition_to_error()
        assert success is True
        assert state_machine.state == SSEState.ERROR_SENT

    @pytest.mark.asyncio
    async def test_cannot_transition_to_final_from_init(self):
        """final is invalid before streaming has started; must be rejected."""
        from app.core.sse import SSEStateMachine

        state_machine = SSEStateMachine()

        success = await state_machine.transition_to_final()
        assert success is False
|
||||
|
||||
|
||||
class TestMiddleware:
    """
    [AC-AISVC-10, AC-AISVC-12] Test cases for middleware.
    """

    @pytest.fixture
    def client(self):
        # Fresh TestClient per test.
        return TestClient(app)

    def test_tenant_context_extraction(
        self, client: TestClient
    ):
        """
        [AC-AISVC-10] Test that X-Tenant-Id is properly extracted and used.
        """
        headers = {"X-Tenant-Id": "tenant_test_123"}
        body = {
            "sessionId": "session_001",
            "currentMessage": "Test message",
            "channelType": "wechat",
        }

        response = client.post("/ai/chat", json=body, headers=headers)

        # A 200 here implies the middleware accepted and propagated the tenant.
        assert response.status_code == 200

    def test_health_endpoint_bypasses_tenant_check(
        self, client: TestClient
    ):
        """
        Test that health endpoint doesn't require X-Tenant-Id.
        """
        response = client.get("/ai/health")

        assert response.status_code == 200
|
||||
|
|
@ -0,0 +1,319 @@
|
|||
"""
|
||||
Unit tests for LLM Adapter.
|
||||
[AC-AISVC-02, AC-AISVC-06] Tests for LLM client interface.
|
||||
|
||||
Tests cover:
|
||||
- Non-streaming generation
|
||||
- Streaming generation
|
||||
- Error handling
|
||||
- Retry logic
|
||||
"""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from app.services.llm.base import LLMConfig, LLMResponse, LLMStreamChunk
|
||||
from app.services.llm.openai_client import (
|
||||
LLMException,
|
||||
OpenAIClient,
|
||||
TimeoutException,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def mock_settings():
    """Mock settings for testing — mirrors the llm_* fields the client reads."""
    settings = MagicMock()
    settings.llm_api_key = "test-api-key"
    settings.llm_base_url = "https://api.openai.com/v1"
    settings.llm_model = "gpt-4o-mini"
    settings.llm_max_tokens = 2048
    settings.llm_temperature = 0.7
    settings.llm_timeout_seconds = 30
    settings.llm_max_retries = 3
    return settings


@pytest.fixture
def llm_client(mock_settings):
    """Create LLM client with mocked settings (patched for the test's lifetime)."""
    with patch("app.services.llm.openai_client.get_settings", return_value=mock_settings):
        client = OpenAIClient()
        yield client


@pytest.fixture
def mock_messages():
    """Sample chat messages for testing (OpenAI chat-completions format)."""
    return [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello, how are you?"},
    ]


@pytest.fixture
def mock_generate_response():
    """Sample non-streaming response body matching the OpenAI API schema."""
    return {
        "id": "chatcmpl-123",
        "object": "chat.completion",
        "created": 1677652288,
        "model": "gpt-4o-mini",
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "Hello! I'm doing well, thank you for asking!",
                },
                "finish_reason": "stop",
            }
        ],
        "usage": {
            "prompt_tokens": 20,
            "completion_tokens": 15,
            "total_tokens": 35,
        },
    }


@pytest.fixture
def mock_stream_chunks():
    """Sample SSE lines as emitted by the OpenAI streaming API, ending in [DONE]."""
    return [
        "data: {\"id\":\"chatcmpl-123\",\"choices\":[{\"delta\":{\"content\":\"Hello\"},\"finish_reason\":null}]}\n",
        "data: {\"id\":\"chatcmpl-123\",\"choices\":[{\"delta\":{\"content\":\"!\"},\"finish_reason\":null}]}\n",
        "data: {\"id\":\"chatcmpl-123\",\"choices\":[{\"delta\":{\"content\":\" How\"},\"finish_reason\":null}]}\n",
        "data: {\"id\":\"chatcmpl-123\",\"choices\":[{\"delta\":{\"content\":\" can I help?\"},\"finish_reason\":\"stop\"}]}\n",
        "data: [DONE]\n",
    ]
|
||||
|
||||
|
||||
class TestOpenAIClientGenerate:
    """Tests for non-streaming generation. [AC-AISVC-02]"""

    @pytest.mark.asyncio
    async def test_generate_success(self, llm_client, mock_messages, mock_generate_response):
        """[AC-AISVC-02] Test successful non-streaming generation."""
        mock_response = MagicMock()
        mock_response.json.return_value = mock_generate_response
        mock_response.raise_for_status = MagicMock()

        with patch.object(
            llm_client, "_get_client"
        ) as mock_get_client:
            mock_client = AsyncMock()
            mock_client.post = AsyncMock(return_value=mock_response)
            mock_get_client.return_value = mock_client

            result = await llm_client.generate(mock_messages)

            # The parsed response must surface content, model, finish reason
            # and token usage from the API payload.
            assert isinstance(result, LLMResponse)
            assert result.content == "Hello! I'm doing well, thank you for asking!"
            assert result.model == "gpt-4o-mini"
            assert result.finish_reason == "stop"
            assert result.usage["total_tokens"] == 35

    @pytest.mark.asyncio
    async def test_generate_with_custom_config(self, llm_client, mock_messages, mock_generate_response):
        """[AC-AISVC-02] Test generation with custom configuration."""
        custom_config = LLMConfig(
            model="gpt-4",
            max_tokens=1024,
            temperature=0.5,
        )

        mock_response = MagicMock()
        mock_response.json.return_value = {**mock_generate_response, "model": "gpt-4"}
        mock_response.raise_for_status = MagicMock()

        with patch.object(llm_client, "_get_client") as mock_get_client:
            mock_client = AsyncMock()
            mock_client.post = AsyncMock(return_value=mock_response)
            mock_get_client.return_value = mock_client

            result = await llm_client.generate(mock_messages, config=custom_config)

            # The per-call config must override the client default model.
            assert result.model == "gpt-4"

    @pytest.mark.asyncio
    async def test_generate_timeout_error(self, llm_client, mock_messages):
        """[AC-AISVC-02] Test timeout error handling."""
        with patch.object(llm_client, "_get_client") as mock_get_client:
            mock_client = AsyncMock()
            mock_client.post = AsyncMock(side_effect=httpx.TimeoutException("Timeout"))
            mock_get_client.return_value = mock_client

            # httpx timeouts must be translated to the adapter's TimeoutException.
            with pytest.raises(TimeoutException):
                await llm_client.generate(mock_messages)

    @pytest.mark.asyncio
    async def test_generate_api_error(self, llm_client, mock_messages):
        """[AC-AISVC-02] Test API error handling."""
        mock_response = MagicMock()
        mock_response.status_code = 401
        mock_response.text = '{"error": {"message": "Invalid API key"}}'
        mock_response.json.return_value = {"error": {"message": "Invalid API key"}}

        http_error = httpx.HTTPStatusError(
            "Unauthorized",
            request=MagicMock(),
            response=mock_response,
        )

        with patch.object(llm_client, "_get_client") as mock_get_client:
            mock_client = AsyncMock()
            mock_client.post = AsyncMock(side_effect=http_error)
            mock_get_client.return_value = mock_client

            with pytest.raises(LLMException) as exc_info:
                await llm_client.generate(mock_messages)

            # The upstream error message should be preserved in the exception.
            assert "Invalid API key" in str(exc_info.value.message)

    @pytest.mark.asyncio
    async def test_generate_malformed_response(self, llm_client, mock_messages):
        """[AC-AISVC-02] Test handling of malformed response."""
        mock_response = MagicMock()
        # Missing "choices"/"usage" — the client must not crash with KeyError
        # but raise its own LLMException instead.
        mock_response.json.return_value = {"invalid": "response"}
        mock_response.raise_for_status = MagicMock()

        with patch.object(llm_client, "_get_client") as mock_get_client:
            mock_client = AsyncMock()
            mock_client.post = AsyncMock(return_value=mock_response)
            mock_get_client.return_value = mock_client

            with pytest.raises(LLMException):
                await llm_client.generate(mock_messages)
|
||||
|
||||
|
||||
class MockAsyncStreamContext:
    """Async context-manager stub that hands back a canned response object."""

    def __init__(self, response):
        self._resp = response

    async def __aenter__(self):
        """Entering the context yields the wrapped response."""
        return self._resp

    async def __aexit__(self, *args):
        """Nothing to release for the stub."""
        pass
|
||||
|
||||
|
||||
class TestOpenAIClientStreamGenerate:
    """Tests for streaming generation. [AC-AISVC-06, AC-AISVC-07]"""

    @pytest.mark.asyncio
    async def test_stream_generate_success(self, llm_client, mock_messages, mock_stream_chunks):
        """[AC-AISVC-06, AC-AISVC-07] Test successful streaming generation."""
        async def mock_aiter_lines():
            # Replay the canned SSE lines as an async iterator.
            for chunk in mock_stream_chunks:
                yield chunk

        mock_response = MagicMock()
        mock_response.raise_for_status = MagicMock()
        mock_response.aiter_lines = mock_aiter_lines

        mock_client = AsyncMock()
        # client.stream(...) is used as `async with`, so it must return a
        # context manager rather than a coroutine.
        mock_client.stream = MagicMock(return_value=MockAsyncStreamContext(mock_response))

        with patch.object(llm_client, "_get_client", return_value=mock_client):
            chunks = []
            async for chunk in llm_client.stream_generate(mock_messages):
                chunks.append(chunk)

            # 4 data chunks; the trailing [DONE] line must not yield a chunk.
            assert len(chunks) == 4
            assert chunks[0].delta == "Hello"
            assert chunks[-1].finish_reason == "stop"

    @pytest.mark.asyncio
    async def test_stream_generate_timeout_error(self, llm_client, mock_messages):
        """[AC-AISVC-06] Test streaming timeout error handling."""
        mock_client = AsyncMock()

        class TimeoutContext:
            # Raises on __aenter__ to simulate a timeout while opening the stream.
            async def __aenter__(self):
                raise httpx.TimeoutException("Timeout")
            async def __aexit__(self, *args):
                pass

        mock_client.stream = MagicMock(return_value=TimeoutContext())

        with patch.object(llm_client, "_get_client", return_value=mock_client):
            with pytest.raises(TimeoutException):
                async for _ in llm_client.stream_generate(mock_messages):
                    pass

    @pytest.mark.asyncio
    async def test_stream_generate_api_error(self, llm_client, mock_messages):
        """[AC-AISVC-06] Test streaming API error handling."""
        mock_response = MagicMock()
        mock_response.status_code = 500
        mock_response.text = "Internal Server Error"
        mock_response.json.return_value = {"error": {"message": "Internal Server Error"}}

        http_error = httpx.HTTPStatusError(
            "Internal Server Error",
            request=MagicMock(),
            response=mock_response,
        )

        mock_client = AsyncMock()

        class ErrorContext:
            # Raises an HTTP status error when the stream is opened.
            async def __aenter__(self):
                raise http_error
            async def __aexit__(self, *args):
                pass

        mock_client.stream = MagicMock(return_value=ErrorContext())

        with patch.object(llm_client, "_get_client", return_value=mock_client):
            with pytest.raises(LLMException):
                async for _ in llm_client.stream_generate(mock_messages):
                    pass
|
||||
|
||||
|
||||
class TestOpenAIClientConfig:
    """Tests for LLM configuration."""

    def test_default_config(self, mock_settings):
        """Test default configuration from settings."""
        with patch("app.services.llm.openai_client.get_settings", return_value=mock_settings):
            client = OpenAIClient()

            # With no constructor overrides, values come from settings.
            assert client._model == "gpt-4o-mini"
            assert client._default_config.max_tokens == 2048
            assert client._default_config.temperature == 0.7

    def test_custom_config_override(self, mock_settings):
        """Test custom configuration override."""
        with patch("app.services.llm.openai_client.get_settings", return_value=mock_settings):
            client = OpenAIClient(
                api_key="custom-key",
                base_url="https://custom.api.com/v1",
                model="gpt-4",
            )

            # Constructor arguments must win over settings.
            assert client._api_key == "custom-key"
            assert client._base_url == "https://custom.api.com/v1"
            assert client._model == "gpt-4"
|
||||
|
||||
|
||||
class TestOpenAIClientClose:
    """Tests for client cleanup."""

    @pytest.mark.asyncio
    async def test_close_client(self, llm_client):
        """Test client close releases resources."""
        mock_client = AsyncMock()
        mock_client.aclose = AsyncMock()
        llm_client._client = mock_client

        await llm_client.close()

        # close() must both close the underlying httpx client and drop the
        # reference so a later call can recreate it.
        mock_client.aclose.assert_called_once()
        assert llm_client._client is None
|
||||
|
|
@ -0,0 +1,210 @@
|
|||
"""
|
||||
Unit tests for Memory service.
|
||||
[AC-AISVC-10, AC-AISVC-11, AC-AISVC-13] Tests for multi-tenant session and message management.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.entities import ChatMessage, ChatSession
|
||||
from app.services.memory import MemoryService
|
||||
|
||||
|
||||
@pytest.fixture
def mock_session():
    """Create a mock AsyncSession."""
    session = AsyncMock(spec=AsyncSession)
    # session.add() is synchronous on AsyncSession, so it must be a plain
    # MagicMock rather than the AsyncMock the spec would otherwise produce.
    session.add = MagicMock()
    session.flush = AsyncMock()
    session.delete = AsyncMock()
    return session


@pytest.fixture
def memory_service(mock_session):
    """Create MemoryService with mocked session."""
    return MemoryService(mock_session)
|
||||
|
||||
|
||||
class TestMemoryServiceTenantIsolation:
    """
    [AC-AISVC-10, AC-AISVC-11] Tests for multi-tenant isolation in memory service.
    """

    @pytest.mark.asyncio
    async def test_get_or_create_session_tenant_isolation(self, memory_service, mock_session):
        """
        [AC-AISVC-11] Different tenants with same session_id should have separate sessions.
        """
        # No existing session found -> the service creates a new one per call.
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = None
        mock_session.execute = AsyncMock(return_value=mock_result)

        session1 = await memory_service.get_or_create_session(
            tenant_id="tenant_a",
            session_id="session_123",
        )
        session2 = await memory_service.get_or_create_session(
            tenant_id="tenant_b",
            session_id="session_123",
        )

        # Same session_id, but each is scoped to its own tenant.
        assert session1.tenant_id == "tenant_a"
        assert session2.tenant_id == "tenant_b"
        assert session1.session_id == "session_123"
        assert session2.session_id == "session_123"

    @pytest.mark.asyncio
    async def test_load_history_tenant_isolation(self, memory_service, mock_session):
        """
        [AC-AISVC-11] Loading history should only return messages for the specific tenant.
        """
        mock_result = MagicMock()
        mock_scalars = MagicMock()
        mock_scalars.all.return_value = [
            ChatMessage(tenant_id="tenant_a", session_id="session_123", role="user", content="Hello"),
        ]
        mock_result.scalars.return_value = mock_scalars
        mock_session.execute = AsyncMock(return_value=mock_result)

        messages = await memory_service.load_history(
            tenant_id="tenant_a",
            session_id="session_123",
        )

        assert len(messages) == 1
        assert messages[0].tenant_id == "tenant_a"

    @pytest.mark.asyncio
    async def test_append_message_tenant_scoped(self, memory_service, mock_session):
        """
        [AC-AISVC-10, AC-AISVC-13] Appended messages should be scoped to tenant.
        """
        message = await memory_service.append_message(
            tenant_id="tenant_a",
            session_id="session_123",
            role="user",
            content="Test message",
        )

        # The persisted message must carry the caller's tenant and session.
        assert message.tenant_id == "tenant_a"
        assert message.session_id == "session_123"
        assert message.role == "user"
        assert message.content == "Test message"
|
||||
|
||||
|
||||
class TestMemoryServiceSessionManagement:
    """
    [AC-AISVC-13] Session lifecycle behaviour of the memory service.
    """

    @pytest.mark.asyncio
    async def test_get_existing_session(self, memory_service, mock_session):
        """
        [AC-AISVC-13] An existing session row is returned as-is instead of
        being recreated.
        """
        known_session = ChatSession(
            tenant_id="tenant_a",
            session_id="session_123",
        )
        lookup = MagicMock(**{"scalar_one_or_none.return_value": known_session})
        mock_session.execute = AsyncMock(return_value=lookup)

        resolved = await memory_service.get_or_create_session(
            tenant_id="tenant_a",
            session_id="session_123",
        )

        assert resolved.tenant_id == "tenant_a"
        assert resolved.session_id == "session_123"

    @pytest.mark.asyncio
    async def test_create_new_session(self, memory_service, mock_session):
        """
        [AC-AISVC-13] A lookup miss creates a fresh session carrying the
        supplied channel type and metadata.
        """
        miss = MagicMock(**{"scalar_one_or_none.return_value": None})
        mock_session.execute = AsyncMock(return_value=miss)

        created = await memory_service.get_or_create_session(
            tenant_id="tenant_a",
            session_id="session_new",
            channel_type="wechat",
            metadata={"user_id": "user_123"},
        )

        assert created.tenant_id == "tenant_a"
        assert created.session_id == "session_new"
        assert created.channel_type == "wechat"

    @pytest.mark.asyncio
    async def test_append_multiple_messages(self, memory_service, mock_session):
        """
        [AC-AISVC-13] A batch append persists every message in order.
        """
        payload = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
        ]

        appended = await memory_service.append_messages(
            tenant_id="tenant_a",
            session_id="session_123",
            messages=payload,
        )

        assert len(appended) == 2
        assert appended[0].role == "user"
        assert appended[1].role == "assistant"

    @pytest.mark.asyncio
    async def test_load_history_with_limit(self, memory_service, mock_session):
        """
        [AC-AISVC-13] load_history accepts a limit parameter.

        NOTE(review): the limit is presumably applied at the SQL layer, which
        is mocked here, so the stub still yields all 5 rows; this test only
        proves the limit= call path runs and returns what the DB produced.
        Confirm limiting is covered by an integration test.
        """
        stub_rows = [
            ChatMessage(tenant_id="tenant_a", session_id="session_123", role="user", content=f"Msg {i}")
            for i in range(5)
        ]
        query_result = MagicMock(**{"scalars.return_value.all.return_value": stub_rows})
        mock_session.execute = AsyncMock(return_value=query_result)

        history = await memory_service.load_history(
            tenant_id="tenant_a",
            session_id="session_123",
            limit=3,
        )

        assert len(history) == 5
|
||||
|
||||
|
||||
class TestMemoryServiceClearHistory:
    """
    [AC-AISVC-13] Clearing a session's stored history.
    """

    @pytest.mark.asyncio
    async def test_clear_history_tenant_scoped(self, memory_service, mock_session):
        """
        [AC-AISVC-11] clear_history removes (and counts) only the targeted
        tenant's messages for the session.
        """
        doomed = [
            ChatMessage(tenant_id="tenant_a", session_id="session_123", role="user", content="Msg 1"),
            ChatMessage(tenant_id="tenant_a", session_id="session_123", role="assistant", content="Msg 2"),
        ]
        query_result = MagicMock(**{"scalars.return_value.all.return_value": doomed})
        mock_session.execute = AsyncMock(return_value=query_result)

        removed = await memory_service.clear_history(
            tenant_id="tenant_a",
            session_id="session_123",
        )

        assert removed == 2
|
||||
Loading…
Reference in New Issue