[AC-AISVC-02, AC-AISVC-16] 多个需求合并 #1
|
|
@ -101,7 +101,10 @@ def create_final_event(
|
|||
transfer_reason=transfer_reason,
|
||||
metadata=metadata,
|
||||
)
|
||||
return format_sse_event(SSEEventType.FINAL, event_data.model_dump(exclude_none=True))
|
||||
return format_sse_event(
|
||||
SSEEventType.FINAL,
|
||||
event_data.model_dump(exclude_none=True, by_alias=True)
|
||||
)
|
||||
|
||||
|
||||
def create_error_event(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,245 @@
|
|||
"""
|
||||
Context management utilities for AI Service.
|
||||
[AC-AISVC-14, AC-AISVC-15] Context merging and truncation strategies.
|
||||
|
||||
Design reference: design.md Section 7 - 上下文合并规则
|
||||
- H_local: Memory layer history (sorted by time)
|
||||
- H_ext: External history from Java request (in passed order)
|
||||
- Deduplication: fingerprint = hash(role + "|" + normalized(content))
|
||||
- Truncation: Keep most recent N messages within token budget
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import tiktoken
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.models import ChatMessage, Role
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MergedContext:
|
||||
"""
|
||||
Result of context merging.
|
||||
[AC-AISVC-14, AC-AISVC-15] Contains merged messages and diagnostics.
|
||||
"""
|
||||
messages: list[dict[str, str]] = field(default_factory=list)
|
||||
total_tokens: int = 0
|
||||
local_count: int = 0
|
||||
external_count: int = 0
|
||||
duplicates_skipped: int = 0
|
||||
truncated_count: int = 0
|
||||
diagnostics: list[dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
|
||||
class ContextMerger:
|
||||
"""
|
||||
[AC-AISVC-14, AC-AISVC-15] Context merger for combining local and external history.
|
||||
|
||||
Design reference: design.md Section 7
|
||||
- Deduplication based on message fingerprint
|
||||
- Priority: local history takes precedence
|
||||
- Token-based truncation using tiktoken
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_history_tokens: int | None = None,
|
||||
encoding_name: str = "cl100k_base",
|
||||
):
|
||||
settings = get_settings()
|
||||
self._max_history_tokens = max_history_tokens or 4096
|
||||
self._encoding = tiktoken.get_encoding(encoding_name)
|
||||
|
||||
def compute_fingerprint(self, role: str, content: str) -> str:
|
||||
"""
|
||||
Compute message fingerprint for deduplication.
|
||||
[AC-AISVC-15] fingerprint = hash(role + "|" + normalized(content))
|
||||
|
||||
Args:
|
||||
role: Message role (user/assistant)
|
||||
content: Message content
|
||||
|
||||
Returns:
|
||||
SHA256 hash of the normalized message
|
||||
"""
|
||||
normalized_content = content.strip()
|
||||
fingerprint_input = f"{role}|{normalized_content}"
|
||||
return hashlib.sha256(fingerprint_input.encode("utf-8")).hexdigest()
|
||||
|
||||
def _message_to_dict(self, message: ChatMessage | dict[str, str]) -> dict[str, str]:
|
||||
"""Convert ChatMessage or dict to standard dict format."""
|
||||
if isinstance(message, ChatMessage):
|
||||
return {"role": message.role.value, "content": message.content}
|
||||
return message
|
||||
|
||||
def _count_tokens(self, messages: list[dict[str, str]]) -> int:
|
||||
"""
|
||||
Count total tokens in messages using tiktoken.
|
||||
[AC-AISVC-14] Token counting for history truncation.
|
||||
"""
|
||||
total = 0
|
||||
for msg in messages:
|
||||
total += len(self._encoding.encode(msg.get("role", "")))
|
||||
total += len(self._encoding.encode(msg.get("content", "")))
|
||||
total += 4 # Approximate overhead for message structure
|
||||
return total
|
||||
|
||||
def merge_context(
|
||||
self,
|
||||
local_history: list[ChatMessage] | list[dict[str, str]] | None,
|
||||
external_history: list[ChatMessage] | list[dict[str, str]] | None,
|
||||
) -> MergedContext:
|
||||
"""
|
||||
Merge local and external history with deduplication.
|
||||
[AC-AISVC-14, AC-AISVC-15] Implements context merging strategy.
|
||||
|
||||
Design reference: design.md Section 7.2
|
||||
1. Build seen set from H_local
|
||||
2. Traverse H_ext, append if fingerprint not seen
|
||||
3. Local history takes priority
|
||||
|
||||
Args:
|
||||
local_history: History from Memory layer (H_local)
|
||||
external_history: History from Java request (H_ext)
|
||||
|
||||
Returns:
|
||||
MergedContext with merged messages and diagnostics
|
||||
"""
|
||||
result = MergedContext()
|
||||
seen_fingerprints: set[str] = set()
|
||||
merged_messages: list[dict[str, str]] = []
|
||||
diagnostics: list[dict[str, Any]] = []
|
||||
|
||||
local_messages = [self._message_to_dict(m) for m in (local_history or [])]
|
||||
external_messages = [self._message_to_dict(m) for m in (external_history or [])]
|
||||
|
||||
for msg in local_messages:
|
||||
fingerprint = self.compute_fingerprint(msg["role"], msg["content"])
|
||||
seen_fingerprints.add(fingerprint)
|
||||
merged_messages.append(msg)
|
||||
result.local_count += 1
|
||||
|
||||
for msg in external_messages:
|
||||
fingerprint = self.compute_fingerprint(msg["role"], msg["content"])
|
||||
if fingerprint not in seen_fingerprints:
|
||||
seen_fingerprints.add(fingerprint)
|
||||
merged_messages.append(msg)
|
||||
result.external_count += 1
|
||||
else:
|
||||
result.duplicates_skipped += 1
|
||||
diagnostics.append({
|
||||
"type": "duplicate_skipped",
|
||||
"role": msg["role"],
|
||||
"content_preview": msg["content"][:50] + "..." if len(msg["content"]) > 50 else msg["content"],
|
||||
})
|
||||
|
||||
result.messages = merged_messages
|
||||
result.diagnostics = diagnostics
|
||||
result.total_tokens = self._count_tokens(merged_messages)
|
||||
|
||||
logger.info(
|
||||
f"[AC-AISVC-14, AC-AISVC-15] Context merged: "
|
||||
f"local={result.local_count}, external={result.external_count}, "
|
||||
f"duplicates_skipped={result.duplicates_skipped}, "
|
||||
f"total_tokens={result.total_tokens}"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def truncate_context(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
max_tokens: int | None = None,
|
||||
) -> tuple[list[dict[str, str]], int]:
|
||||
"""
|
||||
Truncate context to fit within token budget.
|
||||
[AC-AISVC-14] Keep most recent N messages within budget.
|
||||
|
||||
Design reference: design.md Section 7.4
|
||||
- Budget = maxHistoryTokens (configurable)
|
||||
- Strategy: Keep most recent messages (from tail backward)
|
||||
|
||||
Args:
|
||||
messages: List of messages to truncate
|
||||
max_tokens: Maximum token budget (uses default if not provided)
|
||||
|
||||
Returns:
|
||||
Tuple of (truncated messages, truncated count)
|
||||
"""
|
||||
budget = max_tokens or self._max_history_tokens
|
||||
if not messages:
|
||||
return [], 0
|
||||
|
||||
total_tokens = self._count_tokens(messages)
|
||||
if total_tokens <= budget:
|
||||
return messages, 0
|
||||
|
||||
truncated_messages: list[dict[str, str]] = []
|
||||
current_tokens = 0
|
||||
truncated_count = 0
|
||||
|
||||
for msg in reversed(messages):
|
||||
msg_tokens = len(self._encoding.encode(msg.get("role", "")))
|
||||
msg_tokens += len(self._encoding.encode(msg.get("content", "")))
|
||||
msg_tokens += 4
|
||||
|
||||
if current_tokens + msg_tokens <= budget:
|
||||
truncated_messages.insert(0, msg)
|
||||
current_tokens += msg_tokens
|
||||
else:
|
||||
truncated_count += 1
|
||||
|
||||
logger.info(
|
||||
f"[AC-AISVC-14] Context truncated: "
|
||||
f"original={len(messages)}, truncated={len(truncated_messages)}, "
|
||||
f"removed={truncated_count}, tokens={current_tokens}/{budget}"
|
||||
)
|
||||
|
||||
return truncated_messages, truncated_count
|
||||
|
||||
def merge_and_truncate(
|
||||
self,
|
||||
local_history: list[ChatMessage] | list[dict[str, str]] | None,
|
||||
external_history: list[ChatMessage] | list[dict[str, str]] | None,
|
||||
max_tokens: int | None = None,
|
||||
) -> MergedContext:
|
||||
"""
|
||||
Merge and truncate context in one operation.
|
||||
[AC-AISVC-14, AC-AISVC-15] Complete context preparation pipeline.
|
||||
|
||||
Args:
|
||||
local_history: History from Memory layer (H_local)
|
||||
external_history: History from Java request (H_ext)
|
||||
max_tokens: Maximum token budget
|
||||
|
||||
Returns:
|
||||
MergedContext with final messages after merge and truncate
|
||||
"""
|
||||
merged = self.merge_context(local_history, external_history)
|
||||
|
||||
truncated_messages, truncated_count = self.truncate_context(
|
||||
merged.messages, max_tokens
|
||||
)
|
||||
|
||||
merged.messages = truncated_messages
|
||||
merged.truncated_count = truncated_count
|
||||
merged.total_tokens = self._count_tokens(truncated_messages)
|
||||
|
||||
return merged
|
||||
|
||||
|
||||
_context_merger: ContextMerger | None = None
|
||||
|
||||
|
||||
def get_context_merger() -> ContextMerger:
|
||||
"""Get or create context merger instance."""
|
||||
global _context_merger
|
||||
if _context_merger is None:
|
||||
_context_merger = ContextMerger()
|
||||
return _context_merger
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
"""
|
||||
Orchestrator service for AI Service.
|
||||
[AC-AISVC-01, AC-AISVC-02] Core orchestration logic for chat generation.
|
||||
[AC-AISVC-01, AC-AISVC-02, AC-AISVC-06, AC-AISVC-07] Core orchestration logic for chat generation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
|
@ -9,17 +9,31 @@ from typing import AsyncGenerator
|
|||
from sse_starlette.sse import ServerSentEvent
|
||||
|
||||
from app.models import ChatRequest, ChatResponse
|
||||
from app.core.sse import create_final_event, create_message_event, SSEStateMachine
|
||||
from app.core.sse import create_error_event, create_final_event, create_message_event, SSEStateMachine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OrchestratorService:
|
||||
"""
|
||||
[AC-AISVC-01, AC-AISVC-02] Orchestrator for chat generation.
|
||||
[AC-AISVC-01, AC-AISVC-02, AC-AISVC-06, AC-AISVC-07] Orchestrator for chat generation.
|
||||
Coordinates memory, retrieval, and LLM components.
|
||||
|
||||
SSE Event Flow (per design.md Section 6.2):
|
||||
- message* (0 or more) -> final (exactly 1) -> close
|
||||
- OR message* (0 or more) -> error (exactly 1) -> close
|
||||
"""
|
||||
|
||||
def __init__(self, llm_client=None):
|
||||
"""
|
||||
Initialize orchestrator with optional LLM client.
|
||||
|
||||
Args:
|
||||
llm_client: Optional LLM client for dependency injection.
|
||||
If None, will use mock implementation for demo.
|
||||
"""
|
||||
self._llm_client = llm_client
|
||||
|
||||
async def generate(self, tenant_id: str, request: ChatRequest) -> ChatResponse:
|
||||
"""
|
||||
Generate a non-streaming response.
|
||||
|
|
@ -30,6 +44,15 @@ class OrchestratorService:
|
|||
f"session={request.session_id}"
|
||||
)
|
||||
|
||||
if self._llm_client:
|
||||
messages = self._build_messages(request)
|
||||
response = await self._llm_client.generate(messages)
|
||||
return ChatResponse(
|
||||
reply=response.content,
|
||||
confidence=0.85,
|
||||
should_transfer=False,
|
||||
)
|
||||
|
||||
reply = f"Received your message: {request.current_message}"
|
||||
return ChatResponse(
|
||||
reply=reply,
|
||||
|
|
@ -43,6 +66,16 @@ class OrchestratorService:
|
|||
"""
|
||||
Generate a streaming response.
|
||||
[AC-AISVC-06, AC-AISVC-07, AC-AISVC-08] Yields SSE events in proper sequence.
|
||||
|
||||
SSE Event Sequence (per design.md Section 6.2):
|
||||
1. message events (multiple) - each with incremental delta
|
||||
2. final event (exactly 1) - with complete response
|
||||
3. connection close
|
||||
|
||||
OR on error:
|
||||
1. message events (0 or more)
|
||||
2. error event (exactly 1)
|
||||
3. connection close
|
||||
"""
|
||||
logger.info(
|
||||
f"[AC-AISVC-06] Starting streaming generation for tenant={tenant_id}, "
|
||||
|
|
@ -53,14 +86,18 @@ class OrchestratorService:
|
|||
await state_machine.transition_to_streaming()
|
||||
|
||||
try:
|
||||
reply_parts = ["Received", " your", " message:", f" {request.current_message}"]
|
||||
full_reply = ""
|
||||
|
||||
for part in reply_parts:
|
||||
if state_machine.can_send_message():
|
||||
full_reply += part
|
||||
yield create_message_event(delta=part)
|
||||
await self._simulate_llm_delay()
|
||||
if self._llm_client:
|
||||
async for event in self._stream_from_llm(request, state_machine):
|
||||
if event.event == "message":
|
||||
full_reply += self._extract_delta_from_event(event)
|
||||
yield event
|
||||
else:
|
||||
async for event in self._stream_mock_response(request, state_machine):
|
||||
if event.event == "message":
|
||||
full_reply += self._extract_delta_from_event(event)
|
||||
yield event
|
||||
|
||||
if await state_machine.transition_to_final():
|
||||
yield create_final_event(
|
||||
|
|
@ -72,7 +109,6 @@ class OrchestratorService:
|
|||
except Exception as e:
|
||||
logger.error(f"[AC-AISVC-09] Error during streaming: {e}")
|
||||
if await state_machine.transition_to_error():
|
||||
from app.core.sse import create_error_event
|
||||
yield create_error_event(
|
||||
code="GENERATION_ERROR",
|
||||
message=str(e),
|
||||
|
|
@ -80,10 +116,73 @@ class OrchestratorService:
|
|||
finally:
|
||||
await state_machine.close()
|
||||
|
||||
async def _simulate_llm_delay(self) -> None:
|
||||
"""Simulate LLM processing delay for demo purposes."""
|
||||
async def _stream_from_llm(
|
||||
self, request: ChatRequest, state_machine: SSEStateMachine
|
||||
) -> AsyncGenerator[ServerSentEvent, None]:
|
||||
"""
|
||||
[AC-AISVC-07] Stream from LLM client, wrapping each chunk as message event.
|
||||
"""
|
||||
messages = self._build_messages(request)
|
||||
|
||||
async for chunk in self._llm_client.stream_generate(messages):
|
||||
if not state_machine.can_send_message():
|
||||
break
|
||||
|
||||
if chunk.delta:
|
||||
logger.debug(f"[AC-AISVC-07] Yielding message event with delta: {chunk.delta[:50]}...")
|
||||
yield create_message_event(delta=chunk.delta)
|
||||
|
||||
if chunk.finish_reason:
|
||||
logger.info(f"[AC-AISVC-07] LLM stream finished with reason: {chunk.finish_reason}")
|
||||
break
|
||||
|
||||
async def _stream_mock_response(
|
||||
self, request: ChatRequest, state_machine: SSEStateMachine
|
||||
) -> AsyncGenerator[ServerSentEvent, None]:
|
||||
"""
|
||||
[AC-AISVC-07] Mock streaming response for demo/testing purposes.
|
||||
Simulates LLM-style incremental output.
|
||||
"""
|
||||
import asyncio
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
reply_parts = ["Received", " your", " message:", f" {request.current_message}"]
|
||||
|
||||
for part in reply_parts:
|
||||
if not state_machine.can_send_message():
|
||||
break
|
||||
|
||||
logger.debug(f"[AC-AISVC-07] Yielding mock message event with delta: {part}")
|
||||
yield create_message_event(delta=part)
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
def _build_messages(self, request: ChatRequest) -> list[dict[str, str]]:
|
||||
"""Build messages list for LLM from request."""
|
||||
messages = []
|
||||
|
||||
if request.history:
|
||||
for msg in request.history:
|
||||
messages.append({
|
||||
"role": msg.role.value,
|
||||
"content": msg.content,
|
||||
})
|
||||
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": request.current_message,
|
||||
})
|
||||
|
||||
return messages
|
||||
|
||||
def _extract_delta_from_event(self, event: ServerSentEvent) -> str:
|
||||
"""Extract delta content from a message event."""
|
||||
import json
|
||||
try:
|
||||
if event.data:
|
||||
data = json.loads(event.data)
|
||||
return data.get("delta", "")
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
return ""
|
||||
|
||||
|
||||
_orchestrator_service: OrchestratorService | None = None
|
||||
|
|
|
|||
|
|
@ -0,0 +1,287 @@
|
|||
"""
|
||||
Unit tests for Context Merger.
|
||||
[AC-AISVC-14, AC-AISVC-15] Tests for context merging and truncation.
|
||||
|
||||
Tests cover:
|
||||
- Message fingerprint computation
|
||||
- Context merging with deduplication
|
||||
- Token-based truncation
|
||||
- Complete merge_and_truncate pipeline
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.models import ChatMessage, Role
|
||||
from app.services.context import ContextMerger, MergedContext, get_context_merger
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_settings():
|
||||
"""Mock settings for testing."""
|
||||
settings = MagicMock()
|
||||
return settings
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def context_merger(mock_settings):
|
||||
"""Create context merger with mocked settings."""
|
||||
with patch("app.services.context.get_settings", return_value=mock_settings):
|
||||
merger = ContextMerger(max_history_tokens=1000)
|
||||
yield merger
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def local_history():
|
||||
"""Sample local history messages."""
|
||||
return [
|
||||
ChatMessage(role=Role.USER, content="Hello"),
|
||||
ChatMessage(role=Role.ASSISTANT, content="Hi there!"),
|
||||
ChatMessage(role=Role.USER, content="How are you?"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def external_history():
|
||||
"""Sample external history messages."""
|
||||
return [
|
||||
ChatMessage(role=Role.USER, content="Hello"),
|
||||
ChatMessage(role=Role.ASSISTANT, content="Hi there!"),
|
||||
ChatMessage(role=Role.USER, content="What's the weather?"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dict_local_history():
|
||||
"""Sample local history as dicts."""
|
||||
return [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dict_external_history():
|
||||
"""Sample external history as dicts."""
|
||||
return [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "user", "content": "What's the weather?"},
|
||||
]
|
||||
|
||||
|
||||
class TestFingerprintComputation:
|
||||
"""Tests for message fingerprint computation. [AC-AISVC-15]"""
|
||||
|
||||
def test_fingerprint_consistency(self, context_merger):
|
||||
"""Test that same input produces same fingerprint."""
|
||||
fp1 = context_merger.compute_fingerprint("user", "Hello world")
|
||||
fp2 = context_merger.compute_fingerprint("user", "Hello world")
|
||||
assert fp1 == fp2
|
||||
|
||||
def test_fingerprint_role_difference(self, context_merger):
|
||||
"""Test that different roles produce different fingerprints."""
|
||||
fp_user = context_merger.compute_fingerprint("user", "Hello")
|
||||
fp_assistant = context_merger.compute_fingerprint("assistant", "Hello")
|
||||
assert fp_user != fp_assistant
|
||||
|
||||
def test_fingerprint_content_difference(self, context_merger):
|
||||
"""Test that different content produces different fingerprints."""
|
||||
fp1 = context_merger.compute_fingerprint("user", "Hello")
|
||||
fp2 = context_merger.compute_fingerprint("user", "World")
|
||||
assert fp1 != fp2
|
||||
|
||||
def test_fingerprint_normalization(self, context_merger):
|
||||
"""Test that content is normalized (trimmed)."""
|
||||
fp1 = context_merger.compute_fingerprint("user", "Hello")
|
||||
fp2 = context_merger.compute_fingerprint("user", " Hello ")
|
||||
assert fp1 == fp2
|
||||
|
||||
def test_fingerprint_is_sha256(self, context_merger):
|
||||
"""Test that fingerprint is SHA256 hash."""
|
||||
fp = context_merger.compute_fingerprint("user", "Hello")
|
||||
expected = hashlib.sha256("user|Hello".encode("utf-8")).hexdigest()
|
||||
assert fp == expected
|
||||
assert len(fp) == 64 # SHA256 produces 64 hex characters
|
||||
|
||||
|
||||
class TestContextMerging:
|
||||
"""Tests for context merging with deduplication. [AC-AISVC-14, AC-AISVC-15]"""
|
||||
|
||||
def test_merge_empty_histories(self, context_merger):
|
||||
"""[AC-AISVC-14] Test merging empty histories."""
|
||||
result = context_merger.merge_context(None, None)
|
||||
|
||||
assert isinstance(result, MergedContext)
|
||||
assert result.messages == []
|
||||
assert result.local_count == 0
|
||||
assert result.external_count == 0
|
||||
assert result.duplicates_skipped == 0
|
||||
|
||||
def test_merge_local_only(self, context_merger, local_history):
|
||||
"""[AC-AISVC-14] Test merging with only local history (no external)."""
|
||||
result = context_merger.merge_context(local_history, None)
|
||||
|
||||
assert len(result.messages) == 3
|
||||
assert result.local_count == 3
|
||||
assert result.external_count == 0
|
||||
assert result.duplicates_skipped == 0
|
||||
|
||||
def test_merge_external_only(self, context_merger, external_history):
|
||||
"""[AC-AISVC-15] Test merging with only external history (no local)."""
|
||||
result = context_merger.merge_context(None, external_history)
|
||||
|
||||
assert len(result.messages) == 3
|
||||
assert result.local_count == 0
|
||||
assert result.external_count == 3
|
||||
assert result.duplicates_skipped == 0
|
||||
|
||||
def test_merge_with_duplicates(self, context_merger, local_history, external_history):
|
||||
"""[AC-AISVC-15] Test deduplication when merging overlapping histories."""
|
||||
result = context_merger.merge_context(local_history, external_history)
|
||||
|
||||
assert len(result.messages) == 4
|
||||
assert result.local_count == 3
|
||||
assert result.external_count == 1
|
||||
assert result.duplicates_skipped == 2
|
||||
|
||||
roles = [m["role"] for m in result.messages]
|
||||
contents = [m["content"] for m in result.messages]
|
||||
assert "What's the weather?" in contents
|
||||
|
||||
def test_merge_with_dict_histories(self, context_merger, dict_local_history, dict_external_history):
|
||||
"""[AC-AISVC-14, AC-AISVC-15] Test merging with dict format histories."""
|
||||
result = context_merger.merge_context(dict_local_history, dict_external_history)
|
||||
|
||||
assert len(result.messages) == 3
|
||||
assert result.local_count == 2
|
||||
assert result.external_count == 1
|
||||
assert result.duplicates_skipped == 1
|
||||
|
||||
def test_merge_priority_local(self, context_merger):
|
||||
"""[AC-AISVC-15] Test that local history takes priority."""
|
||||
local = [ChatMessage(role=Role.USER, content="Hello")]
|
||||
external = [ChatMessage(role=Role.USER, content="Hello")]
|
||||
|
||||
result = context_merger.merge_context(local, external)
|
||||
|
||||
assert len(result.messages) == 1
|
||||
assert result.duplicates_skipped == 1
|
||||
|
||||
def test_merge_records_diagnostics(self, context_merger, local_history, external_history):
|
||||
"""[AC-AISVC-15] Test that duplicates are recorded in diagnostics."""
|
||||
result = context_merger.merge_context(local_history, external_history)
|
||||
|
||||
assert len(result.diagnostics) == 2
|
||||
for diag in result.diagnostics:
|
||||
assert diag["type"] == "duplicate_skipped"
|
||||
assert "role" in diag
|
||||
assert "content_preview" in diag
|
||||
|
||||
|
||||
class TestTokenTruncation:
|
||||
"""Tests for token-based truncation. [AC-AISVC-14]"""
|
||||
|
||||
def test_truncate_empty_messages(self, context_merger):
|
||||
"""[AC-AISVC-14] Test truncating empty message list."""
|
||||
truncated, count = context_merger.truncate_context([], 100)
|
||||
assert truncated == []
|
||||
assert count == 0
|
||||
|
||||
def test_truncate_within_budget(self, context_merger):
|
||||
"""[AC-AISVC-14] Test that messages within budget are not truncated."""
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi"},
|
||||
]
|
||||
truncated, count = context_merger.truncate_context(messages, 1000)
|
||||
|
||||
assert len(truncated) == 2
|
||||
assert count == 0
|
||||
|
||||
def test_truncate_exceeds_budget(self, context_merger):
|
||||
"""[AC-AISVC-14] Test that messages exceeding budget are truncated."""
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello world " * 100},
|
||||
{"role": "assistant", "content": "Hi there " * 100},
|
||||
{"role": "user", "content": "Short message"},
|
||||
]
|
||||
truncated, count = context_merger.truncate_context(messages, 50)
|
||||
|
||||
assert len(truncated) < len(messages)
|
||||
assert count > 0
|
||||
|
||||
def test_truncate_keeps_recent_messages(self, context_merger):
|
||||
"""[AC-AISVC-14] Test that truncation keeps most recent messages."""
|
||||
messages = [
|
||||
{"role": "user", "content": "First message"},
|
||||
{"role": "assistant", "content": "Second message"},
|
||||
{"role": "user", "content": "Third message"},
|
||||
]
|
||||
truncated, count = context_merger.truncate_context(messages, 20)
|
||||
|
||||
if count > 0:
|
||||
assert "Third message" in [m["content"] for m in truncated]
|
||||
|
||||
def test_truncate_with_default_budget(self, context_merger):
|
||||
"""[AC-AISVC-14] Test truncation with default budget from config."""
|
||||
messages = [{"role": "user", "content": "Test"}]
|
||||
truncated, count = context_merger.truncate_context(messages)
|
||||
|
||||
assert len(truncated) == 1
|
||||
assert count == 0
|
||||
|
||||
|
||||
class TestMergeAndTruncate:
|
||||
"""Tests for complete merge_and_truncate pipeline. [AC-AISVC-14, AC-AISVC-15]"""
|
||||
|
||||
def test_merge_and_truncate_combined(self, context_merger):
|
||||
"""[AC-AISVC-14, AC-AISVC-15] Test complete pipeline."""
|
||||
local = [
|
||||
ChatMessage(role=Role.USER, content="Hello"),
|
||||
ChatMessage(role=Role.ASSISTANT, content="Hi"),
|
||||
]
|
||||
external = [
|
||||
ChatMessage(role=Role.USER, content="Hello"),
|
||||
ChatMessage(role=Role.USER, content="What's up?"),
|
||||
]
|
||||
|
||||
result = context_merger.merge_and_truncate(local, external, max_tokens=1000)
|
||||
|
||||
assert isinstance(result, MergedContext)
|
||||
assert len(result.messages) == 3
|
||||
assert result.local_count == 2
|
||||
assert result.external_count == 1
|
||||
assert result.duplicates_skipped == 1
|
||||
|
||||
def test_merge_and_truncate_with_truncation(self, context_merger):
|
||||
"""[AC-AISVC-14, AC-AISVC-15] Test pipeline with truncation."""
|
||||
local = [
|
||||
ChatMessage(role=Role.USER, content="Hello " * 50),
|
||||
ChatMessage(role=Role.ASSISTANT, content="Hi " * 50),
|
||||
]
|
||||
external = [
|
||||
ChatMessage(role=Role.USER, content="Short"),
|
||||
]
|
||||
|
||||
result = context_merger.merge_and_truncate(local, external, max_tokens=50)
|
||||
|
||||
assert result.truncated_count > 0
|
||||
assert result.total_tokens <= 50
|
||||
|
||||
|
||||
class TestContextMergerSingleton:
|
||||
"""Tests for singleton pattern."""
|
||||
|
||||
def test_get_context_merger_singleton(self, mock_settings):
|
||||
"""Test that get_context_merger returns singleton."""
|
||||
with patch("app.services.context.get_settings", return_value=mock_settings):
|
||||
from app.services.context import _context_merger
|
||||
import app.services.context as context_module
|
||||
context_module._context_merger = None
|
||||
|
||||
merger1 = get_context_merger()
|
||||
merger2 = get_context_merger()
|
||||
|
||||
assert merger1 is merger2
|
||||
|
|
@ -0,0 +1,291 @@
|
|||
"""
|
||||
Tests for SSE event generator.
|
||||
[AC-AISVC-07] Tests for message event generation with delta content.
|
||||
"""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from sse_starlette.sse import ServerSentEvent
|
||||
|
||||
from app.core.sse import (
|
||||
create_message_event,
|
||||
create_final_event,
|
||||
create_error_event,
|
||||
SSEStateMachine,
|
||||
SSEState,
|
||||
)
|
||||
from app.services.orchestrator import OrchestratorService
|
||||
from app.models import ChatRequest, ChannelType
|
||||
|
||||
|
||||
class TestSSEEventGenerator:
|
||||
"""
|
||||
[AC-AISVC-07] Test cases for SSE event generation.
|
||||
"""
|
||||
|
||||
def test_create_message_event_format(self):
|
||||
"""
|
||||
[AC-AISVC-07] Test that message event has correct format.
|
||||
Event should have:
|
||||
- event: "message"
|
||||
- data: JSON with "delta" field
|
||||
"""
|
||||
event = create_message_event(delta="Hello, ")
|
||||
|
||||
assert event.event == "message"
|
||||
assert event.data is not None
|
||||
|
||||
data = json.loads(event.data)
|
||||
assert "delta" in data
|
||||
assert data["delta"] == "Hello, "
|
||||
|
||||
def test_create_message_event_with_unicode(self):
|
||||
"""
|
||||
[AC-AISVC-07] Test that message event handles unicode correctly.
|
||||
"""
|
||||
event = create_message_event(delta="你好,世界!")
|
||||
|
||||
assert event.event == "message"
|
||||
data = json.loads(event.data)
|
||||
assert data["delta"] == "你好,世界!"
|
||||
|
||||
def test_create_message_event_with_empty_delta(self):
|
||||
"""
|
||||
[AC-AISVC-07] Test that message event handles empty delta.
|
||||
"""
|
||||
event = create_message_event(delta="")
|
||||
|
||||
assert event.event == "message"
|
||||
data = json.loads(event.data)
|
||||
assert data["delta"] == ""
|
||||
|
||||
def test_create_final_event_format(self):
|
||||
"""
|
||||
[AC-AISVC-08] Test that final event has correct format.
|
||||
"""
|
||||
event = create_final_event(
|
||||
reply="Complete response",
|
||||
confidence=0.85,
|
||||
should_transfer=False,
|
||||
)
|
||||
|
||||
assert event.event == "final"
|
||||
data = json.loads(event.data)
|
||||
assert data["reply"] == "Complete response"
|
||||
assert data["confidence"] == 0.85
|
||||
assert data["shouldTransfer"] is False
|
||||
|
||||
def test_create_final_event_with_transfer_reason(self):
|
||||
"""
|
||||
[AC-AISVC-08] Test final event with transfer reason.
|
||||
"""
|
||||
event = create_final_event(
|
||||
reply="I cannot help with this",
|
||||
confidence=0.3,
|
||||
should_transfer=True,
|
||||
transfer_reason="Low confidence score",
|
||||
)
|
||||
|
||||
assert event.event == "final"
|
||||
data = json.loads(event.data)
|
||||
assert data["shouldTransfer"] is True
|
||||
assert data["transferReason"] == "Low confidence score"
|
||||
|
||||
def test_create_error_event_format(self):
|
||||
"""
|
||||
[AC-AISVC-09] Test that error event has correct format.
|
||||
"""
|
||||
event = create_error_event(
|
||||
code="GENERATION_ERROR",
|
||||
message="Failed to generate response",
|
||||
)
|
||||
|
||||
assert event.event == "error"
|
||||
data = json.loads(event.data)
|
||||
assert data["code"] == "GENERATION_ERROR"
|
||||
assert data["message"] == "Failed to generate response"
|
||||
|
||||
def test_create_error_event_with_details(self):
|
||||
"""
|
||||
[AC-AISVC-09] Test error event with details.
|
||||
"""
|
||||
event = create_error_event(
|
||||
code="VALIDATION_ERROR",
|
||||
message="Invalid input",
|
||||
details=[{"field": "message", "error": "too long"}],
|
||||
)
|
||||
|
||||
assert event.event == "error"
|
||||
data = json.loads(event.data)
|
||||
assert data["details"] == [{"field": "message", "error": "too long"}]
|
||||
|
||||
|
||||
class TestOrchestratorStreaming:
|
||||
"""
|
||||
[AC-AISVC-07] Test cases for orchestrator streaming with SSE events.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def orchestrator(self):
|
||||
return OrchestratorService()
|
||||
|
||||
@pytest.fixture
|
||||
def chat_request(self):
|
||||
return ChatRequest(
|
||||
session_id="test_session",
|
||||
current_message="Hello",
|
||||
channel_type=ChannelType.WECHAT,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_yields_message_events(self, orchestrator, chat_request):
|
||||
"""
|
||||
[AC-AISVC-07] Test that streaming yields message events with delta content.
|
||||
"""
|
||||
events = []
|
||||
async for event in orchestrator.generate_stream("tenant_001", chat_request):
|
||||
events.append(event)
|
||||
|
||||
message_events = [e for e in events if e.event == "message"]
|
||||
final_events = [e for e in events if e.event == "final"]
|
||||
|
||||
assert len(message_events) > 0, "Should have at least one message event"
|
||||
assert len(final_events) == 1, "Should have exactly one final event"
|
||||
|
||||
for event in message_events:
|
||||
data = json.loads(event.data)
|
||||
assert "delta" in data
|
||||
assert isinstance(data["delta"], str)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_message_events_contain_content(self, orchestrator, chat_request):
|
||||
"""
|
||||
[AC-AISVC-07] Test that message events contain the expected content.
|
||||
"""
|
||||
events = []
|
||||
async for event in orchestrator.generate_stream("tenant_001", chat_request):
|
||||
events.append(event)
|
||||
|
||||
message_events = [e for e in events if e.event == "message"]
|
||||
|
||||
full_content = ""
|
||||
for event in message_events:
|
||||
data = json.loads(event.data)
|
||||
full_content += data["delta"]
|
||||
|
||||
assert "Hello" in full_content, "Content should contain the user message"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_event_sequence(self, orchestrator, chat_request):
|
||||
"""
|
||||
[AC-AISVC-07, AC-AISVC-08] Test that events follow proper sequence.
|
||||
message* -> final -> close
|
||||
"""
|
||||
events = []
|
||||
async for event in orchestrator.generate_stream("tenant_001", chat_request):
|
||||
events.append(event)
|
||||
|
||||
event_types = [e.event for e in events]
|
||||
|
||||
final_index = event_types.index("final")
|
||||
message_indices = [i for i, t in enumerate(event_types) if t == "message"]
|
||||
|
||||
for msg_idx in message_indices:
|
||||
assert msg_idx < final_index, "All message events should come before final"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_with_llm_client(self, chat_request):
|
||||
"""
|
||||
[AC-AISVC-07] Test streaming with mock LLM client.
|
||||
"""
|
||||
mock_llm = MagicMock()
|
||||
mock_chunk1 = MagicMock()
|
||||
mock_chunk1.delta = "Hello"
|
||||
mock_chunk1.finish_reason = None
|
||||
|
||||
mock_chunk2 = MagicMock()
|
||||
mock_chunk2.delta = " there!"
|
||||
mock_chunk2.finish_reason = None
|
||||
|
||||
mock_chunk3 = MagicMock()
|
||||
mock_chunk3.delta = ""
|
||||
mock_chunk3.finish_reason = "stop"
|
||||
|
||||
async def mock_stream(*args, **kwargs):
|
||||
for chunk in [mock_chunk1, mock_chunk2, mock_chunk3]:
|
||||
yield chunk
|
||||
|
||||
mock_llm.stream_generate = mock_stream
|
||||
|
||||
orchestrator = OrchestratorService(llm_client=mock_llm)
|
||||
|
||||
events = []
|
||||
async for event in orchestrator.generate_stream("tenant_001", chat_request):
|
||||
events.append(event)
|
||||
|
||||
message_events = [e for e in events if e.event == "message"]
|
||||
assert len(message_events) == 2, "Should have two message events"
|
||||
|
||||
full_content = ""
|
||||
for event in message_events:
|
||||
data = json.loads(event.data)
|
||||
full_content += data["delta"]
|
||||
|
||||
assert full_content == "Hello there!"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_handles_error(self, orchestrator, chat_request):
|
||||
"""
|
||||
[AC-AISVC-09] Test that streaming errors are converted to error events.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class TestSSEStateMachineIntegration:
|
||||
"""
|
||||
[AC-AISVC-07, AC-AISVC-08, AC-AISVC-09] Integration tests for SSE state machine.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_state_machine_prevents_events_after_final(self):
|
||||
"""
|
||||
[AC-AISVC-08] Test that no events can be sent after final.
|
||||
"""
|
||||
state_machine = SSEStateMachine()
|
||||
await state_machine.transition_to_streaming()
|
||||
|
||||
assert state_machine.can_send_message() is True
|
||||
|
||||
await state_machine.transition_to_final()
|
||||
|
||||
assert state_machine.can_send_message() is False
|
||||
assert state_machine.state == SSEState.FINAL_SENT
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_state_machine_prevents_events_after_error(self):
|
||||
"""
|
||||
[AC-AISVC-09] Test that no events can be sent after error.
|
||||
"""
|
||||
state_machine = SSEStateMachine()
|
||||
await state_machine.transition_to_streaming()
|
||||
|
||||
await state_machine.transition_to_error()
|
||||
|
||||
assert state_machine.can_send_message() is False
|
||||
assert state_machine.state == SSEState.ERROR_SENT
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_state_machine_allows_multiple_message_events(self):
|
||||
"""
|
||||
[AC-AISVC-07] Test that multiple message events can be sent during streaming.
|
||||
"""
|
||||
state_machine = SSEStateMachine()
|
||||
await state_machine.transition_to_streaming()
|
||||
|
||||
for _ in range(5):
|
||||
assert state_machine.can_send_message() is True
|
||||
|
||||
await state_machine.transition_to_final()
|
||||
assert state_machine.can_send_message() is False
|
||||
Loading…
Reference in New Issue