"""
Tests for the SSE state machine and SSE error handling.

[AC-AISVC-08, AC-AISVC-09] Tests for proper event sequencing and error handling.
"""

import asyncio
import json
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from fastapi.testclient import TestClient
from sse_starlette.sse import ServerSentEvent

from app.core.sse import (
    SSEState,
    SSEStateMachine,
    create_error_event,
    create_final_event,
    create_message_event,
)
from app.main import app
from app.models import ChatRequest, ChannelType


class TestSSEStateMachineTransitions:
    """
    [AC-AISVC-08, AC-AISVC-09] Test cases for SSE state machine transitions.

    Every test builds a fresh ``SSEStateMachine``, drives it through a short
    sequence of transitions, and checks both the boolean each transition call
    returns and the resulting ``state``.
    """

    @staticmethod
    async def _streaming_machine():
        """Return a new state machine that has already moved INIT -> STREAMING."""
        machine = SSEStateMachine()
        await machine.transition_to_streaming()
        return machine

    @pytest.mark.asyncio
    async def test_init_to_streaming_transition(self):
        """[AC-AISVC-08] Test INIT -> STREAMING transition."""
        machine = SSEStateMachine()
        assert machine.state == SSEState.INIT

        assert await machine.transition_to_streaming() is True
        assert machine.state == SSEState.STREAMING

    @pytest.mark.asyncio
    async def test_streaming_to_final_transition(self):
        """[AC-AISVC-08] Test STREAMING -> FINAL_SENT transition."""
        machine = await self._streaming_machine()

        assert await machine.transition_to_final() is True
        assert machine.state == SSEState.FINAL_SENT

    @pytest.mark.asyncio
    async def test_streaming_to_error_transition(self):
        """[AC-AISVC-09] Test STREAMING -> ERROR_SENT transition."""
        machine = await self._streaming_machine()

        assert await machine.transition_to_error() is True
        assert machine.state == SSEState.ERROR_SENT

    @pytest.mark.asyncio
    async def test_init_to_error_transition(self):
        """[AC-AISVC-09] Test INIT -> ERROR_SENT transition (error before streaming starts)."""
        machine = SSEStateMachine()

        assert await machine.transition_to_error() is True
        assert machine.state == SSEState.ERROR_SENT

    @pytest.mark.asyncio
    async def test_cannot_transition_from_final(self):
        """[AC-AISVC-08] Test that no transitions are possible after FINAL_SENT."""
        machine = await self._streaming_machine()
        await machine.transition_to_final()

        # Both attempts must be refused and leave the terminal state untouched.
        assert await machine.transition_to_streaming() is False
        assert await machine.transition_to_error() is False
        assert machine.state == SSEState.FINAL_SENT

    @pytest.mark.asyncio
    async def test_cannot_transition_from_error(self):
        """[AC-AISVC-09] Test that no transitions are possible after ERROR_SENT."""
        machine = await self._streaming_machine()
        await machine.transition_to_error()

        # Both attempts must be refused and leave the terminal state untouched.
        assert await machine.transition_to_streaming() is False
        assert await machine.transition_to_final() is False
        assert machine.state == SSEState.ERROR_SENT

    @pytest.mark.asyncio
    async def test_cannot_send_message_after_final(self):
        """[AC-AISVC-08] Test that can_send_message returns False after FINAL_SENT."""
        machine = await self._streaming_machine()
        await machine.transition_to_final()

        assert machine.can_send_message() is False

    @pytest.mark.asyncio
    async def test_cannot_send_message_after_error(self):
        """[AC-AISVC-09] Test that can_send_message returns False after ERROR_SENT."""
        machine = await self._streaming_machine()
        await machine.transition_to_error()

        assert machine.can_send_message() is False

    @pytest.mark.asyncio
    async def test_close_transition(self):
        """[AC-AISVC-08] Test that close() transitions to CLOSED state."""
        machine = await self._streaming_machine()
        await machine.transition_to_final()

        await machine.close()
        assert machine.state == SSEState.CLOSED


class TestSSEEventSequence:
    """
    [AC-AISVC-08, AC-AISVC-09] Test cases for SSE event sequence enforcement.

    These tests call the real ``/ai/chat`` endpoint and inspect the raw SSE
    body. Event names are taken only from ``event:`` field lines, so text that
    merely appears inside a ``data:`` payload is never mistaken for an event,
    and both LF and CRLF line endings are handled.
    """

    @pytest.fixture
    def client(self):
        return TestClient(app)

    @pytest.fixture
    def valid_headers(self):
        return {"X-Tenant-Id": "tenant_001", "Accept": "text/event-stream"}

    @pytest.fixture
    def valid_body(self):
        return {
            "sessionId": "test_session",
            "currentMessage": "Hello",
            "channelType": "wechat",
        }

    @staticmethod
    def _event_types(content):
        """
        Return the event names found in an SSE body, in stream order.

        Only lines whose field name is ``event`` are considered; the optional
        space after the colon is stripped. ``str.splitlines`` also removes a
        trailing carriage return, so CRLF-separated streams parse correctly.
        """
        return [
            line[len("event:"):].strip()
            for line in content.splitlines()
            if line.startswith("event:")
        ]

    def _stream_event_types(self, client, headers, body):
        """POST a chat request, require HTTP 200, and return its SSE event names."""
        response = client.post("/ai/chat", json=body, headers=headers)
        # Without this check, an error response could make the sequence
        # assertions below pass vacuously.
        assert response.status_code == 200
        return self._event_types(response.text)

    def test_sse_sequence_message_then_final(self, client, valid_headers, valid_body):
        """[AC-AISVC-08] Test that SSE events follow: message* -> final -> close."""
        event_types = self._stream_event_types(client, valid_headers, valid_body)

        assert "message" in event_types
        assert "final" in event_types
        assert event_types.index("final") > event_types.index("message"), (
            "final should come after message events"
        )

    def test_sse_only_one_final_event(self, client, valid_headers, valid_body):
        """[AC-AISVC-08] Test that there is exactly one final event."""
        event_types = self._stream_event_types(client, valid_headers, valid_body)

        final_count = event_types.count("final")
        assert final_count == 1, f"Expected exactly 1 final event, got {final_count}"

    def test_sse_no_events_after_final(self, client, valid_headers, valid_body):
        """[AC-AISVC-08] Test that no message events appear after final event."""
        event_types = self._stream_event_types(client, valid_headers, valid_body)

        # Guard against a vacuous pass when the stream never sends a final event.
        assert "final" in event_types, "stream never sent a final event"
        after_final = event_types[event_types.index("final") + 1:]
        if "message" in after_final:
            pytest.fail("Found message event after final event")


class TestSSEErrorHandling:
    """
    [AC-AISVC-09] Test cases for SSE error handling.

    The ``create_error_event`` tests are plain synchronous tests. The helper is
    called without ``await`` and its result is inspected directly, so these
    tests need neither an event loop nor ``pytest.mark.asyncio``.
    """

    def test_error_event_format(self):
        """[AC-AISVC-09] Test error event format."""
        event = create_error_event(
            code="TEST_ERROR",
            message="Test error message",
            details=[{"field": "test"}],
        )

        assert event.event == "error"
        data = json.loads(event.data)
        assert data["code"] == "TEST_ERROR"
        assert data["message"] == "Test error message"
        assert data["details"] == [{"field": "test"}]

    def test_error_event_without_details(self):
        """[AC-AISVC-09] Test error event without details."""
        event = create_error_event(
            code="SIMPLE_ERROR",
            message="Simple error",
        )

        assert event.event == "error"
        data = json.loads(event.data)
        assert data["code"] == "SIMPLE_ERROR"
        assert data["message"] == "Simple error"
        # Omitted details must be absent from the payload entirely.
        assert "details" not in data

    def test_missing_tenant_id_returns_400(self):
        """[AC-AISVC-12] Test that missing X-Tenant-Id returns 400 error."""
        client = TestClient(app)
        headers = {"Accept": "text/event-stream"}  # deliberately no X-Tenant-Id
        body = {
            "sessionId": "test_session",
            "currentMessage": "Hello",
            "channelType": "wechat",
        }

        response = client.post("/ai/chat", json=body, headers=headers)

        # The rejection is expected as a JSON error document (parsed below),
        # not as an SSE stream.
        assert response.status_code == 400
        data = response.json()
        assert data["code"] == "MISSING_TENANT_ID"


class TestSSEStateConcurrency:
    """
    [AC-AISVC-08, AC-AISVC-09] Test cases for state machine thread safety.

    Several transition coroutines are launched together with
    ``asyncio.gather``. Exactly one may win; every other attempt must return
    ``False``, matching the ``is False`` contract checked elsewhere for
    refused transitions.

    NOTE(review): if the transition methods never actually suspend (for
    example, an uncontended lock), the gathered coroutines run one after
    another. These tests then check sequential rejection rather than true
    interleaving. Confirm against the SSEStateMachine implementation.
    """

    @pytest.mark.asyncio
    async def test_concurrent_transitions(self):
        """[AC-AISVC-08] Test that concurrent transitions are handled correctly."""
        state_machine = SSEStateMachine()

        # gather returns each coroutine's result in argument order, so no
        # shared accumulator list is needed.
        results = await asyncio.gather(
            *(state_machine.transition_to_streaming() for _ in range(3))
        )

        assert results.count(True) == 1, "Only one transition should succeed"
        assert results.count(False) == 2, "Losing transitions should return False"
        assert state_machine.state == SSEState.STREAMING

    @pytest.mark.asyncio
    async def test_concurrent_final_transitions(self):
        """[AC-AISVC-08] Test that only one final transition succeeds."""
        state_machine = SSEStateMachine()
        await state_machine.transition_to_streaming()

        results = await asyncio.gather(
            *(state_machine.transition_to_final() for _ in range(2))
        )

        assert results.count(True) == 1, "Only one final transition should succeed"
        assert results.count(False) == 1, "Losing final transition should return False"
        assert state_machine.state == SSEState.FINAL_SENT


class TestSSEIntegrationWithOrchestrator:
    """
    [AC-AISVC-08, AC-AISVC-09] Integration tests for SSE with Orchestrator.

    The expected sequences are ``message* -> final`` on success and
    ``message* -> error`` on failure. In both cases the terminal event must be
    the last event emitted. This is consistent with the state machine refusing
    any transition out of FINAL_SENT or ERROR_SENT.
    """

    @pytest.mark.asyncio
    async def test_orchestrator_stream_with_error(self):
        """[AC-AISVC-09] Test that orchestrator errors are properly handled."""
        from app.services.orchestrator import OrchestratorService

        mock_llm = MagicMock()

        async def failing_stream(*args, **kwargs):
            # Emit one chunk, then fail mid-stream to simulate a dropped connection.
            yield MagicMock(delta="Hello", finish_reason=None)
            raise Exception("LLM connection lost")

        mock_llm.stream_generate = failing_stream

        orchestrator = OrchestratorService(llm_client=mock_llm)
        request = ChatRequest(
            session_id="test",
            current_message="Hi",
            channel_type=ChannelType.WECHAT,
        )

        event_types = [
            event.event
            async for event in orchestrator.generate_stream("tenant", request)
        ]

        assert "message" in event_types
        assert "error" in event_types
        # An error terminates the stream: nothing may follow it, and in
        # particular no final event (ERROR_SENT -> FINAL_SENT is refused).
        assert event_types[-1] == "error", "error must be the last event"
        assert "final" not in event_types, "no final event may be sent after an error"

    @pytest.mark.asyncio
    async def test_orchestrator_stream_normal_flow(self):
        """[AC-AISVC-08] Test normal streaming flow ends with final event."""
        from app.services.orchestrator import OrchestratorService

        orchestrator = OrchestratorService()
        request = ChatRequest(
            session_id="test",
            current_message="Hi",
            channel_type=ChannelType.WECHAT,
        )

        event_types = [
            event.event
            async for event in orchestrator.generate_stream("tenant", request)
        ]

        assert "message" in event_types
        assert "final" in event_types
        # Exactly one final, emitted last, implies every message precedes it.
        assert event_types.count("final") == 1, "exactly one final event expected"
        assert event_types[-1] == "final", "message events should come before final"
        assert "error" not in event_types, "normal flow must not emit an error event"