[AC-AISVC-02, AC-AISVC-16] 多个需求合并 #1

Merged
MerCry merged 45 commits from feature/prompt-unification-and-logging into main 2026-02-25 17:17:35 +00:00
31 changed files with 3060 additions and 0 deletions
Showing only changes of commit 0a167d69f0 - Show all commits

26
ai-service/README.md Normal file
View File

@ -0,0 +1,26 @@
# AI Service
Python AI Service for intelligent chat with RAG support.
## Features
- Multi-tenant isolation via X-Tenant-Id header
- SSE streaming support via Accept: text/event-stream
- RAG-powered responses with confidence scoring
## Installation
```bash
pip install -e ".[dev]"
```
## Running
```bash
uvicorn app.main:app --host 0.0.0.0 --port 8080
```
## API Endpoints
- `POST /ai/chat` - Generate AI reply
- `GET /ai/health` - Health check

View File

@ -0,0 +1,4 @@
"""
AI Service - Python AI Middle Platform
[AC-AISVC-01] FastAPI-based AI chat service with multi-tenant support.
"""

View File

@ -0,0 +1,8 @@
"""
API module for AI Service.
"""
from app.api.chat import router as chat_router
from app.api.health import router as health_router
__all__ = ["chat_router", "health_router"]

127
ai-service/app/api/chat.py Normal file
View File

@ -0,0 +1,127 @@
"""
Chat endpoint for AI Service.
[AC-AISVC-01, AC-AISVC-02, AC-AISVC-06] Main chat endpoint with streaming/non-streaming modes.
"""
import logging
from typing import Annotated, Any
from fastapi import APIRouter, Depends, Header, Request
from fastapi.responses import JSONResponse
from sse_starlette.sse import EventSourceResponse
from app.core.middleware import get_response_mode, is_sse_request
from app.core.sse import create_error_event
from app.core.tenant import get_tenant_id
from app.models import ChatRequest, ChatResponse, ErrorResponse
from app.services.orchestrator import OrchestratorService, get_orchestrator_service
logger = logging.getLogger(__name__)
router = APIRouter(tags=["AI Chat"])


@router.post(
    "/ai/chat",
    operation_id="generateReply",
    summary="Generate AI reply",
    description="""
[AC-AISVC-01, AC-AISVC-02, AC-AISVC-06] Generate AI reply based on user message.
Response mode is determined by Accept header:
- Accept: text/event-stream -> SSE streaming response
- Other -> JSON response
""",
    responses={
        200: {
            "description": "Success - JSON or SSE stream",
            "content": {
                "application/json": {"schema": {"$ref": "#/components/schemas/ChatResponse"}},
                "text/event-stream": {"schema": {"type": "string"}},
            },
        },
        400: {"description": "Invalid request", "model": ErrorResponse},
        500: {"description": "Internal error", "model": ErrorResponse},
        503: {"description": "Service unavailable", "model": ErrorResponse},
    },
)
async def generate_reply(
    request: Request,
    chat_request: ChatRequest,
    # Declared only so the Accept header shows up in the OpenAPI docs; the
    # actual mode switch reads request.headers via is_sse_request().
    accept: Annotated[str | None, Header()] = None,
    orchestrator: OrchestratorService = Depends(get_orchestrator_service),
) -> Any:
    """
    [AC-AISVC-06] Generate AI reply with automatic response mode switching.

    Based on Accept header:
    - text/event-stream: Returns SSE stream with message/final/error events
    - Other: Returns JSON ChatResponse
    """
    tenant_id = get_tenant_id()
    if not tenant_id:
        # Defense-in-depth: TenantContextMiddleware already rejects requests
        # without X-Tenant-Id; this guards direct/handler-level invocation.
        from app.core.exceptions import MissingTenantIdException

        raise MissingTenantIdException()
    # Lazy %-style logging: arguments are only formatted if INFO is enabled.
    logger.info(
        "[AC-AISVC-06] Processing chat request: tenant=%s, session=%s, mode=%s",
        tenant_id,
        chat_request.session_id,
        get_response_mode(request),
    )
    if is_sse_request(request):
        return await _handle_streaming_request(tenant_id, chat_request, orchestrator)
    return await _handle_json_request(tenant_id, chat_request, orchestrator)
async def _handle_json_request(
    tenant_id: str,
    chat_request: ChatRequest,
    orchestrator: OrchestratorService,
) -> JSONResponse:
    """
    [AC-AISVC-02] Handle non-streaming JSON request.

    Returns the ChatResponse serialized with camelCase aliases and nulls
    omitted.

    Raises:
        AIServiceException: re-raised unchanged when the orchestrator already
            produced a structured error; any other failure is wrapped as
            INTERNAL_ERROR (with the original exception chained) so the
            registered handler can emit a structured body.
    """
    logger.info("[AC-AISVC-02] Processing JSON request for tenant=%s", tenant_id)
    # Single import (the original imported AIServiceException twice).
    from app.core.exceptions import AIServiceException, ErrorCode

    try:
        response = await orchestrator.generate(tenant_id, chat_request)
    except AIServiceException:
        # Already structured; let the registered exception handler format it.
        raise
    except Exception as e:
        # logger.exception records the traceback, not just the message.
        logger.exception("[AC-AISVC-04] Error generating response: %s", e)
        raise AIServiceException(
            code=ErrorCode.INTERNAL_ERROR,
            message=str(e),
        ) from e
    return JSONResponse(
        content=response.model_dump(exclude_none=True, by_alias=True),
    )
async def _handle_streaming_request(
    tenant_id: str,
    chat_request: ChatRequest,
    orchestrator: OrchestratorService,
) -> EventSourceResponse:
    """
    [AC-AISVC-06, AC-AISVC-07, AC-AISVC-08, AC-AISVC-09] Handle SSE streaming request.

    Streams the orchestrator's events through unchanged; if the stream
    raises, a single error event is emitted instead of dropping the
    connection silently.
    """
    logger.info("[AC-AISVC-06] Processing SSE request for tenant=%s", tenant_id)

    async def event_generator():
        try:
            async for event in orchestrator.generate_stream(tenant_id, chat_request):
                yield event
        except Exception as e:
            logger.exception("[AC-AISVC-09] Streaming error: %s", e)
            yield create_error_event(
                code="STREAMING_ERROR",
                message=str(e),
            )

    # Fix: the ping interval was hard-coded to 15, silently ignoring the
    # AI_SERVICE_SSE_PING_INTERVAL_SECONDS setting defined in config.py.
    from app.core.config import get_settings

    return EventSourceResponse(
        event_generator(),
        ping=get_settings().sse_ping_interval_seconds,
    )

View File

@ -0,0 +1,30 @@
"""
Health check endpoint.
[AC-AISVC-20] Health check for service monitoring.
"""
from fastapi import APIRouter, status
from fastapi.responses import JSONResponse
router = APIRouter(tags=["Health"])


@router.get(
    "/ai/health",
    operation_id="healthCheck",
    summary="Health check",
    description="[AC-AISVC-20] Check if AI service is healthy",
    responses={
        200: {"description": "Service is healthy"},
        503: {"description": "Service is unhealthy"},
    },
)
async def health_check() -> JSONResponse:
    """
    [AC-AISVC-20] Health check endpoint.

    Currently always returns 200 {"status": "healthy"}; the documented 503
    branch is advertised in the OpenAPI schema but not produced here.
    This path is exempted from tenant validation by TenantContextMiddleware.
    """
    return JSONResponse(
        status_code=status.HTTP_200_OK,
        content={"status": "healthy"},
    )

View File

@ -0,0 +1,19 @@
"""
Core module - Configuration, dependencies, and utilities.
[AC-AISVC-01, AC-AISVC-10, AC-AISVC-11] Core infrastructure components.
"""
from app.core.config import Settings, get_settings
from app.core.database import async_session_maker, get_session, init_db, close_db
from app.core.qdrant_client import QdrantClient, get_qdrant_client
__all__ = [
"Settings",
"get_settings",
"async_session_maker",
"get_session",
"init_db",
"close_db",
"QdrantClient",
"get_qdrant_client",
]

View File

@ -0,0 +1,54 @@
"""
Configuration management for AI Service.
[AC-AISVC-01] Centralized configuration with environment variable support.
"""
from functools import lru_cache
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
    """
    [AC-AISVC-01] Centralized application settings.

    Values are read from environment variables prefixed AI_SERVICE_ (or a
    local .env file); unknown variables are ignored.
    """
    model_config = SettingsConfigDict(env_prefix="AI_SERVICE_", env_file=".env", extra="ignore")
    # --- Application ---
    app_name: str = "AI Service"
    app_version: str = "0.1.0"
    debug: bool = False
    host: str = "0.0.0.0"
    port: int = 8080
    request_timeout_seconds: int = 20
    # SSE keep-alive ping interval (used by sse.SSEResponseBuilder).
    sse_ping_interval_seconds: int = 15
    log_level: str = "INFO"
    # --- LLM provider ---
    llm_provider: str = "openai"
    llm_api_key: str = ""
    llm_base_url: str = "https://api.openai.com/v1"
    llm_model: str = "gpt-4o-mini"
    llm_max_tokens: int = 2048
    llm_temperature: float = 0.7
    llm_timeout_seconds: int = 30
    llm_max_retries: int = 3
    # --- PostgreSQL (async via asyncpg) ---
    database_url: str = "postgresql+asyncpg://postgres:postgres@localhost:5432/ai_service"
    database_pool_size: int = 10
    database_max_overflow: int = 20
    # --- Qdrant vector store (collections are named {prefix}{tenant_id}) ---
    qdrant_url: str = "http://localhost:6333"
    qdrant_collection_prefix: str = "kb_"
    qdrant_vector_size: int = 1536
    # --- RAG retrieval ---
    rag_top_k: int = 5
    rag_score_threshold: float = 0.7
    rag_min_hits: int = 1
    rag_max_evidence_tokens: int = 2000
    # --- Conversation behavior ---
    confidence_threshold_low: float = 0.5
    max_history_tokens: int = 4000
@lru_cache
def get_settings() -> Settings:
    """Return the process-wide Settings instance (cached after first call)."""
    return Settings()

View File

@ -0,0 +1,67 @@
"""
Database client for AI Service.
[AC-AISVC-11] PostgreSQL database with SQLModel for multi-tenant data isolation.
"""
import logging
from typing import AsyncGenerator
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import NullPool
from sqlmodel import SQLModel
from app.core.config import get_settings
logger = logging.getLogger(__name__)
settings = get_settings()

# [AC-AISVC-11] Process-wide async engine. pool_pre_ping revalidates pooled
# connections before use so stale connections surface as reconnects, not
# query errors.
engine = create_async_engine(
    settings.database_url,
    pool_size=settings.database_pool_size,
    max_overflow=settings.database_max_overflow,
    echo=settings.debug,
    pool_pre_ping=True,
)

# Session factory; expire_on_commit=False keeps ORM objects readable after
# commit (get_session() below commits on success before the session closes).
async_session_maker = async_sessionmaker(
    engine,
    class_=AsyncSession,
    expire_on_commit=False,
    autocommit=False,
    autoflush=False,
)
async def init_db() -> None:
    """
    [AC-AISVC-11] Initialize database tables.

    Creates every table registered on SQLModel.metadata; create_all skips
    tables that already exist, so this is safe to run on every startup.
    """
    async with engine.begin() as conn:
        await conn.run_sync(SQLModel.metadata.create_all)
    logger.info("[AC-AISVC-11] Database tables initialized")
async def close_db() -> None:
    """
    Close database connections.

    Disposes the engine's connection pool; called from the app lifespan
    shutdown hook.
    """
    await engine.dispose()
    logger.info("Database connections closed")
async def get_session() -> AsyncGenerator[AsyncSession, None]:
    """
    [AC-AISVC-11] Dependency injection for database session.

    Commit-on-success pattern: the transaction is committed after the caller
    finishes with the session, rolled back (and the exception re-raised) on
    any failure, and the session is always closed.
    """
    async with async_session_maker() as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise
        finally:
            # NOTE(review): the `async with` block already closes the session
            # on exit; this explicit close is redundant but harmless.
            await session.close()

View File

@ -0,0 +1,99 @@
"""
Exception handling for AI Service.
[AC-AISVC-03, AC-AISVC-04, AC-AISVC-05] Structured error responses.
"""
from fastapi import HTTPException, Request, status
from fastapi.responses import JSONResponse
from app.models import ErrorCode, ErrorResponse
class AIServiceException(Exception):
    """
    [AC-AISVC-04] Base exception for structured AI-service errors.

    Carries the machine-readable code, HTTP status and optional detail list
    that ai_service_exception_handler serializes into an ErrorResponse body.
    """

    def __init__(
        self,
        code: ErrorCode,
        message: str,
        status_code: int = status.HTTP_500_INTERNAL_SERVER_ERROR,
        details: list[dict] | None = None,
    ):
        self.code = code                # machine-readable ErrorCode
        self.message = message          # human-readable description
        self.status_code = status_code  # HTTP status to respond with
        self.details = details          # optional structured sub-errors
        super().__init__(message)
class MissingTenantIdException(AIServiceException):
    """[AC-AISVC-12] 400 MISSING_TENANT_ID - X-Tenant-Id header absent or blank."""

    def __init__(self, message: str = "Missing required header: X-Tenant-Id"):
        super().__init__(
            code=ErrorCode.MISSING_TENANT_ID,
            message=message,
            status_code=status.HTTP_400_BAD_REQUEST,
        )
class InvalidRequestException(AIServiceException):
    """[AC-AISVC-03] 400 INVALID_REQUEST - malformed or invalid request payload."""

    def __init__(self, message: str, details: list[dict] | None = None):
        super().__init__(
            code=ErrorCode.INVALID_REQUEST,
            message=message,
            status_code=status.HTTP_400_BAD_REQUEST,
            details=details,
        )
class ServiceUnavailableException(AIServiceException):
    """[AC-AISVC-05] 503 SERVICE_UNAVAILABLE - downstream dependency unavailable."""

    def __init__(self, message: str = "Service temporarily unavailable"):
        super().__init__(
            code=ErrorCode.SERVICE_UNAVAILABLE,
            message=message,
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
        )
class TimeoutException(AIServiceException):
    """
    [AC-AISVC-05] TIMEOUT error.

    Mapped to 503 (not 504) so clients treat timeouts like any other
    transient unavailability.
    """

    def __init__(self, message: str = "Request timeout"):
        super().__init__(
            code=ErrorCode.TIMEOUT,
            message=message,
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
        )
async def ai_service_exception_handler(request: Request, exc: AIServiceException) -> JSONResponse:
    """[AC-AISVC-04] Serialize an AIServiceException into its structured JSON body."""
    return JSONResponse(
        status_code=exc.status_code,
        content=ErrorResponse(
            code=exc.code.value,
            message=exc.message,
            details=exc.details,
        ).model_dump(exclude_none=True),
    )
async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
    """
    [AC-AISVC-04] Convert framework HTTPExceptions into the ErrorResponse shape.

    Only 400 and 503 are mapped to specific codes; every other status falls
    back to INTERNAL_ERROR (the original HTTP status is still preserved).
    """
    if exc.status_code == status.HTTP_400_BAD_REQUEST:
        code = ErrorCode.INVALID_REQUEST
    elif exc.status_code == status.HTTP_503_SERVICE_UNAVAILABLE:
        code = ErrorCode.SERVICE_UNAVAILABLE
    else:
        code = ErrorCode.INTERNAL_ERROR
    return JSONResponse(
        status_code=exc.status_code,
        content=ErrorResponse(
            code=code.value,
            message=exc.detail or "An error occurred",
        ).model_dump(exclude_none=True),
    )
async def generic_exception_handler(request: Request, exc: Exception) -> JSONResponse:
    """
    [AC-AISVC-04] Last-resort handler for uncaught exceptions.

    Returns a deliberately generic message so internal details are never
    leaked to clients.
    """
    return JSONResponse(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        content=ErrorResponse(
            code=ErrorCode.INTERNAL_ERROR.value,
            message="An unexpected error occurred",
        ).model_dump(exclude_none=True),
    )

View File

@ -0,0 +1,74 @@
"""
Middleware for AI Service.
[AC-AISVC-10, AC-AISVC-12] X-Tenant-Id header validation and tenant context injection.
"""
import logging
from typing import Callable
from fastapi import Request, Response, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from app.core.exceptions import ErrorCode, ErrorResponse, MissingTenantIdException
from app.core.tenant import clear_tenant_context, set_tenant_context
logger = logging.getLogger(__name__)

# Header / content-type constants shared by the middleware and the helpers
# defined in this module.
TENANT_ID_HEADER = "X-Tenant-Id"
ACCEPT_HEADER = "Accept"
SSE_CONTENT_TYPE = "text/event-stream"


class TenantContextMiddleware(BaseHTTPMiddleware):
    """
    [AC-AISVC-10, AC-AISVC-12] Middleware to extract and validate X-Tenant-Id header.

    Injects the tenant into both the ContextVar-based tenant context and
    request.state, and always clears the context afterwards so a pooled
    worker never leaks one request's tenant into the next.
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        # Start from a clean slate in case a previous request leaked context.
        clear_tenant_context()
        # The health endpoint is tenant-agnostic and skips validation.
        if request.url.path == "/ai/health":
            return await call_next(request)
        # Strip once and reuse (the original recomputed .strip() three times).
        raw_tenant_id = request.headers.get(TENANT_ID_HEADER)
        tenant_id = raw_tenant_id.strip() if raw_tenant_id else ""
        if not tenant_id:
            logger.warning("[AC-AISVC-12] Missing or empty X-Tenant-Id header")
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content=ErrorResponse(
                    code=ErrorCode.MISSING_TENANT_ID.value,
                    message="Missing required header: X-Tenant-Id",
                ).model_dump(exclude_none=True),
            )
        set_tenant_context(tenant_id)
        request.state.tenant_id = tenant_id
        logger.info("[AC-AISVC-10] Tenant context set: tenant_id=%s", tenant_id)
        try:
            response = await call_next(request)
        finally:
            clear_tenant_context()
        return response
def is_sse_request(request: Request) -> bool:
    """
    [AC-AISVC-06] Return True when the client asked for SSE streaming.

    Detection is based purely on the Accept header containing
    text/event-stream.
    """
    return SSE_CONTENT_TYPE in request.headers.get(ACCEPT_HEADER, "")
def get_response_mode(request: Request) -> str:
    """
    [AC-AISVC-06] Determine the response mode from the Accept header.

    Returns:
        "streaming" when the client asked for text/event-stream, otherwise
        "json".
    """
    if is_sse_request(request):
        return "streaming"
    return "json"

View File

@ -0,0 +1,175 @@
"""
Qdrant client for AI Service.
[AC-AISVC-10] Vector database client with tenant-isolated collection management.
"""
import logging
from typing import Any
from qdrant_client import AsyncQdrantClient
from qdrant_client.models import Distance, PointStruct, VectorParams
from app.core.config import get_settings
logger = logging.getLogger(__name__)
settings = get_settings()
class QdrantClient:
    """
    [AC-AISVC-10] Qdrant client with tenant-isolated collection management.

    Tenant isolation is achieved by giving every tenant its own collection,
    named {prefix}{tenantId} (prefix defaults to "kb_"). The underlying
    AsyncQdrantClient is created lazily on first use. All operations are
    best-effort: failures are logged and reported via the return value
    (False / empty list) rather than raised.
    """

    def __init__(self):
        # Lazily-created SDK client; see get_client().
        self._client: AsyncQdrantClient | None = None
        self._collection_prefix = settings.qdrant_collection_prefix
        self._vector_size = settings.qdrant_vector_size

    async def get_client(self) -> AsyncQdrantClient:
        """Get or create Qdrant client instance."""
        if self._client is None:
            self._client = AsyncQdrantClient(url=settings.qdrant_url)
            logger.info(f"[AC-AISVC-10] Qdrant client initialized: {settings.qdrant_url}")
        return self._client

    async def close(self) -> None:
        """Close Qdrant client connection."""
        if self._client:
            await self._client.close()
            self._client = None
            logger.info("Qdrant client connection closed")

    def get_collection_name(self, tenant_id: str) -> str:
        """
        [AC-AISVC-10] Get collection name for a tenant.

        Naming convention: kb_{tenantId}
        """
        return f"{self._collection_prefix}{tenant_id}"

    async def ensure_collection_exists(self, tenant_id: str) -> bool:
        """
        [AC-AISVC-10] Ensure collection exists for tenant.

        Note: MVP uses pre-provisioned collections, this is for development/testing.
        Returns True when the collection exists or was created; False on error.
        """
        client = await self.get_client()
        collection_name = self.get_collection_name(tenant_id)
        try:
            collections = await client.get_collections()
            exists = any(c.name == collection_name for c in collections.collections)
            if not exists:
                await client.create_collection(
                    collection_name=collection_name,
                    vectors_config=VectorParams(
                        size=self._vector_size,
                        distance=Distance.COSINE,
                    ),
                )
                logger.info(
                    f"[AC-AISVC-10] Created collection: {collection_name} for tenant={tenant_id}"
                )
            return True
        except Exception as e:
            # Best-effort: report failure instead of raising.
            logger.error(f"[AC-AISVC-10] Error ensuring collection: {e}")
            return False

    async def upsert_vectors(
        self,
        tenant_id: str,
        points: list[PointStruct],
    ) -> bool:
        """
        [AC-AISVC-10] Upsert vectors into tenant's collection.

        Returns True on success, False on any error (logged, not raised).
        """
        client = await self.get_client()
        collection_name = self.get_collection_name(tenant_id)
        try:
            await client.upsert(
                collection_name=collection_name,
                points=points,
            )
            logger.info(
                f"[AC-AISVC-10] Upserted {len(points)} vectors for tenant={tenant_id}"
            )
            return True
        except Exception as e:
            logger.error(f"[AC-AISVC-10] Error upserting vectors: {e}")
            return False

    async def search(
        self,
        tenant_id: str,
        query_vector: list[float],
        limit: int = 5,
        score_threshold: float | None = None,
    ) -> list[dict[str, Any]]:
        """
        [AC-AISVC-10] Search vectors in tenant's collection.

        Returns results with score >= score_threshold if specified; an empty
        list on any error (logged, not raised).

        NOTE(review): AsyncQdrantClient.search is deprecated in newer
        qdrant-client releases in favor of query_points — confirm against the
        pinned SDK version.
        """
        client = await self.get_client()
        collection_name = self.get_collection_name(tenant_id)
        try:
            results = await client.search(
                collection_name=collection_name,
                query_vector=query_vector,
                limit=limit,
                score_threshold=score_threshold,
            )
            hits = [
                {
                    "id": str(result.id),
                    "score": result.score,
                    "payload": result.payload or {},
                }
                for result in results
            ]
            logger.info(
                f"[AC-AISVC-10] Search returned {len(hits)} results for tenant={tenant_id}"
            )
            return hits
        except Exception as e:
            logger.error(f"[AC-AISVC-10] Error searching vectors: {e}")
            return []

    async def delete_collection(self, tenant_id: str) -> bool:
        """
        [AC-AISVC-10] Delete tenant's collection.

        Used when tenant is removed. Returns True on success, False on error.
        """
        client = await self.get_client()
        collection_name = self.get_collection_name(tenant_id)
        try:
            await client.delete_collection(collection_name=collection_name)
            logger.info(f"[AC-AISVC-10] Deleted collection: {collection_name}")
            return True
        except Exception as e:
            logger.error(f"[AC-AISVC-10] Error deleting collection: {e}")
            return False
# Module-level singleton instance, created on first use.
_qdrant_client: QdrantClient | None = None


async def get_qdrant_client() -> QdrantClient:
    """
    Get or create the singleton QdrantClient.

    Declared async (it performs no awaiting itself); the wrapped SDK client
    inside is also created lazily.
    """
    global _qdrant_client
    if _qdrant_client is None:
        _qdrant_client = QdrantClient()
    return _qdrant_client
async def close_qdrant_client() -> None:
    """
    Close and discard the singleton QdrantClient.

    Called from the app lifespan shutdown hook; safe to call when no client
    was ever created.
    """
    global _qdrant_client
    if _qdrant_client:
        await _qdrant_client.close()
        _qdrant_client = None

170
ai-service/app/core/sse.py Normal file
View File

@ -0,0 +1,170 @@
"""
SSE utilities for AI Service.
[AC-AISVC-06, AC-AISVC-07, AC-AISVC-08, AC-AISVC-09] SSE event generation and state machine.
"""
import asyncio
import json
import logging
from enum import Enum
from typing import Any, AsyncGenerator
from sse_starlette.sse import EventSourceResponse, ServerSentEvent
from app.core.config import get_settings
from app.models import SSEErrorEvent, SSEEventType, SSEFinalEvent, SSEMessageEvent
logger = logging.getLogger(__name__)


class SSEState(str, Enum):
    """[AC-AISVC-08] Lifecycle states for one SSE response stream."""
    INIT = "INIT"                # stream created, nothing sent yet
    STREAMING = "STREAMING"      # message events may be sent
    FINAL_SENT = "FINAL_SENT"    # terminal: final event emitted
    ERROR_SENT = "ERROR_SENT"    # terminal: error event emitted
    CLOSED = "CLOSED"            # generator finished / connection closed
class SSEStateMachine:
    """
    [AC-AISVC-08, AC-AISVC-09] SSE state machine ensuring proper event sequence.

    State transitions: INIT -> STREAMING -> FINAL_SENT/ERROR_SENT -> CLOSED.
    Every transition is serialized with an asyncio.Lock so concurrent
    producers cannot race past a terminal state.
    """

    def __init__(self):
        self._state = SSEState.INIT
        self._lock = asyncio.Lock()

    @property
    def state(self) -> SSEState:
        """Current state (read-only)."""
        return self._state

    async def transition_to_streaming(self) -> bool:
        """INIT -> STREAMING. Returns True if the transition happened."""
        async with self._lock:
            if self._state == SSEState.INIT:
                self._state = SSEState.STREAMING
                logger.debug("[AC-AISVC-07] SSE state transition: INIT -> STREAMING")
                return True
            return False

    async def transition_to_final(self) -> bool:
        """STREAMING -> FINAL_SENT. Returns True if the transition happened."""
        async with self._lock:
            if self._state == SSEState.STREAMING:
                self._state = SSEState.FINAL_SENT
                logger.debug("[AC-AISVC-08] SSE state transition: STREAMING -> FINAL_SENT")
                return True
            return False

    async def transition_to_error(self) -> bool:
        """INIT/STREAMING -> ERROR_SENT. Returns True if the transition happened."""
        async with self._lock:
            if self._state in (SSEState.INIT, SSEState.STREAMING):
                # Fix: capture the previous state BEFORE mutating it; the
                # original logged self._state after assignment, always
                # printing "ERROR_SENT -> ERROR_SENT".
                previous = self._state
                self._state = SSEState.ERROR_SENT
                logger.debug(
                    "[AC-AISVC-09] SSE state transition: %s -> ERROR_SENT",
                    previous.value,
                )
                return True
            return False

    async def close(self) -> None:
        """Force the terminal CLOSED state regardless of the current state."""
        async with self._lock:
            self._state = SSEState.CLOSED
            logger.debug("SSE state transition: -> CLOSED")

    def can_send_message(self) -> bool:
        """True only while STREAMING; message events are suppressed otherwise."""
        return self._state == SSEState.STREAMING
def format_sse_event(event_type: SSEEventType, data: dict[str, Any]) -> ServerSentEvent:
    """Serialize *data* as JSON (keeping non-ASCII characters intact) and wrap
    it in a ServerSentEvent named after *event_type*."""
    serialized = json.dumps(data, ensure_ascii=False)
    return ServerSentEvent(event=event_type.value, data=serialized)
def create_message_event(delta: str) -> ServerSentEvent:
    """[AC-AISVC-07] Build a 'message' event carrying one incremental text delta."""
    payload = SSEMessageEvent(delta=delta).model_dump()
    return format_sse_event(SSEEventType.MESSAGE, payload)
def create_final_event(
    reply: str,
    confidence: float,
    should_transfer: bool,
    transfer_reason: str | None = None,
    metadata: dict[str, Any] | None = None,
) -> ServerSentEvent:
    """[AC-AISVC-08] Build the terminal 'final' event carrying the complete
    reply, confidence score and transfer decision (null fields omitted)."""
    payload = SSEFinalEvent(
        reply=reply,
        confidence=confidence,
        should_transfer=should_transfer,
        transfer_reason=transfer_reason,
        metadata=metadata,
    ).model_dump(exclude_none=True)
    return format_sse_event(SSEEventType.FINAL, payload)
def create_error_event(
    code: str,
    message: str,
    details: list[dict[str, Any]] | None = None,
) -> ServerSentEvent:
    """[AC-AISVC-09] Build an 'error' event (null fields omitted from payload)."""
    payload = SSEErrorEvent(
        code=code,
        message=message,
        details=details,
    ).model_dump(exclude_none=True)
    return format_sse_event(SSEEventType.ERROR, payload)
async def ping_generator(interval_seconds: int) -> AsyncGenerator[str, None]:
    """
    [AC-AISVC-06] Yield SSE keep-alive comments forever.

    Emits ': ping' comment lines (not events) every *interval_seconds* so
    intermediaries don't time out an idle connection.
    """
    ping_comment = ": ping\n\n"
    while True:
        await asyncio.sleep(interval_seconds)
        yield ping_comment
class SSEResponseBuilder:
    """
    Builder for SSE response with proper event sequencing and ping keep-alive.

    NOTE(review): chat.py constructs EventSourceResponse directly; this
    builder wires in the state machine and the configured ping interval but
    appears unused by the chat endpoint — confirm intended usage.
    """

    def __init__(self):
        self._state_machine = SSEStateMachine()
        self._settings = get_settings()

    async def build_response(
        self,
        content_generator: AsyncGenerator[ServerSentEvent, None],
    ) -> EventSourceResponse:
        """
        Build SSE response with ping keep-alive mechanism.

        [AC-AISVC-06] Uses the configured sse_ping_interval_seconds for
        keep-alive; the state machine stops forwarding events once a
        terminal state is reached and guarantees at most one error event.
        """
        async def event_generator() -> AsyncGenerator[ServerSentEvent, None]:
            await self._state_machine.transition_to_streaming()
            try:
                async for event in content_generator:
                    # Forward only while STREAMING; stop as soon as a
                    # terminal state has been entered.
                    if self._state_machine.can_send_message():
                        yield event
                    else:
                        break
            except Exception as e:
                logger.error(f"[AC-AISVC-09] Error during SSE streaming: {e}")
                # transition_to_error returns False if a terminal state was
                # already reached, preventing a duplicate error event.
                if await self._state_machine.transition_to_error():
                    yield create_error_event(
                        code="STREAMING_ERROR",
                        message=str(e),
                    )
            finally:
                await self._state_machine.close()

        return EventSourceResponse(
            event_generator(),
            ping=self._settings.sse_ping_interval_seconds,
        )

View File

@ -0,0 +1,31 @@
"""
Tenant context management.
[AC-AISVC-10, AC-AISVC-12] Multi-tenant isolation via X-Tenant-Id header.
"""
from contextvars import ContextVar
from dataclasses import dataclass
tenant_context: ContextVar["TenantContext | None"] = ContextVar("tenant_context", default=None)
@dataclass
class TenantContext:
tenant_id: str
def set_tenant_context(tenant_id: str) -> None:
tenant_context.set(TenantContext(tenant_id=tenant_id))
def get_tenant_context() -> TenantContext | None:
return tenant_context.get()
def get_tenant_id() -> str | None:
ctx = get_tenant_context()
return ctx.tenant_id if ctx else None
def clear_tenant_context() -> None:
tenant_context.set(None)

123
ai-service/app/main.py Normal file
View File

@ -0,0 +1,123 @@
"""
Main FastAPI application for AI Service.
[AC-AISVC-01] Entry point with middleware and exception handlers.
"""
import logging
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request, status
from fastapi.exceptions import HTTPException, RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from app.api import chat_router, health_router
from app.core.config import get_settings
from app.core.database import close_db, init_db
from app.core.exceptions import (
AIServiceException,
ErrorCode,
ErrorResponse,
ai_service_exception_handler,
generic_exception_handler,
http_exception_handler,
)
from app.core.middleware import TenantContextMiddleware
from app.core.qdrant_client import close_qdrant_client
settings = get_settings()

# Root logging configuration; level comes from AI_SERVICE_LOG_LEVEL.
logging.basicConfig(
    level=getattr(logging, settings.log_level.upper()),
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    [AC-AISVC-01, AC-AISVC-11] Application lifespan manager.

    Startup: best-effort DB schema creation — failures are logged but not
    fatal, so the service can still boot against a pre-provisioned database.
    Shutdown: releases the database pool and the Qdrant client; the Qdrant
    client is closed even if the database close fails.
    """
    logger.info("[AC-AISVC-01] Starting %s v%s", settings.app_name, settings.app_version)
    try:
        await init_db()
        logger.info("[AC-AISVC-11] Database initialized successfully")
    except Exception as e:
        logger.warning("[AC-AISVC-11] Database initialization skipped: %s", e)
    yield
    # Log before tearing down so the message appears even if a close fails.
    logger.info(f"Shutting down {settings.app_name}")
    try:
        await close_db()
    finally:
        await close_qdrant_client()
app = FastAPI(
    title=settings.app_name,
    version=settings.app_version,
    description="""
Python AI Service for intelligent chat with RAG support.
## Features
- Multi-tenant isolation via X-Tenant-Id header
- SSE streaming support via Accept: text/event-stream
- RAG-powered responses with confidence scoring
## Response Modes
- **JSON**: Default response mode (Accept: application/json or no Accept header)
- **SSE Streaming**: Set Accept: text/event-stream for streaming responses
""",
    docs_url="/docs",
    redoc_url="/redoc",
    lifespan=lifespan,
)

# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by browsers per the CORS spec and is unsafe for production —
# confirm the intended origin list before release.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# NOTE(review): Starlette runs the last-added middleware first, so tenant
# validation executes BEFORE CORS — preflight OPTIONS requests without
# X-Tenant-Id would get 400 instead of a CORS response; verify the ordering.
app.add_middleware(TenantContextMiddleware)

# [AC-AISVC-03, AC-AISVC-04] Structured error responses for each failure class.
app.add_exception_handler(AIServiceException, ai_service_exception_handler)
app.add_exception_handler(HTTPException, http_exception_handler)
app.add_exception_handler(Exception, generic_exception_handler)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """
    [AC-AISVC-03] Handle request validation errors with structured response.

    Overrides FastAPI's default 422 behavior: validation failures are
    returned as 400 with the service's ErrorResponse envelope
    (code INVALID_REQUEST plus the per-field pydantic error details).
    """
    logger.warning(f"[AC-AISVC-03] Request validation error: {exc.errors()}")
    error_response = ErrorResponse(
        code=ErrorCode.INVALID_REQUEST.value,
        message="Request validation failed",
        details=[{"loc": list(err["loc"]), "msg": err["msg"], "type": err["type"]} for err in exc.errors()],
    )
    return JSONResponse(
        status_code=status.HTTP_400_BAD_REQUEST,
        content=error_response.model_dump(exclude_none=True),
    )
# Route registration: health (tenant-exempt) and the chat endpoint.
app.include_router(health_router)
app.include_router(chat_router)

if __name__ == "__main__":
    # Local development entry point; per the README, deployments run
    # `uvicorn app.main:app ...` directly.
    import uvicorn

    uvicorn.run(
        "app.main:app",
        host=settings.host,
        port=settings.port,
        reload=settings.debug,
    )

View File

@ -0,0 +1,88 @@
"""
Data models for AI Service.
[AC-AISVC-02] Request/Response models aligned with OpenAPI contract.
[AC-AISVC-13] Entity models for database persistence.
"""
from enum import Enum
from typing import Any
from pydantic import BaseModel, Field
class ChannelType(str, Enum):
    """Supported chat channels; wire values are the lowercase strings."""
    WECHAT = "wechat"
    DOUYIN = "douyin"
    JD = "jd"
class Role(str, Enum):
    """Author of a chat message."""
    USER = "user"
    ASSISTANT = "assistant"
class ChatMessage(BaseModel):
    """One turn of conversation history in a ChatRequest."""
    role: Role = Field(..., description="Message role: user or assistant")
    content: str = Field(..., description="Message content")
class ChatRequest(BaseModel):
    """
    [AC-AISVC-02] Request body for POST /ai/chat.

    Wire format uses camelCase aliases; populate_by_name also allows
    snake_case construction from Python code.
    """
    session_id: str = Field(..., alias="sessionId", description="Session ID for conversation tracking")
    current_message: str = Field(..., alias="currentMessage", description="Current user message")
    channel_type: ChannelType = Field(..., alias="channelType", description="Channel type: wechat, douyin, jd")
    history: list[ChatMessage] | None = Field(default=None, description="Optional conversation history")
    metadata: dict[str, Any] | None = Field(default=None, description="Optional metadata")
    model_config = {"populate_by_name": True}
class ChatResponse(BaseModel):
    """
    [AC-AISVC-02] Non-streaming response body for POST /ai/chat.

    Serialized with camelCase aliases and null fields omitted (see
    _handle_json_request).
    """
    reply: str = Field(..., description="AI generated reply content")
    confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence score between 0.0 and 1.0")
    should_transfer: bool = Field(..., alias="shouldTransfer", description="Whether to suggest transfer to human agent")
    transfer_reason: str | None = Field(default=None, alias="transferReason", description="Reason for transfer suggestion")
    metadata: dict[str, Any] | None = Field(default=None, description="Response metadata")
    model_config = {"populate_by_name": True}
class ErrorCode(str, Enum):
    """[AC-AISVC-04] Machine-readable error codes used in ErrorResponse.code."""
    INVALID_REQUEST = "INVALID_REQUEST"
    MISSING_TENANT_ID = "MISSING_TENANT_ID"
    INTERNAL_ERROR = "INTERNAL_ERROR"
    SERVICE_UNAVAILABLE = "SERVICE_UNAVAILABLE"
    TIMEOUT = "TIMEOUT"
    LLM_ERROR = "LLM_ERROR"
    RETRIEVAL_ERROR = "RETRIEVAL_ERROR"
class ErrorResponse(BaseModel):
    """[AC-AISVC-04] Structured error body returned by all exception handlers."""
    code: str = Field(..., description="Error code")
    message: str = Field(..., description="Error message")
    details: list[dict[str, Any]] | None = Field(default=None, description="Detailed error information")
class SSEEventType(str, Enum):
    """[AC-AISVC-07/08/09] SSE event names emitted on the chat stream."""
    MESSAGE = "message"
    FINAL = "final"
    ERROR = "error"
class SSEMessageEvent(BaseModel):
    """[AC-AISVC-07] Payload of a 'message' SSE event (one incremental delta)."""
    delta: str = Field(..., description="Incremental text content")
class SSEFinalEvent(BaseModel):
    """[AC-AISVC-08] Payload of the terminal 'final' SSE event."""
    reply: str = Field(..., description="Complete AI reply")
    confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence score")
    should_transfer: bool = Field(..., alias="shouldTransfer", description="Transfer suggestion")
    transfer_reason: str | None = Field(default=None, alias="transferReason", description="Transfer reason")
    metadata: dict[str, Any] | None = Field(default=None, description="Response metadata")
    model_config = {"populate_by_name": True}
class SSEErrorEvent(BaseModel):
    """[AC-AISVC-09] Payload of an 'error' SSE event."""
    code: str = Field(..., description="Error code")
    message: str = Field(..., description="Error message")
    details: list[dict[str, Any]] | None = Field(default=None, description="Error details")

View File

@ -0,0 +1,74 @@
"""
Memory layer entities for AI Service.
[AC-AISVC-13] SQLModel entities for chat sessions and messages with tenant isolation.
"""
import uuid
from datetime import datetime
from typing import Any
from sqlalchemy import Column, JSON
from sqlmodel import Field, Index, SQLModel
class ChatSession(SQLModel, table=True):
    """
    [AC-AISVC-13] Chat session entity with tenant isolation.

    Uniqueness: one row per (tenant_id, session_id), enforced via a unique
    composite index; the row itself uses a surrogate UUID primary key.
    """
    __tablename__ = "chat_sessions"
    __table_args__ = (
        Index("ix_chat_sessions_tenant_session", "tenant_id", "session_id", unique=True),
        # NOTE(review): tenant_id already gets an index from
        # Field(index=True) below; this explicit index duplicates it —
        # confirm and drop one of the two.
        Index("ix_chat_sessions_tenant_id", "tenant_id"),
    )
    # Surrogate primary key, generated client-side.
    id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
    tenant_id: str = Field(..., description="Tenant ID for multi-tenant isolation", index=True)
    session_id: str = Field(..., description="Session ID for conversation tracking")
    channel_type: str | None = Field(default=None, description="Channel type: wechat, douyin, jd")
    # Stored in a column literally named "metadata"; the Python attribute
    # carries a trailing underscore because "metadata" is reserved on
    # SQLAlchemy declarative classes.
    metadata_: dict[str, Any] | None = Field(
        default=None,
        sa_column=Column("metadata", JSON, nullable=True),
        description="Session metadata"
    )
    # NOTE(review): datetime.utcnow produces naive timestamps and is
    # deprecated in Python 3.12+ — consider datetime.now(timezone.utc).
    # updated_at is only set at insert time; nothing here refreshes it.
    created_at: datetime = Field(default_factory=datetime.utcnow, description="Session creation time")
    updated_at: datetime = Field(default_factory=datetime.utcnow, description="Last update time")
class ChatMessage(SQLModel, table=True):
    """
    [AC-AISVC-13] Chat message entity with tenant isolation.

    Messages are scoped by (tenant_id, session_id); the extra index with
    created_at supports chronological history queries per session.
    """
    __tablename__ = "chat_messages"
    __table_args__ = (
        Index("ix_chat_messages_tenant_session", "tenant_id", "session_id"),
        Index("ix_chat_messages_tenant_session_created", "tenant_id", "session_id", "created_at"),
    )
    # Surrogate primary key, generated client-side.
    id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
    tenant_id: str = Field(..., description="Tenant ID for multi-tenant isolation", index=True)
    session_id: str = Field(..., description="Session ID for conversation tracking", index=True)
    role: str = Field(..., description="Message role: user or assistant")
    content: str = Field(..., description="Message content")
    # NOTE(review): naive utcnow — see ChatSession timestamps.
    created_at: datetime = Field(default_factory=datetime.utcnow, description="Message creation time")
class ChatSessionCreate(SQLModel):
    """Schema for creating a new chat session (plain model, no table)."""
    tenant_id: str
    session_id: str
    channel_type: str | None = None
    metadata_: dict[str, Any] | None = None
class ChatMessageCreate(SQLModel):
    """Schema for creating a new chat message."""
    tenant_id: str
    session_id: str
    # "user" or "assistant" per ChatMessage.role (not validated here).
    role: str
    content: str

View File

@ -0,0 +1,9 @@
"""
Services module for AI Service.
[AC-AISVC-13, AC-AISVC-16] Core services for memory and retrieval.
"""
from app.services.memory import MemoryService
from app.services.orchestrator import OrchestratorService, get_orchestrator_service
__all__ = ["MemoryService", "OrchestratorService", "get_orchestrator_service"]

View File

@ -0,0 +1,15 @@
"""
LLM Adapter module for AI Service.
[AC-AISVC-02, AC-AISVC-06] Provides unified interface for LLM providers.
"""
from app.services.llm.base import LLMClient, LLMConfig, LLMResponse, LLMStreamChunk
from app.services.llm.openai_client import OpenAIClient
__all__ = [
"LLMClient",
"LLMConfig",
"LLMResponse",
"LLMStreamChunk",
"OpenAIClient",
]

View File

@ -0,0 +1,115 @@
"""
Base LLM client interface.
[AC-AISVC-02, AC-AISVC-06] Abstract interface for LLM providers.
Design reference: design.md Section 8.1 - LLMClient interface
- generate(prompt, params) -> text
- stream_generate(prompt, params) -> iterator[delta]
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, AsyncGenerator
@dataclass
class LLMConfig:
    """
    Configuration for LLM client.
    [AC-AISVC-02] Supports configurable model parameters.
    """
    model: str = "gpt-4o-mini"  # model name sent in the request body
    max_tokens: int = 2048  # upper bound on generated tokens
    temperature: float = 0.7  # sampling temperature
    top_p: float = 1.0  # nucleus-sampling cutoff
    timeout_seconds: int = 30  # per-request HTTP timeout
    max_retries: int = 3  # retry budget for transient failures
    # Provider-specific extras merged into the request body (per-call kwargs win).
    extra_params: dict[str, Any] = field(default_factory=dict)
@dataclass
class LLMResponse:
    """
    Response from LLM generation.
    [AC-AISVC-02] Contains generated content and metadata.
    """
    content: str  # full generated text
    model: str  # model name reported by the provider (or the requested one)
    usage: dict[str, int] = field(default_factory=dict)  # token accounting, e.g. total_tokens
    finish_reason: str = "stop"  # provider finish reason
    metadata: dict[str, Any] = field(default_factory=dict)  # extras, e.g. the raw response
@dataclass
class LLMStreamChunk:
    """
    Streaming chunk from LLM.
    [AC-AISVC-06, AC-AISVC-07] Incremental output for SSE streaming.
    """
    delta: str  # incremental text fragment
    model: str  # model name for this stream
    finish_reason: str | None = None  # set only on the terminating chunk
    metadata: dict[str, Any] = field(default_factory=dict)  # extras, e.g. the raw chunk
class LLMClient(ABC):
    """
    Abstract base class for LLM clients.
    [AC-AISVC-02, AC-AISVC-06] Provides unified interface for different LLM providers.
    Design reference: design.md Section 8.2 - Plugin points
    - OpenAICompatibleClient / LocalModelClient can be swapped
    """
    @abstractmethod
    async def generate(
        self,
        messages: list[dict[str, str]],
        config: LLMConfig | None = None,
        **kwargs: Any,
    ) -> LLMResponse:
        """
        Generate a non-streaming response.
        [AC-AISVC-02] Returns complete response for ChatResponse.
        Args:
            messages: List of chat messages with 'role' and 'content'.
            config: Optional LLM configuration overrides.
            **kwargs: Additional provider-specific parameters.
        Returns:
            LLMResponse with generated content and metadata.
        Raises:
            LLMException: If generation fails.
        """
        pass
    @abstractmethod
    async def stream_generate(
        self,
        messages: list[dict[str, str]],
        config: LLMConfig | None = None,
        **kwargs: Any,
    ) -> AsyncGenerator[LLMStreamChunk, None]:
        """
        Generate a streaming response.
        [AC-AISVC-06, AC-AISVC-07] Yields incremental chunks for SSE.
        Concrete implementations are expected to be async generator
        functions (i.e. to contain `yield`).
        Args:
            messages: List of chat messages with 'role' and 'content'.
            config: Optional LLM configuration overrides.
            **kwargs: Additional provider-specific parameters.
        Yields:
            LLMStreamChunk with incremental content.
        Raises:
            LLMException: If generation fails.
        """
        pass
    @abstractmethod
    async def close(self) -> None:
        """Close the client and release resources."""
        pass

View File

@ -0,0 +1,319 @@
"""
OpenAI-compatible LLM client implementation.
[AC-AISVC-02, AC-AISVC-06] Concrete implementation using httpx for OpenAI API.
Design reference: design.md Section 8.1 - LLMClient interface
- Uses langchain-openai or official SDK pattern
- Supports generate and stream_generate
"""
import json
import logging
from typing import Any, AsyncGenerator
import httpx
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.core.config import get_settings
from app.core.exceptions import AIServiceException, ErrorCode, ServiceUnavailableException, TimeoutException
from app.services.llm.base import LLMClient, LLMConfig, LLMResponse, LLMStreamChunk
logger = logging.getLogger(__name__)
class LLMException(AIServiceException):
    """
    Signals a failure while talking to the LLM provider.

    Carries ErrorCode.LLM_ERROR and maps to HTTP 503 so callers and the
    error middleware can treat upstream model failures uniformly.
    """

    def __init__(self, message: str, details: list[dict] | None = None):
        super().__init__(
            code=ErrorCode.LLM_ERROR,
            status_code=503,
            message=message,
            details=details,
        )
class OpenAIClient(LLMClient):
    """
    OpenAI-compatible LLM client.
    [AC-AISVC-02, AC-AISVC-06] Implements LLMClient interface for OpenAI API.
    Supports:
    - OpenAI API (official)
    - OpenAI-compatible endpoints (Azure, local models, etc.)
    """

    def __init__(
        self,
        api_key: str | None = None,
        base_url: str | None = None,
        model: str | None = None,
        default_config: LLMConfig | None = None,
    ):
        """
        Initialize the client; any argument left as None falls back to settings.

        Args:
            api_key: Bearer token for the API.
            base_url: API base URL; a trailing slash is stripped.
            model: Default model name.
            default_config: Default generation parameters.
        """
        settings = get_settings()
        self._api_key = api_key or settings.llm_api_key
        self._base_url = (base_url or settings.llm_base_url).rstrip("/")
        self._model = model or settings.llm_model
        self._default_config = default_config or LLMConfig(
            model=self._model,
            max_tokens=settings.llm_max_tokens,
            temperature=settings.llm_temperature,
            timeout_seconds=settings.llm_timeout_seconds,
            max_retries=settings.llm_max_retries,
        )
        # Lazily created, shared across requests; released via close().
        self._client: httpx.AsyncClient | None = None

    def _get_client(self, timeout_seconds: int) -> httpx.AsyncClient:
        """
        Get or create the shared HTTP client.

        NOTE(review): the timeout is bound only when the client is first
        created; later calls passing a different timeout_seconds silently
        reuse the existing client.
        """
        if self._client is None:
            self._client = httpx.AsyncClient(
                timeout=httpx.Timeout(timeout_seconds),
                headers={
                    "Authorization": f"Bearer {self._api_key}",
                    "Content-Type": "application/json",
                },
            )
        return self._client

    def _build_request_body(
        self,
        messages: list[dict[str, str]],
        config: LLMConfig,
        stream: bool = False,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """
        Build the /chat/completions request body.

        config.extra_params is merged first, then kwargs, so per-call kwargs
        override config-level extras.
        """
        body: dict[str, Any] = {
            "model": config.model,
            "messages": messages,
            "max_tokens": config.max_tokens,
            "temperature": config.temperature,
            "top_p": config.top_p,
            "stream": stream,
        }
        body.update(config.extra_params)
        body.update(kwargs)
        return body

    @retry(
        # BUGFIX: the except-block below converts httpx.TimeoutException into
        # the project TimeoutException before tenacity ever sees it, so
        # matching only the httpx type meant a timeout was NEVER retried.
        # Retry on the converted type as well.
        retry=retry_if_exception_type((httpx.TimeoutException, TimeoutException)),
        # NOTE(review): attempt count is fixed at 3; config.max_retries is not
        # consulted because decorator arguments are bound at import time.
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=1, max=10),
        # Re-raise the final TimeoutException instead of tenacity's RetryError
        # so callers keep catching the documented exception type.
        reraise=True,
    )
    async def generate(
        self,
        messages: list[dict[str, str]],
        config: LLMConfig | None = None,
        **kwargs: Any,
    ) -> LLMResponse:
        """
        Generate a non-streaming response.
        [AC-AISVC-02] Returns complete response for ChatResponse.
        Args:
            messages: List of chat messages with 'role' and 'content'.
            config: Optional LLM configuration overrides.
            **kwargs: Additional provider-specific parameters.
        Returns:
            LLMResponse with generated content and metadata.
        Raises:
            LLMException: If generation fails.
            TimeoutException: If the request times out (after retries).
        """
        effective_config = config or self._default_config
        client = self._get_client(effective_config.timeout_seconds)
        body = self._build_request_body(messages, effective_config, stream=False, **kwargs)
        logger.info(f"[AC-AISVC-02] Generating response with model={effective_config.model}")
        try:
            response = await client.post(
                f"{self._base_url}/chat/completions",
                json=body,
            )
            response.raise_for_status()
            data = response.json()
        except httpx.TimeoutException as e:
            logger.error(f"[AC-AISVC-02] LLM request timeout: {e}")
            # Chain the cause so the original httpx error stays in tracebacks.
            raise TimeoutException(message=f"LLM request timed out: {e}") from e
        except httpx.HTTPStatusError as e:
            logger.error(f"[AC-AISVC-02] LLM API error: {e}")
            error_detail = self._parse_error_response(e.response)
            raise LLMException(
                message=f"LLM API error: {error_detail}",
                details=[{"status_code": e.response.status_code, "response": error_detail}],
            ) from e
        except json.JSONDecodeError as e:
            logger.error(f"[AC-AISVC-02] Failed to parse LLM response: {e}")
            raise LLMException(message=f"Failed to parse LLM response: {e}") from e
        # Payload parsing lives outside the network try-block so a malformed
        # body is reported as a format error rather than an HTTP failure.
        try:
            choice = data["choices"][0]
            content = choice["message"]["content"]
            usage = data.get("usage", {})
            finish_reason = choice.get("finish_reason", "stop")
            logger.info(
                f"[AC-AISVC-02] Generated response: "
                f"tokens={usage.get('total_tokens', 'N/A')}, "
                f"finish_reason={finish_reason}"
            )
            return LLMResponse(
                content=content,
                model=data.get("model", effective_config.model),
                usage=usage,
                finish_reason=finish_reason,
                metadata={"raw_response": data},
            )
        except (KeyError, IndexError) as e:
            logger.error(f"[AC-AISVC-02] Unexpected LLM response format: {e}")
            raise LLMException(
                message=f"Unexpected LLM response format: {e}",
                details=[{"response": str(data)}],
            ) from e

    async def stream_generate(
        self,
        messages: list[dict[str, str]],
        config: LLMConfig | None = None,
        **kwargs: Any,
    ) -> AsyncGenerator[LLMStreamChunk, None]:
        """
        Generate a streaming response.
        [AC-AISVC-06, AC-AISVC-07] Yields incremental chunks for SSE.
        Args:
            messages: List of chat messages with 'role' and 'content'.
            config: Optional LLM configuration overrides.
            **kwargs: Additional provider-specific parameters.
        Yields:
            LLMStreamChunk with incremental content.
        Raises:
            LLMException: If generation fails.
            TimeoutException: If the request times out.
        """
        effective_config = config or self._default_config
        client = self._get_client(effective_config.timeout_seconds)
        body = self._build_request_body(messages, effective_config, stream=True, **kwargs)
        logger.info(f"[AC-AISVC-06] Starting streaming generation with model={effective_config.model}")
        try:
            async with client.stream(
                "POST",
                f"{self._base_url}/chat/completions",
                json=body,
            ) as response:
                response.raise_for_status()
                async for line in response.aiter_lines():
                    # Skip keep-alive blanks and the terminal sentinel.
                    if not line or line == "data: [DONE]":
                        continue
                    if line.startswith("data: "):
                        json_str = line[6:]
                        try:
                            chunk_data = json.loads(json_str)
                            chunk = self._parse_stream_chunk(chunk_data, effective_config.model)
                            if chunk:
                                yield chunk
                        except json.JSONDecodeError as e:
                            # One bad chunk should not abort the whole stream.
                            logger.warning(f"[AC-AISVC-06] Failed to parse stream chunk: {e}")
                            continue
        except httpx.TimeoutException as e:
            logger.error(f"[AC-AISVC-06] LLM streaming request timeout: {e}")
            raise TimeoutException(message=f"LLM streaming request timed out: {e}") from e
        except httpx.HTTPStatusError as e:
            logger.error(f"[AC-AISVC-06] LLM streaming API error: {e}")
            error_detail = self._parse_error_response(e.response)
            raise LLMException(
                message=f"LLM streaming API error: {error_detail}",
                details=[{"status_code": e.response.status_code, "response": error_detail}],
            ) from e
        # BUGFIX: was an f-string with no placeholders (lint F541).
        logger.info("[AC-AISVC-06] Streaming generation completed")

    def _parse_stream_chunk(
        self,
        data: dict[str, Any],
        model: str,
    ) -> LLMStreamChunk | None:
        """
        Parse one SSE data payload into an LLMStreamChunk.

        Returns None for heartbeat chunks that carry neither content nor a
        finish_reason.
        """
        try:
            choices = data.get("choices", [])
            if not choices:
                return None
            delta = choices[0].get("delta", {})
            content = delta.get("content", "")
            finish_reason = choices[0].get("finish_reason")
            if not content and not finish_reason:
                return None
            return LLMStreamChunk(
                delta=content,
                model=data.get("model", model),
                finish_reason=finish_reason,
                metadata={"raw_chunk": data},
            )
        except (KeyError, IndexError) as e:
            logger.warning(f"[AC-AISVC-06] Failed to parse stream chunk: {e}")
            return None

    def _parse_error_response(self, response: httpx.Response) -> str:
        """Extract a human-readable message from an API error response."""
        try:
            data = response.json()
            if "error" in data:
                error = data["error"]
                if isinstance(error, dict):
                    return error.get("message", str(error))
                return str(error)
            return response.text
        except Exception:
            # Body was not JSON; fall back to the raw text.
            return response.text

    async def close(self) -> None:
        """Close the HTTP client and release pooled connections."""
        if self._client:
            await self._client.aclose()
            self._client = None
_llm_client: OpenAIClient | None = None
def get_llm_client() -> OpenAIClient:
    """Return the process-wide OpenAI client, constructing it lazily on first call."""
    global _llm_client
    if _llm_client is not None:
        return _llm_client
    _llm_client = OpenAIClient()
    return _llm_client
async def close_llm_client() -> None:
    """Dispose of the global LLM client, if one was created, and reset the singleton."""
    global _llm_client
    if _llm_client is not None:
        await _llm_client.close()
        _llm_client = None

View File

@ -0,0 +1,170 @@
"""
Memory service for AI Service.
[AC-AISVC-13] Session-based memory management with tenant isolation.
"""
import logging
from typing import Sequence
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import col
from app.models.entities import ChatMessage, ChatMessageCreate, ChatSession, ChatSessionCreate
logger = logging.getLogger(__name__)
class MemoryService:
    """
    [AC-AISVC-13] Memory service for session-based conversation history.
    All operations are scoped by (tenant_id, session_id) for multi-tenant isolation.

    NOTE(review): methods flush() but never commit(); the surrounding request
    transaction is assumed to own commit/rollback — confirm against the
    session factory wiring.
    """
    def __init__(self, session: AsyncSession):
        # Async SQLAlchemy session injected per request; not owned by this service.
        self._session = session
    async def get_or_create_session(
        self,
        tenant_id: str,
        session_id: str,
        channel_type: str | None = None,
        metadata: dict | None = None,
    ) -> ChatSession:
        """
        [AC-AISVC-13] Get existing session or create a new one.
        Ensures tenant isolation by querying with tenant_id.

        channel_type and metadata are stored only on creation; an existing
        session is returned unchanged.
        """
        stmt = select(ChatSession).where(
            ChatSession.tenant_id == tenant_id,
            ChatSession.session_id == session_id,
        )
        result = await self._session.execute(stmt)
        existing_session = result.scalar_one_or_none()
        if existing_session:
            logger.info(
                f"[AC-AISVC-13] Found existing session: tenant={tenant_id}, session={session_id}"
            )
            return existing_session
        new_session = ChatSession(
            tenant_id=tenant_id,
            session_id=session_id,
            channel_type=channel_type,
            metadata_=metadata,
        )
        self._session.add(new_session)
        # flush persists within the open transaction without committing.
        await self._session.flush()
        logger.info(
            f"[AC-AISVC-13] Created new session: tenant={tenant_id}, session={session_id}"
        )
        return new_session
    async def load_history(
        self,
        tenant_id: str,
        session_id: str,
        limit: int | None = None,
    ) -> Sequence[ChatMessage]:
        """
        [AC-AISVC-13] Load conversation history for a session.
        All queries are filtered by tenant_id to ensure isolation.

        Returns messages oldest-first (ascending created_at).
        NOTE(review): the truthiness check makes limit=0 behave like "no
        limit", and the limit keeps the OLDEST rows, not the most recent —
        confirm the intended semantics.
        """
        stmt = (
            select(ChatMessage)
            .where(
                ChatMessage.tenant_id == tenant_id,
                ChatMessage.session_id == session_id,
            )
            .order_by(col(ChatMessage.created_at).asc())
        )
        if limit:
            stmt = stmt.limit(limit)
        result = await self._session.execute(stmt)
        messages = result.scalars().all()
        logger.info(
            f"[AC-AISVC-13] Loaded {len(messages)} messages for tenant={tenant_id}, session={session_id}"
        )
        return messages
    async def append_message(
        self,
        tenant_id: str,
        session_id: str,
        role: str,
        content: str,
    ) -> ChatMessage:
        """
        [AC-AISVC-13] Append a message to the session history.
        Message is scoped by tenant_id for isolation.

        role is expected to be "user" or "assistant" (not validated here).
        """
        message = ChatMessage(
            tenant_id=tenant_id,
            session_id=session_id,
            role=role,
            content=content,
        )
        self._session.add(message)
        await self._session.flush()
        logger.info(
            f"[AC-AISVC-13] Appended message: tenant={tenant_id}, session={session_id}, role={role}"
        )
        return message
    async def append_messages(
        self,
        tenant_id: str,
        session_id: str,
        messages: list[dict[str, str]],
    ) -> list[ChatMessage]:
        """
        [AC-AISVC-13] Append multiple messages to the session history.
        Used for batch insertion of conversation turns.

        Each dict must provide "role" and "content" keys (KeyError otherwise);
        a single flush covers the whole batch.
        """
        chat_messages = []
        for msg in messages:
            message = ChatMessage(
                tenant_id=tenant_id,
                session_id=session_id,
                role=msg["role"],
                content=msg["content"],
            )
            self._session.add(message)
            chat_messages.append(message)
        await self._session.flush()
        logger.info(
            f"[AC-AISVC-13] Appended {len(chat_messages)} messages for tenant={tenant_id}, session={session_id}"
        )
        return chat_messages
    async def clear_history(self, tenant_id: str, session_id: str) -> int:
        """
        [AC-AISVC-13] Clear all messages for a session.
        Only affects messages within the tenant's scope.

        Returns:
            The number of messages deleted.

        NOTE(review): loads and deletes rows one by one; a single bulk
        DELETE statement would be cheaper for long histories.
        """
        stmt = select(ChatMessage).where(
            ChatMessage.tenant_id == tenant_id,
            ChatMessage.session_id == session_id,
        )
        result = await self._session.execute(stmt)
        messages = result.scalars().all()
        count = 0
        for message in messages:
            await self._session.delete(message)
            count += 1
        await self._session.flush()
        logger.info(
            f"[AC-AISVC-13] Cleared {count} messages for tenant={tenant_id}, session={session_id}"
        )
        return count

View File

@ -0,0 +1,97 @@
"""
Orchestrator service for AI Service.
[AC-AISVC-01, AC-AISVC-02] Core orchestration logic for chat generation.
"""
import logging
from typing import AsyncGenerator
from sse_starlette.sse import ServerSentEvent
from app.models import ChatRequest, ChatResponse
from app.core.sse import create_final_event, create_message_event, SSEStateMachine
logger = logging.getLogger(__name__)
class OrchestratorService:
    """
    [AC-AISVC-01, AC-AISVC-02] Orchestrator for chat generation.
    Coordinates memory, retrieval, and LLM components.

    Current implementation is a placeholder that echoes the user message.
    """

    async def generate(self, tenant_id: str, request: ChatRequest) -> ChatResponse:
        """
        Generate a non-streaming response.
        [AC-AISVC-02] Returns ChatResponse with reply, confidence, shouldTransfer.
        """
        logger.info(
            f"[AC-AISVC-01] Generating response for tenant={tenant_id}, "
            f"session={request.session_id}"
        )
        return ChatResponse(
            reply=f"Received your message: {request.current_message}",
            confidence=0.85,
            should_transfer=False,
        )

    async def generate_stream(
        self, tenant_id: str, request: ChatRequest
    ) -> AsyncGenerator[ServerSentEvent, None]:
        """
        Generate a streaming response.
        [AC-AISVC-06, AC-AISVC-07, AC-AISVC-08] Yields SSE events in proper sequence:
        message* -> final (or error), then the state machine is closed.
        """
        logger.info(
            f"[AC-AISVC-06] Starting streaming generation for tenant={tenant_id}, "
            f"session={request.session_id}"
        )
        sm = SSEStateMachine()
        await sm.transition_to_streaming()
        try:
            # Placeholder token stream; the joined chunks reproduce the
            # non-streaming reply exactly.
            chunks = ["Received", " your", " message:", f" {request.current_message}"]
            emitted: list[str] = []
            for chunk in chunks:
                if sm.can_send_message():
                    emitted.append(chunk)
                    yield create_message_event(delta=chunk)
                    await self._simulate_llm_delay()
            if await sm.transition_to_final():
                yield create_final_event(
                    reply="".join(emitted),
                    confidence=0.85,
                    should_transfer=False,
                )
        except Exception as e:
            logger.error(f"[AC-AISVC-09] Error during streaming: {e}")
            if await sm.transition_to_error():
                from app.core.sse import create_error_event
                yield create_error_event(
                    code="GENERATION_ERROR",
                    message=str(e),
                )
        finally:
            # Always close the state machine, whatever path the stream took.
            await sm.close()

    async def _simulate_llm_delay(self) -> None:
        """Simulate LLM processing delay for demo purposes."""
        import asyncio
        await asyncio.sleep(0.1)
_orchestrator_service: OrchestratorService | None = None
def get_orchestrator_service() -> OrchestratorService:
    """Return the module-level orchestrator, creating it on first access."""
    global _orchestrator_service
    if _orchestrator_service is not None:
        return _orchestrator_service
    _orchestrator_service = OrchestratorService()
    return _orchestrator_service

View File

@ -0,0 +1,21 @@
"""
Retrieval module for AI Service.
[AC-AISVC-16] Provides retriever implementations with plugin architecture.
"""
from app.services.retrieval.base import (
BaseRetriever,
RetrievalContext,
RetrievalHit,
RetrievalResult,
)
from app.services.retrieval.vector_retriever import VectorRetriever, get_vector_retriever
__all__ = [
"BaseRetriever",
"RetrievalContext",
"RetrievalHit",
"RetrievalResult",
"VectorRetriever",
"get_vector_retriever",
]

View File

@ -0,0 +1,96 @@
"""
Retrieval layer for AI Service.
[AC-AISVC-16] Abstract base class for retrievers with plugin point support.
"""
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any
logger = logging.getLogger(__name__)
@dataclass
class RetrievalContext:
    """
    [AC-AISVC-16] Context for retrieval operations.
    Contains all necessary information for retrieval plugins.
    """
    tenant_id: str  # tenant scope passed through to the search backend
    query: str  # natural-language query to embed and search with
    session_id: str | None = None  # optional conversation identifier
    channel_type: str | None = None  # optional channel hint
    metadata: dict[str, Any] | None = None  # plugin-specific extras
@dataclass
class RetrievalHit:
    """
    [AC-AISVC-16] Single retrieval result hit.
    Unified structure for all retriever types.
    """
    text: str  # retrieved passage text
    score: float  # similarity score; higher is better (threshold keeps >=)
    source: str  # origin label, e.g. "vector"
    metadata: dict[str, Any] = field(default_factory=dict)  # raw payload for diagnostics
@dataclass
class RetrievalResult:
    """
    [AC-AISVC-16] Result of a retrieval operation.

    Holds the filtered hits plus optional backend diagnostics; convenience
    properties summarize the hit list for confidence decisions.
    """
    hits: list[RetrievalHit] = field(default_factory=list)
    diagnostics: dict[str, Any] | None = None

    @property
    def is_empty(self) -> bool:
        """True when no hits were found."""
        return not self.hits

    @property
    def max_score(self) -> float:
        """Highest score among the hits, or 0.0 when there are none."""
        return max((hit.score for hit in self.hits), default=0.0)

    @property
    def hit_count(self) -> int:
        """Number of hits."""
        return len(self.hits)
class BaseRetriever(ABC):
    """
    [AC-AISVC-16] Abstract base class for retrievers.
    Provides plugin point for different retrieval strategies (Vector, Graph, Hybrid).
    Implementations must honor the tenant scope carried in RetrievalContext.
    """
    @abstractmethod
    async def retrieve(self, ctx: RetrievalContext) -> RetrievalResult:
        """
        [AC-AISVC-16] Retrieve relevant documents for the given context.
        Args:
            ctx: Retrieval context containing tenant_id, query, and optional metadata.
        Returns:
            RetrievalResult with hits and optional diagnostics.
        """
        pass
    @abstractmethod
    async def health_check(self) -> bool:
        """
        Check if the retriever is healthy and ready to serve requests.
        Returns:
            True if healthy, False otherwise.
        """
        pass

View File

@ -0,0 +1,169 @@
"""
Vector retriever for AI Service.
[AC-AISVC-16, AC-AISVC-17] Qdrant-based vector retrieval with score threshold filtering.
"""
import logging
from typing import Any
from app.core.config import get_settings
from app.core.qdrant_client import QdrantClient, get_qdrant_client
from app.services.retrieval.base import (
BaseRetriever,
RetrievalContext,
RetrievalHit,
RetrievalResult,
)
logger = logging.getLogger(__name__)
settings = get_settings()
class VectorRetriever(BaseRetriever):
    """
    [AC-AISVC-16, AC-AISVC-17] Vector-based retriever using Qdrant.
    Supports score threshold filtering and tenant isolation.
    """

    def __init__(
        self,
        qdrant_client: QdrantClient | None = None,
        top_k: int | None = None,
        score_threshold: float | None = None,
        min_hits: int | None = None,
    ):
        """
        Args:
            qdrant_client: Pre-built client; resolved lazily when None.
            top_k: Max candidates to request; defaults to settings.rag_top_k.
            score_threshold: Minimum similarity score to keep a hit;
                defaults to settings.rag_score_threshold.
            min_hits: Hit count below which retrieval is flagged insufficient;
                defaults to settings.rag_min_hits.
        """
        self._qdrant_client = qdrant_client
        # BUGFIX: explicit None-checks instead of `or`, so legitimate falsy
        # overrides (e.g. score_threshold=0.0 or min_hits=0) are honored
        # rather than silently replaced by the settings defaults.
        self._top_k = top_k if top_k is not None else settings.rag_top_k
        self._score_threshold = (
            score_threshold if score_threshold is not None else settings.rag_score_threshold
        )
        self._min_hits = min_hits if min_hits is not None else settings.rag_min_hits

    async def _get_client(self) -> QdrantClient:
        """Get (and cache) the Qdrant client instance."""
        if self._qdrant_client is None:
            self._qdrant_client = await get_qdrant_client()
        return self._qdrant_client

    async def retrieve(self, ctx: RetrievalContext) -> RetrievalResult:
        """
        [AC-AISVC-16, AC-AISVC-17] Retrieve documents from vector store.
        Steps:
        1. Generate embedding for query (placeholder - requires embedding provider)
        2. Search in tenant's collection
        3. Filter by score threshold
        4. Return structured result

        Never raises: any failure is reported via empty hits plus
        "error"/"is_insufficient" diagnostics so callers can degrade.

        Args:
            ctx: Retrieval context with tenant_id and query.

        Returns:
            RetrievalResult with filtered hits.
        """
        logger.info(
            f"[AC-AISVC-16] Starting vector retrieval for tenant={ctx.tenant_id}, query={ctx.query[:50]}..."
        )
        try:
            client = await self._get_client()
            query_vector = await self._get_embedding(ctx.query)
            hits = await client.search(
                tenant_id=ctx.tenant_id,
                query_vector=query_vector,
                limit=self._top_k,
                score_threshold=self._score_threshold,
            )
            # Defensive re-filter: the threshold is also passed to the search
            # call, but we do not assume the backend enforced it.
            retrieval_hits = [
                RetrievalHit(
                    text=hit.get("payload", {}).get("text", ""),
                    score=hit.get("score", 0.0),
                    source=hit.get("payload", {}).get("source", "vector"),
                    metadata=hit.get("payload", {}),
                )
                for hit in hits
                if hit.get("score", 0.0) >= self._score_threshold
            ]
            is_insufficient = len(retrieval_hits) < self._min_hits
            diagnostics = {
                "query_length": len(ctx.query),
                "top_k": self._top_k,
                "score_threshold": self._score_threshold,
                "min_hits": self._min_hits,
                "total_candidates": len(hits),
                "filtered_hits": len(retrieval_hits),
                "is_insufficient": is_insufficient,
                "max_score": max((h.score for h in retrieval_hits), default=0.0),
            }
            logger.info(
                f"[AC-AISVC-17] Retrieval complete: {len(retrieval_hits)} hits, "
                f"insufficient={is_insufficient}, max_score={diagnostics['max_score']:.3f}"
            )
            return RetrievalResult(
                hits=retrieval_hits,
                diagnostics=diagnostics,
            )
        except Exception as e:
            # Degrade gracefully: surface the failure through diagnostics.
            logger.error(f"[AC-AISVC-16] Retrieval error: {e}")
            return RetrievalResult(
                hits=[],
                diagnostics={"error": str(e), "is_insufficient": True},
            )

    async def _get_embedding(self, text: str) -> list[float]:
        """
        Generate embedding for text.
        [AC-AISVC-16] Placeholder for embedding generation.
        TODO: Integrate with actual embedding provider (OpenAI, local model, etc.)

        Deterministic stand-in: spreads SHA-256 bits of the text into a
        0.0/1.0 vector, zero-padded/truncated to settings.qdrant_vector_size.
        Not semantically meaningful — only stable for testing.
        """
        import hashlib
        hash_obj = hashlib.sha256(text.encode())
        hash_bytes = hash_obj.digest()
        embedding = []
        for i in range(0, min(len(hash_bytes) * 8, settings.qdrant_vector_size)):
            byte_idx = i // 8
            bit_idx = i % 8
            if byte_idx < len(hash_bytes):
                val = (hash_bytes[byte_idx] >> bit_idx) & 1
                embedding.append(float(val))
            else:
                embedding.append(0.0)
        while len(embedding) < settings.qdrant_vector_size:
            embedding.append(0.0)
        return embedding[: settings.qdrant_vector_size]

    async def health_check(self) -> bool:
        """
        [AC-AISVC-16] Check if Qdrant connection is healthy.

        Returns:
            True when listing collections succeeds; False on any error.
        """
        try:
            client = await self._get_client()
            qdrant = await client.get_client()
            await qdrant.get_collections()
            return True
        except Exception as e:
            logger.error(f"[AC-AISVC-16] Health check failed: {e}")
            return False
_vector_retriever: VectorRetriever | None = None
async def get_vector_retriever() -> VectorRetriever:
    """Return the shared VectorRetriever, constructing it lazily on first call."""
    global _vector_retriever
    if _vector_retriever is not None:
        return _vector_retriever
    _vector_retriever = VectorRetriever()
    return _vector_retriever

53
ai-service/pyproject.toml Normal file
View File

@ -0,0 +1,53 @@
[project]
name = "ai-service"
version = "0.1.0"
description = "Python AI Service for intelligent chat with RAG support"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"fastapi>=0.109.0",
"uvicorn[standard]>=0.27.0",
"pydantic>=2.5.0",
"pydantic-settings>=2.1.0",
"sse-starlette>=2.0.0",
"httpx>=0.26.0",
"tenacity>=8.2.0",
"sqlmodel>=0.0.14",
"asyncpg>=0.29.0",
"qdrant-client>=1.7.0",
"tiktoken>=0.5.0",
]
[project.optional-dependencies]
dev = [
"pytest>=7.4.0",
"pytest-asyncio>=0.23.0",
"pytest-cov>=4.1.0",
"httpx>=0.26.0",
"ruff>=0.1.0",
"mypy>=1.8.0",
]
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.build.targets.wheel]
packages = ["app"]
[tool.ruff]
line-length = 120
target-version = "py310"
[tool.ruff.lint]
select = ["E", "F", "I", "N", "W", "UP"]
[tool.mypy]
python_version = "3.10"
strict = true
warn_return_any = true
warn_unused_configs = true
[tool.pytest.ini_options]
asyncio_mode = "auto"
testpaths = ["tests"]

View File

@ -0,0 +1,3 @@
"""
Tests package for AI Service.
"""

View File

@ -0,0 +1,10 @@
"""
Pytest configuration for AI Service tests.
"""
import pytest
@pytest.fixture
def anyio_backend():
    """Force anyio-parametrized tests to run on the asyncio backend only."""
    return "asyncio"

View File

@ -0,0 +1,285 @@
"""
Tests for response mode switching based on Accept header.
[AC-AISVC-06] Tests for automatic switching between JSON and SSE streaming modes.
"""
import pytest
from fastapi.testclient import TestClient
from httpx import AsyncClient
from app.main import app
class TestAcceptHeaderSwitching:
    """
    [AC-AISVC-06] Test cases for Accept header based response mode switching.
    """
    @pytest.fixture
    def client(self):
        # Synchronous test client wrapping the FastAPI app.
        return TestClient(app)
    @pytest.fixture
    def valid_request_body(self):
        # Minimal valid request payload (camelCase JSON field names).
        return {
            "sessionId": "test_session_001",
            "currentMessage": "Hello, how are you?",
            "channelType": "wechat",
        }
    @pytest.fixture
    def valid_headers(self):
        # Tenant header required by the multi-tenant middleware.
        return {"X-Tenant-Id": "tenant_001"}
    def test_json_response_with_default_accept(
        self, client: TestClient, valid_request_body: dict, valid_headers: dict
    ):
        """
        [AC-AISVC-06] Test that default Accept header returns JSON response.
        """
        response = client.post(
            "/ai/chat",
            json=valid_request_body,
            headers=valid_headers,
        )
        assert response.status_code == 200
        assert response.headers["content-type"] == "application/json"
        data = response.json()
        assert "reply" in data
        assert "confidence" in data
        assert "shouldTransfer" in data
    def test_json_response_with_application_json_accept(
        self, client: TestClient, valid_request_body: dict, valid_headers: dict
    ):
        """
        [AC-AISVC-06] Test that Accept: application/json returns JSON response.
        """
        headers = {**valid_headers, "Accept": "application/json"}
        response = client.post(
            "/ai/chat",
            json=valid_request_body,
            headers=headers,
        )
        assert response.status_code == 200
        assert response.headers["content-type"] == "application/json"
        data = response.json()
        assert "reply" in data
        assert "confidence" in data
        assert "shouldTransfer" in data
    def test_sse_response_with_text_event_stream_accept(
        self, client: TestClient, valid_request_body: dict, valid_headers: dict
    ):
        """
        [AC-AISVC-06] Test that Accept: text/event-stream returns SSE response.
        """
        headers = {**valid_headers, "Accept": "text/event-stream"}
        response = client.post(
            "/ai/chat",
            json=valid_request_body,
            headers=headers,
        )
        assert response.status_code == 200
        assert "text/event-stream" in response.headers["content-type"]
        content = response.text
        assert "event: message" in content
        assert "event: final" in content
    def test_sse_response_event_sequence(
        self, client: TestClient, valid_request_body: dict, valid_headers: dict
    ):
        """
        [AC-AISVC-07, AC-AISVC-08] Test that SSE events follow proper sequence.
        message* -> final -> close
        """
        headers = {**valid_headers, "Accept": "text/event-stream"}
        response = client.post(
            "/ai/chat",
            json=valid_request_body,
            headers=headers,
        )
        content = response.text
        # Accept both "event:message" and "event: message" spellings, since
        # SSE serializers differ on the space after the colon.
        assert "event:message" in content or "event: message" in content, f"Expected message event in: {content[:500]}"
        assert "event:final" in content or "event: final" in content, f"Expected final event in: {content[:500]}"
        message_idx = content.find("event:message")
        if message_idx == -1:
            message_idx = content.find("event: message")
        final_idx = content.find("event:final")
        if final_idx == -1:
            final_idx = content.find("event: final")
        # Ordering check: the first final event must appear after the first message event.
        assert final_idx > message_idx, "final event should come after message events"
    def test_missing_tenant_id_returns_400(
        self, client: TestClient, valid_request_body: dict
    ):
        """
        [AC-AISVC-12] Test that missing X-Tenant-Id returns 400 error.
        """
        response = client.post(
            "/ai/chat",
            json=valid_request_body,
        )
        assert response.status_code == 400
        data = response.json()
        assert data["code"] == "MISSING_TENANT_ID"
        assert "message" in data
    def test_invalid_channel_type_returns_400(
        self, client: TestClient, valid_headers: dict
    ):
        """
        [AC-AISVC-03] Test that invalid channel type returns 400 error.
        """
        # NOTE(review): FastAPI's default validation status is 422; asserting
        # 400 presumes a custom RequestValidationError handler — confirm.
        invalid_body = {
            "sessionId": "test_session_001",
            "currentMessage": "Hello",
            "channelType": "invalid_channel",
        }
        response = client.post(
            "/ai/chat",
            json=invalid_body,
            headers=valid_headers,
        )
        assert response.status_code == 400
    def test_missing_required_fields_returns_400(
        self, client: TestClient, valid_headers: dict
    ):
        """
        [AC-AISVC-03] Test that missing required fields return 400 error.
        """
        # Missing currentMessage; same 400-vs-422 caveat as above.
        incomplete_body = {
            "sessionId": "test_session_001",
        }
        response = client.post(
            "/ai/chat",
            json=incomplete_body,
            headers=valid_headers,
        )
        assert response.status_code == 400
class TestHealthEndpoint:
    """
    [AC-AISVC-20] Test cases for health check endpoint.
    """
    @pytest.fixture
    def client(self):
        # Synchronous test client wrapping the FastAPI app.
        return TestClient(app)
    def test_health_check_returns_200(self, client: TestClient):
        """
        [AC-AISVC-20] Test that health check returns 200 with status.
        """
        response = client.get("/ai/health")
        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "healthy"
class TestSSEStateMachine:
    """
    [AC-AISVC-08, AC-AISVC-09] Test cases for SSE state machine.
    """
    @pytest.mark.asyncio
    async def test_state_transitions(self):
        """Happy path: INIT -> STREAMING -> FINAL_SENT -> CLOSED."""
        from app.core.sse import SSEState, SSEStateMachine
        state_machine = SSEStateMachine()
        assert state_machine.state == SSEState.INIT
        success = await state_machine.transition_to_streaming()
        assert success is True
        assert state_machine.state == SSEState.STREAMING
        assert state_machine.can_send_message() is True
        success = await state_machine.transition_to_final()
        assert success is True
        assert state_machine.state == SSEState.FINAL_SENT
        # After final, no further message events may be sent.
        assert state_machine.can_send_message() is False
        await state_machine.close()
        assert state_machine.state == SSEState.CLOSED
    @pytest.mark.asyncio
    async def test_error_transition_from_streaming(self):
        """[AC-AISVC-09] STREAMING may transition to ERROR_SENT."""
        from app.core.sse import SSEState, SSEStateMachine
        state_machine = SSEStateMachine()
        await state_machine.transition_to_streaming()
        success = await state_machine.transition_to_error()
        assert success is True
        assert state_machine.state == SSEState.ERROR_SENT
    @pytest.mark.asyncio
    async def test_cannot_transition_to_final_from_init(self):
        """final is only reachable from STREAMING, never directly from INIT."""
        from app.core.sse import SSEStateMachine
        state_machine = SSEStateMachine()
        success = await state_machine.transition_to_final()
        assert success is False
class TestMiddleware:
    """
    [AC-AISVC-10, AC-AISVC-12] Test cases for middleware.
    """

    @pytest.fixture
    def client(self):
        """Fresh TestClient bound to the FastAPI app."""
        return TestClient(app)

    def test_tenant_context_extraction(self, client: TestClient):
        """
        [AC-AISVC-10] Test that X-Tenant-Id is properly extracted and used.
        """
        payload = {
            "sessionId": "session_001",
            "currentMessage": "Test message",
            "channelType": "wechat",
        }

        resp = client.post(
            "/ai/chat",
            json=payload,
            headers={"X-Tenant-Id": "tenant_test_123"},
        )

        assert resp.status_code == 200

    def test_health_endpoint_bypasses_tenant_check(self, client: TestClient):
        """
        Test that health endpoint doesn't require X-Tenant-Id.
        """
        # No X-Tenant-Id header at all -- must still succeed.
        assert client.get("/ai/health").status_code == 200

View File

@ -0,0 +1,319 @@
"""
Unit tests for LLM Adapter.
[AC-AISVC-02, AC-AISVC-06] Tests for LLM client interface.
Tests cover:
- Non-streaming generation
- Streaming generation
- Error handling
- Retry logic
"""
import json
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
from app.services.llm.base import LLMConfig, LLMResponse, LLMStreamChunk
from app.services.llm.openai_client import (
LLMException,
OpenAIClient,
TimeoutException,
)
@pytest.fixture
def mock_settings():
    """Mock settings for testing."""
    overrides = {
        "llm_api_key": "test-api-key",
        "llm_base_url": "https://api.openai.com/v1",
        "llm_model": "gpt-4o-mini",
        "llm_max_tokens": 2048,
        "llm_temperature": 0.7,
        "llm_timeout_seconds": 30,
        "llm_max_retries": 3,
    }
    settings = MagicMock()
    for name, value in overrides.items():
        setattr(settings, name, value)
    return settings
@pytest.fixture
def llm_client(mock_settings):
    """Create LLM client with mocked settings."""
    # Patch settings lookup for the whole lifetime of the client under test.
    patch_target = "app.services.llm.openai_client.get_settings"
    with patch(patch_target, return_value=mock_settings):
        yield OpenAIClient()
@pytest.fixture
def mock_messages():
    """Sample chat messages for testing."""
    conversation = [
        ("system", "You are a helpful assistant."),
        ("user", "Hello, how are you?"),
    ]
    return [{"role": role, "content": text} for role, text in conversation]
@pytest.fixture
def mock_generate_response():
    """Sample non-streaming response from OpenAI API."""
    assistant_message = {
        "role": "assistant",
        "content": "Hello! I'm doing well, thank you for asking!",
    }
    single_choice = {
        "index": 0,
        "message": assistant_message,
        "finish_reason": "stop",
    }
    return {
        "id": "chatcmpl-123",
        "object": "chat.completion",
        "created": 1677652288,
        "model": "gpt-4o-mini",
        "choices": [single_choice],
        "usage": {
            "prompt_tokens": 20,
            "completion_tokens": 15,
            "total_tokens": 35,
        },
    }
@pytest.fixture
def mock_stream_chunks():
    """Sample streaming chunks from OpenAI API."""

    def sse_line(text: str, finish: str) -> str:
        # Assemble the exact wire format the client parses (SSE `data:` line).
        return (
            'data: {"id":"chatcmpl-123","choices":[{"delta":{"content":"'
            + text
            + '"},"finish_reason":'
            + finish
            + "}]}\n"
        )

    return [
        sse_line("Hello", "null"),
        sse_line("!", "null"),
        sse_line(" How", "null"),
        sse_line(" can I help?", '"stop"'),
        "data: [DONE]\n",
    ]
class TestOpenAIClientGenerate:
    """Tests for non-streaming generation. [AC-AISVC-02]"""

    @staticmethod
    def _fake_http_client(post_result=None, post_error=None):
        """Build a fake async HTTP client whose post() returns or raises."""
        fake = AsyncMock()
        if post_error is not None:
            fake.post = AsyncMock(side_effect=post_error)
        else:
            fake.post = AsyncMock(return_value=post_result)
        return fake

    @pytest.mark.asyncio
    async def test_generate_success(self, llm_client, mock_messages, mock_generate_response):
        """[AC-AISVC-02] Test successful non-streaming generation."""
        resp = MagicMock()
        resp.json.return_value = mock_generate_response
        resp.raise_for_status = MagicMock()

        fake = self._fake_http_client(post_result=resp)
        with patch.object(llm_client, "_get_client", return_value=fake):
            result = await llm_client.generate(mock_messages)

        assert isinstance(result, LLMResponse)
        assert result.content == "Hello! I'm doing well, thank you for asking!"
        assert result.model == "gpt-4o-mini"
        assert result.finish_reason == "stop"
        assert result.usage["total_tokens"] == 35

    @pytest.mark.asyncio
    async def test_generate_with_custom_config(self, llm_client, mock_messages, mock_generate_response):
        """[AC-AISVC-02] Test generation with custom configuration."""
        custom_config = LLMConfig(model="gpt-4", max_tokens=1024, temperature=0.5)

        resp = MagicMock()
        resp.json.return_value = {**mock_generate_response, "model": "gpt-4"}
        resp.raise_for_status = MagicMock()

        fake = self._fake_http_client(post_result=resp)
        with patch.object(llm_client, "_get_client", return_value=fake):
            result = await llm_client.generate(mock_messages, config=custom_config)

        assert result.model == "gpt-4"

    @pytest.mark.asyncio
    async def test_generate_timeout_error(self, llm_client, mock_messages):
        """[AC-AISVC-02] Test timeout error handling."""
        fake = self._fake_http_client(post_error=httpx.TimeoutException("Timeout"))
        with patch.object(llm_client, "_get_client", return_value=fake):
            with pytest.raises(TimeoutException):
                await llm_client.generate(mock_messages)

    @pytest.mark.asyncio
    async def test_generate_api_error(self, llm_client, mock_messages):
        """[AC-AISVC-02] Test API error handling."""
        resp = MagicMock()
        resp.status_code = 401
        resp.text = '{"error": {"message": "Invalid API key"}}'
        resp.json.return_value = {"error": {"message": "Invalid API key"}}
        http_error = httpx.HTTPStatusError(
            "Unauthorized",
            request=MagicMock(),
            response=resp,
        )

        fake = self._fake_http_client(post_error=http_error)
        with patch.object(llm_client, "_get_client", return_value=fake):
            with pytest.raises(LLMException) as exc_info:
                await llm_client.generate(mock_messages)

        assert "Invalid API key" in str(exc_info.value.message)

    @pytest.mark.asyncio
    async def test_generate_malformed_response(self, llm_client, mock_messages):
        """[AC-AISVC-02] Test handling of malformed response."""
        # Response parses as JSON but lacks the expected "choices" structure.
        resp = MagicMock()
        resp.json.return_value = {"invalid": "response"}
        resp.raise_for_status = MagicMock()

        fake = self._fake_http_client(post_result=resp)
        with patch.object(llm_client, "_get_client", return_value=fake):
            with pytest.raises(LLMException):
                await llm_client.generate(mock_messages)
class MockAsyncStreamContext:
    """Mock async context manager for streaming.

    Wraps a prepared response object and hands it back on `async with` entry.
    """

    def __init__(self, response):
        self._response = response

    async def __aenter__(self):
        return self._response

    async def __aexit__(self, *args):
        return None
class TestOpenAIClientStreamGenerate:
    """Tests for streaming generation. [AC-AISVC-06, AC-AISVC-07]"""

    class _RaisingStreamContext:
        """Async context manager that raises the given exception on entry."""

        def __init__(self, exc):
            self._exc = exc

        async def __aenter__(self):
            raise self._exc

        async def __aexit__(self, *args):
            return None

    @pytest.mark.asyncio
    async def test_stream_generate_success(self, llm_client, mock_messages, mock_stream_chunks):
        """[AC-AISVC-06, AC-AISVC-07] Test successful streaming generation."""

        async def fake_aiter_lines():
            for line in mock_stream_chunks:
                yield line

        resp = MagicMock()
        resp.raise_for_status = MagicMock()
        resp.aiter_lines = fake_aiter_lines

        http = AsyncMock()
        http.stream = MagicMock(return_value=MockAsyncStreamContext(resp))

        with patch.object(llm_client, "_get_client", return_value=http):
            received = [chunk async for chunk in llm_client.stream_generate(mock_messages)]

        # The [DONE] sentinel is consumed, not yielded: 4 content chunks remain.
        assert len(received) == 4
        assert received[0].delta == "Hello"
        assert received[-1].finish_reason == "stop"

    @pytest.mark.asyncio
    async def test_stream_generate_timeout_error(self, llm_client, mock_messages):
        """[AC-AISVC-06] Test streaming timeout error handling."""
        http = AsyncMock()
        http.stream = MagicMock(
            return_value=self._RaisingStreamContext(httpx.TimeoutException("Timeout"))
        )

        with patch.object(llm_client, "_get_client", return_value=http):
            with pytest.raises(TimeoutException):
                async for _ in llm_client.stream_generate(mock_messages):
                    pass

    @pytest.mark.asyncio
    async def test_stream_generate_api_error(self, llm_client, mock_messages):
        """[AC-AISVC-06] Test streaming API error handling."""
        resp = MagicMock()
        resp.status_code = 500
        resp.text = "Internal Server Error"
        resp.json.return_value = {"error": {"message": "Internal Server Error"}}
        http_error = httpx.HTTPStatusError(
            "Internal Server Error",
            request=MagicMock(),
            response=resp,
        )

        http = AsyncMock()
        http.stream = MagicMock(return_value=self._RaisingStreamContext(http_error))

        with patch.object(llm_client, "_get_client", return_value=http):
            with pytest.raises(LLMException):
                async for _ in llm_client.stream_generate(mock_messages):
                    pass
class TestOpenAIClientConfig:
    """Tests for LLM configuration."""

    _SETTINGS_PATCH = "app.services.llm.openai_client.get_settings"

    def test_default_config(self, mock_settings):
        """Test default configuration from settings."""
        with patch(self._SETTINGS_PATCH, return_value=mock_settings):
            client = OpenAIClient()

        assert client._model == "gpt-4o-mini"
        assert client._default_config.max_tokens == 2048
        assert client._default_config.temperature == 0.7

    def test_custom_config_override(self, mock_settings):
        """Test custom configuration override."""
        with patch(self._SETTINGS_PATCH, return_value=mock_settings):
            client = OpenAIClient(
                api_key="custom-key",
                base_url="https://custom.api.com/v1",
                model="gpt-4",
            )

        # Constructor arguments win over values coming from settings.
        assert client._api_key == "custom-key"
        assert client._base_url == "https://custom.api.com/v1"
        assert client._model == "gpt-4"
class TestOpenAIClientClose:
    """Tests for client cleanup."""

    @pytest.mark.asyncio
    async def test_close_client(self, llm_client):
        """Test client close releases resources."""
        inner = AsyncMock()
        inner.aclose = AsyncMock()
        llm_client._client = inner

        await llm_client.close()

        # The underlying transport must be closed and the reference dropped.
        inner.aclose.assert_called_once()
        assert llm_client._client is None

View File

@ -0,0 +1,210 @@
"""
Unit tests for Memory service.
[AC-AISVC-10, AC-AISVC-11, AC-AISVC-13] Tests for multi-tenant session and message management.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.entities import ChatMessage, ChatSession
from app.services.memory import MemoryService
@pytest.fixture
def mock_session():
    """Create a mock AsyncSession."""
    fake = AsyncMock(spec=AsyncSession)
    # add() is synchronous on AsyncSession, so replace the AsyncMock default.
    fake.add = MagicMock()
    fake.flush = AsyncMock()
    fake.delete = AsyncMock()
    return fake
@pytest.fixture
def memory_service(mock_session):
    """MemoryService wired to the mocked database session."""
    service = MemoryService(mock_session)
    return service
class TestMemoryServiceTenantIsolation:
    """
    [AC-AISVC-10, AC-AISVC-11] Tests for multi-tenant isolation in memory service.
    """

    @pytest.mark.asyncio
    async def test_get_or_create_session_tenant_isolation(self, memory_service, mock_session):
        """
        [AC-AISVC-11] Different tenants with same session_id should have separate sessions.
        """
        # No existing row for either tenant -> both calls create new sessions.
        lookup_result = MagicMock()
        lookup_result.scalar_one_or_none.return_value = None
        mock_session.execute = AsyncMock(return_value=lookup_result)

        created = [
            await memory_service.get_or_create_session(
                tenant_id=tenant,
                session_id="session_123",
            )
            for tenant in ("tenant_a", "tenant_b")
        ]

        assert [s.tenant_id for s in created] == ["tenant_a", "tenant_b"]
        assert all(s.session_id == "session_123" for s in created)

    @pytest.mark.asyncio
    async def test_load_history_tenant_isolation(self, memory_service, mock_session):
        """
        [AC-AISVC-11] Loading history should only return messages for the specific tenant.
        """
        tenant_a_row = ChatMessage(
            tenant_id="tenant_a",
            session_id="session_123",
            role="user",
            content="Hello",
        )
        scalars = MagicMock()
        scalars.all.return_value = [tenant_a_row]
        query_result = MagicMock()
        query_result.scalars.return_value = scalars
        mock_session.execute = AsyncMock(return_value=query_result)

        messages = await memory_service.load_history(
            tenant_id="tenant_a",
            session_id="session_123",
        )

        assert len(messages) == 1
        assert messages[0].tenant_id == "tenant_a"

    @pytest.mark.asyncio
    async def test_append_message_tenant_scoped(self, memory_service, mock_session):
        """
        [AC-AISVC-10, AC-AISVC-13] Appended messages should be scoped to tenant.
        """
        message = await memory_service.append_message(
            tenant_id="tenant_a",
            session_id="session_123",
            role="user",
            content="Test message",
        )

        observed = (message.tenant_id, message.session_id, message.role, message.content)
        assert observed == ("tenant_a", "session_123", "user", "Test message")
class TestMemoryServiceSessionManagement:
    """
    [AC-AISVC-13] Tests for session-based memory management.
    """

    @pytest.mark.asyncio
    async def test_get_existing_session(self, memory_service, mock_session):
        """
        [AC-AISVC-13] Should return existing session if it exists.
        """
        existing_session = ChatSession(
            tenant_id="tenant_a",
            session_id="session_123",
        )
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = existing_session
        mock_session.execute = AsyncMock(return_value=mock_result)

        session = await memory_service.get_or_create_session(
            tenant_id="tenant_a",
            session_id="session_123",
        )

        assert session.tenant_id == "tenant_a"
        assert session.session_id == "session_123"

    @pytest.mark.asyncio
    async def test_create_new_session(self, memory_service, mock_session):
        """
        [AC-AISVC-13] Should create new session if it doesn't exist.
        """
        # scalar_one_or_none() -> None simulates "no matching row" in the DB.
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = None
        mock_session.execute = AsyncMock(return_value=mock_result)

        session = await memory_service.get_or_create_session(
            tenant_id="tenant_a",
            session_id="session_new",
            channel_type="wechat",
            metadata={"user_id": "user_123"},
        )

        assert session.tenant_id == "tenant_a"
        assert session.session_id == "session_new"
        assert session.channel_type == "wechat"

    @pytest.mark.asyncio
    async def test_append_multiple_messages(self, memory_service, mock_session):
        """
        [AC-AISVC-13] Should append multiple messages in batch.
        """
        messages_data = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
        ]

        messages = await memory_service.append_messages(
            tenant_id="tenant_a",
            session_id="session_123",
            messages=messages_data,
        )

        assert len(messages) == 2
        assert messages[0].role == "user"
        assert messages[1].role == "assistant"

    @pytest.mark.asyncio
    async def test_load_history_with_limit(self, memory_service, mock_session):
        """
        [AC-AISVC-13] Should limit the number of messages returned.

        The mocked session does not run SQL, so it returns all five stubbed
        rows regardless of the limit; real limiting happens in the database
        via the statement's LIMIT clause.  The original test only asserted
        the stubbed row count (len == 5) and therefore never verified the
        limit at all -- we now also inspect the statement handed to execute().
        """
        mock_result = MagicMock()
        mock_scalars = MagicMock()
        mock_scalars.all.return_value = [
            ChatMessage(
                tenant_id="tenant_a",
                session_id="session_123",
                role="user",
                content=f"Msg {i}",
            )
            for i in range(5)
        ]
        mock_result.scalars.return_value = mock_scalars
        mock_session.execute = AsyncMock(return_value=mock_result)

        messages = await memory_service.load_history(
            tenant_id="tenant_a",
            session_id="session_123",
            limit=3,
        )

        # The mock bypasses SQL execution, so all 5 stubbed rows come back.
        assert len(messages) == 5
        # Verify the service actually applied a LIMIT to the statement.
        # NOTE(review): assumes load_history uses select(...).limit(...) --
        # confirm against the MemoryService implementation.
        executed_stmt = mock_session.execute.call_args[0][0]
        assert "LIMIT" in str(executed_stmt).upper()
class TestMemoryServiceClearHistory:
    """
    [AC-AISVC-13] Tests for clearing session history.
    """

    @pytest.mark.asyncio
    async def test_clear_history_tenant_scoped(self, memory_service, mock_session):
        """
        [AC-AISVC-11] Clearing history should only affect the specified tenant's messages.
        """
        tenant_rows = [
            ChatMessage(
                tenant_id="tenant_a",
                session_id="session_123",
                role="user",
                content="Msg 1",
            ),
            ChatMessage(
                tenant_id="tenant_a",
                session_id="session_123",
                role="assistant",
                content="Msg 2",
            ),
        ]
        scalars = MagicMock()
        scalars.all.return_value = tenant_rows
        query_result = MagicMock()
        query_result.scalars.return_value = scalars
        mock_session.execute = AsyncMock(return_value=query_result)

        deleted = await memory_service.clear_history(
            tenant_id="tenant_a",
            session_id="session_123",
        )

        # One delete per stubbed row belonging to tenant_a.
        assert deleted == 2